summaryrefslogtreecommitdiff
path: root/torch
diff options
context:
space:
mode:
authorMikhail Zolotukhin <mvz@fb.com>2019-02-22 14:56:02 -0800
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-02-22 15:25:52 -0800
commit6d744f8fbfe6ef1c90534893d4ca2b2caf8a0df4 (patch)
treefd470113098c17c3e7a5a3e199f06ae017cebcda /torch
parentdbd66c17bcae6015844f2603bb5db28af84ac7d7 (diff)
downloadpytorch-6d744f8fbfe6ef1c90534893d4ca2b2caf8a0df4.tar.gz
pytorch-6d744f8fbfe6ef1c90534893d4ca2b2caf8a0df4.tar.bz2
pytorch-6d744f8fbfe6ef1c90534893d4ca2b2caf8a0df4.zip
Preserve names when converting to/from NetDef.
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17378 Differential Revision: D14176515 Pulled By: ZolotukhinM fbshipit-source-id: da9ea28310250ab3ca3a99cdc210fd8d1fbbc82b
Diffstat (limited to 'torch')
-rw-r--r--torch/csrc/jit/netdef_converter.cpp35
1 files changed, 34 insertions, 1 deletions
diff --git a/torch/csrc/jit/netdef_converter.cpp b/torch/csrc/jit/netdef_converter.cpp
index 4c8ca5fe08..8f77165ed2 100644
--- a/torch/csrc/jit/netdef_converter.cpp
+++ b/torch/csrc/jit/netdef_converter.cpp
@@ -77,11 +77,13 @@ void convertNetDefToIR(
std::unordered_map<std::string, Value*>* valueMapPtr,
const std::string& prefix) {
std::unordered_map<std::string, Value*>& valueMap = *valueMapPtr;
+ std::unordered_map<Value*, std::string> namesMap;
valueMap.clear();
for (const auto& inputName : net.external_input()) {
AT_ASSERT(!valueMap.count(inputName));
valueMap[inputName] = g->addInput();
+ namesMap[valueMap.at(inputName)] = inputName;
}
for (const auto& op : net.op()) {
@@ -98,7 +100,9 @@ void convertNetDefToIR(
for (const auto& output : op.output()) {
// If output already exists in valueMap, overwrite it. This way we will
// have the last definition of a value named 'output' in valueMap.
- valueMap[output] = node->outputs()[idx++];
+ Value* v = node->outputs()[idx++];
+ valueMap[output] = v;
+ namesMap[v] = output;
}
for (const auto& arg : op.arg()) {
convertArg(arg, node);
@@ -108,6 +112,35 @@ void convertNetDefToIR(
for (const auto& outputName : net.external_output()) {
AT_ASSERT(valueMap.count(outputName));
g->registerOutput(valueMap.at(outputName));
+ namesMap[valueMap.at(outputName)] = outputName;
+ }
+
+ // Set proper unique names for all values.
+ // We will set the names for external inputs and outputs last, so that if the
+ // names are reused, then intermediate values will be renamed and the external
+ // values will keep the original names.
+ for (Node* n : g->nodes()) {
+ for (Value* v : n->outputs()) {
+ AT_ASSERT(namesMap.count(v));
+ const std::string& name = namesMap.at(v);
+ if (Value::isValidName(name)) {
+ v->setUniqueName(name);
+ }
+ }
+ }
+ for (Value* v : g->inputs()) {
+ AT_ASSERT(namesMap.count(v));
+ const std::string& name = namesMap.at(v);
+ if (Value::isValidName(name)) {
+ v->setUniqueName(name);
+ }
+ }
+ for (Value* v : g->outputs()) {
+ AT_ASSERT(namesMap.count(v));
+ const std::string& name = namesMap.at(v);
+ if (Value::isValidName(name)) {
+ v->setUniqueName(name);
+ }
}
}