summaryrefslogtreecommitdiff
path: root/binaries
diff options
context:
space:
mode:
authorJerry Zhang <jerryzh@fb.com>2018-09-05 16:13:54 -0700
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2018-09-05 16:28:09 -0700
commit9f4bcdf0752b688775ded8752c2db8d30e480fdf (patch)
treedc89973b5840cca39cf2f952422b4ba08aaa4050 /binaries
parentac9f0a68846bbef8489112a027174dffab4b3ae6 (diff)
downloadpytorch-9f4bcdf0752b688775ded8752c2db8d30e480fdf.tar.gz
pytorch-9f4bcdf0752b688775ded8752c2db8d30e480fdf.tar.bz2
pytorch-9f4bcdf0752b688775ded8752c2db8d30e480fdf.zip
caffe2::DeviceType -> at::DeviceType (#11254)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/11254 Previously we use DeviceType in caffe2.proto directly, but it's an `enum` and have implicit conversion to int, which does not have type safety, e.g. we have to explicitly check for a device type is valid in event.h: ``` template <int d> struct EventCreateFunctionRegisterer { explicit EventCreateFunctionRegisterer(EventCreateFunction f) { static_assert(d < MaxDeviceTypes, ""); Event::event_creator_[d] = f; } }; ``` at::DeviceType is an `enum class`, and it does not have implicit conversion to int, and provides better type safety guarantees. In this diff we have done the following refactor(taking CPU as an example): 1. caffe2::DeviceType → caffe2::DeviceTypeProto 2. caffe2::CPU → caffe2::PROTO_CPU 3. caffe2::DeviceType = at::DeviceType 4. caffe2::CPU = at::DeviceType::CPU codemod -d caffe2/caffe2 --extensions h,cc,cpp 'device_type\(\), ' 'device_type(), PROTO_' + some manual changes In short, after this diff, in c++, caffe2::CPU refers to the at::DeviceType::CPU and the old proto caffe2::CPU will be caffe2::PROTO_CPU. In python side, we have a temporary workaround that alias `caffe2_pb2.CPU = caffe2_pb2.PROOT_CPU` to make the change easier to review and this will be removed later. Reviewed By: ezyang Differential Revision: D9545704 fbshipit-source-id: 461a28a4ca74e616d3ee183a607078a717fd38a7
Diffstat (limited to 'binaries')
-rw-r--r--binaries/benchmark_helper.cc2
-rw-r--r--binaries/core_overhead_benchmark_gpu.cc4
-rw-r--r--binaries/print_registered_core_operators.cc4
3 files changed, 5 insertions, 5 deletions
diff --git a/binaries/benchmark_helper.cc b/binaries/benchmark_helper.cc
index c5becbb2b4..255fa8ce4b 100644
--- a/binaries/benchmark_helper.cc
+++ b/binaries/benchmark_helper.cc
@@ -68,7 +68,7 @@ bool backendCudaSet(const string& backend) {
void setDeviceType(caffe2::NetDef* net_def, caffe2::DeviceType& run_dev) {
for (int j = 0; j < net_def->op_size(); j++) {
caffe2::OperatorDef* op = net_def->mutable_op(j);
- op->mutable_device_option()->set_device_type(run_dev);
+ op->mutable_device_option()->set_device_type(caffe2::TypeToProto(run_dev));
}
}
diff --git a/binaries/core_overhead_benchmark_gpu.cc b/binaries/core_overhead_benchmark_gpu.cc
index 5cb0a62797..018880432d 100644
--- a/binaries/core_overhead_benchmark_gpu.cc
+++ b/binaries/core_overhead_benchmark_gpu.cc
@@ -167,7 +167,7 @@ static void BM_OperatorCreationCPU(benchmark::State& state) {
OperatorDef def;
Workspace ws;
def.set_type("DummyEmpty");
- def.mutable_device_option()->set_device_type(CPU);
+ def.mutable_device_option()->set_device_type(PROTO_CPU);
while (state.KeepRunning()) {
op = CreateOperator(def, &ws);
}
@@ -180,7 +180,7 @@ static void BM_OperatorCreationCUDA(benchmark::State& state) {
OperatorDef def;
Workspace ws;
def.set_type("DummyEmpty");
- def.mutable_device_option()->set_device_type(CUDA);
+ def.mutable_device_option()->set_device_type(PROTO_CUDA);
while (state.KeepRunning()) {
op = CreateOperator(def, &ws);
}
diff --git a/binaries/print_registered_core_operators.cc b/binaries/print_registered_core_operators.cc
index c76ea3eaca..412ce88c44 100644
--- a/binaries/print_registered_core_operators.cc
+++ b/binaries/print_registered_core_operators.cc
@@ -52,8 +52,8 @@ int main(int argc, char** argv) {
for (const auto& pair : *caffe2::gDeviceTypeRegistry()) {
std::cout << "Device type " << pair.first
#ifndef CAFFE2_USE_LITE_PROTO
- << " (" << caffe2::DeviceType_Name(
- static_cast<caffe2::DeviceType>(pair.first))
+ << " ("
+ << at::DeviceTypeName(static_cast<caffe2::DeviceType>(pair.first))
<< ")"
#endif
<< std::endl;