diff options
author | gchanan <gregchanan@gmail.com> | 2018-04-19 22:03:25 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-04-19 22:03:25 -0400 |
commit | a4ab83045d992650381278dc03eac70b592789a0 (patch) | |
tree | d410d0cd2d21f01d2a5b7a539017c7a49b961c99 /torch/csrc | |
parent | 1a53e45558127755a151a8925be25c1f696b07e3 (diff) | |
download | pytorch-a4ab83045d992650381278dc03eac70b592789a0.tar.gz pytorch-a4ab83045d992650381278dc03eac70b592789a0.tar.bz2 pytorch-a4ab83045d992650381278dc03eac70b592789a0.zip |
Fix cross device indexing for more than 1 cuda device. (#6781)
* Fix cross device indexing for more than 1 cuda device.
Cross device indexing is attempted from ATen, which doesn't work well because ATen doesn't have AutoGPU, etc.
Instead, before dispatching to ATen we do type conversion on the indices; it would probably be better if we
pushed all this down to ATen, but that will take some work.
* Small cleanup.
Diffstat (limited to 'torch/csrc')
-rw-r--r-- | torch/csrc/autograd/python_variable_indexing.cpp | 22 |
1 file changed, 18 insertions, 4 deletions
diff --git a/torch/csrc/autograd/python_variable_indexing.cpp b/torch/csrc/autograd/python_variable_indexing.cpp index 797141a115..02abe3e69d 100644 --- a/torch/csrc/autograd/python_variable_indexing.cpp +++ b/torch/csrc/autograd/python_variable_indexing.cpp @@ -10,6 +10,7 @@ #include "torch/csrc/utils/python_compat.h" #include "torch/csrc/utils/python_numbers.h" #include "torch/csrc/utils/tensor_new.h" +#include "torch/csrc/utils/tensor_conversion_dispatch.h" #include <ATen/ExpandUtils.h> #include <vector> @@ -168,20 +169,33 @@ static Variable applySlicing(const Variable& self, PyObject* index, variable_lis return result; } -static std::vector<Tensor> asTensorList(const variable_list& v) { - return std::vector<Tensor>(v.begin(), v.end()); +static std::vector<Tensor> typeConvertIndices(const Variable& self, const variable_list& indices) { + std::vector<Tensor> converted_inds(indices.size()); + int64_t device = self.is_cuda() ? self.get_device() : -1; + for (size_t i = 0; i < indices.size(); ++i) { + const auto &ind = indices[i]; + if (ind.defined()) { + auto& new_type = ind.type().toBackend(self.type().backend()); + converted_inds[i] = torch::utils::dispatch_type_conversion(ind, new_type, device, false); + } else { + converted_inds[i] = indices[i]; + } + } + return converted_inds; } static Variable dispatch_index(const Variable& self, const variable_list& indices) { + std::vector<Tensor> converted_indices = typeConvertIndices(self, indices); AutoNoGIL no_gil; AutoGPU auto_gpu(self); - return self.index(asTensorList(indices)); + return self.index(converted_indices); } static Variable dispatch_index_put_(Variable& self, const variable_list& indices, const Variable& value) { + std::vector<Tensor> converted_indices = typeConvertIndices(self, indices); AutoNoGIL no_gil; AutoGPU auto_gpu(self); - return self.index_put_(asTensorList(indices), value); + return self.index_put_(converted_indices, value); } static bool treatSequenceAsTuple(PyObject* index) { |