diff options
author | Duc Ngo <duc@fb.com> | 2019-03-22 11:14:40 -0700 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-03-22 11:23:03 -0700 |
commit | 172ec4ace520a72729191b95b9f21c651aa5e245 (patch) | |
tree | 33c294862de560a1b7c0b6505e227f49595dd39a /caffe2 | |
parent | 7397eb7e8edb0bb1f3d467acaa5f1c5648d50901 (diff) | |
download | pytorch-172ec4ace520a72729191b95b9f21c651aa5e245.tar.gz pytorch-172ec4ace520a72729191b95b9f21c651aa5e245.tar.bz2 pytorch-172ec4ace520a72729191b95b9f21c651aa5e245.zip |
caffe2 - Util to cleanup external inputs and outputs from a NetDef (#18194)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18194
Add a util method to cleanup external inputs and outputs from a NetDef
The following conditions will be met after the modification
- No duplicate external inputs
- No duplicate external outputs
- Going through list of ops in order, all op inputs must be outputs
from other ops, or registered as external inputs.
- All external outputs must be outputs of some operators.
Reviewed By: ZolotukhinM
Differential Revision: D14528589
fbshipit-source-id: c8d82fda1946aa3696abcbec869a4a8bb22f09b6
Diffstat (limited to 'caffe2')
-rw-r--r-- | caffe2/utils/proto_utils.cc | 54 | ||||
-rw-r--r-- | caffe2/utils/proto_utils.h | 8 | ||||
-rw-r--r-- | caffe2/utils/proto_utils_test.cc | 35 |
3 files changed, 94 insertions, 3 deletions
diff --git a/caffe2/utils/proto_utils.cc b/caffe2/utils/proto_utils.cc index 213feb4300..40cc1b8f31 100644 --- a/caffe2/utils/proto_utils.cc +++ b/caffe2/utils/proto_utils.cc @@ -558,4 +558,56 @@ C10_EXPORT Argument* GetMutableArgument( } } -} // namespace caffe2 +C10_EXPORT void cleanupExternalInputsAndOutputs(NetDef* net) { + std::vector<std::string> oldExternalInputs; + for (const auto& input : net->external_input()) { + oldExternalInputs.emplace_back(input); + } + std::vector<std::string> oldExternalOutputs; + for (const auto& output : net->external_output()) { + oldExternalOutputs.emplace_back(output); + } + + net->clear_external_input(); + net->clear_external_output(); + + std::set<std::string> inputSet; + for (const auto& input : oldExternalInputs) { + if (inputSet.count(input)) { + // Prevent duplicate external inputs. + continue; + } + inputSet.insert(input); + net->add_external_input(input); + } + + // Set of blobs that are external inputs or outputs of some operators. + std::set<std::string> allOutputs(inputSet.begin(), inputSet.end()); + for (const auto& op : net->op()) { + for (const auto& input : op.input()) { + if (inputSet.count(input) || allOutputs.count(input)) { + continue; + } + // Add missing external inputs. + inputSet.insert(input); + net->add_external_input(input); + } + for (const auto& output : op.output()) { + allOutputs.insert(output); + } + } + + std::set<std::string> outputSet; + for (const auto& output : oldExternalOutputs) { + if (!allOutputs.count(output)) { + continue; + } + if (outputSet.count(output)) { + continue; + } + outputSet.insert(output); + net->add_external_output(output); + } +} + +} // namespace caffe2 diff --git a/caffe2/utils/proto_utils.h b/caffe2/utils/proto_utils.h index 963783673b..22ccc63590 100644 --- a/caffe2/utils/proto_utils.h +++ b/caffe2/utils/proto_utils.h @@ -329,6 +329,14 @@ bool inline operator==(const DeviceOption& dl, const DeviceOption& dr) { return IsSameDevice(dl, dr); } +// Given a net, modify the external inputs/outputs if necessary so that +// the following conditions are met +// - No duplicate external inputs +// - No duplicate external outputs +// - Going through list of ops in order, all op inputs must be outputs +// from other ops, or registered as external inputs. +// - All external outputs must be outputs of some operators. +CAFFE2_API void cleanupExternalInputsAndOutputs(NetDef* net); } // namespace caffe2 diff --git a/caffe2/utils/proto_utils_test.cc b/caffe2/utils/proto_utils_test.cc index 5d8fb86b34..1a687690c6 100644 --- a/caffe2/utils/proto_utils_test.cc +++ b/caffe2/utils/proto_utils_test.cc @@ -1,6 +1,8 @@ -#include "caffe2/utils/proto_utils.h" #include <gtest/gtest.h> +#include "caffe2/core/test_utils.h" +#include "caffe2/utils/proto_utils.h" + namespace caffe2 { TEST(ProtoUtilsTest, IsSameDevice) { @@ -29,4 +31,33 @@ TEST(ProtoUtilsTest, SimpleReadWrite) { EXPECT_EQ(content, read_back); } -} // namespace caffe2 +TEST(ProtoUtilsTest, CleanupExternalInputsAndOutputs) { + caffe2::NetDef net; + caffe2::testing::NetMutator(&net) + .newOp("op1", {"X1", "X2"}, {"Y"}) + .newOp("op2", {"W", "Y"}, {"Z1", "Z2"}) + .newOp("op3", {"Z2", "W"}, {"O"}) + .externalInputs({"X1", "X3", "X1", "W"}) + .externalOutputs({"O", "Z2", "Z3", "O", "X3"}); + cleanupExternalInputsAndOutputs(&net); + + std::vector<std::string> externalInputs; + for (const auto& inputName : net.external_input()) { + externalInputs.emplace_back(inputName); + } + // The 2nd X1 is removed because of duplication. + // X2 is added because it should be a missing external input. + std::vector<std::string> expectedExternalInputs{"X1", "X3", "W", "X2"}; + EXPECT_EQ(externalInputs, expectedExternalInputs); + + std::vector<std::string> externalOutputs; + for (const auto& outputName : net.external_output()) { + externalOutputs.emplace_back(outputName); + } + // Z3 is removed because it's not an output of any operator in the net. + // The 2nd O is removed because of duplication. + std::vector<std::string> expectedexternalOutputs{"O", "Z2", "X3"}; + EXPECT_EQ(externalOutputs, expectedexternalOutputs); +} + +} // namespace caffe2 |