summaryrefslogtreecommitdiff
path: root/caffe2
diff options
context:
space:
mode:
authorDuc Ngo <duc@fb.com>2019-03-22 11:14:40 -0700
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-03-22 11:23:03 -0700
commit172ec4ace520a72729191b95b9f21c651aa5e245 (patch)
tree33c294862de560a1b7c0b6505e227f49595dd39a /caffe2
parent7397eb7e8edb0bb1f3d467acaa5f1c5648d50901 (diff)
downloadpytorch-172ec4ace520a72729191b95b9f21c651aa5e245.tar.gz
pytorch-172ec4ace520a72729191b95b9f21c651aa5e245.tar.bz2
pytorch-172ec4ace520a72729191b95b9f21c651aa5e245.zip
caffe2 - Util to cleanup external inputs and outputs from a NetDef (#18194)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18194 Add a util method to cleanup external inputs and outputs from a NetDef The following conditions will be met after the modification - No duplicate external inputs - No duplicate external outputs - Going through list of ops in order, all op inputs must be outputs from other ops, or registered as external inputs. - All external outputs must be outputs of some operators. Reviewed By: ZolotukhinM Differential Revision: D14528589 fbshipit-source-id: c8d82fda1946aa3696abcbec869a4a8bb22f09b6
Diffstat (limited to 'caffe2')
-rw-r--r--caffe2/utils/proto_utils.cc54
-rw-r--r--caffe2/utils/proto_utils.h8
-rw-r--r--caffe2/utils/proto_utils_test.cc35
3 files changed, 94 insertions, 3 deletions
diff --git a/caffe2/utils/proto_utils.cc b/caffe2/utils/proto_utils.cc
index 213feb4300..40cc1b8f31 100644
--- a/caffe2/utils/proto_utils.cc
+++ b/caffe2/utils/proto_utils.cc
@@ -558,4 +558,56 @@ C10_EXPORT Argument* GetMutableArgument(
}
}
-} // namespace caffe2
+C10_EXPORT void cleanupExternalInputsAndOutputs(NetDef* net) {
+ std::vector<std::string> oldExternalInputs;
+ for (const auto& input : net->external_input()) {
+ oldExternalInputs.emplace_back(input);
+ }
+ std::vector<std::string> oldExternalOutputs;
+ for (const auto& output : net->external_output()) {
+ oldExternalOutputs.emplace_back(output);
+ }
+
+ net->clear_external_input();
+ net->clear_external_output();
+
+ std::set<std::string> inputSet;
+ for (const auto& input : oldExternalInputs) {
+ if (inputSet.count(input)) {
+ // Prevent duplicate external inputs.
+ continue;
+ }
+ inputSet.insert(input);
+ net->add_external_input(input);
+ }
+
+ // Set of blobs that are external inputs or outputs of some operators.
+ std::set<std::string> allOutputs(inputSet.begin(), inputSet.end());
+ for (const auto& op : net->op()) {
+ for (const auto& input : op.input()) {
+ if (inputSet.count(input) || allOutputs.count(input)) {
+ continue;
+ }
+ // Add missing external inputs.
+ inputSet.insert(input);
+ net->add_external_input(input);
+ }
+ for (const auto& output : op.output()) {
+ allOutputs.insert(output);
+ }
+ }
+
+ std::set<std::string> outputSet;
+ for (const auto& output : oldExternalOutputs) {
+ if (!allOutputs.count(output)) {
+ continue;
+ }
+ if (outputSet.count(output)) {
+ continue;
+ }
+ outputSet.insert(output);
+ net->add_external_output(output);
+ }
+}
+
+} // namespace caffe2
diff --git a/caffe2/utils/proto_utils.h b/caffe2/utils/proto_utils.h
index 963783673b..22ccc63590 100644
--- a/caffe2/utils/proto_utils.h
+++ b/caffe2/utils/proto_utils.h
@@ -329,6 +329,14 @@ bool inline operator==(const DeviceOption& dl, const DeviceOption& dr) {
return IsSameDevice(dl, dr);
}
+// Given a net, modify the external inputs/outputs if necessary so that
+// the following conditions are met
+// - No duplicate external inputs
+// - No duplicate external outputs
+// - Going through list of ops in order, all op inputs must be outputs
+// from other ops, or registered as external inputs.
+// - All external outputs must be outputs of some operators.
+CAFFE2_API void cleanupExternalInputsAndOutputs(NetDef* net);
} // namespace caffe2
diff --git a/caffe2/utils/proto_utils_test.cc b/caffe2/utils/proto_utils_test.cc
index 5d8fb86b34..1a687690c6 100644
--- a/caffe2/utils/proto_utils_test.cc
+++ b/caffe2/utils/proto_utils_test.cc
@@ -1,6 +1,8 @@
-#include "caffe2/utils/proto_utils.h"
#include <gtest/gtest.h>
+#include "caffe2/core/test_utils.h"
+#include "caffe2/utils/proto_utils.h"
+
namespace caffe2 {
TEST(ProtoUtilsTest, IsSameDevice) {
@@ -29,4 +31,33 @@ TEST(ProtoUtilsTest, SimpleReadWrite) {
EXPECT_EQ(content, read_back);
}
-} // namespace caffe2
+TEST(ProtoUtilsTest, CleanupExternalInputsAndOutputs) {
+ caffe2::NetDef net;
+ caffe2::testing::NetMutator(&net)
+ .newOp("op1", {"X1", "X2"}, {"Y"})
+ .newOp("op2", {"W", "Y"}, {"Z1", "Z2"})
+ .newOp("op3", {"Z2", "W"}, {"O"})
+ .externalInputs({"X1", "X3", "X1", "W"})
+ .externalOutputs({"O", "Z2", "Z3", "O", "X3"});
+ cleanupExternalInputsAndOutputs(&net);
+
+ std::vector<std::string> externalInputs;
+ for (const auto& inputName : net.external_input()) {
+ externalInputs.emplace_back(inputName);
+ }
+ // The 2nd X1 is removed because of duplication.
+ // X2 is added because it should be a missing external input.
+ std::vector<std::string> expectedExternalInputs{"X1", "X3", "W", "X2"};
+ EXPECT_EQ(externalInputs, expectedExternalInputs);
+
+ std::vector<std::string> externalOutputs;
+ for (const auto& outputName : net.external_output()) {
+ externalOutputs.emplace_back(outputName);
+ }
+ // Z3 is removed because it's not an output of any operator in the net.
+ // The 2nd O is removed because of duplication.
+ std::vector<std::string> expectedexternalOutputs{"O", "Z2", "X3"};
+ EXPECT_EQ(externalOutputs, expectedexternalOutputs);
+}
+
+} // namespace caffe2