diff options
author | jiej <jiej@nvidia.com> | 2019-01-16 22:12:13 -0800 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-01-16 22:15:25 -0800 |
commit | 7c56db73d5a9e1432dabc0231acad63575c3089e (patch) | |
tree | fd26ce91d87d61ab336165285ef04c28a5dd3f8d /torch/functional.py | |
parent | 55511004d17bd3e0e36e88efa6abdc9a5a03dec1 (diff) | |
download | pytorch-7c56db73d5a9e1432dabc0231acad63575c3089e.tar.gz pytorch-7c56db73d5a9e1432dabc0231acad63575c3089e.tar.bz2 pytorch-7c56db73d5a9e1432dabc0231acad63575c3089e.zip |
Moving torch.norm to ATen using TensorIterator (#15414)
Summary:
Adding supports for torch.nomr:
i. multi dimensions for dim
ii. dtype that specifies math/output tensor type
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15414
Differential Revision: D13702022
Pulled By: ezyang
fbshipit-source-id: da2676f2b6aff988889b1539d0de8ecd4946823a
Diffstat (limited to 'torch/functional.py')
-rw-r--r-- | torch/functional.py | 22 |
1 files changed, 18 insertions, 4 deletions
diff --git a/torch/functional.py b/torch/functional.py index 1ff28b161b..9847f34150 100644 --- a/torch/functional.py +++ b/torch/functional.py @@ -633,7 +633,7 @@ def cartesian_prod(*tensors): return torch._C._VariableFunctions.cartesian_prod(tensors) -def norm(input, p="fro", dim=None, keepdim=False, out=None): +def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None): r"""Returns the matrix norm or vector norm of a given tensor. Args: @@ -662,6 +662,10 @@ def norm(input, p="fro", dim=None, keepdim=False, out=None): :attr:`out` = ``None``. Default: ``False`` out (Tensor, optional): the output tensor. Ignored if :attr:`dim` = ``None`` and :attr:`out` = ``None``. + dtype (:class:`torch.dtype`, optional): the desired data type of + returned tensor. If specified, the input tensor is casted to + :attr:'dtype' while performing the operation. Default: None. + Example:: @@ -692,26 +696,36 @@ def norm(input, p="fro", dim=None, keepdim=False, out=None): ndim = input.dim() # catch default case - if dim is None and out is None: + if dim is None and out is None and dtype is None: if p == "fro": return torch._C._VariableFunctions.frobenius_norm(input) elif p != "nuc": return torch._C._VariableFunctions.norm(input, p) if p == "fro": + if dtype is not None: + raise ValueError("dtype argument is not supported in frobenius norm") if dim is None: dim = tuple(range(ndim)) if out is None: return torch._C._VariableFunctions.frobenius_norm(input, dim, keepdim=keepdim) return torch._C._VariableFunctions.frobenius_norm(input, dim, keepdim=keepdim, out=out) elif p == "nuc": + if dtype is not None: + raise ValueError("dtype argument is not supported in nuclear norm") if out is None: torch._C._VariableFunctions.nuclear_norm(input, keepdim=keepdim) return torch._C._VariableFunctions.nuclear_norm(input, keepdim=keepdim, out=out) else: - if out is None: + if dim is None: + dim = tuple(range(ndim)) + if out is None and dtype is None: return torch._C._VariableFunctions.norm(input, p, dim, keepdim=keepdim) - return torch._C._VariableFunctions.norm(input, p, dim, keepdim=keepdim, out=out) + elif out is None: + return torch._C._VariableFunctions.norm(input, p, dim, keepdim=keepdim, dtype=dtype) + elif dtype is None: + return torch._C._VariableFunctions.norm(input, p, dim, keepdim=keepdim, out=out) + return torch._C._VariableFunctions.norm(input, p, dim, keepdim=keepdim, dtype=dtype, out=out) def chain_matmul(*matrices): |