diff options
-rw-r--r-- | aten/doc/Tensor.h | 4 | ||||
-rw-r--r-- | aten/doc/Type.h | 4 | ||||
-rw-r--r-- | aten/src/ATen/Declarations.cwrap | 4 | ||||
-rw-r--r-- | aten/src/ATen/function_wrapper.py | 2 | ||||
-rw-r--r-- | test/test_torch.py | 13 |
5 files changed, 20 insertions, 7 deletions
diff --git a/aten/doc/Tensor.h b/aten/doc/Tensor.h index 54c79d75e2..120105ad68 100644 --- a/aten/doc/Tensor.h +++ b/aten/doc/Tensor.h @@ -122,8 +122,8 @@ struct Tensor : public detail::TensorBase { int64_t storage_offset() const; Tensor & resize_(IntList size); int64_t numel() const; - Tensor & set_(Storage & storage); - Tensor & set_(Storage & sourceStorage, int64_t storage_offset, IntList size, IntList stride={}); + Tensor & set_(Storage & source); + Tensor & set_(Storage & source, int64_t storage_offset, IntList size, IntList stride={}); Tensor & set_(const Tensor & source); Tensor & set_(); Tensor & fill_(Scalar value); diff --git a/aten/doc/Type.h b/aten/doc/Type.h index 8dce294518..871d979b85 100644 --- a/aten/doc/Type.h +++ b/aten/doc/Type.h @@ -133,8 +133,8 @@ struct AT_API Type { virtual Tensor & ones_like_out(Tensor & result, const Tensor & input) const; virtual Tensor ones_like(const Tensor & input) const; virtual int64_t numel(const Tensor & self) const; - virtual Tensor & set_(Tensor & self, Storage & storage) const; - virtual Tensor & set_(Tensor & self, Storage & sourceStorage, int64_t storage_offset, IntList size, IntList stride={}) const; + virtual Tensor & set_(Tensor & self, Storage & source) const; + virtual Tensor & set_(Tensor & self, Storage & source, int64_t storage_offset, IntList size, IntList stride={}) const; virtual Tensor & set_(Tensor & self, const Tensor & source) const; virtual Tensor & set_(Tensor & self) const; virtual Tensor & fill_(Tensor & self, Scalar value) const; diff --git a/aten/src/ATen/Declarations.cwrap b/aten/src/ATen/Declarations.cwrap index 6d7c7eab03..f2754b16c7 100644 --- a/aten/src/ATen/Declarations.cwrap +++ b/aten/src/ATen/Declarations.cwrap @@ -60,14 +60,14 @@ scalar_check: False arguments: - THTensor* self - - THStorage* storage + - THStorage* source - CONSTANT 0 - CONSTANT __storage_size.get() - CONSTANT NULL - cname: setStorage arguments: - THTensor* self - - THStorage* sourceStorage + - THStorage* source - long storage_offset - THSize* size - arg: THStride* stride diff --git a/aten/src/ATen/function_wrapper.py b/aten/src/ATen/function_wrapper.py index 1bccfbcc88..44783aa62a 100644 --- a/aten/src/ATen/function_wrapper.py +++ b/aten/src/ATen/function_wrapper.py @@ -290,7 +290,7 @@ ALLOC_WRAP = { CONSTANT_REPLACEMENTS = [ ('AS_REAL', '${AS_REAL}'), ('__storage_size.get\\(\\)', - 'THLongStorageView(static_cast<int64_t>(storage.size()), THLongStorageViewKind::LENGTH)'), + 'THLongStorageView(static_cast<int64_t>(source.size()), THLongStorageViewKind::LENGTH)'), ('__last_dim', 'self.ndimension()-1'), ] diff --git a/test/test_torch.py b/test/test_torch.py index 6db51aad7b..45057a8376 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -5805,6 +5805,19 @@ class TestTorch(TestCase): self.assertEqual(t1.size(), size) self.assertEqual(t1.stride(), stride) + # test argument names + t1 = torch.Tensor() + # 1. case when source is tensor + t1.set_(source=t2) + self.assertEqual(t1.storage()._cdata, t2.storage()._cdata) + # 2. case when source is storage + t1.set_(source=t2.storage()) + self.assertEqual(t1.storage()._cdata, t2.storage()._cdata) + # 3. case when source is storage, and other args also specified + t1.set_(source=t2.storage(), storage_offset=0, size=size, stride=stride) + self.assertEqual(t1.size(), size) + self.assertEqual(t1.stride(), stride) + def test_equal(self): # Contiguous, 1D t1 = torch.Tensor((3, 4, 9, 10)) |