summaryrefslogtreecommitdiff
path: root/caffe2/proto
diff options
context:
space:
mode:
authorJerry Zhang <jerryzh@fb.com>2018-10-01 11:02:11 -0700
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2018-10-01 11:10:46 -0700
commit006171fffc4f3fe7d8538f9f7a5b015d5bfc0332 (patch)
tree60908329ef6026751552d49ff7b9e0725ca7e593 /caffe2/proto
parentfed91f873fc73a9ec4d212a0d1abd3fc966eacc0 (diff)
downloadpytorch-006171fffc4f3fe7d8538f9f7a5b015d5bfc0332.tar.gz
pytorch-006171fffc4f3fe7d8538f9f7a5b015d5bfc0332.tar.bz2
pytorch-006171fffc4f3fe7d8538f9f7a5b015d5bfc0332.zip
Back out "[pytorch][PR] Revert "Move CreateContext to global registry (#11688)"" (#12121)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/12121 Pull Request resolved: https://github.com/pytorch/pytorch/pull/12055 Original commit changeset: 6ca9de65b707 Reviewed By: ezyang Differential Revision: D10033396 fbshipit-source-id: ca9f4b2f7ef0561f619b833415d394a8b9972bf4
Diffstat (limited to 'caffe2/proto')
-rw-r--r--caffe2/proto/caffe2_pb.h19
1 files changed, 18 insertions, 1 deletions
diff --git a/caffe2/proto/caffe2_pb.h b/caffe2/proto/caffe2_pb.h
index 0a08c8db24..e0eb8e8dcd 100644
--- a/caffe2/proto/caffe2_pb.h
+++ b/caffe2/proto/caffe2_pb.h
@@ -1,5 +1,5 @@
#pragma once
-#include <ATen/core/DeviceType.h>
+#include <ATen/core/Device.h>
#include <ATen/core/Error.h>
#include <caffe2/proto/caffe2.pb.h>
@@ -47,6 +47,10 @@ inline CAFFE2_API DeviceType ProtoToType(const caffe2::DeviceTypeProto p) {
}
}
+inline CAFFE2_API DeviceType ProtoToType(int p) {
+ return ProtoToType(static_cast<caffe2::DeviceTypeProto>(p));
+}
+
inline CAFFE2_API DeviceTypeProto TypeToProto(const DeviceType& t) {
switch (t) {
case DeviceType::CPU:
@@ -77,4 +81,17 @@ inline CAFFE2_API DeviceTypeProto TypeToProto(const DeviceType& t) {
}
}
+inline CAFFE2_API caffe2::DeviceOption DeviceToOption(
+ const at::Device& device) {
+ caffe2::DeviceOption option;
+ auto type = device.type();
+ option.set_device_type(TypeToProto(type));
+ option.set_device_id(device.index());
+ return option;
+}
+
+inline CAFFE2_API at::Device OptionToDevice(const caffe2::DeviceOption option) {
+ return at::Device(ProtoToType(option.device_type()), option.device_id());
+}
+
} // namespace caffe2