author     Gregory Chanan <gchanan@fb.com>       2017-05-30 11:10:16 -0700
committer  Soumith Chintala <soumith@gmail.com>  2017-06-11 05:37:59 -0400
commit     65b23f146e7450105842654da432857a1b341815 (patch)
tree       6d87e20e9d8f437ac3d5e1693e50b7fbb6a1e20d /torch/csrc/cuda
parent     c54e53295420defd8c38e4f5d12bf2dc977c91ea (diff)
Add broadcasting support for copy_; simplify code generation by moving much of the currently generated code into expand_utils.
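The broadcasting behavior this adds to copy_ follows the usual expand-then-copy pattern: when the source's sizes are broadcastable to the destination's, the source is first expanded to the destination's shape and the copy then proceeds element-wise. A minimal C++ sketch of that pattern, built on the specializations introduced below; broadcastCopy is a hypothetical helper for illustration, not code from this commit, and it assumes the primary templates declared in torch/csrc/expand_utils.h take the tensor type as their template parameter:

    // Illustrative only: copy `src` into `dst`, expanding `src` to dst's
    // sizes first when the shapes differ but are broadcastable.
    static void broadcastCopy(THCState *state, THCudaTensor *dst, THCudaTensor *src) {
      THLongStorage *dstSizes = THCudaTensor_newSizeOf(state, dst);
      THCudaTensor *srcExpanded = newForExpand<THCudaTensor>(state);
      if (expand(state, srcExpanded, src, dstSizes, /*raiseErrors=*/1) == 0) {
        THCudaTensor_copy(state, dst, srcExpanded);  // element-wise copy at dst's shape
      }
      THCudaTensor_free(state, srcExpanded);
      THLongStorage_free(dstSizes);
    }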
Diffstat (limited to 'torch/csrc/cuda')
-rw-r--r--  torch/csrc/cuda/expand_utils.cpp    188
-rw-r--r--  torch/csrc/cuda/override_macros.h     1
-rw-r--r--  torch/csrc/cuda/undef_macros.h        1
3 files changed, 190 insertions, 0 deletions
diff --git a/torch/csrc/cuda/expand_utils.cpp b/torch/csrc/cuda/expand_utils.cpp
new file mode 100644
index 0000000000..c0fd25f666
--- /dev/null
+++ b/torch/csrc/cuda/expand_utils.cpp
@@ -0,0 +1,188 @@
+#include "torch/csrc/cuda/THCP.h"
+#include "torch/csrc/expand_utils.h"
+
+#include "torch/csrc/expand_utils-inl.h"
+
+template <>
+THCudaTensor *newForExpand(THCState *s) {
+ return THCudaTensor_new(s);
+}
+
+template <>
+THCudaDoubleTensor *newForExpand(THCState *s) {
+ return THCudaDoubleTensor_new(s);
+}
+
+#ifdef CUDA_HALF_TENSOR
+template <>
+THCudaHalfTensor *newForExpand(THCState *s) {
+ return THCudaHalfTensor_new(s);
+}
+#endif // CUDA_HALF_TENSOR
+
+template <>
+THCudaByteTensor *newForExpand(THCState *s) {
+ return THCudaByteTensor_new(s);
+}
+
+template <>
+THCudaCharTensor *newForExpand(THCState *s) {
+ return THCudaCharTensor_new(s);
+}
+
+template <>
+THCudaShortTensor *newForExpand(THCState *s) {
+ return THCudaShortTensor_new(s);
+}
+
+template <>
+THCudaIntTensor *newForExpand(THCState *s) {
+ return THCudaIntTensor_new(s);
+}
+
+template <>
+THCudaLongTensor *newForExpand(THCState *s) {
+ return THCudaLongTensor_new(s);
+}
+
+template<>
+int expand(THCState *s, THCudaTensor *r, THCudaTensor *tensor, THLongStorage *sizes, int raiseErrors) {
+ return THCudaTensor_expand(s, r, tensor, sizes, raiseErrors);
+}
+
+template<>
+int expand(THCState *s, THCudaDoubleTensor *r, THCudaDoubleTensor *tensor, THLongStorage *sizes, int raiseErrors) {
+ return THCudaDoubleTensor_expand(s, r, tensor, sizes, raiseErrors);
+}
+
+#ifdef CUDA_HALF_TENSOR
+template<>
+int expand(THCState *s, THCudaHalfTensor *r, THCudaHalfTensor *tensor, THLongStorage *sizes, int raiseErrors) {
+ return THCudaHalfTensor_expand(s, r, tensor, sizes, raiseErrors);
+}
+#endif // CUDA_HALF_TENSOR
+
+template<>
+int expand(THCState *s, THCudaByteTensor *r, THCudaByteTensor *tensor, THLongStorage *sizes, int raiseErrors) {
+ return THCudaByteTensor_expand(s, r, tensor, sizes, raiseErrors);
+}
+
+template<>
+int expand(THCState *s, THCudaCharTensor *r, THCudaCharTensor *tensor, THLongStorage *sizes, int raiseErrors) {
+ return THCudaCharTensor_expand(s, r, tensor, sizes, raiseErrors);
+}
+
+template<>
+int expand(THCState *s, THCudaShortTensor *r, THCudaShortTensor *tensor, THLongStorage *sizes, int raiseErrors) {
+ return THCudaShortTensor_expand(s, r, tensor, sizes, raiseErrors);
+}
+
+template<>
+int expand(THCState *s, THCudaIntTensor *r, THCudaIntTensor *tensor, THLongStorage *sizes, int raiseErrors) {
+ return THCudaIntTensor_expand(s, r, tensor, sizes, raiseErrors);
+}
+
+template<>
+int expand(THCState *s, THCudaLongTensor *r, THCudaLongTensor *tensor, THLongStorage *sizes, int raiseErrors) {
+ return THCudaLongTensor_expand(s, r, tensor, sizes, raiseErrors);
+}
+
+template <>
+int expand2(THCState *s, THCudaTensor *r1, THCudaTensor *r2,
+ THCudaTensor *e1, THCudaTensor *e2, int raiseErrors) {
+ return THCudaTensor_expand2(s, r1, r2, e1, e2, raiseErrors);
+}
+
+template <>
+int expand2(THCState *s, THCudaDoubleTensor *r1, THCudaDoubleTensor *r2,
+ THCudaDoubleTensor *e1, THCudaDoubleTensor *e2, int raiseErrors) {
+ return THCudaDoubleTensor_expand2(s, r1, r2, e1, e2, raiseErrors);
+}
+
+#ifdef CUDA_HALF_TENSOR
+template <>
+int expand2(THCState *s, THCudaHalfTensor *r1, THCudaHalfTensor *r2,
+ THCudaHalfTensor *e1, THCudaHalfTensor *e2, int raiseErrors) {
+ return THCudaHalfTensor_expand2(s, r1, r2, e1, e2, raiseErrors);
+}
+#endif // CUDA_HALF_TENSOR
+
+template <>
+int expand2(THCState *s, THCudaByteTensor *r1, THCudaByteTensor *r2,
+ THCudaByteTensor *e1, THCudaByteTensor *e2, int raiseErrors) {
+ return THCudaByteTensor_expand2(s, r1, r2, e1, e2, raiseErrors);
+}
+
+template <>
+int expand2(THCState *s, THCudaCharTensor *r1, THCudaCharTensor *r2,
+ THCudaCharTensor *e1, THCudaCharTensor *e2, int raiseErrors) {
+ return THCudaCharTensor_expand2(s, r1, r2, e1, e2, raiseErrors);
+}
+
+template <>
+int expand2(THCState *s, THCudaShortTensor *r1, THCudaShortTensor *r2,
+ THCudaShortTensor *e1, THCudaShortTensor *e2, int raiseErrors) {
+ return THCudaShortTensor_expand2(s, r1, r2, e1, e2, raiseErrors);
+}
+
+template <>
+int expand2(THCState *s, THCudaIntTensor *r1, THCudaIntTensor *r2,
+ THCudaIntTensor *e1, THCudaIntTensor *e2, int raiseErrors) {
+ return THCudaIntTensor_expand2(s, r1, r2, e1, e2, raiseErrors);
+}
+
+template <>
+int expand2(THCState *s, THCudaLongTensor *r1, THCudaLongTensor *r2,
+ THCudaLongTensor *e1, THCudaLongTensor *e2, int raiseErrors) {
+ return THCudaLongTensor_expand2(s, r1, r2, e1, e2, raiseErrors);
+}
+
+template <>
+int expand3(THCState *s, THCudaTensor *r1, THCudaTensor *r2, THCudaTensor *r3,
+ THCudaTensor *e1, THCudaTensor *e2, THCudaTensor *e3, int raiseErrors) {
+ return THCudaTensor_expand3(s, r1, r2, r3, e1, e2, e3, raiseErrors);
+}
+
+template <>
+int expand3(THCState *s, THCudaDoubleTensor *r1, THCudaDoubleTensor *r2, THCudaDoubleTensor *r3,
+ THCudaDoubleTensor *e1, THCudaDoubleTensor *e2, THCudaDoubleTensor *e3, int raiseErrors) {
+ return THCudaDoubleTensor_expand3(s, r1, r2, r3, e1, e2, e3, raiseErrors);
+}
+
+#ifdef CUDA_HALF_TENSOR
+template <>
+int expand3(THCState *s, THCudaHalfTensor *r1, THCudaHalfTensor *r2, THCudaHalfTensor *r3,
+ THCudaHalfTensor *e1, THCudaHalfTensor *e2, THCudaHalfTensor *e3, int raiseErrors) {
+ return THCudaHalfTensor_expand3(s, r1, r2, r3, e1, e2, e3, raiseErrors);
+}
+#endif // CUDA_HALF_TENSOR
+
+template <>
+int expand3(THCState *s, THCudaByteTensor *r1, THCudaByteTensor *r2, THCudaByteTensor *r3,
+ THCudaByteTensor *e1, THCudaByteTensor *e2, THCudaByteTensor *e3, int raiseErrors) {
+ return THCudaByteTensor_expand3(s, r1, r2, r3, e1, e2, e3, raiseErrors);
+}
+
+template <>
+int expand3(THCState *s, THCudaCharTensor *r1, THCudaCharTensor *r2, THCudaCharTensor *r3,
+ THCudaCharTensor *e1, THCudaCharTensor *e2, THCudaCharTensor *e3, int raiseErrors) {
+ return THCudaCharTensor_expand3(s, r1, r2, r3, e1, e2, e3, raiseErrors);
+}
+
+template <>
+int expand3(THCState *s, THCudaShortTensor *r1, THCudaShortTensor *r2, THCudaShortTensor *r3,
+ THCudaShortTensor *e1, THCudaShortTensor *e2, THCudaShortTensor *e3, int raiseErrors) {
+ return THCudaShortTensor_expand3(s, r1, r2, r3, e1, e2, e3, raiseErrors);
+}
+
+template <>
+int expand3(THCState *s, THCudaIntTensor *r1, THCudaIntTensor *r2, THCudaIntTensor *r3,
+ THCudaIntTensor *e1, THCudaIntTensor *e2, THCudaIntTensor *e3, int raiseErrors) {
+ return THCudaIntTensor_expand3(s, r1, r2, r3, e1, e2, e3, raiseErrors);
+}
+
+template <>
+int expand3(THCState *s, THCudaLongTensor *r1, THCudaLongTensor *r2, THCudaLongTensor *r3,
+ THCudaLongTensor *e1, THCudaLongTensor *e2, THCudaLongTensor *e3, int raiseErrors) {
+ return THCudaLongTensor_expand3(s, r1, r2, r3, e1, e2, e3, raiseErrors);
+}
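These per-type specializations are what lets the (previously generated) binding code stay type-generic: one function template can broadcast the operands of a binary op, and the concrete THCudaTensor_expand2 / THCudaDoubleTensor_expand2 / ... call is resolved by the specializations above. A rough sketch of the pattern; expandBinaryOperands is a hypothetical name, the real callers live in torch/csrc/expand_utils-inl.h and the generated tensor bindings:

    // Expand two same-typed CUDA tensors to their common broadcast shape.
    // Which THCuda*Tensor_expand2 is called is decided by the explicit
    // specialization of expand2<> for THCTensorT.
    template <typename THCTensorT>
    static int expandBinaryOperands(THCState *state,
                                    THCTensorT *a, THCTensorT *b,
                                    THCTensorT **aExpanded, THCTensorT **bExpanded) {
      *aExpanded = newForExpand<THCTensorT>(state);
      *bExpanded = newForExpand<THCTensorT>(state);
      return expand2(state, *aExpanded, *bExpanded, a, b, /*raiseErrors=*/1);
    }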
diff --git a/torch/csrc/cuda/override_macros.h b/torch/csrc/cuda/override_macros.h
index cef3b29184..a124903bf5 100644
--- a/torch/csrc/cuda/override_macros.h
+++ b/torch/csrc/cuda/override_macros.h
@@ -49,6 +49,7 @@
#define LIBRARY_STATE_NOARGS state
#define LIBRARY_STATE state,
#define LIBRARY_STATE_TYPE THCState*,
+#define LIBRARY_STATE_TYPE_NOARGS THCState*
#define TH_GENERIC_FILE THC_GENERIC_FILE
#define THHostTensor TH_CONCAT_3(TH,Real,Tensor)
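The existing LIBRARY_STATE_TYPE macro carries a trailing comma so it can be prepended to a parameter list that has further arguments; the new LIBRARY_STATE_TYPE_NOARGS variant covers declarations where the state is the only parameter, as in newForExpand above. A sketch of how a shared header can use it (illustrative; the actual declaration lives in torch/csrc/expand_utils.h, and the CPU-side definition of the macro is assumed rather than shown in this diff):

    // In a header compiled for both the TH (CPU) and THC (CUDA) backends:
    template <typename ExpandType>
    ExpandType *newForExpand(LIBRARY_STATE_TYPE_NOARGS);
    //   CUDA build:  ExpandType *newForExpand(THCState*);
    //   CPU build:   ExpandType *newForExpand();   // assuming the CPU macro expands to nothing or void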
diff --git a/torch/csrc/cuda/undef_macros.h b/torch/csrc/cuda/undef_macros.h
index d1a8c10d48..bfc6600959 100644
--- a/torch/csrc/cuda/undef_macros.h
+++ b/torch/csrc/cuda/undef_macros.h
@@ -2,6 +2,7 @@
#undef LIBRARY_STATE
#undef LIBRARY_STATE_NOARGS
#undef LIBRARY_STATE_TYPE
+#undef LIBRARY_STATE_TYPE_NOARGS
#undef THPTensor_
#undef THPTensor_stateless_