author     Gregory Chanan <gchanan@fb.com>         2017-05-30 11:10:16 -0700
committer  Soumith Chintala <soumith@gmail.com>    2017-06-11 05:37:59 -0400
commit     65b23f146e7450105842654da432857a1b341815 (patch)
tree       6d87e20e9d8f437ac3d5e1693e50b7fbb6a1e20d /torch/csrc/cuda
parent     c54e53295420defd8c38e4f5d12bf2dc977c91ea (diff)
Add broadcasting support for copy_, simplify code generation by moving a lot of currently generated code to expand_utils.
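The new torch/csrc/cuda/expand_utils.cpp (below) consists almost entirely of explicit template specializations: one newForExpand / expand / expand2 / expand3 specialization per THC tensor type, each forwarding to the corresponding THCuda*Tensor_* C function. Generic code (included via expand_utils-inl.h) can then dispatch on the tensor type at compile time instead of having a near-identical wrapper emitted per type by the code generator. A minimal, self-contained sketch of that pattern follows; the FloatTensor/DoubleTensor structs and *_new functions are illustrative stand-ins, not the real THC API.

#include <cstdio>

// Illustrative stand-ins for the per-type C structs (THCudaTensor, THCudaDoubleTensor, ...).
struct FloatTensor  {};
struct DoubleTensor {};

// Per-type C-style constructors, analogous to THCudaTensor_new / THCudaDoubleTensor_new.
FloatTensor  *FloatTensor_new()  { std::puts("FloatTensor_new");  return new FloatTensor();  }
DoubleTensor *DoubleTensor_new() { std::puts("DoubleTensor_new"); return new DoubleTensor(); }

// Single templated entry point, declared once (the role played by expand_utils.h).
template <typename TensorT>
TensorT *newForExpand();

// One explicit specialization per concrete tensor type (the role played by expand_utils.cpp).
template <>
FloatTensor *newForExpand<FloatTensor>() { return FloatTensor_new(); }

template <>
DoubleTensor *newForExpand<DoubleTensor>() { return DoubleTensor_new(); }

// Generic (or previously generated) code can now allocate the right tensor type
// without a per-type code path: the same call site works for every specialization.
template <typename TensorT>
TensorT *allocateForBroadcast() {
  return newForExpand<TensorT>();
}

int main() {
  delete allocateForBroadcast<FloatTensor>();   // prints "FloatTensor_new"
  delete allocateForBroadcast<DoubleTensor>();  // prints "DoubleTensor_new"
  return 0;
}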
Diffstat (limited to 'torch/csrc/cuda')
-rw-r--r--  torch/csrc/cuda/expand_utils.cpp   | 188
-rw-r--r--  torch/csrc/cuda/override_macros.h  |   1
-rw-r--r--  torch/csrc/cuda/undef_macros.h     |   1
3 files changed, 190 insertions, 0 deletions
diff --git a/torch/csrc/cuda/expand_utils.cpp b/torch/csrc/cuda/expand_utils.cpp
new file mode 100644
index 0000000000..c0fd25f666
--- /dev/null
+++ b/torch/csrc/cuda/expand_utils.cpp
@@ -0,0 +1,188 @@
+#include "torch/csrc/cuda/THCP.h"
+#include "torch/csrc/expand_utils.h"
+
+#include "torch/csrc/expand_utils-inl.h"
+
+template <>
+THCudaTensor *newForExpand(THCState *s) {
+  return THCudaTensor_new(s);
+}
+
+template <>
+THCudaDoubleTensor *newForExpand(THCState *s) {
+  return THCudaDoubleTensor_new(s);
+}
+
+#ifdef CUDA_HALF_TENSOR
+template <>
+THCudaHalfTensor *newForExpand(THCState *s) {
+  return THCudaHalfTensor_new(s);
+}
+#endif // CUDA_HALF_TENSOR
+
+template <>
+THCudaByteTensor *newForExpand(THCState *s) {
+  return THCudaByteTensor_new(s);
+}
+
+template <>
+THCudaCharTensor *newForExpand(THCState *s) {
+  return THCudaCharTensor_new(s);
+}
+
+template <>
+THCudaShortTensor *newForExpand(THCState *s) {
+  return THCudaShortTensor_new(s);
+}
+
+template <>
+THCudaIntTensor *newForExpand(THCState *s) {
+  return THCudaIntTensor_new(s);
+}
+
+template <>
+THCudaLongTensor *newForExpand(THCState *s) {
+  return THCudaLongTensor_new(s);
+}
+
+template<>
+int expand(THCState *s, THCudaTensor *r, THCudaTensor *tensor, THLongStorage *sizes, int raiseErrors) {
+  return THCudaTensor_expand(s, r, tensor, sizes, raiseErrors);
+}
+
+template<>
+int expand(THCState *s, THCudaDoubleTensor *r, THCudaDoubleTensor *tensor, THLongStorage *sizes, int raiseErrors) {
+  return THCudaDoubleTensor_expand(s, r, tensor, sizes, raiseErrors);
+}
+
+#ifdef CUDA_HALF_TENSOR
+template<>
+int expand(THCState *s, THCudaHalfTensor *r, THCudaHalfTensor *tensor, THLongStorage *sizes, int raiseErrors) {
+  return THCudaHalfTensor_expand(s, r, tensor, sizes, raiseErrors);
+}
+#endif // CUDA_HALF_TENSOR
+
+template<>
+int expand(THCState *s, THCudaByteTensor *r, THCudaByteTensor *tensor, THLongStorage *sizes, int raiseErrors) {
+  return THCudaByteTensor_expand(s, r, tensor, sizes, raiseErrors);
+}
+
+template<>
+int expand(THCState *s, THCudaCharTensor *r, THCudaCharTensor *tensor, THLongStorage *sizes, int raiseErrors) {
+  return THCudaCharTensor_expand(s, r, tensor, sizes, raiseErrors);
+}
+
+template<>
+int expand(THCState *s, THCudaShortTensor *r, THCudaShortTensor *tensor, THLongStorage *sizes, int raiseErrors) {
+  return THCudaShortTensor_expand(s, r, tensor, sizes, raiseErrors);
+}
+
+template<>
+int expand(THCState *s, THCudaIntTensor *r, THCudaIntTensor *tensor, THLongStorage *sizes, int raiseErrors) {
+  return THCudaIntTensor_expand(s, r, tensor, sizes, raiseErrors);
+}
+
+template<>
+int expand(THCState *s, THCudaLongTensor *r, THCudaLongTensor *tensor, THLongStorage *sizes, int raiseErrors) {
+  return THCudaLongTensor_expand(s, r, tensor, sizes, raiseErrors);
+}
+
+template <>
+int expand2(THCState *s, THCudaTensor *r1, THCudaTensor *r2,
+            THCudaTensor *e1, THCudaTensor *e2, int raiseErrors) {
+  return THCudaTensor_expand2(s, r1, r2, e1, e2, raiseErrors);
+}
+
+template <>
+int expand2(THCState *s, THCudaDoubleTensor *r1, THCudaDoubleTensor *r2,
+            THCudaDoubleTensor *e1, THCudaDoubleTensor *e2, int raiseErrors) {
+  return THCudaDoubleTensor_expand2(s, r1, r2, e1, e2, raiseErrors);
+}
+
+#ifdef CUDA_HALF_TENSOR
+template <>
+int expand2(THCState *s, THCudaHalfTensor *r1, THCudaHalfTensor *r2,
+            THCudaHalfTensor *e1, THCudaHalfTensor *e2, int raiseErrors) {
+  return THCudaHalfTensor_expand2(s, r1, r2, e1, e2, raiseErrors);
+}
+#endif // CUDA_HALF_TENSOR
+
+template <>
+int expand2(THCState *s, THCudaByteTensor *r1, THCudaByteTensor *r2,
+            THCudaByteTensor *e1, THCudaByteTensor *e2, int raiseErrors) {
+  return THCudaByteTensor_expand2(s, r1, r2, e1, e2, raiseErrors);
+}
+
+template <>
+int expand2(THCState *s, THCudaCharTensor *r1, THCudaCharTensor *r2,
+            THCudaCharTensor *e1, THCudaCharTensor *e2, int raiseErrors) {
+  return THCudaCharTensor_expand2(s, r1, r2, e1, e2, raiseErrors);
+}
+
+template <>
+int expand2(THCState *s, THCudaShortTensor *r1, THCudaShortTensor *r2,
+            THCudaShortTensor *e1, THCudaShortTensor *e2, int raiseErrors) {
+  return THCudaShortTensor_expand2(s, r1, r2, e1, e2, raiseErrors);
+}
+
+template <>
+int expand2(THCState *s, THCudaIntTensor *r1, THCudaIntTensor *r2,
+            THCudaIntTensor *e1, THCudaIntTensor *e2, int raiseErrors) {
+  return THCudaIntTensor_expand2(s, r1, r2, e1, e2, raiseErrors);
+}
+
+template <>
+int expand2(THCState *s, THCudaLongTensor *r1, THCudaLongTensor *r2,
+            THCudaLongTensor *e1, THCudaLongTensor *e2, int raiseErrors) {
+  return THCudaLongTensor_expand2(s, r1, r2, e1, e2, raiseErrors);
+}
+
+template <>
+int expand3(THCState *s, THCudaTensor *r1, THCudaTensor *r2, THCudaTensor *r3,
+            THCudaTensor *e1, THCudaTensor *e2, THCudaTensor *e3, int raiseErrors) {
+  return THCudaTensor_expand3(s, r1, r2, r3, e1, e2, e3, raiseErrors);
+}
+
+template <>
+int expand3(THCState *s, THCudaDoubleTensor *r1, THCudaDoubleTensor *r2, THCudaDoubleTensor *r3,
+            THCudaDoubleTensor *e1, THCudaDoubleTensor *e2, THCudaDoubleTensor *e3, int raiseErrors) {
+  return THCudaDoubleTensor_expand3(s, r1, r2, r3, e1, e2, e3, raiseErrors);
+}
+
+#ifdef CUDA_HALF_TENSOR
+template <>
+int expand3(THCState *s, THCudaHalfTensor *r1, THCudaHalfTensor *r2, THCudaHalfTensor *r3,
+            THCudaHalfTensor *e1, THCudaHalfTensor *e2, THCudaHalfTensor *e3, int raiseErrors) {
+  return THCudaHalfTensor_expand3(s, r1, r2, r3, e1, e2, e3, raiseErrors);
+}
+#endif // CUDA_HALF_TENSOR
+
+template <>
+int expand3(THCState *s, THCudaByteTensor *r1, THCudaByteTensor *r2, THCudaByteTensor *r3,
+            THCudaByteTensor *e1, THCudaByteTensor *e2, THCudaByteTensor *e3, int raiseErrors) {
+  return THCudaByteTensor_expand3(s, r1, r2, r3, e1, e2, e3, raiseErrors);
+}
+
+template <>
+int expand3(THCState *s, THCudaCharTensor *r1, THCudaCharTensor *r2, THCudaCharTensor *r3,
+            THCudaCharTensor *e1, THCudaCharTensor *e2, THCudaCharTensor *e3, int raiseErrors) {
+  return THCudaCharTensor_expand3(s, r1, r2, r3, e1, e2, e3, raiseErrors);
+}
+
+template <>
+int expand3(THCState *s, THCudaShortTensor *r1, THCudaShortTensor *r2, THCudaShortTensor *r3,
+            THCudaShortTensor *e1, THCudaShortTensor *e2, THCudaShortTensor *e3, int raiseErrors) {
+  return THCudaShortTensor_expand3(s, r1, r2, r3, e1, e2, e3, raiseErrors);
+}
+
+template <>
+int expand3(THCState *s, THCudaIntTensor *r1, THCudaIntTensor *r2, THCudaIntTensor *r3,
+            THCudaIntTensor *e1, THCudaIntTensor *e2, THCudaIntTensor *e3, int raiseErrors) {
+  return THCudaIntTensor_expand3(s, r1, r2, r3, e1, e2, e3, raiseErrors);
+}
+
+template <>
+int expand3(THCState *s, THCudaLongTensor *r1, THCudaLongTensor *r2, THCudaLongTensor *r3,
+            THCudaLongTensor *e1, THCudaLongTensor *e2, THCudaLongTensor *e3, int raiseErrors) {
+  return THCudaLongTensor_expand3(s, r1, r2, r3, e1, e2, e3, raiseErrors);
+}
diff --git a/torch/csrc/cuda/override_macros.h b/torch/csrc/cuda/override_macros.h
index cef3b29184..a124903bf5 100644
--- a/torch/csrc/cuda/override_macros.h
+++ b/torch/csrc/cuda/override_macros.h
@@ -49,6 +49,7 @@
 #define LIBRARY_STATE_NOARGS state
 #define LIBRARY_STATE state,
 #define LIBRARY_STATE_TYPE THCState*,
+#define LIBRARY_STATE_TYPE_NOARGS THCState*

 #define TH_GENERIC_FILE THC_GENERIC_FILE
 #define THHostTensor TH_CONCAT_3(TH,Real,Tensor)
diff --git a/torch/csrc/cuda/undef_macros.h b/torch/csrc/cuda/undef_macros.h
index d1a8c10d48..bfc6600959 100644
--- a/torch/csrc/cuda/undef_macros.h
+++ b/torch/csrc/cuda/undef_macros.h
@@ -2,6 +2,7 @@
 #undef LIBRARY_STATE
 #undef LIBRARY_STATE_NOARGS
 #undef LIBRARY_STATE_TYPE
+#undef LIBRARY_STATE_TYPE_NOARGS

 #undef THPTensor_
 #undef THPTensor_stateless_
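The one-line additions to override_macros.h and undef_macros.h extend the LIBRARY_STATE macro family, which lets the same generic source compile against both the CPU (TH) backend, where functions take no state argument, and the CUDA (THC) backend, where every call threads a THCState* through; the new LIBRARY_STATE_TYPE_NOARGS covers declarations whose only parameter is the state, such as newForExpand above. The header-style sketch below shows how such a macro pair is typically used. The CPU-side definitions, the Tensor type, and whether these particular declarations use the macros are assumptions for illustration, not copied from the tree.

// Illustrative backend switch; the real CUDA definitions live in override_macros.h.
#ifdef THC_BACKEND
struct THCState;                                 // opaque CUDA backend state
#define LIBRARY_STATE_TYPE_NOARGS THCState*      // the state is the only parameter
#define LIBRARY_STATE_TYPE        THCState*,     // the state precedes other parameters
#define LIBRARY_STATE_NOARGS      state          // forwarded at call sites
#define LIBRARY_STATE             state,
#else
#define LIBRARY_STATE_TYPE_NOARGS void           // CPU build: no state parameter (assumed)
#define LIBRARY_STATE_TYPE
#define LIBRARY_STATE_NOARGS
#define LIBRARY_STATE
#endif

struct Tensor;  // stand-in for a concrete tensor type

// Written once, these declarations expand differently per backend:
//   CPU:  Tensor *newForExpand(void);
//         int expand2(Tensor *r1, Tensor *r2, Tensor *e1, Tensor *e2, int raiseErrors);
//   CUDA: Tensor *newForExpand(THCState*);
//         int expand2(THCState*, Tensor *r1, Tensor *r2, Tensor *e1, Tensor *e2, int raiseErrors);
Tensor *newForExpand(LIBRARY_STATE_TYPE_NOARGS);
int expand2(LIBRARY_STATE_TYPE Tensor *r1, Tensor *r2, Tensor *e1, Tensor *e2, int raiseErrors);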