summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorEdward Yang <ezyang@fb.com>2019-02-04 19:23:57 -0800
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-02-04 19:30:04 -0800
commit6c04224cd84385910021580176413a98839cec72 (patch)
tree60739d235527165884dad0417f24efc16f253d14
parent1409a2afc8d81d911dbcfc2d0b210cf62658237b (diff)
downloadpytorch-6c04224cd84385910021580176413a98839cec72.tar.gz
pytorch-6c04224cd84385910021580176413a98839cec72.tar.bz2
pytorch-6c04224cd84385910021580176413a98839cec72.zip
Revert "Move outplace ops to ATen (#12413)" (#16731)
Summary: This reverts commit f660d3ae19decc64390e894fbaf8de80d87585e0. cc zasdfgbnm Reasoning at https://github.com/pytorch/pytorch/pull/12413#issuecomment-460424129 Pull Request resolved: https://github.com/pytorch/pytorch/pull/16731 Differential Revision: D13948022 Pulled By: ezyang fbshipit-source-id: b10669cf03679e306850314b7b5b08bed0839e19
-rw-r--r--aten/src/ATen/core/Tensor.h7
-rw-r--r--aten/src/ATen/core/TensorMethods.h21
-rw-r--r--aten/src/ATen/core/Type.h7
-rw-r--r--aten/src/ATen/native/Indexing.cpp32
-rw-r--r--aten/src/ATen/native/native_functions.yaml21
-rw-r--r--test/common_methods_invocations.py4
-rw-r--r--test/test_jit.py2
-rw-r--r--test/test_torch.py12
-rw-r--r--tools/pyi/gen_pyi.py7
-rw-r--r--torch/_tensor_docs.py49
-rw-r--r--torch/tensor.py35
11 files changed, 41 insertions, 156 deletions
diff --git a/aten/src/ATen/core/Tensor.h b/aten/src/ATen/core/Tensor.h
index cf3e05d30a..ba4ccb4393 100644
--- a/aten/src/ATen/core/Tensor.h
+++ b/aten/src/ATen/core/Tensor.h
@@ -389,15 +389,8 @@ class CAFFE2_API Tensor {
Tensor irfft(int64_t signal_ndim, bool normalized=false, bool onesided=true, IntList signal_sizes={}) const;
Tensor index(TensorList indices) const;
Tensor & index_copy_(int64_t dim, const Tensor & index, const Tensor & source);
- Tensor index_copy(int64_t dim, const Tensor & index, const Tensor & source) const;
Tensor index_put(TensorList indices, const Tensor & values, bool accumulate=false) const;
Tensor & index_put_(TensorList indices, const Tensor & values, bool accumulate=false);
- Tensor index_add(int64_t dim, const Tensor & index, const Tensor & source) const;
- Tensor index_fill(int64_t dim, const Tensor & index, Scalar source) const;
- Tensor scatter(int64_t dim, const Tensor & index, const Tensor & source) const;
- Tensor scatter_add(int64_t dim, const Tensor & index, const Tensor & source) const;
- Tensor masked_scatter(const Tensor & mask, const Tensor & source) const;
- Tensor masked_fill(const Tensor & mask, Scalar source) const;
Tensor inverse() const;
Tensor isclose(const Tensor & other, double rtol=1e-05, double atol=1e-08, bool equal_nan=false) const;
bool is_distributed() const;
diff --git a/aten/src/ATen/core/TensorMethods.h b/aten/src/ATen/core/TensorMethods.h
index 5a5389f6c4..e76a0eb358 100644
--- a/aten/src/ATen/core/TensorMethods.h
+++ b/aten/src/ATen/core/TensorMethods.h
@@ -310,33 +310,12 @@ inline Tensor Tensor::index(TensorList indices) const {
inline Tensor & Tensor::index_copy_(int64_t dim, const Tensor & index, const Tensor & source) {
return type().index_copy_(*this, dim, index, source);
}
-inline Tensor Tensor::index_copy(int64_t dim, const Tensor & index, const Tensor & source) const {
- return type().index_copy(*this, dim, index, source);
-}
inline Tensor Tensor::index_put(TensorList indices, const Tensor & values, bool accumulate) const {
return type().index_put(*this, indices, values, accumulate);
}
inline Tensor & Tensor::index_put_(TensorList indices, const Tensor & values, bool accumulate) {
return type().index_put_(*this, indices, values, accumulate);
}
-inline Tensor Tensor::index_add(int64_t dim, const Tensor & index, const Tensor & source) const {
- return type().index_add(*this, dim, index, source);
-}
-inline Tensor Tensor::index_fill(int64_t dim, const Tensor & index, Scalar source) const {
- return type().index_fill(*this, dim, index, source);
-}
-inline Tensor Tensor::scatter(int64_t dim, const Tensor & index, const Tensor & source) const {
- return type().scatter(*this, dim, index, source);
-}
-inline Tensor Tensor::scatter_add(int64_t dim, const Tensor & index, const Tensor & source) const {
- return type().scatter_add(*this, dim, index, source);
-}
-inline Tensor Tensor::masked_scatter(const Tensor & mask, const Tensor & source) const {
- return type().masked_scatter(*this, mask, source);
-}
-inline Tensor Tensor::masked_fill(const Tensor & mask, Scalar source) const {
- return type().masked_fill(*this, mask, source);
-}
inline Tensor Tensor::inverse() const {
return type().inverse(*this);
}
diff --git a/aten/src/ATen/core/Type.h b/aten/src/ATen/core/Type.h
index f7b20e7e53..30601cd3cf 100644
--- a/aten/src/ATen/core/Type.h
+++ b/aten/src/ATen/core/Type.h
@@ -271,15 +271,8 @@ struct CAFFE2_API Type {
virtual Tensor irfft(const Tensor & self, int64_t signal_ndim, bool normalized, bool onesided, IntList signal_sizes) const = 0;
virtual Tensor index(const Tensor & self, TensorList indices) const = 0;
virtual Tensor & index_copy_(Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) const = 0;
- virtual Tensor index_copy(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) const = 0;
virtual Tensor index_put(const Tensor & self, TensorList indices, const Tensor & values, bool accumulate) const = 0;
virtual Tensor & index_put_(Tensor & self, TensorList indices, const Tensor & values, bool accumulate) const = 0;
- virtual Tensor index_add(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) const = 0;
- virtual Tensor index_fill(const Tensor & self, int64_t dim, const Tensor & index, Scalar source) const = 0;
- virtual Tensor scatter(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) const = 0;
- virtual Tensor scatter_add(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) const = 0;
- virtual Tensor masked_scatter(const Tensor & self, const Tensor & mask, const Tensor & source) const = 0;
- virtual Tensor masked_fill(const Tensor & self, const Tensor & mask, Scalar source) const = 0;
virtual Tensor inverse(const Tensor & self) const = 0;
virtual Tensor isclose(const Tensor & self, const Tensor & other, double rtol, double atol, bool equal_nan) const = 0;
virtual bool is_distributed(const Tensor & self) const = 0;
diff --git a/aten/src/ATen/native/Indexing.cpp b/aten/src/ATen/native/Indexing.cpp
index 0a299a6830..eb6c764202 100644
--- a/aten/src/ATen/native/Indexing.cpp
+++ b/aten/src/ATen/native/Indexing.cpp
@@ -497,36 +497,4 @@ Tensor & index_copy_(Tensor & self, int64_t dim, const Tensor & index, const Ten
return at::legacy::th::_th_index_copy_(self, dim, index, source);
}
-Tensor index_copy(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) {
- return self.clone().index_copy_(dim, index, source);
-}
-
-Tensor index_add(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) {
- return self.clone().index_add_(dim, index, source);
-}
-
-Tensor index_fill(const Tensor & self, int64_t dim, const Tensor & index, Scalar source) {
- return self.clone().index_fill_(dim, index, source);
-}
-
-Tensor scatter(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) {
- return self.clone().scatter_(dim, index, source);
-}
-
-Tensor scatter_add(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) {
- return self.clone().scatter_add_(dim, index, source);
-}
-
-Tensor masked_scatter(const Tensor & self, const Tensor & mask, const Tensor & source) {
- Tensor _mask, _self;
- std::tie(_mask, _self) = expand_outplace(mask, self);
- return _self.clone().masked_scatter_(_mask, source);
-}
-
-Tensor masked_fill(const Tensor & self, const Tensor & mask, Scalar source) {
- Tensor _mask, _self;
- std::tie(_mask, _self) = expand_outplace(mask, self);
- return _self.clone().masked_fill_(mask, source);
-}
-
}} // at::native
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 91cd1284fa..b980982158 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -1102,9 +1102,6 @@
- func: index_copy_(Tensor(a!) self, int dim, IndexTensor index, Tensor source) -> Tensor(a!)
variants: method
-- func: index_copy(Tensor self, int64_t dim, IndexTensor index, Tensor source) -> Tensor
- variants: function, method
-
- func: index_put(Tensor self, Tensor[] indices, Tensor values, bool accumulate=False) -> Tensor
matches_jit_signature: True
variants: function, method
@@ -1113,24 +1110,6 @@
matches_jit_signature: True
variants: function, method
-- func: index_add(Tensor self, int64_t dim, IndexTensor index, Tensor source) -> Tensor
- variants: function, method
-
-- func: index_fill(Tensor self, int64_t dim, IndexTensor index, Scalar source) -> Tensor
- variants: function, method
-
-- func: scatter(Tensor self, int64_t dim, IndexTensor index, Tensor source) -> Tensor
- variants: function, method
-
-- func: scatter_add(Tensor self, int64_t dim, IndexTensor index, Tensor source) -> Tensor
- variants: function, method
-
-- func: masked_scatter(Tensor self, Tensor mask, Tensor source) -> Tensor
- variants: function, method
-
-- func: masked_fill(Tensor self, Tensor mask, Scalar source) -> Tensor
- variants: function, method
-
- func: instance_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool use_input_stats, float momentum, float eps, bool cudnn_enabled) -> Tensor
matches_jit_signature: True
variants: function
diff --git a/test/common_methods_invocations.py b/test/common_methods_invocations.py
index b998c6fef6..41c8763e16 100644
--- a/test/common_methods_invocations.py
+++ b/test/common_methods_invocations.py
@@ -744,7 +744,7 @@ def method_tests():
('masked_select', (), (mask_not_all_zeros((M, M)),), 'scalar_broadcast_lhs'),
('masked_fill', (M, M), (torch.ByteTensor(M, M).bernoulli_(), 10)),
('masked_fill', (M, M), (torch.ByteTensor(M, M).bernoulli_(), torch.tensor(10)), 'tensor'),
- ('masked_fill', (M,), (torch.ByteTensor(M, M).bernoulli_(), 10), 'broadcast_lhs'),
+ # no lhs or all broadcast on masked_fill or masked_scatter because it's always inplace
('masked_fill', (M, M), (torch.ByteTensor(M,).bernoulli_(), 10), 'broadcast_rhs'),
('masked_fill', (), (torch.tensor(0, dtype=torch.uint8, requires_grad=False).bernoulli_(), 10), 'scalar'),
('masked_fill', (), (torch.tensor(0, dtype=torch.uint8, requires_grad=False).bernoulli_(), torch.tensor(10)),
@@ -752,8 +752,6 @@ def method_tests():
('masked_fill', (M, M), (torch.tensor(0, dtype=torch.uint8, requires_grad=False).bernoulli_(), 10),
'scalar_broadcast_rhs'),
('masked_scatter', (M, M), (torch.ByteTensor(M, M).bernoulli_(), (M, M))),
- ('masked_scatter', (M,), (torch.ByteTensor(M, M).bernoulli_(), (M, M)),
- 'broadcast_lhs'),
('masked_scatter', (M, M), (torch.ByteTensor(M,).bernoulli_(), (M, M)),
'broadcast_rhs'),
('masked_scatter', (M, M), (bernoulli_scalar(), (M, M)), 'scalar'),
diff --git a/test/test_jit.py b/test/test_jit.py
index 42ec3d34c3..b0ef8326da 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -8834,7 +8834,7 @@ a")
def test_builtin_error_messsage(self):
from torch.nn.modules.utils import _single, _pair, _triple, _quadruple
- with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
+ with self.assertRaisesRegex(RuntimeError, "aten::masked_fill_"):
@torch.jit.script
def close_match(x):
return x.masked_fill(True)
diff --git a/test/test_torch.py b/test/test_torch.py
index 1c538d5a2d..f3e8d8234c 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -3450,7 +3450,6 @@ class _TestTorchMixin(object):
for fn in fns:
(dims_small, dims_large, dims_full) = self._select_broadcastable_dims()
- full1d = cast(torch.randn(*dims_full).flatten().float())
small = cast(torch.randn(*dims_small).float())
large = cast(torch.randn(*dims_large).float())
small_expanded = small.expand(*dims_full)
@@ -3467,7 +3466,8 @@ class _TestTorchMixin(object):
# map and map2 are not implementd on CUDA tensors
continue
- if hasattr(large_expanded, fn):
+ # TODO: fix masked_scatter and masked_fill broadcasting
+ if hasattr(large_expanded, fn) and fn not in ['masked_scatter', 'masked_fill']:
# run through tensor versions of functions
# and verify fully expanded inputs give same results
expanded = {large: large_expanded, small: small_expanded, small2: small2_expanded}
@@ -3477,10 +3477,6 @@ class _TestTorchMixin(object):
return myfn(t1, 0.5)
elif fn == "masked_select":
return myfn(t1 < 0)
- elif fn == "masked_scatter":
- return myfn(t1 < 0.5, full1d)
- elif fn == "masked_fill":
- return myfn(t1 < 0.5, 1.0)
elif fn in fns_3_args:
return myfn(1, t1, t2)
else:
@@ -3508,7 +3504,7 @@ class _TestTorchMixin(object):
elif fn == "masked_select":
return fntorch(t1, t2 < 0)
elif fn == "masked_scatter":
- return fntorch(t1, t2 < 0.5, full1d)
+ return fntorch(t1, t2 < 0.5, cast(torch.arange(1, t1.nelement() + 1).float()))
elif fn == "masked_fill":
return fntorch(t1, t2 < 0.5, 1.0)
elif fn in fns_3_args:
@@ -3539,7 +3535,7 @@ class _TestTorchMixin(object):
if fn == "lerp":
return t0_fn(t1, 0.5)
elif fn == "masked_scatter":
- return t0_fn(t1 < 0.5, full1d)
+ return t0_fn(t1 < 0.5, cast(torch.arange(1, t0.nelement() + 1).float()))
elif fn == "masked_fill":
return t0_fn(t1 < 0.5, 1.0)
elif fn == "map":
diff --git a/tools/pyi/gen_pyi.py b/tools/pyi/gen_pyi.py
index 827c57385c..6905f80f47 100644
--- a/tools/pyi/gen_pyi.py
+++ b/tools/pyi/gen_pyi.py
@@ -76,13 +76,6 @@ blacklist = [
'tensordot',
'norm',
'split',
- 'index_add',
- 'index_copy',
- 'index_fill',
- 'scatter',
- 'scatter_add',
- 'masked_scatter',
- 'masked_fill',
# These are handled specially by python_arg_parser.cpp
'add',
'add_',
diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py
index a94a633e34..ba0e31f8fb 100644
--- a/torch/_tensor_docs.py
+++ b/torch/_tensor_docs.py
@@ -2962,55 +2962,6 @@ pinverse() -> Tensor
See :func:`torch.pinverse`
""")
-add_docstr_all('index_add',
- r"""
-index_add(dim, index, tensor) -> Tensor
-
-Out-of-place version of :meth:`torch.Tensor.index_add_`
-""")
-
-add_docstr_all('index_copy',
- r"""
-index_copy(dim, index, tensor) -> Tensor
-
-Out-of-place version of :meth:`torch.Tensor.index_copy_`
-""")
-
-add_docstr_all('index_fill',
- r"""
-index_fill(dim, index, value) -> Tensor
-
-Out-of-place version of :meth:`torch.Tensor.index_fill_`
-""")
-
-add_docstr_all('scatter',
- r"""
-scatter(dim, index, source) -> Tensor
-
-Out-of-place version of :meth:`torch.Tensor.scatter_`
-""")
-
-add_docstr_all('scatter_add',
- r"""
-scatter_add(dim, index, source) -> Tensor
-
-Out-of-place version of :meth:`torch.Tensor.scatter_add_`
-""")
-
-add_docstr_all('masked_scatter',
- r"""
-masked_scatter(mask, tensor) -> Tensor
-
-Out-of-place version of :meth:`torch.Tensor.masked_scatter_`
-""")
-
-add_docstr_all('masked_fill',
- r"""
-masked_fill(mask, value) -> Tensor
-
-Out-of-place version of :meth:`torch.Tensor.masked_fill_`
-""")
-
add_docstr_all('grad',
r"""
This attribute is ``None`` by default and becomes a Tensor the first time a call to
diff --git a/torch/tensor.py b/torch/tensor.py
index 0cd19ba60b..2ed379cada 100644
--- a/torch/tensor.py
+++ b/torch/tensor.py
@@ -307,6 +307,41 @@ class Tensor(torch._C._TensorBase):
else:
return super(Tensor, self).split_with_sizes(split_size, dim)
+ def index_add(self, dim, index, tensor):
+ r"""Out-of-place version of :meth:`torch.Tensor.index_add_`
+ """
+ return self.clone().index_add_(dim, index, tensor)
+
+ def index_copy(self, dim, index, tensor):
+ r"""Out-of-place version of :meth:`torch.Tensor.index_copy_`
+ """
+ return self.clone().index_copy_(dim, index, tensor)
+
+ def index_fill(self, dim, index, value):
+ r"""Out-of-place version of :meth:`torch.Tensor.index_fill_`
+ """
+ return self.clone().index_fill_(dim, index, value)
+
+ def scatter(self, dim, index, source):
+ r"""Out-of-place version of :meth:`torch.Tensor.scatter_`
+ """
+ return self.clone().scatter_(dim, index, source)
+
+ def scatter_add(self, dim, index, source):
+ r"""Out-of-place version of :meth:`torch.Tensor.scatter_add_`
+ """
+ return self.clone().scatter_add_(dim, index, source)
+
+ def masked_scatter(self, mask, tensor):
+ r"""Out-of-place version of :meth:`torch.Tensor.masked_scatter_`
+ """
+ return self.clone().masked_scatter_(mask, tensor)
+
+ def masked_fill(self, mask, value):
+ r"""Out-of-place version of :meth:`torch.Tensor.masked_fill_`
+ """
+ return self.clone().masked_fill_(mask, value)
+
def unique(self, sorted=True, return_inverse=False, dim=None):
r"""Returns the unique scalar elements of the tensor as a 1-D tensor.