diff options
author | Sam Gross <colesbury@gmail.com> | 2018-01-17 17:30:43 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-01-17 17:30:43 -0500 |
commit | 720c7b1e2c9580fa897110abca3a41ae676174da (patch) | |
tree | 9ebe99e73cd0a22a1932e99113c8f8d40d4a133a /torch/_utils.py | |
parent | b37aa2bf0e1d0fc62a7a924c0b50fe06870c1bce (diff) | |
download | pytorch-720c7b1e2c9580fa897110abca3a41ae676174da.tar.gz pytorch-720c7b1e2c9580fa897110abca3a41ae676174da.tar.bz2 pytorch-720c7b1e2c9580fa897110abca3a41ae676174da.zip |
Move repeat to torch/_utils.py (#4712)
This moves the implementation of repeat to _utils so that the autograd
function can call it directly instead of relying on forward being called
on tensors.
This also removes _range, which was previously necessary because we
shadowed the built-in range() function.
Diffstat (limited to 'torch/_utils.py')
-rw-r--r-- | torch/_utils.py | 55 |
1 files changed, 51 insertions, 4 deletions
diff --git a/torch/_utils.py b/torch/_utils.py index 94cd88acf3..84c19a0b2e 100644 --- a/torch/_utils.py +++ b/torch/_utils.py @@ -76,10 +76,6 @@ def _rebuild_tensor(storage, storage_offset, size, stride): return tensor_class().set_(storage, storage_offset, size, stride) -def _range(*args, **kwargs): - return __builtins__['range'](*args, **kwargs) - - def _import_dotted_name(name): components = name.split('.') obj = __import__(components[0]) @@ -238,3 +234,54 @@ def _take_tensors(tensors, size_limit): for buf, _ in buf_dict.values(): if len(buf) > 0: yield buf + + +def _repeat(self, *sizes): + r"""Repeats this tensor along the specified dimensions. + + Unlike :meth:`expand`, this function copies the tensor's data. + + Args: + *sizes (torch.Size or int...): The number of times to repeat this + tensor along each dimension + + Example: + >>> x = torch.Tensor([1, 2, 3]) + >>> x.repeat(4, 2) + 1 2 3 1 2 3 + 1 2 3 1 2 3 + 1 2 3 1 2 3 + 1 2 3 1 2 3 + [torch.FloatTensor of size 4x6] + >>> x.repeat(4, 2, 1).size() + torch.Size([4, 2, 3]) + """ + # If args == (torch.Size,), then we need to unpack the tuple + if len(sizes) == 1 and isinstance(sizes[0], torch.Size): + sizes = sizes[0] + + repeats = list(sizes) + + if len(repeats) < self.dim(): + raise ValueError('Number of dimensions of repeat dims can not be ' + 'smaller than number of dimensions of tensor') + + # Add new leading dimensions to the tensor if the + # number of target dimensions is larger than the + # number of source dimensions. + num_new_dimensions = len(repeats) - self.dim() + padded_size = [1] * num_new_dimensions + list(self.size()) + target_size = torch.Size([a * b for a, b in zip(padded_size, repeats)]) + + xtensor = self.new().set_(self) + xtensor = xtensor.expand(padded_size) + + result = self.new() + result.resize_(target_size) + urtensor = result.new(result) + for i in range(xtensor.dim()): + urtensor = urtensor.unfold(i, xtensor.size(i), xtensor.size(i)) + + urtensor.copy_(xtensor.expand_as(urtensor)) + + return result |