diff options
author | Bram Wasti <bwasti@fb.com> | 2018-09-28 14:06:08 -0700 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2018-09-28 14:11:51 -0700 |
commit | 60061a20d95751c11d9d1083defa097d74a1894f (patch) | |
tree | 39b5b22582d3fb8758945f106192897b44130fd5 /caffe2/opt | |
parent | 7b2c0a09e4795bb8dc06c3d1881289fbe75d84e2 (diff) | |
download | pytorch-60061a20d95751c11d9d1083defa097d74a1894f.tar.gz pytorch-60061a20d95751c11d9d1083defa097d74a1894f.tar.bz2 pytorch-60061a20d95751c11d9d1083defa097d74a1894f.zip |
Adding Declare and Export operators (#11954)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11954
Adding an alternative to external_input and external_output for use in some distributed settings
Reviewed By: aazzolini
Differential Revision: D9997121
fbshipit-source-id: 1b5cc03fd3051368a3edc69e7bc472386f5746b5
Diffstat (limited to 'caffe2/opt')
-rw-r--r-- | caffe2/opt/converter.cc | 44 | ||||
-rw-r--r-- | caffe2/opt/converter.h | 3 | ||||
-rw-r--r-- | caffe2/opt/converter_nomigraph_test.cc | 34 |
3 files changed, 81 insertions, 0 deletions
diff --git a/caffe2/opt/converter.cc b/caffe2/opt/converter.cc index 6a75207160..46fd8349b0 100644 --- a/caffe2/opt/converter.cc +++ b/caffe2/opt/converter.cc @@ -519,4 +519,48 @@ caffe2::NetDef convertToCaffe2Proto(repr::NNModule &m, const caffe2::NetDef& old return predictNet; } +void pushOpToFront(caffe2::OperatorDef& op, caffe2::NetDef* net) { + *net->add_op() = op; + google::protobuf::RepeatedPtrField<caffe2::OperatorDef>* op_list( + net->mutable_op()); + // Reverse iterate, swapping new element in front each time + for (int i(net->op_size() - 1); i > 0; --i) { + op_list->SwapElements(i, i - 1); + } +} + +void injectDataEdgeIndicators(caffe2::NetDef* net) { + for (const auto& input : net->external_input()) { + caffe2::OperatorDef op; + op.set_type("Declare"); + op.add_output(input); + pushOpToFront(op, net); + } + for (const auto& output : net->external_output()) { + caffe2::OperatorDef op; + op.set_type("Export"); + op.add_input(output); + *net->add_op() = op; + } + net->clear_external_input(); + net->clear_external_output(); +} + +void removeDataEdgeIndicators(caffe2::NetDef* net) { + google::protobuf::RepeatedPtrField<caffe2::OperatorDef>* op_list( + net->mutable_op()); + for (auto i = 0; i < net->op_size(); ++i) { + auto op = net->op(i); + if (op.type() == "Declare") { + net->add_external_input(op.output(0)); + } else if (op.type() == "Export") { + net->add_external_output(op.input(0)); + } else { + continue; + } + // Note that this compensates for modifying the list inplace + op_list->DeleteSubrange(i--, 1); + } +} + } // namespace caffe2 diff --git a/caffe2/opt/converter.h b/caffe2/opt/converter.h index f5933313c7..be0901ac64 100644 --- a/caffe2/opt/converter.h +++ b/caffe2/opt/converter.h @@ -13,6 +13,9 @@ namespace caffe2 { +void injectDataEdgeIndicators(caffe2::NetDef* net); +void removeDataEdgeIndicators(caffe2::NetDef* net); + CAFFE2_API nom::repr::NNModule convertToNNModule(caffe2::NetDef &net, bool strict = false); CAFFE2_API caffe2::NetDef convertToCaffe2Proto(nom::repr::NNModule&); diff --git a/caffe2/opt/converter_nomigraph_test.cc b/caffe2/opt/converter_nomigraph_test.cc index 995c9a5961..e9da69a42d 100644 --- a/caffe2/opt/converter_nomigraph_test.cc +++ b/caffe2/opt/converter_nomigraph_test.cc @@ -98,3 +98,37 @@ TEST(Converter, ExternalOutputs) { EXPECT_EQ(new_netdef.external_output(i), net.external_output(i)); } } + +TEST(Converter, InjectDataEdgeIndicators) { + auto net = fakeNet(); + caffe2::injectDataEdgeIndicators(&net); + + EXPECT_EQ(net.op_size(), 3 + 1 + 2); // Inserted 1 Declare and 2 Export + + auto declare_count = 0; + auto export_count = 0; + for (const auto& op : net.op()) { + declare_count += op.type() == "Declare"; + export_count += op.type() == "Export"; + } + EXPECT_EQ(declare_count, 1); + EXPECT_EQ(export_count, 2); + + // Remove them from the network + EXPECT_EQ(net.external_input_size(), 0); + EXPECT_EQ(net.external_output_size(), 0); + + // Ensure nomnigraph can handle this change + auto nn = caffe2::convertToNNModule(net); + auto new_net = caffe2::convertToCaffe2Proto(nn); + + caffe2::removeDataEdgeIndicators(&new_net); + + for (const auto& op : new_net.op()) { + EXPECT_NE(op.type(), "Declare"); + EXPECT_NE(op.type(), "Export"); + } + + EXPECT_EQ(new_net.external_input_size(), 1); + EXPECT_EQ(new_net.external_output_size(), 2); +} |