summaryrefslogtreecommitdiff
path: root/torch
diff options
context:
space:
mode:
authorNatalia Gimelshein <ngimelshein@nvidia.com>2017-11-28 16:27:58 -0800
committerAdam Paszke <adam.paszke@gmail.com>2017-11-29 10:54:57 +0100
commitea28deee75d3a2d39be95652383179832d2483e7 (patch)
tree28d187de7b79ed98d6a5ef3f304992ba110e303f /torch
parent0a434ff6853b8c03415344286e6ade6d4f17162a (diff)
downloadpytorch-ea28deee75d3a2d39be95652383179832d2483e7.tar.gz
pytorch-ea28deee75d3a2d39be95652383179832d2483e7.tar.bz2
pytorch-ea28deee75d3a2d39be95652383179832d2483e7.zip
use torch.cat in _flatten
Diffstat (limited to 'torch')
-rw-r--r--torch/_utils.py8
1 files changed, 1 insertions, 7 deletions
diff --git a/torch/_utils.py b/torch/_utils.py
index 45ddbb1534..94cd88acf3 100644
--- a/torch/_utils.py
+++ b/torch/_utils.py
@@ -120,13 +120,7 @@ def _flatten_dense_tensors(tensors):
"""
if len(tensors) == 1:
return tensors[0].contiguous().view(-1)
- numels = [tensor.numel() for tensor in tensors]
- size = sum(numels)
- offset = 0
- flat = tensors[0].new(size)
- for tensor, numel in zip(tensors, numels):
- flat.narrow(0, offset, numel).copy_(tensor, broadcast=False)
- offset += numel
+ flat = torch.cat([t.contiguous().view(-1) for t in tensors], dim=0)
return flat