diff options
Diffstat (limited to 'caffe2')
-rw-r--r-- | caffe2/contrib/aten/aten_op_template.h | 16 |
1 files changed, 10 insertions, 6 deletions
diff --git a/caffe2/contrib/aten/aten_op_template.h b/caffe2/contrib/aten/aten_op_template.h index b597084939..8c5f02e5b8 100644 --- a/caffe2/contrib/aten/aten_op_template.h +++ b/caffe2/contrib/aten/aten_op_template.h @@ -54,18 +54,22 @@ private: #undef DEFINE_CASE } - at::Type& typeFor(const Tensor& ten) { - at::Backend b = backend(); + at::TensorOptions optionsFor(const Tensor& ten) { + at::Device device = ten.GetDevice(); #ifdef __HIP_PLATFORM_HCC__ - if (b == at::Backend::HIP) { - b = at::Backend::CUDA; + if (backend() == at::Backend::HIP) { + device = at::Device(kCUDA, device.index()); } #endif - return at::getNonVariableType(b, typeMetaToScalarType(ten.meta())); + return at::TensorOptions(device).dtype(ten.dtype()); } + at::Tensor tensorWrapping(const Tensor& ten_) { auto& ten = const_cast<Tensor&>(ten_); - return typeFor(ten).tensorFromBlob(ten.raw_mutable_data(), ten.sizes()); + return at::from_blob( + ten.raw_mutable_data(), + ten.sizes(), + optionsFor(ten)); } at::Tensor peek(size_t i, size_t N) { |