diff options
Diffstat (limited to 'torch/csrc/cuda/comm.cpp')
-rw-r--r-- | torch/csrc/cuda/comm.cpp | 2 |
1 files changed, 1 insertions, 1 deletions
diff --git a/torch/csrc/cuda/comm.cpp b/torch/csrc/cuda/comm.cpp index 53faa6baa5..a1743355bf 100644 --- a/torch/csrc/cuda/comm.cpp +++ b/torch/csrc/cuda/comm.cpp @@ -59,7 +59,7 @@ std::vector<Tensor> broadcast(const Tensor& tensor, IntArrayRef devices) { tensors.push_back(tensor); for (auto device : devices.slice(1)) { _device_guard.set_index(device); - tensors.push_back(at::empty(tensor.sizes(), type.options())); + tensors.push_back(at::empty(tensor.sizes(), type.options(tensor.scalar_type()))); } nccl::broadcast(tensors); } else { |