author | Edward Yang <ezyang@fb.com> | 2019-02-04 19:23:57 -0800
---|---|---
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-02-04 19:30:04 -0800
commit | 6c04224cd84385910021580176413a98839cec72 (patch) |
tree | 60739d235527165884dad0417f24efc16f253d14 |
parent | 1409a2afc8d81d911dbcfc2d0b210cf62658237b (diff) |
Revert "Move outplace ops to ATen (#12413)" (#16731)
Summary:
This reverts commit f660d3ae19decc64390e894fbaf8de80d87585e0.
cc zasdfgbnm
Reasoning at https://github.com/pytorch/pytorch/pull/12413#issuecomment-460424129
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16731
Differential Revision: D13948022
Pulled By: ezyang
fbshipit-source-id: b10669cf03679e306850314b7b5b08bed0839e19
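For context, every op touched by this revert pairs an in-place method (trailing underscore, mutates `self`) with an out-of-place variant that returns a new tensor. A minimal sketch of that distinction using `index_add`, which behaves the same before and after the revert (the values are illustrative):

```python
import torch

x = torch.zeros(5)
index = torch.tensor([0, 2, 4])
source = torch.tensor([1.0, 2.0, 3.0])

# Out-of-place: returns a new tensor; x is left untouched.
y = x.index_add(0, index, source)

# In-place: the trailing underscore means x itself is mutated.
x.index_add_(0, index, source)

assert torch.equal(x, y)  # both are tensor([1., 0., 2., 0., 3.])
```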
-rw-r--r-- | aten/src/ATen/core/Tensor.h | 7
-rw-r--r-- | aten/src/ATen/core/TensorMethods.h | 21
-rw-r--r-- | aten/src/ATen/core/Type.h | 7
-rw-r--r-- | aten/src/ATen/native/Indexing.cpp | 32
-rw-r--r-- | aten/src/ATen/native/native_functions.yaml | 21
-rw-r--r-- | test/common_methods_invocations.py | 4
-rw-r--r-- | test/test_jit.py | 2
-rw-r--r-- | test/test_torch.py | 12
-rw-r--r-- | tools/pyi/gen_pyi.py | 7
-rw-r--r-- | torch/_tensor_docs.py | 49
-rw-r--r-- | torch/tensor.py | 35

11 files changed, 41 insertions(+), 156 deletions(-)
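The revert deletes the ATen-level declarations and native implementations of the outplace ops and restores Python-level methods on `torch.Tensor` that wrap the in-place variants (see the `torch/tensor.py` hunk at the end of the diff). A minimal sketch of that clone-then-mutate pattern, written as a free function for illustration rather than the actual method definition:

```python
import torch

def masked_fill(self, mask, value):
    # Clone first so the original tensor is never mutated, then apply
    # the in-place variant to the copy and return it.
    return self.clone().masked_fill_(mask, value)

t = torch.arange(4.0)
out = masked_fill(t, t < 2, -1.0)
print(out)  # tensor([-1., -1., 2., 3.]); t itself is unchanged
```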
```diff
diff --git a/aten/src/ATen/core/Tensor.h b/aten/src/ATen/core/Tensor.h
index cf3e05d30a..ba4ccb4393 100644
--- a/aten/src/ATen/core/Tensor.h
+++ b/aten/src/ATen/core/Tensor.h
@@ -389,15 +389,8 @@ class CAFFE2_API Tensor {
   Tensor irfft(int64_t signal_ndim, bool normalized=false, bool onesided=true, IntList signal_sizes={}) const;
   Tensor index(TensorList indices) const;
   Tensor & index_copy_(int64_t dim, const Tensor & index, const Tensor & source);
-  Tensor index_copy(int64_t dim, const Tensor & index, const Tensor & source) const;
   Tensor index_put(TensorList indices, const Tensor & values, bool accumulate=false) const;
   Tensor & index_put_(TensorList indices, const Tensor & values, bool accumulate=false);
-  Tensor index_add(int64_t dim, const Tensor & index, const Tensor & source) const;
-  Tensor index_fill(int64_t dim, const Tensor & index, Scalar source) const;
-  Tensor scatter(int64_t dim, const Tensor & index, const Tensor & source) const;
-  Tensor scatter_add(int64_t dim, const Tensor & index, const Tensor & source) const;
-  Tensor masked_scatter(const Tensor & mask, const Tensor & source) const;
-  Tensor masked_fill(const Tensor & mask, Scalar source) const;
   Tensor inverse() const;
   Tensor isclose(const Tensor & other, double rtol=1e-05, double atol=1e-08, bool equal_nan=false) const;
   bool is_distributed() const;
diff --git a/aten/src/ATen/core/TensorMethods.h b/aten/src/ATen/core/TensorMethods.h
index 5a5389f6c4..e76a0eb358 100644
--- a/aten/src/ATen/core/TensorMethods.h
+++ b/aten/src/ATen/core/TensorMethods.h
@@ -310,33 +310,12 @@ inline Tensor Tensor::index(TensorList indices) const {
 inline Tensor & Tensor::index_copy_(int64_t dim, const Tensor & index, const Tensor & source) {
   return type().index_copy_(*this, dim, index, source);
 }
-inline Tensor Tensor::index_copy(int64_t dim, const Tensor & index, const Tensor & source) const {
-  return type().index_copy(*this, dim, index, source);
-}
 inline Tensor Tensor::index_put(TensorList indices, const Tensor & values, bool accumulate) const {
   return type().index_put(*this, indices, values, accumulate);
 }
 inline Tensor & Tensor::index_put_(TensorList indices, const Tensor & values, bool accumulate) {
   return type().index_put_(*this, indices, values, accumulate);
 }
-inline Tensor Tensor::index_add(int64_t dim, const Tensor & index, const Tensor & source) const {
-  return type().index_add(*this, dim, index, source);
-}
-inline Tensor Tensor::index_fill(int64_t dim, const Tensor & index, Scalar source) const {
-  return type().index_fill(*this, dim, index, source);
-}
-inline Tensor Tensor::scatter(int64_t dim, const Tensor & index, const Tensor & source) const {
-  return type().scatter(*this, dim, index, source);
-}
-inline Tensor Tensor::scatter_add(int64_t dim, const Tensor & index, const Tensor & source) const {
-  return type().scatter_add(*this, dim, index, source);
-}
-inline Tensor Tensor::masked_scatter(const Tensor & mask, const Tensor & source) const {
-  return type().masked_scatter(*this, mask, source);
-}
-inline Tensor Tensor::masked_fill(const Tensor & mask, Scalar source) const {
-  return type().masked_fill(*this, mask, source);
-}
 inline Tensor Tensor::inverse() const {
   return type().inverse(*this);
 }
diff --git a/aten/src/ATen/core/Type.h b/aten/src/ATen/core/Type.h
index f7b20e7e53..30601cd3cf 100644
--- a/aten/src/ATen/core/Type.h
+++ b/aten/src/ATen/core/Type.h
@@ -271,15 +271,8 @@ struct CAFFE2_API Type {
   virtual Tensor irfft(const Tensor & self, int64_t signal_ndim, bool normalized, bool onesided, IntList signal_sizes) const = 0;
   virtual Tensor index(const Tensor & self, TensorList indices) const = 0;
   virtual Tensor & index_copy_(Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) const = 0;
-  virtual Tensor index_copy(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) const = 0;
   virtual Tensor index_put(const Tensor & self, TensorList indices, const Tensor & values, bool accumulate) const = 0;
   virtual Tensor & index_put_(Tensor & self, TensorList indices, const Tensor & values, bool accumulate) const = 0;
-  virtual Tensor index_add(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) const = 0;
-  virtual Tensor index_fill(const Tensor & self, int64_t dim, const Tensor & index, Scalar source) const = 0;
-  virtual Tensor scatter(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) const = 0;
-  virtual Tensor scatter_add(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) const = 0;
-  virtual Tensor masked_scatter(const Tensor & self, const Tensor & mask, const Tensor & source) const = 0;
-  virtual Tensor masked_fill(const Tensor & self, const Tensor & mask, Scalar source) const = 0;
   virtual Tensor inverse(const Tensor & self) const = 0;
   virtual Tensor isclose(const Tensor & self, const Tensor & other, double rtol, double atol, bool equal_nan) const = 0;
   virtual bool is_distributed(const Tensor & self) const = 0;
diff --git a/aten/src/ATen/native/Indexing.cpp b/aten/src/ATen/native/Indexing.cpp
index 0a299a6830..eb6c764202 100644
--- a/aten/src/ATen/native/Indexing.cpp
+++ b/aten/src/ATen/native/Indexing.cpp
@@ -497,36 +497,4 @@ Tensor & index_copy_(Tensor & self, int64_t dim, const Tensor & index, const Ten
   return at::legacy::th::_th_index_copy_(self, dim, index, source);
 }
 
-Tensor index_copy(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) {
-  return self.clone().index_copy_(dim, index, source);
-}
-
-Tensor index_add(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) {
-  return self.clone().index_add_(dim, index, source);
-}
-
-Tensor index_fill(const Tensor & self, int64_t dim, const Tensor & index, Scalar source) {
-  return self.clone().index_fill_(dim, index, source);
-}
-
-Tensor scatter(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) {
-  return self.clone().scatter_(dim, index, source);
-}
-
-Tensor scatter_add(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) {
-  return self.clone().scatter_add_(dim, index, source);
-}
-
-Tensor masked_scatter(const Tensor & self, const Tensor & mask, const Tensor & source) {
-  Tensor _mask, _self;
-  std::tie(_mask, _self) = expand_outplace(mask, self);
-  return _self.clone().masked_scatter_(_mask, source);
-}
-
-Tensor masked_fill(const Tensor & self, const Tensor & mask, Scalar source) {
-  Tensor _mask, _self;
-  std::tie(_mask, _self) = expand_outplace(mask, self);
-  return _self.clone().masked_fill_(mask, source);
-}
-
 }} // at::native
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 91cd1284fa..b980982158 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -1102,9 +1102,6 @@
 - func: index_copy_(Tensor(a!) self, int dim, IndexTensor index, Tensor source) -> Tensor(a!)
   variants: method
 
-- func: index_copy(Tensor self, int64_t dim, IndexTensor index, Tensor source) -> Tensor
-  variants: function, method
-
 - func: index_put(Tensor self, Tensor[] indices, Tensor values, bool accumulate=False) -> Tensor
   matches_jit_signature: True
   variants: function, method
@@ -1113,24 +1110,6 @@
   matches_jit_signature: True
   variants: function, method
 
-- func: index_add(Tensor self, int64_t dim, IndexTensor index, Tensor source) -> Tensor
-  variants: function, method
-
-- func: index_fill(Tensor self, int64_t dim, IndexTensor index, Scalar source) -> Tensor
-  variants: function, method
-
-- func: scatter(Tensor self, int64_t dim, IndexTensor index, Tensor source) -> Tensor
-  variants: function, method
-
-- func: scatter_add(Tensor self, int64_t dim, IndexTensor index, Tensor source) -> Tensor
-  variants: function, method
-
-- func: masked_scatter(Tensor self, Tensor mask, Tensor source) -> Tensor
-  variants: function, method
-
-- func: masked_fill(Tensor self, Tensor mask, Scalar source) -> Tensor
-  variants: function, method
-
 - func: instance_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool use_input_stats, float momentum, float eps, bool cudnn_enabled) -> Tensor
   matches_jit_signature: True
   variants: function
diff --git a/test/common_methods_invocations.py b/test/common_methods_invocations.py
index b998c6fef6..41c8763e16 100644
--- a/test/common_methods_invocations.py
+++ b/test/common_methods_invocations.py
@@ -744,7 +744,7 @@ def method_tests():
     ('masked_select', (), (mask_not_all_zeros((M, M)),), 'scalar_broadcast_lhs'),
     ('masked_fill', (M, M), (torch.ByteTensor(M, M).bernoulli_(), 10)),
     ('masked_fill', (M, M), (torch.ByteTensor(M, M).bernoulli_(), torch.tensor(10)), 'tensor'),
-    ('masked_fill', (M,), (torch.ByteTensor(M, M).bernoulli_(), 10), 'broadcast_lhs'),
+    # no lhs or all broadcast on masked_fill or masked_scatter because it's always inplace
     ('masked_fill', (M, M), (torch.ByteTensor(M,).bernoulli_(), 10), 'broadcast_rhs'),
     ('masked_fill', (), (torch.tensor(0, dtype=torch.uint8, requires_grad=False).bernoulli_(), 10), 'scalar'),
     ('masked_fill', (), (torch.tensor(0, dtype=torch.uint8, requires_grad=False).bernoulli_(), torch.tensor(10)),
@@ -752,8 +752,6 @@
     ('masked_fill', (M, M), (torch.tensor(0, dtype=torch.uint8, requires_grad=False).bernoulli_(), 10),
      'scalar_broadcast_rhs'),
     ('masked_scatter', (M, M), (torch.ByteTensor(M, M).bernoulli_(), (M, M))),
-    ('masked_scatter', (M,), (torch.ByteTensor(M, M).bernoulli_(), (M, M)),
-     'broadcast_lhs'),
     ('masked_scatter', (M, M), (torch.ByteTensor(M,).bernoulli_(), (M, M)),
      'broadcast_rhs'),
     ('masked_scatter', (M, M), (bernoulli_scalar(), (M, M)), 'scalar'),
diff --git a/test/test_jit.py b/test/test_jit.py
index 42ec3d34c3..b0ef8326da 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -8834,7 +8834,7 @@ a")
     def test_builtin_error_messsage(self):
         from torch.nn.modules.utils import _single, _pair, _triple, _quadruple
 
-        with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
+        with self.assertRaisesRegex(RuntimeError, "aten::masked_fill_"):
             @torch.jit.script
             def close_match(x):
                 return x.masked_fill(True)
diff --git a/test/test_torch.py b/test/test_torch.py
index 1c538d5a2d..f3e8d8234c 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -3450,7 +3450,6 @@ class _TestTorchMixin(object):
 
         for fn in fns:
             (dims_small, dims_large, dims_full) = self._select_broadcastable_dims()
-            full1d = cast(torch.randn(*dims_full).flatten().float())
             small = cast(torch.randn(*dims_small).float())
             large = cast(torch.randn(*dims_large).float())
             small_expanded = small.expand(*dims_full)
@@ -3467,7 +3466,8 @@
                 # map and map2 are not implementd on CUDA tensors
                 continue
 
-            if hasattr(large_expanded, fn):
+            # TODO: fix masked_scatter and masked_fill broadcasting
+            if hasattr(large_expanded, fn) and fn not in ['masked_scatter', 'masked_fill']:
                 # run through tensor versions of functions
                 # and verify fully expanded inputs give same results
                 expanded = {large: large_expanded, small: small_expanded, small2: small2_expanded}
@@ -3477,10 +3477,6 @@
                         return myfn(t1, 0.5)
                     elif fn == "masked_select":
                         return myfn(t1 < 0)
-                    elif fn == "masked_scatter":
-                        return myfn(t1 < 0.5, full1d)
-                    elif fn == "masked_fill":
-                        return myfn(t1 < 0.5, 1.0)
                     elif fn in fns_3_args:
                         return myfn(1, t1, t2)
                     else:
@@ -3508,7 +3504,7 @@
                 elif fn == "masked_select":
                     return fntorch(t1, t2 < 0)
                 elif fn == "masked_scatter":
-                    return fntorch(t1, t2 < 0.5, full1d)
+                    return fntorch(t1, t2 < 0.5, cast(torch.arange(1, t1.nelement() + 1).float()))
                 elif fn == "masked_fill":
                     return fntorch(t1, t2 < 0.5, 1.0)
                 elif fn in fns_3_args:
@@ -3539,7 +3535,7 @@
                 if fn == "lerp":
                     return t0_fn(t1, 0.5)
                 elif fn == "masked_scatter":
-                    return t0_fn(t1 < 0.5, full1d)
+                    return t0_fn(t1 < 0.5, cast(torch.arange(1, t0.nelement() + 1).float()))
                 elif fn == "masked_fill":
                     return t0_fn(t1 < 0.5, 1.0)
                 elif fn == "map":
diff --git a/tools/pyi/gen_pyi.py b/tools/pyi/gen_pyi.py
index 827c57385c..6905f80f47 100644
--- a/tools/pyi/gen_pyi.py
+++ b/tools/pyi/gen_pyi.py
@@ -76,13 +76,6 @@ blacklist = [
     'tensordot',
     'norm',
     'split',
-    'index_add',
-    'index_copy',
-    'index_fill',
-    'scatter',
-    'scatter_add',
-    'masked_scatter',
-    'masked_fill',
     # These are handled specially by python_arg_parser.cpp
     'add',
     'add_',
diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py
index a94a633e34..ba0e31f8fb 100644
--- a/torch/_tensor_docs.py
+++ b/torch/_tensor_docs.py
@@ -2962,55 +2962,6 @@ pinverse() -> Tensor
 
 See :func:`torch.pinverse`
 """)
-
-add_docstr_all('index_add',
-               r"""
-index_add(dim, index, tensor) -> Tensor
-
-Out-of-place version of :meth:`torch.Tensor.index_add_`
-""")
-
-add_docstr_all('index_copy',
-               r"""
-index_copy(dim, index, tensor) -> Tensor
-
-Out-of-place version of :meth:`torch.Tensor.index_copy_`
-""")
-
-add_docstr_all('index_fill',
-               r"""
-index_fill(dim, index, value) -> Tensor
-
-Out-of-place version of :meth:`torch.Tensor.index_fill_`
-""")
-
-add_docstr_all('scatter',
-               r"""
-scatter(dim, index, source) -> Tensor
-
-Out-of-place version of :meth:`torch.Tensor.scatter_`
-""")
-
-add_docstr_all('scatter_add',
-               r"""
-scatter_add(dim, index, source) -> Tensor
-
-Out-of-place version of :meth:`torch.Tensor.scatter_add_`
-""")
-
-add_docstr_all('masked_scatter',
-               r"""
-masked_scatter(mask, tensor) -> Tensor
-
-Out-of-place version of :meth:`torch.Tensor.masked_scatter_`
-""")
-
-add_docstr_all('masked_fill',
-               r"""
-masked_fill(mask, value) -> Tensor
-
-Out-of-place version of :meth:`torch.Tensor.masked_fill_`
-""")
 add_docstr_all('grad',
                r"""
 This attribute is ``None`` by default and becomes a Tensor the first time a call to
diff --git a/torch/tensor.py b/torch/tensor.py
index 0cd19ba60b..2ed379cada 100644
--- a/torch/tensor.py
+++ b/torch/tensor.py
@@ -307,6 +307,41 @@ class Tensor(torch._C._TensorBase):
         else:
             return super(Tensor, self).split_with_sizes(split_size, dim)
 
+    def index_add(self, dim, index, tensor):
+        r"""Out-of-place version of :meth:`torch.Tensor.index_add_`
+        """
+        return self.clone().index_add_(dim, index, tensor)
+
+    def index_copy(self, dim, index, tensor):
+        r"""Out-of-place version of :meth:`torch.Tensor.index_copy_`
+        """
+        return self.clone().index_copy_(dim, index, tensor)
+
+    def index_fill(self, dim, index, value):
+        r"""Out-of-place version of :meth:`torch.Tensor.index_fill_`
+        """
+        return self.clone().index_fill_(dim, index, value)
+
+    def scatter(self, dim, index, source):
+        r"""Out-of-place version of :meth:`torch.Tensor.scatter_`
+        """
+        return self.clone().scatter_(dim, index, source)
+
+    def scatter_add(self, dim, index, source):
+        r"""Out-of-place version of :meth:`torch.Tensor.scatter_add_`
+        """
+        return self.clone().scatter_add_(dim, index, source)
+
+    def masked_scatter(self, mask, tensor):
+        r"""Out-of-place version of :meth:`torch.Tensor.masked_scatter_`
+        """
+        return self.clone().masked_scatter_(mask, tensor)
+
+    def masked_fill(self, mask, value):
+        r"""Out-of-place version of :meth:`torch.Tensor.masked_fill_`
+        """
+        return self.clone().masked_fill_(mask, value)
+
     def unique(self, sorted=True, return_inverse=False, dim=None):
         r"""Returns the unique scalar elements of the tensor as a 1-D tensor.
```
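The test changes above drop the `broadcast_lhs` cases for `masked_fill` and `masked_scatter`, and the `test_torch.py` TODO notes the same limitation: because the restored implementations clone `self` and then dispatch to the in-place op, `self` can no longer be broadcast up to a larger mask shape; the result has to fit in the cloned storage. Broadcasting the mask against a larger `self` still works (the `broadcast_rhs` cases remain). A small sketch of the failure mode, assuming current PyTorch semantics (the exact error text may differ across versions):

```python
import torch

t = torch.zeros(3)
mask = torch.ones(2, 3, dtype=torch.bool)

# A smaller mask may broadcast up to self's shape, but an in-place op can
# never broadcast self itself: the result here would have shape (2, 3),
# which cannot be written back into a tensor of shape (3,).
try:
    t.masked_fill_(mask, 1.0)
except RuntimeError as err:
    print(err)  # raises because the mask is not broadcastable to t's shape
```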