diff options
author | SsnL <tongzhou.wang.1994@gmail.com> | 2018-12-28 11:51:26 -0800 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2018-12-28 12:26:46 -0800 |
commit | fb22f76eb6eae224469493b26d88b87fe5455fa7 (patch) | |
tree | dc774bb88cf648d78870d2f2376ab73ec0496eef /torch/utils | |
parent | 6a3e54eda90dd4bef003b35cb291621ea0b2a65d (diff) | |
download | pytorch-fb22f76eb6eae224469493b26d88b87fe5455fa7.tar.gz pytorch-fb22f76eb6eae224469493b26d88b87fe5455fa7.tar.bz2 pytorch-fb22f76eb6eae224469493b26d88b87fe5455fa7.zip |
default_collate should collate bool list to byte tensors (#14669)
Summary:
Based on #15331 . Review only the last commit.
Fixes https://github.com/pytorch/pytorch/issues/14507.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14669
Reviewed By: ezyang
Differential Revision: D13528725
Pulled By: soumith
fbshipit-source-id: f12f1ac1c4ff2a3ddd6877c0c096a5da3a1ffa3c
Diffstat (limited to 'torch/utils')
-rw-r--r-- | torch/utils/data/_utils/collate.py | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/torch/utils/data/_utils/collate.py b/torch/utils/data/_utils/collate.py index e0823b2397..fdc287f979 100644 --- a/torch/utils/data/_utils/collate.py +++ b/torch/utils/data/_utils/collate.py @@ -53,10 +53,10 @@ def default_collate(batch): if elem.shape == (): # scalars py_type = float if elem.dtype.name.startswith('float') else int return numpy_type_map[elem.dtype.name](list(map(py_type, batch))) - elif isinstance(batch[0], int_classes): - return torch.LongTensor(batch) elif isinstance(batch[0], float): - return torch.DoubleTensor(batch) + return torch.tensor(batch, dtype=torch.float64) + elif isinstance(batch[0], int_classes): + return torch.tensor(batch) elif isinstance(batch[0], string_classes): return batch elif isinstance(batch[0], container_abcs.Mapping): |