summaryrefslogtreecommitdiff
path: root/caffe2
diff options
context:
space:
mode:
Diffstat (limited to 'caffe2')
-rw-r--r--caffe2/contrib/aten/aten_op_template.h16
1 files changed, 10 insertions, 6 deletions
diff --git a/caffe2/contrib/aten/aten_op_template.h b/caffe2/contrib/aten/aten_op_template.h
index b597084939..8c5f02e5b8 100644
--- a/caffe2/contrib/aten/aten_op_template.h
+++ b/caffe2/contrib/aten/aten_op_template.h
@@ -54,18 +54,22 @@ private:
#undef DEFINE_CASE
}
- at::Type& typeFor(const Tensor& ten) {
- at::Backend b = backend();
+ at::TensorOptions optionsFor(const Tensor& ten) {
+ at::Device device = ten.GetDevice();
#ifdef __HIP_PLATFORM_HCC__
- if (b == at::Backend::HIP) {
- b = at::Backend::CUDA;
+ if (backend() == at::Backend::HIP) {
+ device = at::Device(kCUDA, device.index());
}
#endif
- return at::getNonVariableType(b, typeMetaToScalarType(ten.meta()));
+ return at::TensorOptions(device).dtype(ten.dtype());
}
+
at::Tensor tensorWrapping(const Tensor& ten_) {
auto& ten = const_cast<Tensor&>(ten_);
- return typeFor(ten).tensorFromBlob(ten.raw_mutable_data(), ten.sizes());
+ return at::from_blob(
+ ten.raw_mutable_data(),
+ ten.sizes(),
+ optionsFor(ten));
}
at::Tensor peek(size_t i, size_t N) {