summaryrefslogtreecommitdiff
path: root/torch/csrc
diff options
context:
space:
mode:
authorgchanan <gregchanan@gmail.com>2018-04-19 22:03:25 -0400
committerGitHub <noreply@github.com>2018-04-19 22:03:25 -0400
commita4ab83045d992650381278dc03eac70b592789a0 (patch)
treed410d0cd2d21f01d2a5b7a539017c7a49b961c99 /torch/csrc
parent1a53e45558127755a151a8925be25c1f696b07e3 (diff)
downloadpytorch-a4ab83045d992650381278dc03eac70b592789a0.tar.gz
pytorch-a4ab83045d992650381278dc03eac70b592789a0.tar.bz2
pytorch-a4ab83045d992650381278dc03eac70b592789a0.zip
Fix cross device indexing for more than 1 cuda device. (#6781)
* Fix cross device indexing for more than 1 cuda device.

  Cross device indexing is attempted from ATen, which doesn't work well because ATen doesn't have AutoGPU, etc. Instead, before dispatching to ATen we do type conversion on the indices; it would probably be better if we pushed all this down to ATen, but that will take some work.

* Small cleanup.
Diffstat (limited to 'torch/csrc')
-rw-r--r--torch/csrc/autograd/python_variable_indexing.cpp22
1 file changed, 18 insertions, 4 deletions
diff --git a/torch/csrc/autograd/python_variable_indexing.cpp b/torch/csrc/autograd/python_variable_indexing.cpp
index 797141a115..02abe3e69d 100644
--- a/torch/csrc/autograd/python_variable_indexing.cpp
+++ b/torch/csrc/autograd/python_variable_indexing.cpp
@@ -10,6 +10,7 @@
#include "torch/csrc/utils/python_compat.h"
#include "torch/csrc/utils/python_numbers.h"
#include "torch/csrc/utils/tensor_new.h"
+#include "torch/csrc/utils/tensor_conversion_dispatch.h"
#include <ATen/ExpandUtils.h>
#include <vector>
@@ -168,20 +169,33 @@ static Variable applySlicing(const Variable& self, PyObject* index, variable_lis
return result;
}
-static std::vector<Tensor> asTensorList(const variable_list& v) {
- return std::vector<Tensor>(v.begin(), v.end());
+static std::vector<Tensor> typeConvertIndices(const Variable& self, const variable_list& indices) {
+ std::vector<Tensor> converted_inds(indices.size());
+ int64_t device = self.is_cuda() ? self.get_device() : -1;
+ for (size_t i = 0; i < indices.size(); ++i) {
+ const auto &ind = indices[i];
+ if (ind.defined()) {
+ auto& new_type = ind.type().toBackend(self.type().backend());
+ converted_inds[i] = torch::utils::dispatch_type_conversion(ind, new_type, device, false);
+ } else {
+ converted_inds[i] = indices[i];
+ }
+ }
+ return converted_inds;
}
static Variable dispatch_index(const Variable& self, const variable_list& indices) {
+ std::vector<Tensor> converted_indices = typeConvertIndices(self, indices);
AutoNoGIL no_gil;
AutoGPU auto_gpu(self);
- return self.index(asTensorList(indices));
+ return self.index(converted_indices);
}
static Variable dispatch_index_put_(Variable& self, const variable_list& indices, const Variable& value) {
+ std::vector<Tensor> converted_indices = typeConvertIndices(self, indices);
AutoNoGIL no_gil;
AutoGPU auto_gpu(self);
- return self.index_put_(asTensorList(indices), value);
+ return self.index_put_(converted_indices, value);
}
static bool treatSequenceAsTuple(PyObject* index) {