diff options
author | Jerry Zhang <jerryzh@fb.com> | 2018-09-05 16:13:54 -0700 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2018-09-05 16:28:09 -0700 |
commit | 9f4bcdf0752b688775ded8752c2db8d30e480fdf (patch) | |
tree | dc89973b5840cca39cf2f952422b4ba08aaa4050 /binaries | |
parent | ac9f0a68846bbef8489112a027174dffab4b3ae6 (diff) | |
download | pytorch-9f4bcdf0752b688775ded8752c2db8d30e480fdf.tar.gz pytorch-9f4bcdf0752b688775ded8752c2db8d30e480fdf.tar.bz2 pytorch-9f4bcdf0752b688775ded8752c2db8d30e480fdf.zip |
caffe2::DeviceType -> at::DeviceType (#11254)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11254
Previously we use DeviceType in caffe2.proto directly, but it's an `enum` and have implicit conversion to int, which does not have type safety, e.g. we have to explicitly check for a device type is valid in event.h:
```
template <int d>
struct EventCreateFunctionRegisterer {
explicit EventCreateFunctionRegisterer(EventCreateFunction f) {
static_assert(d < MaxDeviceTypes, "");
Event::event_creator_[d] = f;
}
};
```
at::DeviceType is an `enum class`, and it does not have implicit conversion to int, and provides better type safety guarantees. In this diff we have done the following refactor(taking CPU as an example):
1. caffe2::DeviceType → caffe2::DeviceTypeProto
2. caffe2::CPU → caffe2::PROTO_CPU
3. caffe2::DeviceType = at::DeviceType
4. caffe2::CPU = at::DeviceType::CPU
codemod -d caffe2/caffe2 --extensions h,cc,cpp 'device_type\(\), ' 'device_type(), PROTO_'
+ some manual changes
In short, after this diff, in c++, caffe2::CPU refers to the at::DeviceType::CPU and the old proto caffe2::CPU will be caffe2::PROTO_CPU.
In python side, we have a temporary workaround that alias `caffe2_pb2.CPU = caffe2_pb2.PROOT_CPU` to make the change easier to review and this will be removed later.
Reviewed By: ezyang
Differential Revision: D9545704
fbshipit-source-id: 461a28a4ca74e616d3ee183a607078a717fd38a7
Diffstat (limited to 'binaries')
-rw-r--r-- | binaries/benchmark_helper.cc | 2 | ||||
-rw-r--r-- | binaries/core_overhead_benchmark_gpu.cc | 4 | ||||
-rw-r--r-- | binaries/print_registered_core_operators.cc | 4 |
3 files changed, 5 insertions, 5 deletions
diff --git a/binaries/benchmark_helper.cc b/binaries/benchmark_helper.cc index c5becbb2b4..255fa8ce4b 100644 --- a/binaries/benchmark_helper.cc +++ b/binaries/benchmark_helper.cc @@ -68,7 +68,7 @@ bool backendCudaSet(const string& backend) { void setDeviceType(caffe2::NetDef* net_def, caffe2::DeviceType& run_dev) { for (int j = 0; j < net_def->op_size(); j++) { caffe2::OperatorDef* op = net_def->mutable_op(j); - op->mutable_device_option()->set_device_type(run_dev); + op->mutable_device_option()->set_device_type(caffe2::TypeToProto(run_dev)); } } diff --git a/binaries/core_overhead_benchmark_gpu.cc b/binaries/core_overhead_benchmark_gpu.cc index 5cb0a62797..018880432d 100644 --- a/binaries/core_overhead_benchmark_gpu.cc +++ b/binaries/core_overhead_benchmark_gpu.cc @@ -167,7 +167,7 @@ static void BM_OperatorCreationCPU(benchmark::State& state) { OperatorDef def; Workspace ws; def.set_type("DummyEmpty"); - def.mutable_device_option()->set_device_type(CPU); + def.mutable_device_option()->set_device_type(PROTO_CPU); while (state.KeepRunning()) { op = CreateOperator(def, &ws); } @@ -180,7 +180,7 @@ static void BM_OperatorCreationCUDA(benchmark::State& state) { OperatorDef def; Workspace ws; def.set_type("DummyEmpty"); - def.mutable_device_option()->set_device_type(CUDA); + def.mutable_device_option()->set_device_type(PROTO_CUDA); while (state.KeepRunning()) { op = CreateOperator(def, &ws); } diff --git a/binaries/print_registered_core_operators.cc b/binaries/print_registered_core_operators.cc index c76ea3eaca..412ce88c44 100644 --- a/binaries/print_registered_core_operators.cc +++ b/binaries/print_registered_core_operators.cc @@ -52,8 +52,8 @@ int main(int argc, char** argv) { for (const auto& pair : *caffe2::gDeviceTypeRegistry()) { std::cout << "Device type " << pair.first #ifndef CAFFE2_USE_LITE_PROTO - << " (" << caffe2::DeviceType_Name( - static_cast<caffe2::DeviceType>(pair.first)) + << " (" + << at::DeviceTypeName(static_cast<caffe2::DeviceType>(pair.first)) << ")" #endif << std::endl; |