author    | gchanan <gregchanan@gmail.com>  | 2018-04-16 23:52:59 -0400
committer | Edward Z. Yang <ezyang@mit.edu> | 2018-04-16 23:52:59 -0400
commit    | 5ed3f3347a5684a2b6208546e34c9a13771e77ab
tree      | d3db48be704db088d4b33a026d3b98c464127b58 /aten
parent    | dd91d57c3f5647fb4ac63ac4325a42224f9a3028
Add dtypes (with reasonable defaults) to sum, prod, cumsum, cumprod. (#6573)
* Add dtypes (with reasonable defaults) to sum, prod, cumsum, cumprod.
This adds optional dtypes to torch.sum, torch.prod, torch.cumsum, torch.cumprod.
By default, the dtype is torch.int64 for integral input types, and the dtype of the input for floating-point types (a brief usage sketch follows this list).
* Don't use optional<ScalarType>, because the jit can't handle it yet.
Instead, we manually build the overloads. This is fairly painful because of default arguments, but should be easy to pull out once the jit can handle optional<ScalarType>.
* Fix keepdim with out parameters.
* Fix _cudnn_rnn_flatten_weight.
* If dtype is provided to an out function, make sure it matches the dtype of the result.
* Fix typo.
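
As a quick illustration of the user-facing behavior described above, here is a minimal Python sketch (assuming the torch 0.4-era API; printed dtypes reflect the defaults added by this patch and are illustrative):

import torch

x = torch.ones(4, dtype=torch.int32)

# No dtype given: integral inputs are accumulated in torch.int64 (Long).
print(x.sum().dtype)    # torch.int64
print(x.prod().dtype)   # torch.int64

# An explicit dtype overrides the default accumulation type.
print(x.sum(dtype=torch.float64).dtype)   # torch.float64

# cumsum/cumprod take the same optional dtype alongside dim.
y = torch.arange(1, 5, dtype=torch.int32)
print(torch.cumsum(y, dim=0, dtype=torch.float32))   # tensor([ 1.,  3.,  6., 10.])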
Diffstat (limited to 'aten')
-rw-r--r-- | aten/src/ATen/Declarations.cwrap           |  10
-rw-r--r-- | aten/src/ATen/native/ReduceOps.cpp         | 174
-rw-r--r-- | aten/src/ATen/native/cudnn/RNN.cpp         |   2
-rw-r--r-- | aten/src/ATen/native/native_functions.yaml |  62
-rw-r--r-- | aten/src/ATen/native_parse.py              |   2
5 files changed, 237 insertions, 13 deletions
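
The heart of the change is the integer_upcast helper added to ReduceOps.cpp in the diff below, which picks the accumulation type. A rough Python restatement of that rule (the function name and structure here are illustrative only, not part of the patch):

import torch

def pick_accumulation_dtype(input_dtype, dtype=None):
    # Mirrors integer_upcast: an explicit dtype wins; otherwise integral
    # inputs are upcast to int64 (Long) and floating inputs keep their dtype.
    if dtype is not None:
        return dtype
    if not input_dtype.is_floating_point:
        return torch.int64
    return input_dtype

assert pick_accumulation_dtype(torch.int32) == torch.int64
assert pick_accumulation_dtype(torch.float32) == torch.float32
assert pick_accumulation_dtype(torch.int32, torch.float64) == torch.float64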
diff --git a/aten/src/ATen/Declarations.cwrap b/aten/src/ATen/Declarations.cwrap
index 7891652109..3666b77e42 100644
--- a/aten/src/ATen/Declarations.cwrap
+++ b/aten/src/ATen/Declarations.cwrap
@@ -2257,7 +2257,7 @@
     - THTensor* self
 ]]
 [[
-  name: _sum
+  name: _th_sum
   variants:
     - method
     - function
@@ -2286,7 +2286,7 @@
     - THTensor* self
 ]]
 [[
-  name: _prod
+  name: _th_prod
   variants:
     - method
     - function
@@ -2304,7 +2304,8 @@
       default: "false"
 ]]
 [[
-  name: cumsum
+  name: _cumsum
+  cname: cumsum
   variants:
     - method
     - function
@@ -2317,7 +2318,8 @@
       wrap_dim: self
 ]]
 [[
-  name: cumprod
+  name: _cumprod
+  cname: cumprod
   variants:
     - method
     - function
diff --git a/aten/src/ATen/native/ReduceOps.cpp b/aten/src/ATen/native/ReduceOps.cpp
index f06674ca82..ac17173a18 100644
--- a/aten/src/ATen/native/ReduceOps.cpp
+++ b/aten/src/ATen/native/ReduceOps.cpp
@@ -15,8 +15,82 @@
 namespace at {
 namespace native {
 
+static inline Tensor integer_upcast(const Tensor& self, optional<ScalarType> dtype) {
+  ScalarType scalarType = self.type().scalarType();
+  ScalarType upcast_scalarType = dtype.value_or(at::isIntegralType(scalarType) ? ScalarType::Long : scalarType);
+  return self.toType(upcast_scalarType);
+}
+
+static inline Tensor cumsum(const Tensor& self, int64_t dim, optional<ScalarType> dtype) {
+  return at::_cumsum(integer_upcast(self, dtype), dim);
+}
+
+Tensor cumsum(const Tensor& self, int64_t dim, ScalarType dtype) {
+  return at::native::cumsum(self, dim, optional<ScalarType>(dtype));
+}
+
+Tensor cumsum(const Tensor& self, int64_t dim) {
+  return at::native::cumsum(self, dim, nullopt);
+}
+
+static inline Tensor& cumsum_out(Tensor& result, const Tensor& self, int64_t dim, optional<ScalarType> dtype) {
+  // result type is favored over dtype; check that they match if provided (NumPy doesn't check)
+  AT_ASSERT(!dtype.has_value() || (result.type().scalarType() == dtype.value()),
+            "provided dtype must match dtype of result in cumsum. Got %s and %s.",
+            at::toString(result.type().scalarType()), at::toString(dtype.value()));
+  return at::_cumsum_out(result, self.toType(result.type().scalarType()), dim);
+}
+
+Tensor& cumsum_out(Tensor& result, const Tensor& self, int64_t dim, ScalarType dtype) {
+  return at::native::cumsum_out(result, self, dim, optional<ScalarType>(dtype));
+}
+
+Tensor& cumsum_out(Tensor& result, const Tensor& self, int64_t dim) {
+  return at::native::cumsum_out(result, self, dim, nullopt);
+}
+
+static inline Tensor cumprod(const Tensor& self, int64_t dim, optional<ScalarType> dtype) {
+  return at::_cumprod(integer_upcast(self, dtype), dim);
+}
+
+Tensor cumprod(const Tensor& self, int64_t dim, ScalarType dtype) {
+  return at::native::cumprod(self, dim, optional<ScalarType>(dtype));
+}
+
+Tensor cumprod(const Tensor& self, int64_t dim) {
+  return at::native::cumprod(self, dim, nullopt);
+}
+
+static inline Tensor& cumprod_out(Tensor& result, const Tensor& self, int64_t dim, optional<ScalarType> dtype) {
+  // result type is favored over dtype; check that they match if provided (NumPy doesn't check)
+  AT_ASSERT(!dtype.has_value() || (result.type().scalarType() == dtype.value()),
+            "provided dtype must match dtype of result in cumprod. Got %s and %s.",
+            at::toString(result.type().scalarType()), at::toString(dtype.value()));
+  return at::_cumprod_out(result, self.toType(result.type().scalarType()), dim);
+}
+
+Tensor& cumprod_out(Tensor& result, const Tensor& self, int64_t dim, ScalarType dtype) {
+  return at::native::cumprod_out(result, self, dim, optional<ScalarType>(dtype));
+}
+
+Tensor& cumprod_out(Tensor& result, const Tensor& self, int64_t dim) {
+  return at::native::cumprod_out(result, self, dim, nullopt);
+}
+
 // ALL REDUCE #################################################################
 
+static inline Tensor sum(const Tensor &self, optional<ScalarType> dtype) {
+  return at::_sum(integer_upcast(self, dtype));
+}
+
+Tensor sum(const Tensor &self, ScalarType dtype) {
+  return at::native::sum(self, optional<ScalarType>(dtype));
+}
+
+Tensor sum(const Tensor &self) {
+  return at::native::sum(self, nullopt);
+}
+
 Tensor _sum_cpu(const Tensor& self) {
   if (self.is_contiguous()) {
     Tensor result = self.type().tensor({});
@@ -26,6 +100,18 @@ Tensor _sum_cpu(const Tensor& self) {
   return self._sumall();
 }
 
+static inline Tensor prod(const Tensor &self, optional<ScalarType> dtype) {
+  return at::_prod(integer_upcast(self, dtype));
+}
+
+Tensor prod(const Tensor &self, ScalarType dtype) {
+  return at::native::prod(self, optional<ScalarType>(dtype));
+}
+
+Tensor prod(const Tensor &self) {
+  return at::native::prod(self, nullopt);
+}
+
 Tensor _prod_cpu(const Tensor &self) {
   if (self.is_contiguous()) {
     Tensor result = self.type().tensor({});
@@ -69,6 +155,26 @@ static Tensor &_dimreduce_setup(Tensor &result, const Tensor &self,
   return result;
 }
 
+static inline Tensor &sum_out(Tensor &result, const Tensor &self, int64_t dim,
+                              bool keepdim, optional<ScalarType> dtype) {
+  // result type is favored over dtype; check that they match if provided (NumPy doesn't check)
+  AT_ASSERT(!dtype.has_value() || (result.type().scalarType() == dtype.value()),
+            "provided dtype must match dtype of result in sum. Got %s and %s.",
+            at::toString(result.type().scalarType()), at::toString(dtype.value()));
+  return at::_sum_out(result, self.toType(result.type().scalarType()), dim, keepdim);
+}
+
+Tensor& sum_out(Tensor& result, const Tensor& self, int64_t dim, bool keepdim, ScalarType dtype) {
+  return at::native::sum_out(result, self, dim, keepdim, at::optional<ScalarType>(dtype));
+}
+Tensor& sum_out(Tensor& result, const Tensor& self, int64_t dim, bool keepdim) {
+  return at::native::sum_out(result, self, dim, keepdim, nullopt);
+}
+
+Tensor& sum_out(Tensor& result, const Tensor& self, int64_t dim, ScalarType dtype) {
+  return at::native::sum_out(result, self, dim, false, dtype);
+}
+
 Tensor &_sum_out_cpu(Tensor &result, const Tensor &self, int64_t dim_,
                      bool keepdim) {
   int64_t dim = maybe_wrap_dim(dim_, self.dim());
@@ -80,7 +186,27 @@ Tensor &_sum_out_cpu(Tensor &result, const Tensor &self, int64_t dim_,
     if (!keepdim) result.squeeze_(dim);
     return result;
   }
-  return at::_sum_out(result, self, dim, keepdim);
+  return at::_th_sum_out(result, self, dim, keepdim);
+}
+
+static inline Tensor &prod_out(Tensor &result, const Tensor &self, int64_t dim,
+                               bool keepdim, optional<ScalarType> dtype) {
+  // result type is favored over dtype; check that they match if provided (NumPy doesn't check)
+  AT_ASSERT(!dtype.has_value() || (result.type().scalarType() == dtype.value()),
+            "provided dtype must match dtype of result in prod. Got %s and %s.",
+            at::toString(result.type().scalarType()), at::toString(dtype.value()));
+  return at::_prod_out(result, self.toType(result.type().scalarType()), dim, keepdim);
+}
+
+Tensor& prod_out(Tensor& result, const Tensor& self, int64_t dim, bool keepdim, ScalarType dtype) {
+  return at::native::prod_out(result, self, dim, keepdim, at::optional<ScalarType>(dtype));
+}
+Tensor& prod_out(Tensor& result, const Tensor& self, int64_t dim, bool keepdim) {
+  return at::native::prod_out(result, self, dim, keepdim, nullopt);
+}
+
+Tensor& prod_out(Tensor& result, const Tensor& self, int64_t dim, ScalarType dtype) {
+  return at::native::prod_out(result, self, dim, false, dtype);
 }
 
 Tensor &_prod_out_cpu(Tensor &result, const Tensor &self, int64_t dim_,
@@ -94,29 +220,61 @@ Tensor &_prod_out_cpu(Tensor &result, const Tensor &self, int64_t dim_,
     if (!keepdim) result.squeeze_(dim);
     return result;
   }
-  return at::_prod_out(result, self, dim, keepdim);
+  return at::_th_prod_out(result, self, dim, keepdim);
 }
 
 Tensor &_sum_out_cuda(Tensor &result, const Tensor &self, int64_t dim,
                       bool keepdim) {
-  return at::_sum_out(result, self, dim, keepdim);
+  return at::_th_sum_out(result, self, dim, keepdim);
 }
 
 Tensor &_prod_out_cuda(Tensor &result, const Tensor &self, int64_t dim,
                        bool keepdim) {
-  return at::_prod_out(result, self, dim, keepdim);
+  return at::_th_prod_out(result, self, dim, keepdim);
 }
 
-Tensor sum(const Tensor &self, int64_t dim_, bool keepdim) {
+static inline Tensor sum(const Tensor &self, int64_t dim_, bool keepdim, optional<ScalarType> dtype) {
+  return at::_sum(integer_upcast(self, dtype), dim_, keepdim);
+}
+
+Tensor sum(const Tensor& self, int64_t dim, bool keepdim, ScalarType dtype) {
+  return at::native::sum(self, dim, keepdim, at::optional<ScalarType>(dtype));
+}
+
+Tensor sum(const Tensor& self, int64_t dim, bool keepdim) {
+  return at::native::sum(self, dim, keepdim, nullopt);
+}
+
+Tensor sum(const Tensor& self, int64_t dim, ScalarType dtype) {
+  return at::native::sum(self, dim, false, dtype);
+}
+
+Tensor _sum(const Tensor &self, int64_t dim_, bool keepdim) {
   int64_t dim = maybe_wrap_dim(dim_, self.dim());
   Tensor result = self.type().tensor();
-  return at::sum_out(result, self, dim, keepdim);
+  return at::_sum_out(result, self, dim, keepdim);
 }
 
-Tensor prod(const Tensor &self, int64_t dim_, bool keepdim) {
+static inline Tensor prod(const Tensor &self, int64_t dim_, bool keepdim, optional<ScalarType> dtype) {
+  return at::_prod(integer_upcast(self, dtype), dim_, keepdim);
+}
+
+Tensor prod(const Tensor& self, int64_t dim, bool keepdim, ScalarType dtype) {
+  return at::native::prod(self, dim, keepdim, at::optional<ScalarType>(dtype));
+}
+
+Tensor prod(const Tensor& self, int64_t dim, bool keepdim) {
+  return at::native::prod(self, dim, keepdim, nullopt);
+}
+
+Tensor prod(const Tensor& self, int64_t dim, ScalarType dtype) {
+  return at::native::prod(self, dim, false, dtype);
+}
+
+Tensor _prod(const Tensor &self, int64_t dim_, bool keepdim) {
   int64_t dim = maybe_wrap_dim(dim_, self.dim());
   Tensor result = self.type().tensor();
-  return at::prod_out(result, self, dim, keepdim);
+  return at::_prod_out(result, self, dim, keepdim);
 }
 
 // \DIM REDUCE ################################################################
diff --git a/aten/src/ATen/native/cudnn/RNN.cpp b/aten/src/ATen/native/cudnn/RNN.cpp
index 6d2c5c3259..3a53353810 100644
--- a/aten/src/ATen/native/cudnn/RNN.cpp
+++ b/aten/src/ATen/native/cudnn/RNN.cpp
@@ -448,7 +448,7 @@ namespace {
         // (same for the hh weights, and the ih and hh biases).
         // Since we're storing all the weights in a single tensor anyway,
        // might as well merge the CUDNN ones into a single tensor as well
-        int mat_numel = *filter_dim_a.prod().data<int>();
+        int mat_numel = *filter_dim_a.prod(at::ScalarType::Int).data<int>();
         if (linear_id == 0 || linear_id == num_linear_layers / 2) {
           std::initializer_list<int64_t> size = {
             mat_numel * num_linear_layers / 2, 1};
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 47214e0074..5ed406cb77 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -250,6 +250,28 @@
       name: grad_grid
   variants: function
 
+# FIXME: These could be combined as optional<ScalarType> but for https://github.com/pytorch/pytorch/issues/6593.
+- func: cumsum(Tensor self, int64_t dim, *, ScalarType dtype) -> Tensor
+
+- func: cumsum(Tensor self, int64_t dim) -> Tensor
+
+- func: cumsum_out(Tensor result, Tensor self, int64_t dim, *, ScalarType dtype) -> Tensor
+  variants: function
+
+- func: cumsum_out(Tensor result, Tensor self, int64_t dim) -> Tensor
+  variants: function
+
+# FIXME: These could be combined as optional<ScalarType> but for https://github.com/pytorch/pytorch/issues/6593.
+- func: cumprod(Tensor self, int64_t dim, *, ScalarType dtype) -> Tensor
+
+- func: cumprod(Tensor self, int64_t dim) -> Tensor
+
+- func: cumprod_out(Tensor result, Tensor self, int64_t dim, *, ScalarType dtype) -> Tensor
+  variants: function
+
+- func: cumprod_out(Tensor result, Tensor self, int64_t dim) -> Tensor
+  variants: function
+
 - func: det(Tensor self) -> Tensor
 
 - func: diagflat(Tensor self, int64_t offset=0) -> Tensor
@@ -653,15 +675,35 @@
 
 - func: stride(Tensor self, int64_t dim) -> int64_t
 
+# FIXME: These could be combined as optional<ScalarType> but for https://github.com/pytorch/pytorch/issues/6593.
+- func: sum(Tensor self, *, ScalarType dtype) -> Tensor
+
 - func: sum(Tensor self) -> Tensor
+
+- func: _sum(Tensor self) -> Tensor
   dispatch:
     CPU: _sum_cpu
     CUDA: _sum_cuda
 
+- func: sum(Tensor self, int64_t dim, bool keepdim, *, ScalarType dtype) -> Tensor
+
 - func: sum(Tensor self, int64_t dim, bool keepdim=False) -> Tensor
 
+- func: sum(Tensor self, int64_t dim, *, ScalarType dtype) -> Tensor
+
+- func: _sum(Tensor self, int64_t dim, bool keepdim=False) -> Tensor
+
+- func: sum_out(Tensor result, Tensor self, int64_t dim, bool keepdim, *, ScalarType dtype) -> Tensor
+  variants: function
+
 - func: sum_out(Tensor result, Tensor self, int64_t dim, bool keepdim=False) -> Tensor
   variants: function
+
+- func: sum_out(Tensor result, Tensor self, int64_t dim, *, ScalarType dtype) -> Tensor
+  variants: function
+
+- func: _sum_out(Tensor result, Tensor self, int64_t dim, bool keepdim=False) -> Tensor
+  variants: function
   dispatch:
     CPU: _sum_out_cpu
     CUDA: _sum_out_cuda
@@ -676,15 +718,35 @@
     CPU: _sqrt_out_cpu
     CUDA: _sqrt_out_cuda
 
+# FIXME: These could be combined as optional<ScalarType> but for https://github.com/pytorch/pytorch/issues/6593.
+- func: prod(Tensor self, *, ScalarType dtype) -> Tensor
+
 - func: prod(Tensor self) -> Tensor
+
+- func: _prod(Tensor self) -> Tensor
   dispatch:
     CPU: _prod_cpu
     CUDA: _prod_cuda
 
+- func: prod(Tensor self, int64_t dim, bool keepdim, *, ScalarType dtype) -> Tensor
+
 - func: prod(Tensor self, int64_t dim, bool keepdim=False) -> Tensor
 
+- func: prod(Tensor self, int64_t dim, *, ScalarType dtype) -> Tensor
+
+- func: _prod(Tensor self, int64_t dim, bool keepdim=False) -> Tensor
+
+- func: prod_out(Tensor result, Tensor self, int64_t dim, bool keepdim, *, ScalarType dtype) -> Tensor
+  variants: function
+
 - func: prod_out(Tensor result, Tensor self, int64_t dim, bool keepdim=False) -> Tensor
   variants: function
+
+- func: prod_out(Tensor result, Tensor self, int64_t dim, *, ScalarType dtype) -> Tensor
+  variants: function
+
+- func: _prod_out(Tensor result, Tensor self, int64_t dim, bool keepdim=False) -> Tensor
+  variants: function
   dispatch:
     CPU: _prod_out_cpu
     CUDA: _prod_out_cuda
diff --git a/aten/src/ATen/native_parse.py b/aten/src/ATen/native_parse.py
index bb109f849f..c64b6a0a21 100644
--- a/aten/src/ATen/native_parse.py
+++ b/aten/src/ATen/native_parse.py
@@ -17,6 +17,8 @@ def parse_default(s):
         return s
     elif s == '{}':
         return '{}'
+    elif s == 'nullopt':
+        return s
     try:
         return int(s)
     except Exception:
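
Finally, a small hypothetical Python session sketching the out=/dtype interaction enforced by the AT_ASSERT checks above (the mismatch surfaces as a runtime error; NumPy, by contrast, does not enforce this):

import torch

x = torch.ones(3, dtype=torch.int32)
out = torch.empty(3, dtype=torch.int64)

# OK: the explicit dtype matches the dtype of the provided result tensor.
torch.cumsum(x, dim=0, dtype=torch.int64, out=out)

# Rejected: dtype and out disagree, so the check in cumsum_out fires.
try:
    torch.cumsum(x, dim=0, dtype=torch.float32, out=out)
except RuntimeError as err:
    print("dtype/out mismatch:", err)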