summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--aten/doc/Tensor.h4
-rw-r--r--aten/doc/Type.h4
-rw-r--r--aten/src/ATen/Declarations.cwrap4
-rw-r--r--aten/src/ATen/function_wrapper.py2
-rw-r--r--test/test_torch.py13
5 files changed, 20 insertions, 7 deletions
diff --git a/aten/doc/Tensor.h b/aten/doc/Tensor.h
index 54c79d75e2..120105ad68 100644
--- a/aten/doc/Tensor.h
+++ b/aten/doc/Tensor.h
@@ -122,8 +122,8 @@ struct Tensor : public detail::TensorBase {
int64_t storage_offset() const;
Tensor & resize_(IntList size);
int64_t numel() const;
- Tensor & set_(Storage & storage);
- Tensor & set_(Storage & sourceStorage, int64_t storage_offset, IntList size, IntList stride={});
+ Tensor & set_(Storage & source);
+ Tensor & set_(Storage & source, int64_t storage_offset, IntList size, IntList stride={});
Tensor & set_(const Tensor & source);
Tensor & set_();
Tensor & fill_(Scalar value);
diff --git a/aten/doc/Type.h b/aten/doc/Type.h
index 8dce294518..871d979b85 100644
--- a/aten/doc/Type.h
+++ b/aten/doc/Type.h
@@ -133,8 +133,8 @@ struct AT_API Type {
virtual Tensor & ones_like_out(Tensor & result, const Tensor & input) const;
virtual Tensor ones_like(const Tensor & input) const;
virtual int64_t numel(const Tensor & self) const;
- virtual Tensor & set_(Tensor & self, Storage & storage) const;
- virtual Tensor & set_(Tensor & self, Storage & sourceStorage, int64_t storage_offset, IntList size, IntList stride={}) const;
+ virtual Tensor & set_(Tensor & self, Storage & source) const;
+ virtual Tensor & set_(Tensor & self, Storage & source, int64_t storage_offset, IntList size, IntList stride={}) const;
virtual Tensor & set_(Tensor & self, const Tensor & source) const;
virtual Tensor & set_(Tensor & self) const;
virtual Tensor & fill_(Tensor & self, Scalar value) const;
diff --git a/aten/src/ATen/Declarations.cwrap b/aten/src/ATen/Declarations.cwrap
index 6d7c7eab03..f2754b16c7 100644
--- a/aten/src/ATen/Declarations.cwrap
+++ b/aten/src/ATen/Declarations.cwrap
@@ -60,14 +60,14 @@
scalar_check: False
arguments:
- THTensor* self
- - THStorage* storage
+ - THStorage* source
- CONSTANT 0
- CONSTANT __storage_size.get()
- CONSTANT NULL
- cname: setStorage
arguments:
- THTensor* self
- - THStorage* sourceStorage
+ - THStorage* source
- long storage_offset
- THSize* size
- arg: THStride* stride
diff --git a/aten/src/ATen/function_wrapper.py b/aten/src/ATen/function_wrapper.py
index 1bccfbcc88..44783aa62a 100644
--- a/aten/src/ATen/function_wrapper.py
+++ b/aten/src/ATen/function_wrapper.py
@@ -290,7 +290,7 @@ ALLOC_WRAP = {
CONSTANT_REPLACEMENTS = [
('AS_REAL', '${AS_REAL}'),
('__storage_size.get\\(\\)',
- 'THLongStorageView(static_cast<int64_t>(storage.size()), THLongStorageViewKind::LENGTH)'),
+ 'THLongStorageView(static_cast<int64_t>(source.size()), THLongStorageViewKind::LENGTH)'),
('__last_dim', 'self.ndimension()-1'),
]
diff --git a/test/test_torch.py b/test/test_torch.py
index 6db51aad7b..45057a8376 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -5805,6 +5805,19 @@ class TestTorch(TestCase):
self.assertEqual(t1.size(), size)
self.assertEqual(t1.stride(), stride)
+ # test argument names
+ t1 = torch.Tensor()
+ # 1. case when source is tensor
+ t1.set_(source=t2)
+ self.assertEqual(t1.storage()._cdata, t2.storage()._cdata)
+ # 2. case when source is storage
+ t1.set_(source=t2.storage())
+ self.assertEqual(t1.storage()._cdata, t2.storage()._cdata)
+ # 3. case when source is storage, and other args also specified
+ t1.set_(source=t2.storage(), storage_offset=0, size=size, stride=stride)
+ self.assertEqual(t1.size(), size)
+ self.assertEqual(t1.stride(), stride)
+
def test_equal(self):
# Contiguous, 1D
t1 = torch.Tensor((3, 4, 9, 10))