summaryrefslogtreecommitdiff
path: root/torch/csrc/cuda/comm.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'torch/csrc/cuda/comm.cpp')
-rw-r--r--torch/csrc/cuda/comm.cpp2
1 files changed, 1 insertions, 1 deletions
diff --git a/torch/csrc/cuda/comm.cpp b/torch/csrc/cuda/comm.cpp
index 53faa6baa5..a1743355bf 100644
--- a/torch/csrc/cuda/comm.cpp
+++ b/torch/csrc/cuda/comm.cpp
@@ -59,7 +59,7 @@ std::vector<Tensor> broadcast(const Tensor& tensor, IntArrayRef devices) {
tensors.push_back(tensor);
for (auto device : devices.slice(1)) {
_device_guard.set_index(device);
- tensors.push_back(at::empty(tensor.sizes(), type.options()));
+ tensors.push_back(at::empty(tensor.sizes(), type.options(tensor.scalar_type())));
}
nccl::broadcast(tensors);
} else {