author    jiej <jiej@nvidia.com>  2019-01-16 22:12:13 -0800
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>  2019-01-16 22:15:25 -0800
commit    7c56db73d5a9e1432dabc0231acad63575c3089e (patch)
tree      fd26ce91d87d61ab336165285ef04c28a5dd3f8d /torch/functional.py
parent    55511004d17bd3e0e36e88efa6abdc9a5a03dec1 (diff)
Moving torch.norm to ATen using TensorIterator (#15414)
Summary: Adds support in torch.norm for: (i) multiple dimensions for dim; (ii) a dtype argument that specifies the math/output tensor type.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15414
Differential Revision: D13702022
Pulled By: ezyang
fbshipit-source-id: da2676f2b6aff988889b1539d0de8ecd4946823a
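As a quick illustration of the two additions, the following sketch shows the new call patterns (the tensor values are illustrative, not taken from the commit; behavior as of this PR):

    import torch

    x = torch.arange(6, dtype=torch.float32).reshape(2, 3)

    # (i) `dim` may now be a tuple, reducing over several dimensions at once.
    torch.norm(x, p=2, dim=(0, 1))

    # (ii) `dtype` selects the type used for the computation and the result.
    torch.norm(x, p=2, dim=1, dtype=torch.float64)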
Diffstat (limited to 'torch/functional.py')
-rw-r--r--  torch/functional.py  |  22
1 file changed, 18 insertions(+), 4 deletions(-)
diff --git a/torch/functional.py b/torch/functional.py
index 1ff28b161b..9847f34150 100644
--- a/torch/functional.py
+++ b/torch/functional.py
@@ -633,7 +633,7 @@ def cartesian_prod(*tensors):
    return torch._C._VariableFunctions.cartesian_prod(tensors)


-def norm(input, p="fro", dim=None, keepdim=False, out=None):
+def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None):
    r"""Returns the matrix norm or vector norm of a given tensor.

    Args:
@@ -662,6 +662,10 @@ def norm(input, p="fro", dim=None, keepdim=False, out=None):
            :attr:`out` = ``None``. Default: ``False``
        out (Tensor, optional): the output tensor. Ignored if
            :attr:`dim` = ``None`` and :attr:`out` = ``None``.
+        dtype (:class:`torch.dtype`, optional): the desired data type of
+            the returned tensor. If specified, the input tensor is cast to
+            :attr:`dtype` while performing the operation. Default: None.
+

    Example::
@@ -692,26 +696,36 @@ def norm(input, p="fro", dim=None, keepdim=False, out=None):
    ndim = input.dim()

    # catch default case
-    if dim is None and out is None:
+    if dim is None and out is None and dtype is None:
        if p == "fro":
            return torch._C._VariableFunctions.frobenius_norm(input)
        elif p != "nuc":
            return torch._C._VariableFunctions.norm(input, p)

    if p == "fro":
+        if dtype is not None:
+            raise ValueError("dtype argument is not supported in frobenius norm")
        if dim is None:
            dim = tuple(range(ndim))
        if out is None:
            return torch._C._VariableFunctions.frobenius_norm(input, dim, keepdim=keepdim)
        return torch._C._VariableFunctions.frobenius_norm(input, dim, keepdim=keepdim, out=out)
    elif p == "nuc":
+        if dtype is not None:
+            raise ValueError("dtype argument is not supported in nuclear norm")
        if out is None:
            return torch._C._VariableFunctions.nuclear_norm(input, keepdim=keepdim)
        return torch._C._VariableFunctions.nuclear_norm(input, keepdim=keepdim, out=out)
    else:
-        if out is None:
+        if dim is None:
+            dim = tuple(range(ndim))
+        if out is None and dtype is None:
            return torch._C._VariableFunctions.norm(input, p, dim, keepdim=keepdim)
-        return torch._C._VariableFunctions.norm(input, p, dim, keepdim=keepdim, out=out)
+        elif out is None:
+            return torch._C._VariableFunctions.norm(input, p, dim, keepdim=keepdim, dtype=dtype)
+        elif dtype is None:
+            return torch._C._VariableFunctions.norm(input, p, dim, keepdim=keepdim, out=out)
+        return torch._C._VariableFunctions.norm(input, p, dim, keepdim=keepdim, dtype=dtype, out=out)


def chain_matmul(*matrices):
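From the Python side, the new dispatch works as follows: the fast default path is taken only when dim, out, and dtype are all unset; the Frobenius and nuclear branches reject dtype outright; and only the general p-norm branch forwards dtype (and/or out) to ATen. A minimal sketch exercising each branch, assuming a build that includes this change (inputs and printed output are illustrative only):

    import torch

    x = torch.randn(4, 4)

    torch.norm(x)           # default path -> frobenius_norm(input)
    torch.norm(x, p="nuc")  # nuclear branch (sum of singular values)
    torch.norm(x, p=3, dim=0, dtype=torch.float64)  # general branch, dtype honored

    # The matrix-norm branches reject dtype explicitly:
    try:
        torch.norm(x, p="fro", dim=0, dtype=torch.float64)
    except ValueError as e:
        print(e)  # "dtype argument is not supported in frobenius norm"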