diff options
author | Jerry Zhang <jerryzh@fb.com> | 2018-09-05 16:13:54 -0700 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2018-09-05 16:28:09 -0700 |
commit | 9f4bcdf0752b688775ded8752c2db8d30e480fdf (patch) | |
tree | dc89973b5840cca39cf2f952422b4ba08aaa4050 /caffe2/transforms | |
parent | ac9f0a68846bbef8489112a027174dffab4b3ae6 (diff) | |
download | pytorch-9f4bcdf0752b688775ded8752c2db8d30e480fdf.tar.gz pytorch-9f4bcdf0752b688775ded8752c2db8d30e480fdf.tar.bz2 pytorch-9f4bcdf0752b688775ded8752c2db8d30e480fdf.zip |
caffe2::DeviceType -> at::DeviceType (#11254)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11254
Previously we used DeviceType in caffe2.proto directly, but it's an `enum` and has an implicit conversion to int, which does not provide type safety; e.g. we have to explicitly check that a device type is valid in event.h:
```
template <int d>
struct EventCreateFunctionRegisterer {
explicit EventCreateFunctionRegisterer(EventCreateFunction f) {
static_assert(d < MaxDeviceTypes, "");
Event::event_creator_[d] = f;
}
};
```
at::DeviceType is an `enum class`, and it does not have implicit conversion to int, and provides better type safety guarantees. In this diff we have done the following refactor(taking CPU as an example):
1. caffe2::DeviceType → caffe2::DeviceTypeProto
2. caffe2::CPU → caffe2::PROTO_CPU
3. caffe2::DeviceType = at::DeviceType
4. caffe2::CPU = at::DeviceType::CPU
codemod -d caffe2/caffe2 --extensions h,cc,cpp 'device_type\(\), ' 'device_type(), PROTO_'
+ some manual changes
In short, after this diff, in c++, caffe2::CPU refers to the at::DeviceType::CPU and the old proto caffe2::CPU will be caffe2::PROTO_CPU.
On the Python side, we have a temporary workaround that aliases `caffe2_pb2.CPU = caffe2_pb2.PROTO_CPU` to make the change easier to review; this alias will be removed later.
Reviewed By: ezyang
Differential Revision: D9545704
fbshipit-source-id: 461a28a4ca74e616d3ee183a607078a717fd38a7
Diffstat (limited to 'caffe2/transforms')
-rw-r--r-- | caffe2/transforms/conv_to_nnpack_transform.h | 2 | ||||
-rw-r--r-- | caffe2/transforms/conv_to_nnpack_transform_test.cc | 4 | ||||
-rw-r--r-- | caffe2/transforms/pattern_net_transform_test.cc | 12 |
3 files changed, 9 insertions, 9 deletions
diff --git a/caffe2/transforms/conv_to_nnpack_transform.h b/caffe2/transforms/conv_to_nnpack_transform.h index 5d61be232c..8563732f22 100644 --- a/caffe2/transforms/conv_to_nnpack_transform.h +++ b/caffe2/transforms/conv_to_nnpack_transform.h @@ -12,7 +12,7 @@ class CAFFE2_API ConvToNNPackTransform : public SingleOpTransform { // Specify what the op needs to be to match the pattern. bool MatchOperator(const OperatorDef& op) override { return ( - op.type() == "Conv" && op.device_option().device_type() == CPU && + op.type() == "Conv" && op.device_option().device_type() == PROTO_CPU && op.engine() != "NNPACK"); } diff --git a/caffe2/transforms/conv_to_nnpack_transform_test.cc b/caffe2/transforms/conv_to_nnpack_transform_test.cc index 4ab80fccd9..92b5f1ade2 100644 --- a/caffe2/transforms/conv_to_nnpack_transform_test.cc +++ b/caffe2/transforms/conv_to_nnpack_transform_test.cc @@ -15,7 +15,7 @@ TEST(ConvToNNPackTest, TestSimple) { op = AddOp(&netdef, "Conv", {"in"}, {"out"}); op = AddOp(&netdef, "Relu", {"out"}, {"out"}); op = AddOp(&netdef, "Conv", {"out"}, {"out"}); // if not CPU, won't transform - op->mutable_device_option()->set_device_type(CUDA); + op->mutable_device_option()->set_device_type(PROTO_CUDA); op = AddOp(&netdef, "Relu", {"out"}, {"out"}); op = AddOp(&netdef, "Conv", {"out"}, {"out"}); op->set_engine("NNPACK"); // does not need to be transformed @@ -28,7 +28,7 @@ TEST(ConvToNNPackTest, TestSimple) { int nnpack_count = 0; for (auto& op : transformed_netdef.op()) { - if (op.type() == "Conv" && op.device_option().device_type() == CPU) { + if (op.type() == "Conv" && op.device_option().device_type() == PROTO_CPU) { EXPECT_EQ(op.engine(), "NNPACK"); nnpack_count++; } diff --git a/caffe2/transforms/pattern_net_transform_test.cc b/caffe2/transforms/pattern_net_transform_test.cc index 36925d9d43..8ac21af540 100644 --- a/caffe2/transforms/pattern_net_transform_test.cc +++ b/caffe2/transforms/pattern_net_transform_test.cc @@ -250,19 +250,19 @@ 
TEST(PatternNetTransformTest, TestDeviceOptionMatching) { NetDef pdef; auto op = AddOp(&pdef, "DummyOp1", {"in"}, {"out"}); - op->mutable_device_option()->set_device_type(CPU); + op->mutable_device_option()->set_device_type(PROTO_CPU); NetDef rdef; op = AddOp(&rdef, "DummyOp1", {"in"}, {"out"}); - op->mutable_device_option()->set_device_type(CUDA); + op->mutable_device_option()->set_device_type(PROTO_CUDA); NetDef netdef; op = AddOp(&netdef, "DummyOp1", {"in"}, {"mid"}); - op->mutable_device_option()->set_device_type(CPU); + op->mutable_device_option()->set_device_type(PROTO_CPU); op = AddOp(&netdef, "DummyOp1", {"mid"}, {"mid"}); // should not match - op->mutable_device_option()->set_device_type(CUDA); + op->mutable_device_option()->set_device_type(PROTO_CUDA); op = AddOp(&netdef, "DummyOp1", {"mid"}, {"out"}); - op->mutable_device_option()->set_device_type(CPU); + op->mutable_device_option()->set_device_type(PROTO_CPU); PatternNetTransform t(pdef, rdef); transform::Graph g(netdef); @@ -272,7 +272,7 @@ TEST(PatternNetTransformTest, TestDeviceOptionMatching) { NetDef transformed_net = t.ApplyTo(netdef); for (const auto& opdef : transformed_net.op()) { EXPECT_TRUE(opdef.has_device_option()); - EXPECT_EQ(opdef.device_option().device_type(), CUDA); + EXPECT_EQ(opdef.device_option().device_type(), PROTO_CUDA); } } |