diff options
author | Natalia Gimelshein <ngimelshein@nvidia.com> | 2017-11-28 16:27:58 -0800 |
---|---|---|
committer | Adam Paszke <adam.paszke@gmail.com> | 2017-11-29 10:54:57 +0100 |
commit | ea28deee75d3a2d39be95652383179832d2483e7 (patch) | |
tree | 28d187de7b79ed98d6a5ef3f304992ba110e303f /torch | |
parent | 0a434ff6853b8c03415344286e6ade6d4f17162a (diff) | |
download | pytorch-ea28deee75d3a2d39be95652383179832d2483e7.tar.gz pytorch-ea28deee75d3a2d39be95652383179832d2483e7.tar.bz2 pytorch-ea28deee75d3a2d39be95652383179832d2483e7.zip |
use torch.cat in _flatten
Diffstat (limited to 'torch')
-rw-r--r-- | torch/_utils.py | 8 |
1 files changed, 1 insertions, 7 deletions
diff --git a/torch/_utils.py b/torch/_utils.py index 45ddbb1534..94cd88acf3 100644 --- a/torch/_utils.py +++ b/torch/_utils.py @@ -120,13 +120,7 @@ def _flatten_dense_tensors(tensors): """ if len(tensors) == 1: return tensors[0].contiguous().view(-1) - numels = [tensor.numel() for tensor in tensors] - size = sum(numels) - offset = 0 - flat = tensors[0].new(size) - for tensor, numel in zip(tensors, numels): - flat.narrow(0, offset, numel).copy_(tensor, broadcast=False) - offset += numel + flat = torch.cat([t.contiguous().view(-1) for t in tensors], dim=0) return flat |