author    jiej <jiej@nvidia.com>  2019-01-16 22:12:13 -0800
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>  2019-01-16 22:15:25 -0800
commit    7c56db73d5a9e1432dabc0231acad63575c3089e (patch)
tree      fd26ce91d87d61ab336165285ef04c28a5dd3f8d
parent    55511004d17bd3e0e36e88efa6abdc9a5a03dec1 (diff)
Moving torch.norm to ATen using TensorIterator (#15414)
Summary: Adds support for torch.norm: (i) multiple dimensions for dim; (ii) a dtype argument that specifies the math/output tensor type.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15414
Differential Revision: D13702022
Pulled By: ezyang
fbshipit-source-id: da2676f2b6aff988889b1539d0de8ecd4946823a
aten/src/ATen/Declarations.cwrap              |  40
aten/src/ATen/core/Tensor.h                   |   4
aten/src/ATen/core/TensorMethods.h            |   8
aten/src/ATen/core/Type.h                     |   4
aten/src/ATen/core/aten_interned_strings.h    |   1
aten/src/ATen/native/LinearAlgebra.cpp        |   4
aten/src/ATen/native/ReduceOps.cpp            |  70
aten/src/ATen/native/ReduceOps.h              |   3
aten/src/ATen/native/SharedReduceOps.h        | 129
aten/src/ATen/native/cpu/ReduceOpsKernel.cpp  |  57
aten/src/ATen/native/cuda/ReduceOpsKernel.cu  |  48
aten/src/ATen/native/native_functions.yaml    |  12
test/test_cuda.py                             |   5
tools/autograd/derivatives.yaml               |   9
tools/autograd/templates/Functions.cpp        |  17
torch/csrc/jit/passes/shape_analysis.cpp      |   2
torch/functional.py                           |  22
torch/tensor.py                               |   4
18 files changed, 346 insertions, 93 deletions
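Taken together, the changes below let torch.norm reduce over a list of dimensions and accept a dtype that controls the accumulation/output type. A minimal usage sketch of the resulting API (illustrative only, shapes and values arbitrary):

    import torch

    x = torch.randn(2, 3, 4)

    # reduce over several dimensions at once (previously dim had to be a single int)
    n = x.norm(p=2, dim=(0, 2))                 # shape: (3,)
    nk = x.norm(p=2, dim=(0, 2), keepdim=True)  # shape: (1, 3, 1)

    # request a wider math/output dtype for the reduction
    y = torch.randn(5, dtype=torch.float32)
    z = y.norm(p=2, dtype=torch.float64)        # accumulated and returned as float64
    assert z.dtype == torch.float64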
diff --git a/aten/src/ATen/Declarations.cwrap b/aten/src/ATen/Declarations.cwrap
index fc04feb98e..bc06d116f2 100644
--- a/aten/src/ATen/Declarations.cwrap
+++ b/aten/src/ATen/Declarations.cwrap
@@ -1520,46 +1520,6 @@
default: "false"
]]
[[
- name: _th_norm
- cname: norm
- types:
- - floating_point
- backends:
- - CPU
- - CUDA
- variants:
- - function
- options:
- - cname: normall
- return: accreal
- arguments:
- - THTensor* self
- - arg: real p
- default: AS_REAL(2)
-]]
-[[
- name: _th_norm
- types:
- - floating_point
- backends:
- - CPU
- - CUDA
- variants: function
- options:
- - cname: norm
- return: argument 0
- scalar_check: self_->dim() == 0 || (keepdim == false && self_->dim() == 1)
- arguments:
- - arg: THTensor* result
- output: True
- - THTensor* self
- - real p
- - arg: long dim
- wrap_dim: self
- - arg: bool keepdim
- default: "false"
-]]
-[[
name: _th_renorm
cname: renorm
types:
diff --git a/aten/src/ATen/core/Tensor.h b/aten/src/ATen/core/Tensor.h
index 2f02f92d70..6abf5fcbe9 100644
--- a/aten/src/ATen/core/Tensor.h
+++ b/aten/src/ATen/core/Tensor.h
@@ -496,8 +496,10 @@ public:
Tensor var(IntList dim, bool unbiased=true, bool keepdim=false) const;
Tensor view_as(const Tensor & other) const;
Tensor where(const Tensor & condition, const Tensor & other) const;
+ Tensor norm(c10::optional<Scalar> p, ScalarType dtype) const;
Tensor norm(Scalar p=2) const;
- Tensor norm(c10::optional<Scalar> p, int64_t dim, bool keepdim=false) const;
+ Tensor norm(c10::optional<Scalar> p, IntList dim, bool keepdim, ScalarType dtype) const;
+ Tensor norm(c10::optional<Scalar> p, IntList dim, bool keepdim=false) const;
Tensor clone() const;
Tensor & resize_as_(const Tensor & the_template);
Tensor pow(Scalar exponent) const;
diff --git a/aten/src/ATen/core/TensorMethods.h b/aten/src/ATen/core/TensorMethods.h
index 33f3dc2a0b..e76a0eb358 100644
--- a/aten/src/ATen/core/TensorMethods.h
+++ b/aten/src/ATen/core/TensorMethods.h
@@ -673,10 +673,16 @@ inline Tensor Tensor::view_as(const Tensor & other) const {
inline Tensor Tensor::where(const Tensor & condition, const Tensor & other) const {
return type().where(condition, *this, other);
}
+inline Tensor Tensor::norm(c10::optional<Scalar> p, ScalarType dtype) const {
+ return type().norm(*this, p, dtype);
+}
inline Tensor Tensor::norm(Scalar p) const {
return type().norm(*this, p);
}
-inline Tensor Tensor::norm(c10::optional<Scalar> p, int64_t dim, bool keepdim) const {
+inline Tensor Tensor::norm(c10::optional<Scalar> p, IntList dim, bool keepdim, ScalarType dtype) const {
+ return type().norm(*this, p, dim, keepdim, dtype);
+}
+inline Tensor Tensor::norm(c10::optional<Scalar> p, IntList dim, bool keepdim) const {
return type().norm(*this, p, dim, keepdim);
}
inline Tensor Tensor::clone() const {
diff --git a/aten/src/ATen/core/Type.h b/aten/src/ATen/core/Type.h
index 7de537ff9d..1676c51e29 100644
--- a/aten/src/ATen/core/Type.h
+++ b/aten/src/ATen/core/Type.h
@@ -391,8 +391,10 @@ struct CAFFE2_API Type {
virtual Tensor var(const Tensor & self, IntList dim, bool unbiased, bool keepdim) const = 0;
virtual Tensor view_as(const Tensor & self, const Tensor & other) const = 0;
virtual Tensor where(const Tensor & condition, const Tensor & self, const Tensor & other) const = 0;
+ virtual Tensor norm(const Tensor & self, c10::optional<Scalar> p, ScalarType dtype) const = 0;
virtual Tensor norm(const Tensor & self, Scalar p) const = 0;
- virtual Tensor norm(const Tensor & self, c10::optional<Scalar> p, int64_t dim, bool keepdim) const = 0;
+ virtual Tensor norm(const Tensor & self, c10::optional<Scalar> p, IntList dim, bool keepdim, ScalarType dtype) const = 0;
+ virtual Tensor norm(const Tensor & self, c10::optional<Scalar> p, IntList dim, bool keepdim) const = 0;
virtual Tensor clone(const Tensor & self) const = 0;
virtual Tensor & resize_as_(Tensor & self, const Tensor & the_template) const = 0;
virtual Tensor pow(const Tensor & self, Scalar exponent) const = 0;
diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h
index 8e2fbd7a16..28f08e9a3e 100644
--- a/aten/src/ATen/core/aten_interned_strings.h
+++ b/aten/src/ATen/core/aten_interned_strings.h
@@ -151,7 +151,6 @@ _(aten, _th_max) \
_(aten, _th_median) \
_(aten, _th_min) \
_(aten, _th_mode) \
-_(aten, _th_norm) \
_(aten, _th_prod) \
_(aten, _th_sigmoid) \
_(aten, _th_std) \
diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp
index 4283e5d170..64d40fcb3e 100644
--- a/aten/src/ATen/native/LinearAlgebra.cpp
+++ b/aten/src/ATen/native/LinearAlgebra.cpp
@@ -532,7 +532,7 @@ Tensor frobenius_norm(const Tensor& self, IntList dim, bool keepdim) {
dim.size(),
" dimensions instead.");
if (dim.size() == 1) {
- return at::norm(self, 2, dim[0], keepdim);
+ return at::norm(self, 2, dim, keepdim, self.type().scalarType());
}
return at::sqrt(at::sum(self * self, dim, keepdim));
}
@@ -548,7 +548,7 @@ Tensor &frobenius_norm_out(
dim.size(),
" dimensions instead.");
if (dim.size() == 1) {
- return at::norm_out(result, self, 2, dim[0], keepdim);
+ return at::norm_out(result, self, 2, dim, keepdim, self.type().scalarType());
}
return at::sqrt_out(result, at::sum(self * self, dim, keepdim));
}
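With a single reduction dimension, frobenius_norm now forwards to the new multi-dim at::norm overload instead of the removed single-int one; either way the Frobenius norm over the given dims is just the 2-norm of each slice. A quick equivalence check (illustrative sketch):

    import torch

    x = torch.randn(4, 5)
    fro = torch.norm(x, p="fro", dim=(0, 1))       # sqrt of the sum of squares over both dims
    two = torch.sqrt((x * x).sum(dim=(0, 1)))      # reference computation
    assert torch.allclose(fro, two)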
diff --git a/aten/src/ATen/native/ReduceOps.cpp b/aten/src/ATen/native/ReduceOps.cpp
index a2e2acfc68..7c8c119749 100644
--- a/aten/src/ATen/native/ReduceOps.cpp
+++ b/aten/src/ATen/native/ReduceOps.cpp
@@ -23,6 +23,7 @@ namespace native {
DEFINE_DISPATCH(sum_stub);
DEFINE_DISPATCH(std_var_stub);
DEFINE_DISPATCH(prod_stub);
+DEFINE_DISPATCH(norm_stub);
DEFINE_DISPATCH(mean_stub);
DEFINE_DISPATCH(and_stub);
@@ -385,46 +386,69 @@ Tensor logsumexp(const Tensor &self, int64_t dim_, bool keepdim) {
return at::native::logsumexp_out(result, self, dim, keepdim);
}
-Tensor& _norm_out_cpu(Tensor& result, const Tensor& self, Scalar p, int64_t dim_, bool keepdim) {
- int64_t dim = maybe_wrap_dim(dim_, self.dim());
- if (_dimreduce_return_trivial(result, self, 0, dim, keepdim))
- return result;
- return at::legacy::th::_th_norm_out(result, self, p, dim, keepdim);
-}
-
-Tensor& norm_out(Tensor &result, const Tensor &self, optional<Scalar> pOpt, int64_t dim, bool keepdim) {
+static Tensor& norm_out(Tensor &result, const Tensor &self, optional<Scalar> opt_p,
+ IntList dim, bool keepdim, optional<ScalarType> opt_dtype) {
+ auto p = opt_p.value_or(2.0);
AT_CHECK(self.type().backend() == Backend::CPU || self.type().backend() == Backend::CUDA,
"norm only supports CPU AND CUDA backend, got: ", toString(self.type().backend()));
- AT_CHECK(at::isFloatingType(self.type().scalarType()), "norm only supports floating-point dtypes");
- auto p = pOpt.value_or(2.0);
- dim = maybe_wrap_dim(dim, self.dim());
- if (_dimreduce_return_trivial(result, self, 0, dim, keepdim)) {
- return result;
+
+ ScalarType scalarType = opt_dtype.has_value() ? opt_dtype.value() : self.type().scalarType();
+ AT_CHECK(
+ at::isFloatingType(scalarType),
+ "Can only calculate the mean of floating types. Got ",
+ toString(scalarType),
+ " instead.");
+
+ ScalarType dtype = get_dtype(result, self, opt_dtype, true);
+ auto iter = make_reduction("norm", result, self, dim, keepdim, dtype);
+ if (iter->numel() == 0) {
+ result.zero_();
} else {
- if (self.is_cuda()) {
- return at::legacy::th::_th_norm_out(result, self, p, dim, keepdim);
- } else {
- return _norm_out_cpu(result, self, p, dim, keepdim);
- }
+ norm_stub(iter->device_type(), *iter, p);
}
+ return result;
}
-Tensor _norm(const Tensor &self, Scalar p) {
+static inline Tensor _norm(const Tensor &self, Scalar p) {
if (self.is_sparse()) {
return at::native_norm(self, p);
} else {
AT_CHECK(self.type().backend() == Backend::CPU || self.type().backend() == Backend::CUDA,
"norm only supports CPU AND CUDA backend, got: ", toString(self.type().backend()));
AT_CHECK(at::isFloatingType(self.type().scalarType()), "norm only supports floating-point dtypes");
- return at::legacy::th::_th_norm(self, p);
+
+ Tensor result;
+ return at::native::norm_out(result, self, p, {}, false, c10::nullopt);
}
}
-Tensor norm(const Tensor& self, optional<Scalar> p, int64_t dim, bool keepdim) {
- Tensor result = at::empty({0}, self.options());
- return at::native::norm_out(result, self, p, dim, keepdim);
+Tensor &norm_out(Tensor& result, const Tensor& self, optional<Scalar> p, IntList dim, bool keepdim, ScalarType dtype) {
+ return at::native::norm_out(result, self, p, dim, keepdim, optional<ScalarType>(dtype));
+}
+
+Tensor &norm_out(Tensor& result, const Tensor& self, optional<Scalar> p, IntList dim, bool keepdim) {
+ return at::native::norm_out(result, self, p, dim, keepdim, c10::nullopt);
+}
+
+static Tensor norm(const Tensor& self, optional<Scalar> p, IntList dim, bool keepdim,
+ optional<ScalarType> opt_dtype) {
+ Tensor result;
+ return at::native::norm_out(result, self, p, dim, keepdim, opt_dtype);
+}
+
+Tensor norm(const Tensor& self, optional<Scalar> p, IntList dim, bool keepdim, ScalarType dtype) {
+ return at::native::norm(self, p, dim, keepdim, optional<ScalarType>(dtype));
+}
+
+Tensor norm(const Tensor& self, optional<Scalar> p, ScalarType dtype) {
+ return at::native::norm(self, p, {}, false, optional<ScalarType>(dtype));
+}
+
+Tensor norm(const Tensor& self, optional<Scalar> p, IntList dim, bool keepdim) {
+ return at::native::norm(self, p, dim, keepdim, c10::nullopt);
}
+// keep this overload so sparse tensors are still supported
Tensor norm(const Tensor& self, Scalar p) {
return at::native::_norm(self, p);
}
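The rewritten norm_out builds a reduction TensorIterator via make_reduction and hands it to norm_stub, so a single entry point serves both the full reduction and the dimensioned one on CPU and CUDA. For a general finite p the reduction computes sum(|x|^p)^(1/p); a reference sketch of that semantics (illustrative, assumes a finite nonzero p):

    import torch

    def ref_norm(x, p, dim=None, keepdim=False):
        # reference semantics of the general-p branch: (sum |x|^p)^(1/p)
        if dim is None:
            dim = tuple(range(x.dim()))
        return x.abs().pow(p).sum(dim=dim, keepdim=keepdim).pow(1.0 / p)

    x = torch.randn(3, 4)
    assert torch.allclose(torch.norm(x, p=3, dim=1), ref_norm(x, 3, dim=1))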
diff --git a/aten/src/ATen/native/ReduceOps.h b/aten/src/ATen/native/ReduceOps.h
index 5fb6ac6896..f0c0231377 100644
--- a/aten/src/ATen/native/ReduceOps.h
+++ b/aten/src/ATen/native/ReduceOps.h
@@ -25,4 +25,7 @@ using reduce_norm_fn =
void (*)(Tensor&, const Tensor&, Scalar, c10::optional<int64_t>);
DECLARE_DISPATCH(reduce_norm_fn, norm_kernel);
+using reduce_fn_flag = void(*)(TensorIterator &, Scalar);
+DECLARE_DISPATCH(reduce_fn_flag, norm_stub);
+
}} // namespace at::native
diff --git a/aten/src/ATen/native/SharedReduceOps.h b/aten/src/ATen/native/SharedReduceOps.h
index 30b9a07131..2cd6afd533 100644
--- a/aten/src/ATen/native/SharedReduceOps.h
+++ b/aten/src/ATen/native/SharedReduceOps.h
@@ -13,6 +13,21 @@
#include <cmath>
#define device_sqrt std::sqrt
#endif
+#if defined(__CUDACC__) || defined(__HIPCC__)
+#define MAX(X, Y) ::max(X,Y)
+#define MIN(X, Y) ::min(X,Y)
+#else
+#define MAX(X, Y) std::max(X,Y)
+#define MIN(X, Y) std::min(X,Y)
+#endif
+
+// ROCM hcc doesn't work well with using std:: in kernel functions
+#if defined(__CUDA_ARCH__) || defined(__HIP_PLATFORM_HCC__)
+#include <c10/cuda/CUDAMathCompat.h>
+#define compat_pow c10::cuda::compat::pow
+#else
+#define compat_pow std::pow
+#endif
namespace at { namespace native {
@@ -105,5 +120,119 @@ struct MeanOps {
}
};
+template <typename acc_t>
+struct AbsMinOps {
+
+ inline C10_DEVICE acc_t reduce(acc_t acc, acc_t data) const {
+ return MIN(acc, std::abs(data));
+ }
+
+ inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
+ return MIN(a, b);
+ }
+
+ inline C10_DEVICE acc_t project(acc_t a) const {
+ return a;
+ }
+
+#if defined(__CUDACC__) || defined(__HIPCC__)
+ inline C10_DEVICE acc_t warp_shfl_down(acc_t data, int offset) const {
+ return WARP_SHFL_DOWN(data, offset);
+ }
+#endif
+};
+
+template <typename acc_t>
+struct AbsMaxOps {
+
+ inline C10_DEVICE acc_t reduce(acc_t acc, acc_t data) const {
+ return MAX(acc, std::abs(data));
+ }
+
+ inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
+ return MAX(a, b);
+ }
+
+ inline C10_DEVICE acc_t project(acc_t a) const {
+ return a;
+ }
+
+#if defined(__CUDACC__) || defined(__HIPCC__)
+ inline C10_DEVICE acc_t warp_shfl_down(acc_t data, int offset) const {
+ return WARP_SHFL_DOWN(data, offset);
+ }
+#endif
+};
+
+template <typename acc_t>
+struct NormOps {
+ acc_t norm;
+
+ inline C10_DEVICE acc_t reduce(acc_t acc, acc_t data) const {
+ return acc + compat_pow(std::abs(data), norm);
+ }
+
+ inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
+ return a + b;
+ }
+
+ inline C10_DEVICE acc_t project(acc_t a) const {
+ return compat_pow(a, acc_t(1.0)/norm);
+ }
+
+#if defined(__CUDACC__) || defined(__HIPCC__)
+ inline C10_DEVICE acc_t warp_shfl_down(acc_t data, int offset) const {
+ return WARP_SHFL_DOWN(data, offset);
+ }
+#endif
+
+ NormOps(acc_t norm): norm(norm) {
+ }
+};
+
+template <typename acc_t>
+struct NormZeroOps {
+ inline C10_DEVICE acc_t reduce(acc_t acc, acc_t data) const {
+ return acc + (data==acc_t(0) ? acc_t(0) : acc_t(1));
+ }
+
+ inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
+ return a + b;
+ }
+
+ inline C10_DEVICE acc_t project(acc_t a) const {
+ return a;
+ }
+
+#if defined(__CUDACC__) || defined(__HIPCC__)
+ inline C10_DEVICE acc_t warp_shfl_down(acc_t data, int offset) const {
+ return WARP_SHFL_DOWN(data, offset);
+ }
+#endif
+};
+
+template <typename acc_t>
+struct NormOneOps {
+ inline C10_DEVICE acc_t reduce(acc_t acc, acc_t data) const {
+ return acc + std::abs(data);
+ }
+
+ inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
+ return a + b;
+ }
+
+ inline C10_DEVICE acc_t project(acc_t a) const {
+ return a;
+ }
+
+#if defined(__CUDACC__) || defined(__HIPCC__)
+ inline C10_DEVICE acc_t warp_shfl_down(acc_t data, int offset) const {
+ return WARP_SHFL_DOWN(data, offset);
+ }
+#endif
+};
}} // namespace at::native
+
+#undef MAX
+#undef MIN
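Each functor above factors the reduction into reduce (fold one element into an accumulator), combine (merge partial accumulators) and project (finalize). A toy Python mirror of NormOps, only to illustrate that contract (class and method names here are descriptive, not real APIs):

    class PyNormOps:
        """Toy mirror of NormOps<acc_t>: accumulate |x|^p, finalize with ^(1/p)."""
        def __init__(self, p):
            self.p = p

        def reduce(self, acc, data):   # fold one element into the accumulator
            return acc + abs(data) ** self.p

        def combine(self, a, b):       # merge two partial accumulators
            return a + b

        def project(self, a):          # final transform on the accumulated value
            return a ** (1.0 / self.p)

    ops = PyNormOps(2.0)
    acc = 0.0
    for v in [3.0, 4.0]:
        acc = ops.reduce(acc, v)
    assert ops.project(acc) == 5.0     # classic 3-4-5 check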
diff --git a/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp b/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp
index 4fca2ce7cf..3841dec5d9 100644
--- a/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp
+++ b/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp
@@ -54,6 +54,62 @@ static void prod_kernel_impl(TensorIterator& iter) {
});
}
+static void norm_kernel_tensor_iterator_impl(
+ TensorIterator& iter,
+ Scalar p) {
+ float val;
+ if (p.isIntegral()) {
+ val = p.to<int64_t>();
+ } else if (p.isFloatingPoint()) {
+ val = p.to<float>();
+ } else {
+ AT_ERROR("norm_kernel_tensor_iterator_impl expects norm to be integer or float");
+ }
+
+
+ if (val == 0) {
+ AT_DISPATCH_FLOATING_TYPES(iter.type(), "norm", [&] {
+ binary_kernel_reduce(
+ iter,
+ NormZeroOps<scalar_t>(),
+ scalar_t(0)
+ );
+ });
+ } else if (val == 1) {
+ AT_DISPATCH_FLOATING_TYPES(iter.type(), "norm", [&] {
+ binary_kernel_reduce(
+ iter,
+ NormOneOps<scalar_t>(),
+ scalar_t(0)
+ );
+ });
+ } else if (val == INFINITY) {
+ AT_DISPATCH_FLOATING_TYPES(iter.type(), "norm", [&] {
+ binary_kernel_reduce(
+ iter,
+ AbsMaxOps<scalar_t>(),
+ std::numeric_limits<scalar_t>::min()
+ );
+ });
+ } else if (val == -INFINITY) {
+ AT_DISPATCH_FLOATING_TYPES(iter.type(), "norm", [&] {
+ binary_kernel_reduce(
+ iter,
+ AbsMinOps<scalar_t>(),
+ std::numeric_limits<scalar_t>::max()
+ );
+ });
+ } else {
+ AT_DISPATCH_FLOATING_TYPES(iter.type(), "norm", [&] {
+ binary_kernel_reduce(
+ iter,
+ NormOps<scalar_t> { scalar_t(val) },
+ scalar_t(0)
+ );
+ });
+ }
+}
+
static void and_kernel_impl(TensorIterator& iter) {
binary_kernel_reduce_vec(
iter,
@@ -84,6 +140,7 @@ REGISTER_DISPATCH(sum_stub, &sum_kernel_impl);
REGISTER_DISPATCH(std_var_stub, &std_var_kernel_impl);
REGISTER_DISPATCH(prod_stub, &prod_kernel_impl);
REGISTER_DISPATCH(mean_stub, &mean_kernel_impl);
+REGISTER_DISPATCH(norm_stub, &norm_kernel_tensor_iterator_impl);
REGISTER_DISPATCH(and_stub, &and_kernel_impl);
}} // namespace at::native
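The CPU kernel special-cases p = 0, 1, inf and -inf instead of routing them through the pow-based path; these correspond to the count of non-zeros, the sum of absolute values, and the max/min absolute value. A quick sanity sketch of those equivalences:

    import torch

    x = torch.tensor([-3.0, 0.0, 2.0, -5.0])
    assert torch.norm(x, p=0) == (x != 0).sum().float()      # 3 non-zero entries
    assert torch.norm(x, p=1) == x.abs().sum()               # 10
    assert torch.norm(x, p=float('inf')) == x.abs().max()    # 5
    assert torch.norm(x, p=float('-inf')) == x.abs().min()   # 0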
diff --git a/aten/src/ATen/native/cuda/ReduceOpsKernel.cu b/aten/src/ATen/native/cuda/ReduceOpsKernel.cu
index 5856d3712c..1ce0fecd90 100644
--- a/aten/src/ATen/native/cuda/ReduceOpsKernel.cu
+++ b/aten/src/ATen/native/cuda/ReduceOpsKernel.cu
@@ -14,17 +14,6 @@
namespace at { namespace native {
-namespace {
-
-template <typename scalar_t>
-struct SimpleCopy {
- __device__ __forceinline__ scalar_t operator() (const scalar_t a) const {
- return a;
- }
-};
-
-} // namespace
-
template <typename scalar_t, typename acc_t=scalar_t, typename out_t=scalar_t>
void sum_kernel_impl(TensorIterator& iter) {
gpu_reduce_kernel<scalar_t, out_t>(iter, func_wrapper<out_t> ([]GPU_LAMBDA(acc_t a, acc_t b) -> acc_t {
@@ -86,6 +75,30 @@ void mean_kernel_impl<int16_t, int16_t, int16_t>(TensorIterator& iter) {
}
#endif // __HIPCC__
+template <typename scalar_t, typename acc_t=scalar_t, typename out_t=scalar_t>
+void norm_kernel_cuda_impl(TensorIterator& iter, Scalar val) {
+ float p;
+ if (val.isIntegral()) {
+ p = val.to<int64_t>();
+ } else if (val.isFloatingPoint()) {
+ p = val.to<acc_t>();
+ } else {
+ AT_ERROR("norm_kernel_cuda_impl expects norm to be integer or float");
+ }
+
+ if (p == static_cast<float>(0)) {
+ gpu_reduce_kernel<scalar_t, out_t>(iter, NormZeroOps<acc_t>(), 0);
+ } else if (p == static_cast<float>(1)) {
+ gpu_reduce_kernel<scalar_t, out_t>(iter, NormOneOps<acc_t>(), 0);
+ } else if (p == static_cast<float>(INFINITY)) {
+ gpu_reduce_kernel<scalar_t, out_t>(iter, AbsMaxOps<acc_t>(), std::numeric_limits<acc_t>::min());
+ } else if (p == static_cast<float>(-INFINITY)) {
+ gpu_reduce_kernel<scalar_t, out_t>(iter, AbsMinOps<acc_t>(), std::numeric_limits<acc_t>::max());
+ } else {
+ gpu_reduce_kernel<scalar_t, out_t>(iter, NormOps<acc_t>{ acc_t(p) }, 0);
+ }
+}
+
static void sum_kernel_cuda(TensorIterator& iter) {
if (iter.type().scalarType() == kHalf) {
return sum_kernel_impl<at::Half, float>(iter);
@@ -119,6 +132,18 @@ static void mean_kernel_cuda(TensorIterator& iter) {
});
}
+static void norm_kernel_cuda(TensorIterator& iter, Scalar p) {
+ if (iter.type().scalarType() == kHalf) {
+ return norm_kernel_cuda_impl<at::Half, float>(iter, p);
+ } else if (iter.type(1).scalarType() == kHalf && iter.type().scalarType() == kFloat) {
+ // type promotion that does cast and reduction in a single kernel
+ return norm_kernel_cuda_impl<at::Half, float, float>(iter, p);
+ }
+ AT_DISPATCH_FLOATING_TYPES(iter.type(), "norm", [&]() {
+ norm_kernel_cuda_impl<scalar_t>(iter, p);
+ });
+}
+
void and_kernel_cuda(TensorIterator& iter) {
gpu_reduce_kernel<uint8_t, uint8_t>(
iter, func_wrapper<uint8_t> ([]GPU_LAMBDA(uint8_t a, uint8_t b) -> uint8_t {
@@ -130,6 +155,7 @@ REGISTER_DISPATCH(std_var_stub, &std_var_kernel_cuda);
REGISTER_DISPATCH(sum_stub, &sum_kernel_cuda);
REGISTER_DISPATCH(prod_stub, &prod_kernel_cuda);
REGISTER_DISPATCH(mean_stub, &mean_kernel_cuda);
+REGISTER_DISPATCH(norm_stub, &norm_kernel_cuda);
REGISTER_DISPATCH(and_stub, &and_kernel_cuda);
}} // namespace at::native
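norm_kernel_cuda accumulates half inputs in float, and when the output type is float it casts and reduces in a single kernel. That is what keeps results like the 0-norm of 65536 ones representable, since half's largest finite value is 65504. A sketch mirroring the new test, assuming a CUDA device is available:

    import torch

    a = torch.ones(65536, device="cuda", dtype=torch.half)
    n = a.norm(p=0, dtype=torch.float32)   # count of non-zeros, accumulated and returned in float32
    assert n.item() == 65536               # 65536 has no finite half representation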
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 4f257eed17..69a7428d85 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -1897,13 +1897,21 @@
SparseCPU: _sparse_sum_backward_cpu
SparseCUDA: _sparse_sum_backward_cuda
+- func: norm(Tensor self, Scalar? p, *, ScalarType dtype) -> Tensor
+ variants: function, method
+
- func: norm(Tensor self, Scalar p=2) -> Tensor
variants: function, method
-- func: norm(Tensor self, Scalar? p, int64_t dim, bool keepdim=false) -> Tensor
+- func: norm(Tensor self, Scalar? p, IntList[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor
variants: function, method
-- func: norm_out(Tensor result, Tensor self, Scalar? p, int64_t dim, bool keepdim=false) -> Tensor
+- func: norm(Tensor self, Scalar? p, IntList[1] dim, bool keepdim=false) -> Tensor
+ variants: function, method
+
+- func: norm_out(Tensor result, Tensor self, Scalar? p, IntList[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor
+
+- func: norm_out(Tensor result, Tensor self, Scalar? p, IntList[1] dim, bool keepdim=false) -> Tensor
- func: frobenius_norm(Tensor self) -> Tensor
variants: function
diff --git a/test/test_cuda.py b/test/test_cuda.py
index be5359571c..c22ddb0b40 100644
--- a/test/test_cuda.py
+++ b/test/test_cuda.py
@@ -2212,6 +2212,11 @@ class TestCuda(TestCase):
self.assertGreater(b.norm().item(), 0)
@skipIfRocm
+ def test_norm_type_conversion(self):
+ a = torch.ones(65536).cuda().half()
+ self.assertEqual(a.norm(p=0, dtype=torch.float32), 65536)
+
+ @skipIfRocm
# Test that wrap_with_cuda_memory_check successfully detects leak
def test_cuda_memory_leak_detection(self):
l = []
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 5eef168ee2..b391a6b5e3 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -570,9 +570,16 @@
- name: norm(Tensor self, Scalar p)
self: norm_backward(grad, self, p, result)
-- name: norm(Tensor self, Scalar? p, int64_t dim, bool keepdim)
+- name: norm(Tensor self, Scalar? p, IntList dim, bool keepdim)
self: norm_backward(grad, self, p, result, dim, keepdim)
+- name: norm(Tensor self, Scalar? p, ScalarType dtype)
+ self: norm_backward(grad, self.to(grad.type().scalarType()), p, result).to(self.type().scalarType())
+
+
+- name: norm(Tensor self, Scalar? p, IntList dim, bool keepdim, ScalarType dtype)
+ self: norm_backward(grad, self.to(grad.type().scalarType()), p, result, dim, keepdim).to(self.type().scalarType())
+
- name: _pdist_forward(Tensor self, double p)
self: _pdist_backward(grad, self, p, result)
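The new entries route the dtype overloads through the same norm_backward, casting self to the grad's scalar type and the result back. The underlying gradient for a finite p is grad_i = x_i * |x_i|^(p-2) / ||x||_p^(p-1); gradcheck is a convenient way to sanity-check the dimensioned overload (illustrative sketch):

    import torch
    from torch.autograd import gradcheck

    x = torch.randn(3, 4, dtype=torch.double, requires_grad=True)
    # gradcheck compares the analytic gradient against finite differences
    assert gradcheck(lambda t: t.norm(p=3, dim=(0, 1), keepdim=True), (x,))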
diff --git a/tools/autograd/templates/Functions.cpp b/tools/autograd/templates/Functions.cpp
index 4f1a158844..dd024ac625 100644
--- a/tools/autograd/templates/Functions.cpp
+++ b/tools/autograd/templates/Functions.cpp
@@ -114,10 +114,21 @@ Tensor norm_backward(const Tensor & grad, const Tensor & self, const optional<Sc
return self_scaled * scale_v;
}
-Tensor norm_backward(Tensor grad, const Tensor & self, const optional<Scalar> & p_, Tensor norm, int64_t dim, bool keepdim) {
+Tensor norm_backward(Tensor grad, const Tensor & self, const optional<Scalar> & p_, Tensor norm, IntList dim, bool keepdim) {
+ IntList sizes = self.sizes();
if (!keepdim && self.dim() != 0) {
- grad = grad.unsqueeze(dim);
- norm = norm.unsqueeze(dim);
+ if (dim.size()==1) {
+ grad = grad.unsqueeze(dim[0]);
+ norm = norm.unsqueeze(dim[0]);
+ } else {
+ auto dims_to_unsqueeze = at::dim_list_to_bitset(dim, sizes.size());
+ for (size_t i = 0; i < sizes.size(); i++){
+ if (dims_to_unsqueeze[i]) {
+ grad = grad.unsqueeze(i);
+ norm = norm.unsqueeze(i);
+ }
+ }
+ }
}
return norm_backward(grad, self, p_, norm);
}
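When keepdim is false the reduced dimensions have been squeezed out of grad and norm, so norm_backward re-inserts a size-1 axis at every reduced position before broadcasting against self. The same reshaping expressed in Python, as a sketch of the idea rather than the actual implementation:

    import torch

    x = torch.randn(2, 3, 4)
    dim = (0, 2)
    n = x.norm(p=2, dim=dim)            # shape (3,): reduced dims squeezed away
    g = torch.ones_like(n)              # incoming gradient, same shape as the output

    # re-insert a size-1 axis at each reduced position so g and n broadcast against x
    for d in sorted(dim):
        g = g.unsqueeze(d)
        n = n.unsqueeze(d)
    assert g.shape == (1, 3, 1) and n.shape == (1, 3, 1)

    grad_x = g * x / n                  # 2-norm gradient: x / ||x|| per reduced slice
    assert grad_x.shape == x.shape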
diff --git a/torch/csrc/jit/passes/shape_analysis.cpp b/torch/csrc/jit/passes/shape_analysis.cpp
index 12654657a9..e8b8425f6a 100644
--- a/torch/csrc/jit/passes/shape_analysis.cpp
+++ b/torch/csrc/jit/passes/shape_analysis.cpp
@@ -902,7 +902,6 @@ class ShapePropagator {
"aten::argmin(Tensor self, int dim, bool keepdim) -> Tensor",
"aten::max_values(Tensor self, int dim, bool keepdim) -> Tensor",
"aten::min_values(Tensor self, int dim, bool keepdim) -> Tensor",
- "aten::norm(Tensor self, Scalar? p, int dim, bool keepdim) -> Tensor",
"aten::logsumexp(Tensor self, int dim, bool keepdim) -> Tensor",
"aten::all(Tensor self, int dim, bool keepdim) -> Tensor",
"aten::any(Tensor self, int dim, bool keepdim) -> Tensor",
@@ -955,6 +954,7 @@ class ShapePropagator {
static const register_formula_for multidim_reduce_ops{
{
"aten::mean(Tensor self, int[] dim, bool keepdim) -> Tensor",
+ "aten::norm(Tensor self, Scalar? p, int[] dim, bool keepdim) -> Tensor",
"aten::std(Tensor self, int[] dim, bool unbiased, bool keepdim) -> Tensor",
"aten::var(Tensor self, int[] dim, bool unbiased, bool keepdim) -> Tensor",
},
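Moving the norm schema into multidim_reduce_ops tells the JIT shape propagator that every entry of the int[] dim list is reduced, kept as size 1 when keepdim is true and dropped otherwise, exactly like mean/std/var. The resulting shapes, for reference:

    import torch

    x = torch.randn(2, 3, 4)
    assert x.norm(p=2, dim=(0, 2)).shape == (3,)                     # reduced dims dropped
    assert x.norm(p=2, dim=(0, 2), keepdim=True).shape == (1, 3, 1)  # reduced dims kept as size 1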
diff --git a/torch/functional.py b/torch/functional.py
index 1ff28b161b..9847f34150 100644
--- a/torch/functional.py
+++ b/torch/functional.py
@@ -633,7 +633,7 @@ def cartesian_prod(*tensors):
return torch._C._VariableFunctions.cartesian_prod(tensors)
-def norm(input, p="fro", dim=None, keepdim=False, out=None):
+def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None):
r"""Returns the matrix norm or vector norm of a given tensor.
Args:
@@ -662,6 +662,10 @@ def norm(input, p="fro", dim=None, keepdim=False, out=None):
:attr:`out` = ``None``. Default: ``False``
out (Tensor, optional): the output tensor. Ignored if
:attr:`dim` = ``None`` and :attr:`out` = ``None``.
+ dtype (:class:`torch.dtype`, optional): the desired data type of the
+ returned tensor. If specified, the input tensor is cast to
+ :attr:`dtype` while performing the operation. Default: ``None``.
+
Example::
@@ -692,26 +696,36 @@ def norm(input, p="fro", dim=None, keepdim=False, out=None):
ndim = input.dim()
# catch default case
- if dim is None and out is None:
+ if dim is None and out is None and dtype is None:
if p == "fro":
return torch._C._VariableFunctions.frobenius_norm(input)
elif p != "nuc":
return torch._C._VariableFunctions.norm(input, p)
if p == "fro":
+ if dtype is not None:
+ raise ValueError("dtype argument is not supported in frobenius norm")
if dim is None:
dim = tuple(range(ndim))
if out is None:
return torch._C._VariableFunctions.frobenius_norm(input, dim, keepdim=keepdim)
return torch._C._VariableFunctions.frobenius_norm(input, dim, keepdim=keepdim, out=out)
elif p == "nuc":
+ if dtype is not None:
+ raise ValueError("dtype argument is not supported in nuclear norm")
if out is None:
torch._C._VariableFunctions.nuclear_norm(input, keepdim=keepdim)
return torch._C._VariableFunctions.nuclear_norm(input, keepdim=keepdim, out=out)
else:
- if out is None:
+ if dim is None:
+ dim = tuple(range(ndim))
+ if out is None and dtype is None:
return torch._C._VariableFunctions.norm(input, p, dim, keepdim=keepdim)
- return torch._C._VariableFunctions.norm(input, p, dim, keepdim=keepdim, out=out)
+ elif out is None:
+ return torch._C._VariableFunctions.norm(input, p, dim, keepdim=keepdim, dtype=dtype)
+ elif dtype is None:
+ return torch._C._VariableFunctions.norm(input, p, dim, keepdim=keepdim, out=out)
+ return torch._C._VariableFunctions.norm(input, p, dim, keepdim=keepdim, dtype=dtype, out=out)
def chain_matmul(*matrices):
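torch.functional.norm now threads dtype through to the native op on the p-norm path and explicitly rejects it for the Frobenius and nuclear branches. A few calls showing which branch each one hits (illustrative sketch):

    import torch

    x = torch.randn(4, 4)

    fro = torch.norm(x)                                     # default: frobenius_norm fast path
    p2 = torch.norm(x, p=2, dim=1, dtype=torch.float64)     # p-norm branch, float64 math/output

    try:
        torch.norm(x, p="fro", dim=(0, 1), dtype=torch.float64)
    except ValueError as e:
        print(e)   # "dtype argument is not supported in frobenius norm"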
diff --git a/torch/tensor.py b/torch/tensor.py
index eedd6413cb..be936bf38e 100644
--- a/torch/tensor.py
+++ b/torch/tensor.py
@@ -255,9 +255,9 @@ class Tensor(torch._C._TensorBase):
r"""See :func: `torch.argsort`"""
return torch.argsort(self, dim, descending)
- def norm(self, p="fro", dim=None, keepdim=False):
+ def norm(self, p="fro", dim=None, keepdim=False, dtype=None):
r"""See :func: `torch.norm`"""
- return torch.norm(self, p, dim, keepdim)
+ return torch.norm(self, p, dim, keepdim, dtype=dtype)
def potrf(self, upper=True):
r"""See :func:`torch.cholesky`"""