summaryrefslogtreecommitdiff
path: root/torch/utils
diff options
context:
space:
mode:
authorSsnL <tongzhou.wang.1994@gmail.com>2018-12-28 11:51:26 -0800
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2018-12-28 12:26:46 -0800
commitfb22f76eb6eae224469493b26d88b87fe5455fa7 (patch)
treedc774bb88cf648d78870d2f2376ab73ec0496eef /torch/utils
parent6a3e54eda90dd4bef003b35cb291621ea0b2a65d (diff)
downloadpytorch-fb22f76eb6eae224469493b26d88b87fe5455fa7.tar.gz
pytorch-fb22f76eb6eae224469493b26d88b87fe5455fa7.tar.bz2
pytorch-fb22f76eb6eae224469493b26d88b87fe5455fa7.zip
default_collate should collate bool list to byte tensors (#14669)
Summary: Based on #15331 . Review only the last commit. Fixes https://github.com/pytorch/pytorch/issues/14507. Pull Request resolved: https://github.com/pytorch/pytorch/pull/14669 Reviewed By: ezyang Differential Revision: D13528725 Pulled By: soumith fbshipit-source-id: f12f1ac1c4ff2a3ddd6877c0c096a5da3a1ffa3c
Diffstat (limited to 'torch/utils')
-rw-r--r--torch/utils/data/_utils/collate.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/torch/utils/data/_utils/collate.py b/torch/utils/data/_utils/collate.py
index e0823b2397..fdc287f979 100644
--- a/torch/utils/data/_utils/collate.py
+++ b/torch/utils/data/_utils/collate.py
@@ -53,10 +53,10 @@ def default_collate(batch):
if elem.shape == (): # scalars
py_type = float if elem.dtype.name.startswith('float') else int
return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
- elif isinstance(batch[0], int_classes):
- return torch.LongTensor(batch)
elif isinstance(batch[0], float):
- return torch.DoubleTensor(batch)
+ return torch.tensor(batch, dtype=torch.float64)
+ elif isinstance(batch[0], int_classes):
+ return torch.tensor(batch)
elif isinstance(batch[0], string_classes):
return batch
elif isinstance(batch[0], container_abcs.Mapping):