author     jiej <jiej@nvidia.com>  2019-01-16 22:12:13 -0800
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>  2019-01-16 22:15:25 -0800
commit     7c56db73d5a9e1432dabc0231acad63575c3089e (patch)
tree       fd26ce91d87d61ab336165285ef04c28a5dd3f8d
parent     55511004d17bd3e0e36e88efa6abdc9a5a03dec1 (diff)
Moving torch.norm to ATen using TensorIterator (#15414)
Summary:
Adding support for torch.norm:
i. multiple dimensions for dim
ii. a dtype argument that specifies the math/output tensor type
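A short usage sketch of the two additions. The multi-dim call is illustrative; the half-to-float case mirrors the test_norm_type_conversion test added in test/test_cuda.py below.

```python
import torch

x = torch.randn(4, 5, 6)

# (i) `dim` now accepts several dimensions at once.
v = torch.norm(x, p=2, dim=(0, 2))           # one 2-norm per slice along dim 1
print(v.shape)                               # torch.Size([5])

# (ii) `dtype` picks the math/output type; here a half tensor is reduced
# while accumulating in float32, as in the new CUDA test.
if torch.cuda.is_available():
    a = torch.ones(65536).cuda().half()
    print(a.norm(p=0, dtype=torch.float32))  # 0-"norm" counts non-zeros: 65536
```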
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15414
Differential Revision: D13702022
Pulled By: ezyang
fbshipit-source-id: da2676f2b6aff988889b1539d0de8ecd4946823a
-rw-r--r--  aten/src/ATen/Declarations.cwrap              |  40
-rw-r--r--  aten/src/ATen/core/Tensor.h                   |   4
-rw-r--r--  aten/src/ATen/core/TensorMethods.h            |   8
-rw-r--r--  aten/src/ATen/core/Type.h                     |   4
-rw-r--r--  aten/src/ATen/core/aten_interned_strings.h    |   1
-rw-r--r--  aten/src/ATen/native/LinearAlgebra.cpp        |   4
-rw-r--r--  aten/src/ATen/native/ReduceOps.cpp            |  70
-rw-r--r--  aten/src/ATen/native/ReduceOps.h              |   3
-rw-r--r--  aten/src/ATen/native/SharedReduceOps.h        | 129
-rw-r--r--  aten/src/ATen/native/cpu/ReduceOpsKernel.cpp  |  57
-rw-r--r--  aten/src/ATen/native/cuda/ReduceOpsKernel.cu  |  48
-rw-r--r--  aten/src/ATen/native/native_functions.yaml    |  12
-rw-r--r--  test/test_cuda.py                             |   5
-rw-r--r--  tools/autograd/derivatives.yaml               |   9
-rw-r--r--  tools/autograd/templates/Functions.cpp        |  17
-rw-r--r--  torch/csrc/jit/passes/shape_analysis.cpp      |   2
-rw-r--r--  torch/functional.py                           |  22
-rw-r--r--  torch/tensor.py                               |   4
18 files changed, 346 insertions(+), 93 deletions(-)
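For orientation before the diff: the new kernels special-case p = 0, 1, +inf and -inf and otherwise use the generic sum-of-powers reduction (NormZeroOps, NormOneOps, AbsMaxOps, AbsMinOps and NormOps in SharedReduceOps.h). Below is a minimal reference sketch of that behaviour for a 1-D tensor; `reference_norm` is an illustrative helper, not part of the patch.

```python
import math
import torch

def reference_norm(x, p):
    """Illustrative 1-D reference for the dispatched kernels; the real code
    runs these as TensorIterator reductions on CPU and CUDA."""
    vals = [abs(float(v)) for v in x]
    if p == 0:
        return float(sum(v != 0 for v in vals))    # NormZeroOps: count non-zeros
    if p == 1:
        return sum(vals)                           # NormOneOps: sum of |x|
    if p == math.inf:
        return max(vals)                           # AbsMaxOps: largest |x|
    if p == -math.inf:
        return min(vals)                           # AbsMinOps: smallest |x|
    return sum(v ** p for v in vals) ** (1.0 / p)  # NormOps: (sum |x|^p)^(1/p)

x = torch.tensor([3.0, -4.0, 0.0])
for p in (0, 1, 2, math.inf, -math.inf):
    print(p, reference_norm(x, p), torch.norm(x, p=p).item())
```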
diff --git a/aten/src/ATen/Declarations.cwrap b/aten/src/ATen/Declarations.cwrap index fc04feb98e..bc06d116f2 100644 --- a/aten/src/ATen/Declarations.cwrap +++ b/aten/src/ATen/Declarations.cwrap @@ -1520,46 +1520,6 @@ default: "false" ]] [[ - name: _th_norm - cname: norm - types: - - floating_point - backends: - - CPU - - CUDA - variants: - - function - options: - - cname: normall - return: accreal - arguments: - - THTensor* self - - arg: real p - default: AS_REAL(2) -]] -[[ - name: _th_norm - types: - - floating_point - backends: - - CPU - - CUDA - variants: function - options: - - cname: norm - return: argument 0 - scalar_check: self_->dim() == 0 || (keepdim == false && self_->dim() == 1) - arguments: - - arg: THTensor* result - output: True - - THTensor* self - - real p - - arg: long dim - wrap_dim: self - - arg: bool keepdim - default: "false" -]] -[[ name: _th_renorm cname: renorm types: diff --git a/aten/src/ATen/core/Tensor.h b/aten/src/ATen/core/Tensor.h index 2f02f92d70..6abf5fcbe9 100644 --- a/aten/src/ATen/core/Tensor.h +++ b/aten/src/ATen/core/Tensor.h @@ -496,8 +496,10 @@ public: Tensor var(IntList dim, bool unbiased=true, bool keepdim=false) const; Tensor view_as(const Tensor & other) const; Tensor where(const Tensor & condition, const Tensor & other) const; + Tensor norm(c10::optional<Scalar> p, ScalarType dtype) const; Tensor norm(Scalar p=2) const; - Tensor norm(c10::optional<Scalar> p, int64_t dim, bool keepdim=false) const; + Tensor norm(c10::optional<Scalar> p, IntList dim, bool keepdim, ScalarType dtype) const; + Tensor norm(c10::optional<Scalar> p, IntList dim, bool keepdim=false) const; Tensor clone() const; Tensor & resize_as_(const Tensor & the_template); Tensor pow(Scalar exponent) const; diff --git a/aten/src/ATen/core/TensorMethods.h b/aten/src/ATen/core/TensorMethods.h index 33f3dc2a0b..e76a0eb358 100644 --- a/aten/src/ATen/core/TensorMethods.h +++ b/aten/src/ATen/core/TensorMethods.h @@ -673,10 +673,16 @@ inline Tensor Tensor::view_as(const Tensor & other) const { inline Tensor Tensor::where(const Tensor & condition, const Tensor & other) const { return type().where(condition, *this, other); } +inline Tensor Tensor::norm(c10::optional<Scalar> p, ScalarType dtype) const { + return type().norm(*this, p, dtype); +} inline Tensor Tensor::norm(Scalar p) const { return type().norm(*this, p); } -inline Tensor Tensor::norm(c10::optional<Scalar> p, int64_t dim, bool keepdim) const { +inline Tensor Tensor::norm(c10::optional<Scalar> p, IntList dim, bool keepdim, ScalarType dtype) const { + return type().norm(*this, p, dim, keepdim, dtype); +} +inline Tensor Tensor::norm(c10::optional<Scalar> p, IntList dim, bool keepdim) const { return type().norm(*this, p, dim, keepdim); } inline Tensor Tensor::clone() const { diff --git a/aten/src/ATen/core/Type.h b/aten/src/ATen/core/Type.h index 7de537ff9d..1676c51e29 100644 --- a/aten/src/ATen/core/Type.h +++ b/aten/src/ATen/core/Type.h @@ -391,8 +391,10 @@ struct CAFFE2_API Type { virtual Tensor var(const Tensor & self, IntList dim, bool unbiased, bool keepdim) const = 0; virtual Tensor view_as(const Tensor & self, const Tensor & other) const = 0; virtual Tensor where(const Tensor & condition, const Tensor & self, const Tensor & other) const = 0; + virtual Tensor norm(const Tensor & self, c10::optional<Scalar> p, ScalarType dtype) const = 0; virtual Tensor norm(const Tensor & self, Scalar p) const = 0; - virtual Tensor norm(const Tensor & self, c10::optional<Scalar> p, int64_t dim, bool keepdim) const = 0; + virtual Tensor 
norm(const Tensor & self, c10::optional<Scalar> p, IntList dim, bool keepdim, ScalarType dtype) const = 0; + virtual Tensor norm(const Tensor & self, c10::optional<Scalar> p, IntList dim, bool keepdim) const = 0; virtual Tensor clone(const Tensor & self) const = 0; virtual Tensor & resize_as_(Tensor & self, const Tensor & the_template) const = 0; virtual Tensor pow(const Tensor & self, Scalar exponent) const = 0; diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h index 8e2fbd7a16..28f08e9a3e 100644 --- a/aten/src/ATen/core/aten_interned_strings.h +++ b/aten/src/ATen/core/aten_interned_strings.h @@ -151,7 +151,6 @@ _(aten, _th_max) \ _(aten, _th_median) \ _(aten, _th_min) \ _(aten, _th_mode) \ -_(aten, _th_norm) \ _(aten, _th_prod) \ _(aten, _th_sigmoid) \ _(aten, _th_std) \ diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index 4283e5d170..64d40fcb3e 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -532,7 +532,7 @@ Tensor frobenius_norm(const Tensor& self, IntList dim, bool keepdim) { dim.size(), " dimensions instead."); if (dim.size() == 1) { - return at::norm(self, 2, dim[0], keepdim); + return at::norm(self, 2, dim, keepdim, self.type().scalarType()); } return at::sqrt(at::sum(self * self, dim, keepdim)); } @@ -548,7 +548,7 @@ Tensor &frobenius_norm_out( dim.size(), " dimensions instead."); if (dim.size() == 1) { - return at::norm_out(result, self, 2, dim[0], keepdim); + return at::norm_out(result, self, 2, dim, keepdim, self.type().scalarType()); } return at::sqrt_out(result, at::sum(self * self, dim, keepdim)); } diff --git a/aten/src/ATen/native/ReduceOps.cpp b/aten/src/ATen/native/ReduceOps.cpp index a2e2acfc68..7c8c119749 100644 --- a/aten/src/ATen/native/ReduceOps.cpp +++ b/aten/src/ATen/native/ReduceOps.cpp @@ -23,6 +23,7 @@ namespace native { DEFINE_DISPATCH(sum_stub); DEFINE_DISPATCH(std_var_stub); DEFINE_DISPATCH(prod_stub); +DEFINE_DISPATCH(norm_stub); DEFINE_DISPATCH(mean_stub); DEFINE_DISPATCH(and_stub); @@ -385,46 +386,69 @@ Tensor logsumexp(const Tensor &self, int64_t dim_, bool keepdim) { return at::native::logsumexp_out(result, self, dim, keepdim); } -Tensor& _norm_out_cpu(Tensor& result, const Tensor& self, Scalar p, int64_t dim_, bool keepdim) { - int64_t dim = maybe_wrap_dim(dim_, self.dim()); - if (_dimreduce_return_trivial(result, self, 0, dim, keepdim)) - return result; - return at::legacy::th::_th_norm_out(result, self, p, dim, keepdim); -} - -Tensor& norm_out(Tensor &result, const Tensor &self, optional<Scalar> pOpt, int64_t dim, bool keepdim) { +static Tensor& norm_out(Tensor &result, const Tensor &self, optional<Scalar> opt_p, + IntList dim, bool keepdim, optional<ScalarType> opt_dtype) { + auto p = opt_p.value_or(2.0); AT_CHECK(self.type().backend() == Backend::CPU || self.type().backend() == Backend::CUDA, "norm only supports CPU AND CUDA backend, got: ", toString(self.type().backend())); - AT_CHECK(at::isFloatingType(self.type().scalarType()), "norm only supports floating-point dtypes"); - auto p = pOpt.value_or(2.0); - dim = maybe_wrap_dim(dim, self.dim()); - if (_dimreduce_return_trivial(result, self, 0, dim, keepdim)) { - return result; + + ScalarType scalarType = opt_dtype.has_value() ? opt_dtype.value() : self.type().scalarType(); + AT_CHECK( + at::isFloatingType(scalarType), + "Can only calculate the mean of floating types. 
Got ", + toString(scalarType), + " instead."); + + ScalarType dtype = get_dtype(result, self, opt_dtype, true); + auto iter = make_reduction("norm", result, self, dim, keepdim, dtype); + if (iter->numel() == 0) { + result.zero_(); } else { - if (self.is_cuda()) { - return at::legacy::th::_th_norm_out(result, self, p, dim, keepdim); - } else { - return _norm_out_cpu(result, self, p, dim, keepdim); - } + norm_stub(iter->device_type(), *iter, p); } + return result; } -Tensor _norm(const Tensor &self, Scalar p) { +static inline Tensor _norm(const Tensor &self, Scalar p) { if (self.is_sparse()) { return at::native_norm(self, p); } else { AT_CHECK(self.type().backend() == Backend::CPU || self.type().backend() == Backend::CUDA, "norm only supports CPU AND CUDA backend, got: ", toString(self.type().backend())); AT_CHECK(at::isFloatingType(self.type().scalarType()), "norm only supports floating-point dtypes"); - return at::legacy::th::_th_norm(self, p); + + Tensor result; + return at::native::norm_out(result, self, p, {}, false, c10::nullopt); } } -Tensor norm(const Tensor& self, optional<Scalar> p, int64_t dim, bool keepdim) { - Tensor result = at::empty({0}, self.options()); - return at::native::norm_out(result, self, p, dim, keepdim); +Tensor &norm_out(Tensor& result, const Tensor& self, optional<Scalar> p, IntList dim, bool keepdim, ScalarType dtype) { + return at::native::norm_out(result, self, p, dim, keepdim, optional<ScalarType>(dtype)); +} + +Tensor &norm_out(Tensor& result, const Tensor& self, optional<Scalar> p, IntList dim, bool keepdim) { + return at::native::norm_out(result, self, p, dim, keepdim, c10::nullopt); +} + +static Tensor norm(const Tensor& self, optional<Scalar> p, IntList dim, bool keepdim, + optional<ScalarType> opt_dtype) { + Tensor result; + return at::native::norm_out(result, self, p, dim, keepdim, opt_dtype); +} + +Tensor norm(const Tensor& self, optional<Scalar> p, IntList dim, bool keepdim, ScalarType dtype) { + return at::native::norm(self, p, dim, keepdim, optional<ScalarType>(dtype)); +} + +Tensor norm(const Tensor& self, optional<Scalar> p, ScalarType dtype) { + return at::native::norm(self, p, {}, false, optional<ScalarType>(dtype)); +} + +Tensor norm(const Tensor& self, optional<Scalar> p, IntList dim, bool keepdim) { + return at::native::norm(self, p, dim, keepdim, c10::nullopt); } +// leave it so we support sparse tensors Tensor norm(const Tensor& self, Scalar p) { return at::native::_norm(self, p); } diff --git a/aten/src/ATen/native/ReduceOps.h b/aten/src/ATen/native/ReduceOps.h index 5fb6ac6896..f0c0231377 100644 --- a/aten/src/ATen/native/ReduceOps.h +++ b/aten/src/ATen/native/ReduceOps.h @@ -25,4 +25,7 @@ using reduce_norm_fn = void (*)(Tensor&, const Tensor&, Scalar, c10::optional<int64_t>); DECLARE_DISPATCH(reduce_norm_fn, norm_kernel); +using reduce_fn_flag = void(*)(TensorIterator &, Scalar); +DECLARE_DISPATCH(reduce_fn_flag, norm_stub); + }} // namespace at::native diff --git a/aten/src/ATen/native/SharedReduceOps.h b/aten/src/ATen/native/SharedReduceOps.h index 30b9a07131..2cd6afd533 100644 --- a/aten/src/ATen/native/SharedReduceOps.h +++ b/aten/src/ATen/native/SharedReduceOps.h @@ -13,6 +13,21 @@ #include <cmath> #define device_sqrt std::sqrt #endif +#if defined(__CUDACC__) || defined(__HIPCC__) +#define MAX(X, Y) ::max(X,Y) +#define MIN(X, Y) ::min(X,Y) +#else +#define MAX(X, Y) std::max(X,Y) +#define MIN(X, Y) std::min(X,Y) +#endif + +// ROCM hcc doesn't work well with using std:: in kernel functions +#if defined(__CUDA_ARCH__) || 
defined(__HIP_PLATFORM_HCC__) +#include <c10/cuda/CUDAMathCompat.h> +#define compat_pow c10::cuda::compat::pow +#else +#define compat_pow std::pow +#endif namespace at { namespace native { @@ -105,5 +120,119 @@ struct MeanOps { } }; +template <typename acc_t> +struct AbsMinOps { + + inline C10_DEVICE acc_t reduce(acc_t acc, acc_t data) const { + return MIN(acc, std::abs(data)); + } + + inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const { + return MIN(a, b); + } + + inline C10_DEVICE acc_t project(acc_t a) const { + return a; + } + +#if defined(__CUDACC__) || defined(__HIPCC__) + inline C10_DEVICE acc_t warp_shfl_down(acc_t data, int offset) const { + return WARP_SHFL_DOWN(data, offset); + } +#endif +}; + +template <typename acc_t> +struct AbsMaxOps { + + inline C10_DEVICE acc_t reduce(acc_t acc, acc_t data) const { + return MAX(acc, std::abs(data)); + } + + inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const { + return MAX(a, b); + } + + inline C10_DEVICE acc_t project(acc_t a) const { + return a; + } + +#if defined(__CUDACC__) || defined(__HIPCC__) + inline C10_DEVICE acc_t warp_shfl_down(acc_t data, int offset) const { + return WARP_SHFL_DOWN(data, offset); + } +#endif +}; + +template <typename acc_t> +struct NormOps { + acc_t norm; + + inline C10_DEVICE acc_t reduce(acc_t acc, acc_t data) const { + return acc + compat_pow(std::abs(data), norm); + } + + inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const { + return a + b; + } + + inline C10_DEVICE acc_t project(acc_t a) const { + return compat_pow(a, acc_t(1.0)/norm); + } + +#if defined(__CUDACC__) || defined(__HIPCC__) + inline C10_DEVICE acc_t warp_shfl_down(acc_t data, int offset) const { + return WARP_SHFL_DOWN(data, offset); + } +#endif + + NormOps(acc_t norm): norm(norm) { + } +}; + +template <typename acc_t> +struct NormZeroOps { + inline C10_DEVICE acc_t reduce(acc_t acc, acc_t data) const { + return acc + (data==acc_t(0) ? 
acc_t(0) : acc_t(1)); + } + + inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const { + return a + b; + } + + inline C10_DEVICE acc_t project(acc_t a) const { + return a; + } + +#if defined(__CUDACC__) || defined(__HIPCC__) + inline C10_DEVICE acc_t warp_shfl_down(acc_t data, int offset) const { + return WARP_SHFL_DOWN(data, offset); + } +#endif +}; + +template <typename acc_t> +struct NormOneOps { + inline C10_DEVICE acc_t reduce(acc_t acc, acc_t data) const { + return acc + std::abs(data); + } + + inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const { + return a + b; + } + + inline C10_DEVICE acc_t project(acc_t a) const { + return a; + } + +#if defined(__CUDACC__) || defined(__HIPCC__) + inline C10_DEVICE acc_t warp_shfl_down(acc_t data, int offset) const { + return WARP_SHFL_DOWN(data, offset); + } +#endif +}; }} // namespace at::native + +#undef MAX +#undef MIN diff --git a/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp b/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp index 4fca2ce7cf..3841dec5d9 100644 --- a/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp @@ -54,6 +54,62 @@ static void prod_kernel_impl(TensorIterator& iter) { }); } +static void norm_kernel_tensor_iterator_impl( + TensorIterator& iter, + Scalar p) { + float val; + if (p.isIntegral()) { + val = p.to<int64_t>(); + } else if (p.isFloatingPoint()) { + val = p.to<float>(); + } else { + AT_ERROR("norm_kernel_tensor_iterator_impl expects norm to be integer or float"); + } + + + if (val == 0) { + AT_DISPATCH_FLOATING_TYPES(iter.type(), "norm", [&] { + binary_kernel_reduce( + iter, + NormZeroOps<scalar_t>(), + scalar_t(0) + ); + }); + } else if (val == 1) { + AT_DISPATCH_FLOATING_TYPES(iter.type(), "norm", [&] { + binary_kernel_reduce( + iter, + NormOneOps<scalar_t>(), + scalar_t(0) + ); + }); + } else if (val == INFINITY) { + AT_DISPATCH_FLOATING_TYPES(iter.type(), "norm", [&] { + binary_kernel_reduce( + iter, + AbsMaxOps<scalar_t>(), + std::numeric_limits<scalar_t>::min() + ); + }); + } else if (val == -INFINITY) { + AT_DISPATCH_FLOATING_TYPES(iter.type(), "norm", [&] { + binary_kernel_reduce( + iter, + AbsMinOps<scalar_t>(), + std::numeric_limits<scalar_t>::max() + ); + }); + } else { + AT_DISPATCH_FLOATING_TYPES(iter.type(), "norm", [&] { + binary_kernel_reduce( + iter, + NormOps<scalar_t> { scalar_t(val) }, + scalar_t(0) + ); + }); + } +} + static void and_kernel_impl(TensorIterator& iter) { binary_kernel_reduce_vec( iter, @@ -84,6 +140,7 @@ REGISTER_DISPATCH(sum_stub, &sum_kernel_impl); REGISTER_DISPATCH(std_var_stub, &std_var_kernel_impl); REGISTER_DISPATCH(prod_stub, &prod_kernel_impl); REGISTER_DISPATCH(mean_stub, &mean_kernel_impl); +REGISTER_DISPATCH(norm_stub, &norm_kernel_tensor_iterator_impl); REGISTER_DISPATCH(and_stub, &and_kernel_impl); }} // namespace at::native diff --git a/aten/src/ATen/native/cuda/ReduceOpsKernel.cu b/aten/src/ATen/native/cuda/ReduceOpsKernel.cu index 5856d3712c..1ce0fecd90 100644 --- a/aten/src/ATen/native/cuda/ReduceOpsKernel.cu +++ b/aten/src/ATen/native/cuda/ReduceOpsKernel.cu @@ -14,17 +14,6 @@ namespace at { namespace native { -namespace { - -template <typename scalar_t> -struct SimpleCopy { - __device__ __forceinline__ scalar_t operator() (const scalar_t a) const { - return a; - } -}; - -} // namespace - template <typename scalar_t, typename acc_t=scalar_t, typename out_t=scalar_t> void sum_kernel_impl(TensorIterator& iter) { gpu_reduce_kernel<scalar_t, out_t>(iter, func_wrapper<out_t> ([]GPU_LAMBDA(acc_t a, acc_t b) -> acc_t { @@ 
-86,6 +75,30 @@ void mean_kernel_impl<int16_t, int16_t, int16_t>(TensorIterator& iter) { } #endif // __HIPCC__ +template <typename scalar_t, typename acc_t=scalar_t, typename out_t=scalar_t> +void norm_kernel_cuda_impl(TensorIterator& iter, Scalar val) { + float p; + if (val.isIntegral()) { + p = val.to<int64_t>(); + } else if (val.isFloatingPoint()) { + p = val.to<acc_t>(); + } else { + AT_ERROR("norm_kernel_cuda_impl expects norm to be integer or float"); + } + + if (p == static_cast<float>(0)) { + gpu_reduce_kernel<scalar_t, out_t>(iter, NormZeroOps<acc_t>(), 0); + } else if (p == static_cast<float>(1)) { + gpu_reduce_kernel<scalar_t, out_t>(iter, NormOneOps<acc_t>(), 0); + } else if (p == static_cast<float>(INFINITY)) { + gpu_reduce_kernel<scalar_t, out_t>(iter, AbsMaxOps<acc_t>(), std::numeric_limits<acc_t>::min()); + } else if (p == static_cast<float>(-INFINITY)) { + gpu_reduce_kernel<scalar_t, out_t>(iter, AbsMinOps<acc_t>(), std::numeric_limits<acc_t>::max()); + } else { + gpu_reduce_kernel<scalar_t, out_t>(iter, NormOps<acc_t>{ acc_t(p) }, 0); + } +} + static void sum_kernel_cuda(TensorIterator& iter) { if (iter.type().scalarType() == kHalf) { return sum_kernel_impl<at::Half, float>(iter); @@ -119,6 +132,18 @@ static void mean_kernel_cuda(TensorIterator& iter) { }); } +static void norm_kernel_cuda(TensorIterator& iter, Scalar p) { + if (iter.type().scalarType() == kHalf) { + return norm_kernel_cuda_impl<at::Half, float>(iter, p); + } else if (iter.type(1).scalarType() == kHalf && iter.type().scalarType() == kFloat) { + // type promotion that does cast and reduction in a single kernel + return norm_kernel_cuda_impl<at::Half, float, float>(iter, p); + } + AT_DISPATCH_FLOATING_TYPES(iter.type(), "norm", [&]() { + norm_kernel_cuda_impl<scalar_t>(iter, p); + }); +} + void and_kernel_cuda(TensorIterator& iter) { gpu_reduce_kernel<uint8_t, uint8_t>( iter, func_wrapper<uint8_t> ([]GPU_LAMBDA(uint8_t a, uint8_t b) -> uint8_t { @@ -130,6 +155,7 @@ REGISTER_DISPATCH(std_var_stub, &std_var_kernel_cuda); REGISTER_DISPATCH(sum_stub, &sum_kernel_cuda); REGISTER_DISPATCH(prod_stub, &prod_kernel_cuda); REGISTER_DISPATCH(mean_stub, &mean_kernel_cuda); +REGISTER_DISPATCH(norm_stub, &norm_kernel_cuda); REGISTER_DISPATCH(and_stub, &and_kernel_cuda); }} // namespace at::native diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 4f257eed17..69a7428d85 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -1897,13 +1897,21 @@ SparseCPU: _sparse_sum_backward_cpu SparseCUDA: _sparse_sum_backward_cuda +- func: norm(Tensor self, Scalar? p, *, ScalarType dtype) -> Tensor + variants: function, method + - func: norm(Tensor self, Scalar p=2) -> Tensor variants: function, method -- func: norm(Tensor self, Scalar? p, int64_t dim, bool keepdim=false) -> Tensor +- func: norm(Tensor self, Scalar? p, IntList[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor variants: function, method -- func: norm_out(Tensor result, Tensor self, Scalar? p, int64_t dim, bool keepdim=false) -> Tensor +- func: norm(Tensor self, Scalar? p, IntList[1] dim, bool keepdim=false) -> Tensor + variants: function, method + +- func: norm_out(Tensor result, Tensor self, Scalar? p, IntList[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor + +- func: norm_out(Tensor result, Tensor self, Scalar? 
p, IntList[1] dim, bool keepdim=false) -> Tensor - func: frobenius_norm(Tensor self) -> Tensor variants: function diff --git a/test/test_cuda.py b/test/test_cuda.py index be5359571c..c22ddb0b40 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -2212,6 +2212,11 @@ class TestCuda(TestCase): self.assertGreater(b.norm().item(), 0) @skipIfRocm + def test_norm_type_conversion(self): + a = torch.ones(65536).cuda().half() + self.assertEqual(a.norm(p=0, dtype=torch.float32), 65536) + + @skipIfRocm # Test that wrap_with_cuda_memory_check successfully detects leak def test_cuda_memory_leak_detection(self): l = [] diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 5eef168ee2..b391a6b5e3 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -570,9 +570,16 @@ - name: norm(Tensor self, Scalar p) self: norm_backward(grad, self, p, result) -- name: norm(Tensor self, Scalar? p, int64_t dim, bool keepdim) +- name: norm(Tensor self, Scalar? p, IntList dim, bool keepdim) self: norm_backward(grad, self, p, result, dim, keepdim) +- name: norm(Tensor self, Scalar? p, ScalarType dtype) + self: norm_backward(grad, self.to(grad.type().scalarType()), p, result).to(self.type().scalarType()) + + +- name: norm(Tensor self, Scalar? p, IntList dim, bool keepdim, ScalarType dtype) + self: norm_backward(grad, self.to(grad.type().scalarType()), p, result, dim, keepdim).to(self.type().scalarType()) + - name: _pdist_forward(Tensor self, double p) self: _pdist_backward(grad, self, p, result) diff --git a/tools/autograd/templates/Functions.cpp b/tools/autograd/templates/Functions.cpp index 4f1a158844..dd024ac625 100644 --- a/tools/autograd/templates/Functions.cpp +++ b/tools/autograd/templates/Functions.cpp @@ -114,10 +114,21 @@ Tensor norm_backward(const Tensor & grad, const Tensor & self, const optional<Sc return self_scaled * scale_v; } -Tensor norm_backward(Tensor grad, const Tensor & self, const optional<Scalar> & p_, Tensor norm, int64_t dim, bool keepdim) { +Tensor norm_backward(Tensor grad, const Tensor & self, const optional<Scalar> & p_, Tensor norm, IntList dim, bool keepdim) { + IntList sizes = self.sizes(); if (!keepdim && self.dim() != 0) { - grad = grad.unsqueeze(dim); - norm = norm.unsqueeze(dim); + if (dim.size()==1) { + grad = grad.unsqueeze(dim[0]); + norm = norm.unsqueeze(dim[0]); + } else { + auto dims_to_unsqueeze = at::dim_list_to_bitset(dim, sizes.size()); + for (size_t i = 0; i < sizes.size(); i++){ + if (dims_to_unsqueeze[i]) { + grad = grad.unsqueeze(i); + norm = norm.unsqueeze(i); + } + } + } } return norm_backward(grad, self, p_, norm); } diff --git a/torch/csrc/jit/passes/shape_analysis.cpp b/torch/csrc/jit/passes/shape_analysis.cpp index 12654657a9..e8b8425f6a 100644 --- a/torch/csrc/jit/passes/shape_analysis.cpp +++ b/torch/csrc/jit/passes/shape_analysis.cpp @@ -902,7 +902,6 @@ class ShapePropagator { "aten::argmin(Tensor self, int dim, bool keepdim) -> Tensor", "aten::max_values(Tensor self, int dim, bool keepdim) -> Tensor", "aten::min_values(Tensor self, int dim, bool keepdim) -> Tensor", - "aten::norm(Tensor self, Scalar? 
p, int dim, bool keepdim) -> Tensor", "aten::logsumexp(Tensor self, int dim, bool keepdim) -> Tensor", "aten::all(Tensor self, int dim, bool keepdim) -> Tensor", "aten::any(Tensor self, int dim, bool keepdim) -> Tensor", @@ -955,6 +954,7 @@ class ShapePropagator { static const register_formula_for multidim_reduce_ops{ { "aten::mean(Tensor self, int[] dim, bool keepdim) -> Tensor", + "aten::norm(Tensor self, Scalar? p, int[] dim, bool keepdim) -> Tensor", "aten::std(Tensor self, int[] dim, bool unbiased, bool keepdim) -> Tensor", "aten::var(Tensor self, int[] dim, bool unbiased, bool keepdim) -> Tensor", }, diff --git a/torch/functional.py b/torch/functional.py index 1ff28b161b..9847f34150 100644 --- a/torch/functional.py +++ b/torch/functional.py @@ -633,7 +633,7 @@ def cartesian_prod(*tensors): return torch._C._VariableFunctions.cartesian_prod(tensors) -def norm(input, p="fro", dim=None, keepdim=False, out=None): +def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None): r"""Returns the matrix norm or vector norm of a given tensor. Args: @@ -662,6 +662,10 @@ def norm(input, p="fro", dim=None, keepdim=False, out=None): :attr:`out` = ``None``. Default: ``False`` out (Tensor, optional): the output tensor. Ignored if :attr:`dim` = ``None`` and :attr:`out` = ``None``. + dtype (:class:`torch.dtype`, optional): the desired data type of + returned tensor. If specified, the input tensor is casted to + :attr:'dtype' while performing the operation. Default: None. + Example:: @@ -692,26 +696,36 @@ def norm(input, p="fro", dim=None, keepdim=False, out=None): ndim = input.dim() # catch default case - if dim is None and out is None: + if dim is None and out is None and dtype is None: if p == "fro": return torch._C._VariableFunctions.frobenius_norm(input) elif p != "nuc": return torch._C._VariableFunctions.norm(input, p) if p == "fro": + if dtype is not None: + raise ValueError("dtype argument is not supported in frobenius norm") if dim is None: dim = tuple(range(ndim)) if out is None: return torch._C._VariableFunctions.frobenius_norm(input, dim, keepdim=keepdim) return torch._C._VariableFunctions.frobenius_norm(input, dim, keepdim=keepdim, out=out) elif p == "nuc": + if dtype is not None: + raise ValueError("dtype argument is not supported in nuclear norm") if out is None: torch._C._VariableFunctions.nuclear_norm(input, keepdim=keepdim) return torch._C._VariableFunctions.nuclear_norm(input, keepdim=keepdim, out=out) else: - if out is None: + if dim is None: + dim = tuple(range(ndim)) + if out is None and dtype is None: return torch._C._VariableFunctions.norm(input, p, dim, keepdim=keepdim) - return torch._C._VariableFunctions.norm(input, p, dim, keepdim=keepdim, out=out) + elif out is None: + return torch._C._VariableFunctions.norm(input, p, dim, keepdim=keepdim, dtype=dtype) + elif dtype is None: + return torch._C._VariableFunctions.norm(input, p, dim, keepdim=keepdim, out=out) + return torch._C._VariableFunctions.norm(input, p, dim, keepdim=keepdim, dtype=dtype, out=out) def chain_matmul(*matrices): diff --git a/torch/tensor.py b/torch/tensor.py index eedd6413cb..be936bf38e 100644 --- a/torch/tensor.py +++ b/torch/tensor.py @@ -255,9 +255,9 @@ class Tensor(torch._C._TensorBase): r"""See :func: `torch.argsort`""" return torch.argsort(self, dim, descending) - def norm(self, p="fro", dim=None, keepdim=False): + def norm(self, p="fro", dim=None, keepdim=False, dtype=None): r"""See :func: `torch.norm`""" - return torch.norm(self, p, dim, keepdim) + return torch.norm(self, p, dim, keepdim, 
dtype=dtype) def potrf(self, upper=True): r"""See :func:`torch.cholesky`""" |
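On the autograd side, the norm_backward overload added in tools/autograd/templates/Functions.cpp above only needs to restore the reduced dimensions before broadcasting. A hedged Python sketch of that unsqueeze-then-broadcast step, using the generic p-norm gradient x * |x|^(p-2) / ||x||^(p-1); the helper below is illustrative, covers only p >= 2, and is not the code in the patch (which also special-cases p = 0, 1 and +/-inf).

```python
import torch

def norm_backward_sketch(grad, self, p, norm, dim, keepdim):
    # Mirrors the keepdim handling added for IntList dims: put every reduced
    # dimension back (in ascending order) so grad/norm broadcast against self.
    if not keepdim and self.dim() != 0:
        for d in sorted(dim):
            grad = grad.unsqueeze(d)
            norm = norm.unsqueeze(d)
    # Generic p-norm gradient (illustrative; p >= 2 only).
    return self * self.abs().pow(p - 2) * grad / norm.pow(p - 1)

x = torch.randn(3, 4, 5, requires_grad=True)
y = torch.norm(x, p=3, dim=(0, 2))
y.sum().backward()
manual = norm_backward_sketch(torch.ones_like(y), x.detach(), 3, y.detach(), (0, 2), False)
print(torch.allclose(x.grad, manual, atol=1e-5))   # expected: True
```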