summaryrefslogtreecommitdiff
path: root/caffe2/opt
diff options
context:
space:
mode:
authorBram Wasti <bwasti@fb.com>2018-09-28 14:06:08 -0700
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2018-09-28 14:11:51 -0700
commit60061a20d95751c11d9d1083defa097d74a1894f (patch)
tree39b5b22582d3fb8758945f106192897b44130fd5 /caffe2/opt
parent7b2c0a09e4795bb8dc06c3d1881289fbe75d84e2 (diff)
downloadpytorch-60061a20d95751c11d9d1083defa097d74a1894f.tar.gz
pytorch-60061a20d95751c11d9d1083defa097d74a1894f.tar.bz2
pytorch-60061a20d95751c11d9d1083defa097d74a1894f.zip
Adding Declare and Export operators (#11954)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/11954 Adding an alternative to external_input and external_output for use in some distributed settings Reviewed By: aazzolini Differential Revision: D9997121 fbshipit-source-id: 1b5cc03fd3051368a3edc69e7bc472386f5746b5
Diffstat (limited to 'caffe2/opt')
-rw-r--r--caffe2/opt/converter.cc44
-rw-r--r--caffe2/opt/converter.h3
-rw-r--r--caffe2/opt/converter_nomigraph_test.cc34
3 files changed, 81 insertions, 0 deletions
diff --git a/caffe2/opt/converter.cc b/caffe2/opt/converter.cc
index 6a75207160..46fd8349b0 100644
--- a/caffe2/opt/converter.cc
+++ b/caffe2/opt/converter.cc
@@ -519,4 +519,48 @@ caffe2::NetDef convertToCaffe2Proto(repr::NNModule &m, const caffe2::NetDef& old
return predictNet;
}
+void pushOpToFront(caffe2::OperatorDef& op, caffe2::NetDef* net) {
+ *net->add_op() = op;
+ google::protobuf::RepeatedPtrField<caffe2::OperatorDef>* op_list(
+ net->mutable_op());
+ // Reverse iterate, swapping new element in front each time
+ for (int i(net->op_size() - 1); i > 0; --i) {
+ op_list->SwapElements(i, i - 1);
+ }
+}
+
+void injectDataEdgeIndicators(caffe2::NetDef* net) {
+ for (const auto& input : net->external_input()) {
+ caffe2::OperatorDef op;
+ op.set_type("Declare");
+ op.add_output(input);
+ pushOpToFront(op, net);
+ }
+ for (const auto& output : net->external_output()) {
+ caffe2::OperatorDef op;
+ op.set_type("Export");
+ op.add_input(output);
+ *net->add_op() = op;
+ }
+ net->clear_external_input();
+ net->clear_external_output();
+}
+
+void removeDataEdgeIndicators(caffe2::NetDef* net) {
+ google::protobuf::RepeatedPtrField<caffe2::OperatorDef>* op_list(
+ net->mutable_op());
+ for (auto i = 0; i < net->op_size(); ++i) {
+ auto op = net->op(i);
+ if (op.type() == "Declare") {
+ net->add_external_input(op.output(0));
+ } else if (op.type() == "Export") {
+ net->add_external_output(op.input(0));
+ } else {
+ continue;
+ }
+ // Note that this compensates for modifying the list inplace
+ op_list->DeleteSubrange(i--, 1);
+ }
+}
+
} // namespace caffe2
diff --git a/caffe2/opt/converter.h b/caffe2/opt/converter.h
index f5933313c7..be0901ac64 100644
--- a/caffe2/opt/converter.h
+++ b/caffe2/opt/converter.h
@@ -13,6 +13,9 @@
namespace caffe2 {
+void injectDataEdgeIndicators(caffe2::NetDef* net);
+void removeDataEdgeIndicators(caffe2::NetDef* net);
+
CAFFE2_API nom::repr::NNModule convertToNNModule(caffe2::NetDef &net, bool strict = false);
CAFFE2_API caffe2::NetDef convertToCaffe2Proto(nom::repr::NNModule&);
diff --git a/caffe2/opt/converter_nomigraph_test.cc b/caffe2/opt/converter_nomigraph_test.cc
index 995c9a5961..e9da69a42d 100644
--- a/caffe2/opt/converter_nomigraph_test.cc
+++ b/caffe2/opt/converter_nomigraph_test.cc
@@ -98,3 +98,37 @@ TEST(Converter, ExternalOutputs) {
EXPECT_EQ(new_netdef.external_output(i), net.external_output(i));
}
}
+
+TEST(Converter, InjectDataEdgeIndicators) {
+ auto net = fakeNet();
+ caffe2::injectDataEdgeIndicators(&net);
+
+ EXPECT_EQ(net.op_size(), 3 + 1 + 2); // Inserted 1 Declare and 2 Export
+
+ auto declare_count = 0;
+ auto export_count = 0;
+ for (const auto& op : net.op()) {
+ declare_count += op.type() == "Declare";
+ export_count += op.type() == "Export";
+ }
+ EXPECT_EQ(declare_count, 1);
+ EXPECT_EQ(export_count, 2);
+
+ // Remove them from the network
+ EXPECT_EQ(net.external_input_size(), 0);
+ EXPECT_EQ(net.external_output_size(), 0);
+
+ // Ensure nomnigraph can handle this change
+ auto nn = caffe2::convertToNNModule(net);
+ auto new_net = caffe2::convertToCaffe2Proto(nn);
+
+ caffe2::removeDataEdgeIndicators(&new_net);
+
+ for (const auto& op : new_net.op()) {
+ EXPECT_NE(op.type(), "Declare");
+ EXPECT_NE(op.type(), "Export");
+ }
+
+ EXPECT_EQ(new_net.external_input_size(), 1);
+ EXPECT_EQ(new_net.external_output_size(), 2);
+}