summaryrefslogtreecommitdiff
path: root/torch
diff options
context:
space:
mode:
authorEdward Yang <ezyang@fb.com>2019-02-04 19:23:57 -0800
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-02-04 19:30:04 -0800
commit6c04224cd84385910021580176413a98839cec72 (patch)
tree60739d235527165884dad0417f24efc16f253d14 /torch
parent1409a2afc8d81d911dbcfc2d0b210cf62658237b (diff)
downloadpytorch-6c04224cd84385910021580176413a98839cec72.tar.gz
pytorch-6c04224cd84385910021580176413a98839cec72.tar.bz2
pytorch-6c04224cd84385910021580176413a98839cec72.zip
Revert "Move outplace ops to ATen (#12413)" (#16731)
Summary: This reverts commit f660d3ae19decc64390e894fbaf8de80d87585e0. cc zasdfgbnm Reasoning at https://github.com/pytorch/pytorch/pull/12413#issuecomment-460424129 Pull Request resolved: https://github.com/pytorch/pytorch/pull/16731 Differential Revision: D13948022 Pulled By: ezyang fbshipit-source-id: b10669cf03679e306850314b7b5b08bed0839e19
Diffstat (limited to 'torch')
-rw-r--r--torch/_tensor_docs.py49
-rw-r--r--torch/tensor.py35
2 files changed, 35 insertions, 49 deletions
diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py
index a94a633e34..ba0e31f8fb 100644
--- a/torch/_tensor_docs.py
+++ b/torch/_tensor_docs.py
@@ -2962,55 +2962,6 @@ pinverse() -> Tensor
See :func:`torch.pinverse`
""")
-add_docstr_all('index_add',
- r"""
-index_add(dim, index, tensor) -> Tensor
-
-Out-of-place version of :meth:`torch.Tensor.index_add_`
-""")
-
-add_docstr_all('index_copy',
- r"""
-index_copy(dim, index, tensor) -> Tensor
-
-Out-of-place version of :meth:`torch.Tensor.index_copy_`
-""")
-
-add_docstr_all('index_fill',
- r"""
-index_fill(dim, index, value) -> Tensor
-
-Out-of-place version of :meth:`torch.Tensor.index_fill_`
-""")
-
-add_docstr_all('scatter',
- r"""
-scatter(dim, index, source) -> Tensor
-
-Out-of-place version of :meth:`torch.Tensor.scatter_`
-""")
-
-add_docstr_all('scatter_add',
- r"""
-scatter_add(dim, index, source) -> Tensor
-
-Out-of-place version of :meth:`torch.Tensor.scatter_add_`
-""")
-
-add_docstr_all('masked_scatter',
- r"""
-masked_scatter(mask, tensor) -> Tensor
-
-Out-of-place version of :meth:`torch.Tensor.masked_scatter_`
-""")
-
-add_docstr_all('masked_fill',
- r"""
-masked_fill(mask, value) -> Tensor
-
-Out-of-place version of :meth:`torch.Tensor.masked_fill_`
-""")
-
add_docstr_all('grad',
r"""
This attribute is ``None`` by default and becomes a Tensor the first time a call to
diff --git a/torch/tensor.py b/torch/tensor.py
index 0cd19ba60b..2ed379cada 100644
--- a/torch/tensor.py
+++ b/torch/tensor.py
@@ -307,6 +307,41 @@ class Tensor(torch._C._TensorBase):
else:
return super(Tensor, self).split_with_sizes(split_size, dim)
+ def index_add(self, dim, index, tensor):
+ r"""Out-of-place version of :meth:`torch.Tensor.index_add_`
+ """
+ return self.clone().index_add_(dim, index, tensor)
+
+ def index_copy(self, dim, index, tensor):
+ r"""Out-of-place version of :meth:`torch.Tensor.index_copy_`
+ """
+ return self.clone().index_copy_(dim, index, tensor)
+
+ def index_fill(self, dim, index, value):
+ r"""Out-of-place version of :meth:`torch.Tensor.index_fill_`
+ """
+ return self.clone().index_fill_(dim, index, value)
+
+ def scatter(self, dim, index, source):
+ r"""Out-of-place version of :meth:`torch.Tensor.scatter_`
+ """
+ return self.clone().scatter_(dim, index, source)
+
+ def scatter_add(self, dim, index, source):
+ r"""Out-of-place version of :meth:`torch.Tensor.scatter_add_`
+ """
+ return self.clone().scatter_add_(dim, index, source)
+
+ def masked_scatter(self, mask, tensor):
+ r"""Out-of-place version of :meth:`torch.Tensor.masked_scatter_`
+ """
+ return self.clone().masked_scatter_(mask, tensor)
+
+ def masked_fill(self, mask, value):
+ r"""Out-of-place version of :meth:`torch.Tensor.masked_fill_`
+ """
+ return self.clone().masked_fill_(mask, value)
+
def unique(self, sorted=True, return_inverse=False, dim=None):
r"""Returns the unique scalar elements of the tensor as a 1-D tensor.