From f6af76ead7f03b1e75a920d93c3d2d387f5eaef7 Mon Sep 17 00:00:00 2001 From: Roy Li Date: Sun, 7 Apr 2019 01:35:11 -0700 Subject: Remove tensorFromBlob() from Type (#18779) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18779 ghimport-source-id: e7453b74fcce0e4f4a9cbce0324992a85272a426 Stack from [ghstack](https://github.com/ezyang/ghstack): * #18780 Remove tensorWithAllocator() from Type * **#18779 Remove tensorFromBlob() from Type** Differential Revision: D14739335 fbshipit-source-id: 8a0619a5b412332efa3b2d60c1edebd53d089d50 --- caffe2/contrib/aten/aten_op_template.h | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) (limited to 'caffe2') diff --git a/caffe2/contrib/aten/aten_op_template.h b/caffe2/contrib/aten/aten_op_template.h index b597084939..8c5f02e5b8 100644 --- a/caffe2/contrib/aten/aten_op_template.h +++ b/caffe2/contrib/aten/aten_op_template.h @@ -54,18 +54,22 @@ private: #undef DEFINE_CASE } - at::Type& typeFor(const Tensor& ten) { - at::Backend b = backend(); + at::TensorOptions optionsFor(const Tensor& ten) { + at::Device device = ten.GetDevice(); #ifdef __HIP_PLATFORM_HCC__ - if (b == at::Backend::HIP) { - b = at::Backend::CUDA; + if (backend() == at::Backend::HIP) { + device = at::Device(kCUDA, device.index()); } #endif - return at::getNonVariableType(b, typeMetaToScalarType(ten.meta())); + return at::TensorOptions(device).dtype(ten.dtype()); } + at::Tensor tensorWrapping(const Tensor& ten_) { auto& ten = const_cast(ten_); - return typeFor(ten).tensorFromBlob(ten.raw_mutable_data(), ten.sizes()); + return at::from_blob( + ten.raw_mutable_data(), + ten.sizes(), + optionsFor(ten)); } at::Tensor peek(size_t i, size_t N) { -- cgit v1.2.3