summary | refs | log | tree | commit | diff
path: root/caffe2/cuda_rtc/common_rtc.h
diff options
context:
space:
mode:
author Yangqing Jia <jiayq84@gmail.com> 2016-01-19 12:49:36 -0800
committer Yangqing Jia <jiayq84@gmail.com> 2016-01-19 12:49:36 -0800
commit 5f2d7ba963ee65b31026ca6b62d9629aea49ba3f (patch)
tree bf663240583ae5eae9a94cc61b37efae02c464f2 /caffe2/cuda_rtc/common_rtc.h
parent d244ca90527990d202bae228f9e258b79d2df995 (diff)
download pytorch-5f2d7ba963ee65b31026ca6b62d9629aea49ba3f.tar.gz
pytorch-5f2d7ba963ee65b31026ca6b62d9629aea49ba3f.tar.bz2
pytorch-5f2d7ba963ee65b31026ca6b62d9629aea49ba3f.zip
misc: experimental cuda elementwise rtc, softmax fp16
Diffstat (limited to 'caffe2/cuda_rtc/common_rtc.h')
-rw-r--r-- caffe2/cuda_rtc/common_rtc.h 11
1 files changed, 11 insertions, 0 deletions
diff --git a/caffe2/cuda_rtc/common_rtc.h b/caffe2/cuda_rtc/common_rtc.h
index 536a1c9e18..e3b377efda 100644
--- a/caffe2/cuda_rtc/common_rtc.h
+++ b/caffe2/cuda_rtc/common_rtc.h
@@ -81,6 +81,17 @@ class CudaRTCFunction {
args_voidp, 0));
}
+ void LaunchEx(unsigned int gx, unsigned int gy, unsigned int gz,  // Launches kernel_; gx/gy/gz are forwarded to cuLaunchKernel in the grid-dimension slots.
+ unsigned int bx, unsigned int by, unsigned int bz,  // bx/by/bz are forwarded in the block-dimension slots.
+ unsigned int shared_mem, cudaStream_t stream,  // Forwarded in the shared-memory-bytes and stream slots, respectively.
+ void** extra) {  // Forwarded unchanged as cuLaunchKernel's final `extra` argument.
+ CAFFE_CHECK(module_loaded_)  // Checks module_loaded_ before attempting the launch.
+ << "Cannot call Launch before a module is loaded.";
+ CUDA_DRIVERAPI_CHECK(cuLaunchKernel(  // Driver-API launch; its return status is handed to CUDA_DRIVERAPI_CHECK.
+ kernel_, gx, gy, gz, bx, by, bz, shared_mem, stream,
+ nullptr, extra));  // kernelParams slot is nullptr, so kernel arguments are supplied only via `extra`.
+ }
+
private:
bool module_loaded_;
CUmodule module_;