author     Sam Gross <colesbury@gmail.com>   2018-01-17 17:30:43 -0500
committer  GitHub <noreply@github.com>       2018-01-17 17:30:43 -0500
commit     720c7b1e2c9580fa897110abca3a41ae676174da (patch)
tree       9ebe99e73cd0a22a1932e99113c8f8d40d4a133a /torch/_utils.py
parent     b37aa2bf0e1d0fc62a7a924c0b50fe06870c1bce (diff)
Move repeat to torch/_utils.py (#4712)
This moves the implementation of repeat to _utils so that the autograd function can call it directly instead of relying on forward being called on tensors. This also removes _range, which was previously necessary because we shadowed the built-in range() function.
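For illustration only (this sketch is not part of the commit's diff): with the helper living in torch/_utils.py, an autograd Function can call it on a plain tensor roughly as below. The Repeat class and its argument handling here are hypothetical, and backward is omitted.

import torch
from torch._utils import _repeat
from torch.autograd import Function

class Repeat(Function):
    @staticmethod
    def forward(ctx, input, sizes):
        ctx.sizes = sizes
        # Call the shared helper directly on the input tensor instead of
        # relying on a repeat() method being attached to the tensor class.
        return _repeat(input, *sizes)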
Diffstat (limited to 'torch/_utils.py')
-rw-r--r--  torch/_utils.py  55
1 file changed, 51 insertions(+), 4 deletions(-)
diff --git a/torch/_utils.py b/torch/_utils.py
index 94cd88acf3..84c19a0b2e 100644
--- a/torch/_utils.py
+++ b/torch/_utils.py
@@ -76,10 +76,6 @@ def _rebuild_tensor(storage, storage_offset, size, stride):
     return tensor_class().set_(storage, storage_offset, size, stride)
 
 
-def _range(*args, **kwargs):
-    return __builtins__['range'](*args, **kwargs)
-
-
 def _import_dotted_name(name):
     components = name.split('.')
     obj = __import__(components[0])
@@ -238,3 +234,54 @@ def _take_tensors(tensors, size_limit):
     for buf, _ in buf_dict.values():
         if len(buf) > 0:
             yield buf
+
+
+def _repeat(self, *sizes):
+    r"""Repeats this tensor along the specified dimensions.
+
+    Unlike :meth:`expand`, this function copies the tensor's data.
+
+    Args:
+        *sizes (torch.Size or int...): The number of times to repeat this
+            tensor along each dimension
+
+    Example:
+        >>> x = torch.Tensor([1, 2, 3])
+        >>> x.repeat(4, 2)
+         1  2  3  1  2  3
+         1  2  3  1  2  3
+         1  2  3  1  2  3
+         1  2  3  1  2  3
+        [torch.FloatTensor of size 4x6]
+        >>> x.repeat(4, 2, 1).size()
+        torch.Size([4, 2, 3])
+    """
+    # If args == (torch.Size,), then we need to unpack the tuple
+    if len(sizes) == 1 and isinstance(sizes[0], torch.Size):
+        sizes = sizes[0]
+
+    repeats = list(sizes)
+
+    if len(repeats) < self.dim():
+        raise ValueError('Number of dimensions of repeat dims can not be '
+                         'smaller than number of dimensions of tensor')
+
+    # Add new leading dimensions to the tensor if the
+    # number of target dimensions is larger than the
+    # number of source dimensions.
+    num_new_dimensions = len(repeats) - self.dim()
+    padded_size = [1] * num_new_dimensions + list(self.size())
+    target_size = torch.Size([a * b for a, b in zip(padded_size, repeats)])
+
+    xtensor = self.new().set_(self)
+    xtensor = xtensor.expand(padded_size)
+
+    result = self.new()
+    result.resize_(target_size)
+    urtensor = result.new(result)
+    for i in range(xtensor.dim()):
+        urtensor = urtensor.unfold(i, xtensor.size(i), xtensor.size(i))
+
+    urtensor.copy_(xtensor.expand_as(urtensor))
+
+    return result
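A quick sanity check of the new helper (illustrative only, assuming the 0.3-era torch.Tensor API shown in the docstring above; not part of the diff):

import torch
from torch._utils import _repeat

x = torch.Tensor([1, 2, 3])            # size (3,)
out = _repeat(x, 4, 2)                 # source padded to (1, 3), target (4, 6)
print(out.size())                      # torch.Size([4, 6])
print(out.equal(x.repeat(4, 2)))       # True: same values as Tensor.repeat

# Why the unfold loop works: unfolding the (4, 6) result with window = step =
# padded source size (1 along dim 0, 3 along dim 1) gives a (4, 2, 1, 3) view
# of the result buffer whose trailing dims match the padded source, so a
# single broadcasted copy_ fills every tile at once.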