author | Roy Li <royboy@fb.com> | 2019-04-07 01:35:11 -0700 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-04-07 01:37:43 -0700 |
commit | f6af76ead7f03b1e75a920d93c3d2d387f5eaef7 (patch) | |
tree | f4cc07d4ec7c5b37ea40e9d6bb585cdd4356a833 /caffe2 | |
parent | 9b69f21a95fa626522ef371f8557e7286f9db318 (diff) | |
Remove tensorFromBlob() from Type (#18779)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18779
ghimport-source-id: e7453b74fcce0e4f4a9cbce0324992a85272a426
Stack from [ghstack](https://github.com/ezyang/ghstack):
* #18780 Remove tensorWithAllocator() from Type
* **#18779 Remove tensorFromBlob() from Type**
Differential Revision: D14739335
fbshipit-source-id: 8a0619a5b412332efa3b2d60c1edebd53d089d50
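For readers following the stack, the change moves tensor wrapping off the per-backend `Type` object and onto the free function `at::from_blob`. A minimal sketch of that call-site migration (the wrapper name and parameters below are illustrative, not part of this PR):

```cpp
#include <ATen/ATen.h>

// Illustrative only: before this stack, callers wrapped external memory
// through a Type object, e.g. type.tensorFromBlob(data, sizes).
// Afterwards they build TensorOptions explicitly and call at::from_blob.
at::Tensor wrap_external(
    void* data,
    at::IntArrayRef sizes,
    at::Device device,
    at::ScalarType dtype) {
  // The options carry the same device/dtype information the Type used to,
  // without going through per-backend Type dispatch.
  return at::from_blob(data, sizes, at::TensorOptions(device).dtype(dtype));
}
```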
Diffstat (limited to 'caffe2')
-rw-r--r-- | caffe2/contrib/aten/aten_op_template.h | 16 |
1 file changed, 10 insertions, 6 deletions
```diff
diff --git a/caffe2/contrib/aten/aten_op_template.h b/caffe2/contrib/aten/aten_op_template.h
index b597084939..8c5f02e5b8 100644
--- a/caffe2/contrib/aten/aten_op_template.h
+++ b/caffe2/contrib/aten/aten_op_template.h
@@ -54,18 +54,22 @@ private:
 #undef DEFINE_CASE
   }
 
-  at::Type& typeFor(const Tensor& ten) {
-    at::Backend b = backend();
+  at::TensorOptions optionsFor(const Tensor& ten) {
+    at::Device device = ten.GetDevice();
 #ifdef __HIP_PLATFORM_HCC__
-    if (b == at::Backend::HIP) {
-      b = at::Backend::CUDA;
+    if (backend() == at::Backend::HIP) {
+      device = at::Device(kCUDA, device.index());
     }
 #endif
-    return at::getNonVariableType(b, typeMetaToScalarType(ten.meta()));
+    return at::TensorOptions(device).dtype(ten.dtype());
   }
+
   at::Tensor tensorWrapping(const Tensor& ten_) {
     auto& ten = const_cast<Tensor&>(ten_);
-    return typeFor(ten).tensorFromBlob(ten.raw_mutable_data(), ten.sizes());
+    return at::from_blob(
+        ten.raw_mutable_data(),
+        ten.sizes(),
+        optionsFor(ten));
   }
 
   at::Tensor peek(size_t i, size_t N) {
```
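For context, a small standalone sketch of the `at::from_blob` behavior the new `tensorWrapping()` relies on; the buffer, shape, and values here are made up for illustration and are not part of this change:

```cpp
#include <ATen/ATen.h>
#include <iostream>
#include <vector>

int main() {
  // Raw memory owned by the caller; ATen will neither copy nor free it.
  std::vector<float> buffer(6, 1.0f);

  // Wrap the existing buffer as a 2x3 float tensor on CPU without copying.
  at::Tensor t = at::from_blob(
      buffer.data(),
      {2, 3},
      at::TensorOptions().dtype(at::kFloat).device(at::kCPU));

  // Writes through the tensor are visible in the original buffer,
  // because from_blob aliases the memory instead of copying it.
  t.fill_(42.0f);
  std::cout << buffer[0] << std::endl;  // prints 42
  return 0;
}
```

As in the patched operator, the wrapped tensor aliases the caller's memory, so the caffe2 tensor (here, the vector) must outlive any use of the ATen tensor returned by `at::from_blob`.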