diff options
author | Yangqing Jia <jiayq84@gmail.com> | 2016-01-19 12:49:36 -0800 |
---|---|---|
committer | Yangqing Jia <jiayq84@gmail.com> | 2016-01-19 12:49:36 -0800 |
commit | 5f2d7ba963ee65b31026ca6b62d9629aea49ba3f (patch) | |
tree | bf663240583ae5eae9a94cc61b37efae02c464f2 /caffe2/cuda_rtc/common_rtc.h | |
parent | d244ca90527990d202bae228f9e258b79d2df995 (diff) | |
download | pytorch-5f2d7ba963ee65b31026ca6b62d9629aea49ba3f.tar.gz pytorch-5f2d7ba963ee65b31026ca6b62d9629aea49ba3f.tar.bz2 pytorch-5f2d7ba963ee65b31026ca6b62d9629aea49ba3f.zip |
misc: experimental cuda elementwise rtc, softmax fp16
Diffstat (limited to 'caffe2/cuda_rtc/common_rtc.h')
-rw-r--r-- | caffe2/cuda_rtc/common_rtc.h | 11 |
1 files changed, 11 insertions, 0 deletions
diff --git a/caffe2/cuda_rtc/common_rtc.h b/caffe2/cuda_rtc/common_rtc.h index 536a1c9e18..e3b377efda 100644 --- a/caffe2/cuda_rtc/common_rtc.h +++ b/caffe2/cuda_rtc/common_rtc.h @@ -81,6 +81,17 @@ class CudaRTCFunction { args_voidp, 0)); } + void LaunchEx(unsigned int gx, unsigned int gy, unsigned int gz, + unsigned int bx, unsigned int by, unsigned int bz, + unsigned int shared_mem, cudaStream_t stream, + void** extra) { + CAFFE_CHECK(module_loaded_) + << "Cannot call Launch before a module is loaded."; + CUDA_DRIVERAPI_CHECK(cuLaunchKernel( + kernel_, gx, gy, gz, bx, by, bz, shared_mem, stream, + nullptr, extra)); + } + private: bool module_loaded_; CUmodule module_; |