summaryrefslogtreecommitdiff
path: root/caffe2
diff options
context:
space:
mode:
authorRoy Li <royboy@fb.com>2019-04-07 01:35:11 -0700
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-04-07 01:37:43 -0700
commitf6af76ead7f03b1e75a920d93c3d2d387f5eaef7 (patch)
treef4cc07d4ec7c5b37ea40e9d6bb585cdd4356a833 /caffe2
parent9b69f21a95fa626522ef371f8557e7286f9db318 (diff)
downloadpytorch-f6af76ead7f03b1e75a920d93c3d2d387f5eaef7.tar.gz
pytorch-f6af76ead7f03b1e75a920d93c3d2d387f5eaef7.tar.bz2
pytorch-f6af76ead7f03b1e75a920d93c3d2d387f5eaef7.zip
Remove tensorFromBlob() from Type (#18779)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18779 ghimport-source-id: e7453b74fcce0e4f4a9cbce0324992a85272a426 Stack from [ghstack](https://github.com/ezyang/ghstack): * #18780 Remove tensorWithAllocator() from Type * **#18779 Remove tensorFromBlob() from Type** Differential Revision: D14739335 fbshipit-source-id: 8a0619a5b412332efa3b2d60c1edebd53d089d50
Diffstat (limited to 'caffe2')
-rw-r--r--caffe2/contrib/aten/aten_op_template.h16
1 files changed, 10 insertions, 6 deletions
diff --git a/caffe2/contrib/aten/aten_op_template.h b/caffe2/contrib/aten/aten_op_template.h
index b597084939..8c5f02e5b8 100644
--- a/caffe2/contrib/aten/aten_op_template.h
+++ b/caffe2/contrib/aten/aten_op_template.h
@@ -54,18 +54,22 @@ private:
#undef DEFINE_CASE
}
- at::Type& typeFor(const Tensor& ten) {
- at::Backend b = backend();
+ at::TensorOptions optionsFor(const Tensor& ten) {
+ at::Device device = ten.GetDevice();
#ifdef __HIP_PLATFORM_HCC__
- if (b == at::Backend::HIP) {
- b = at::Backend::CUDA;
+ if (backend() == at::Backend::HIP) {
+ device = at::Device(kCUDA, device.index());
}
#endif
- return at::getNonVariableType(b, typeMetaToScalarType(ten.meta()));
+ return at::TensorOptions(device).dtype(ten.dtype());
}
+
at::Tensor tensorWrapping(const Tensor& ten_) {
auto& ten = const_cast<Tensor&>(ten_);
- return typeFor(ten).tensorFromBlob(ten.raw_mutable_data(), ten.sizes());
+ return at::from_blob(
+ ten.raw_mutable_data(),
+ ten.sizes(),
+ optionsFor(ten));
}
at::Tensor peek(size_t i, size_t N) {