diff options
author | Mikhail Zolotukhin <mvz@fb.com> | 2019-02-22 14:56:02 -0800 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-02-22 15:25:52 -0800 |
commit | 6d744f8fbfe6ef1c90534893d4ca2b2caf8a0df4 (patch) | |
tree | fd470113098c17c3e7a5a3e199f06ae017cebcda /torch | |
parent | dbd66c17bcae6015844f2603bb5db28af84ac7d7 (diff) | |
download | pytorch-6d744f8fbfe6ef1c90534893d4ca2b2caf8a0df4.tar.gz pytorch-6d744f8fbfe6ef1c90534893d4ca2b2caf8a0df4.tar.bz2 pytorch-6d744f8fbfe6ef1c90534893d4ca2b2caf8a0df4.zip |
Preserve names when converting to/from NetDef.
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17378
Differential Revision: D14176515
Pulled By: ZolotukhinM
fbshipit-source-id: da9ea28310250ab3ca3a99cdc210fd8d1fbbc82b
Diffstat (limited to 'torch')
-rw-r--r-- | torch/csrc/jit/netdef_converter.cpp | 35 |
1 files changed, 34 insertions, 1 deletions
diff --git a/torch/csrc/jit/netdef_converter.cpp b/torch/csrc/jit/netdef_converter.cpp index 4c8ca5fe08..8f77165ed2 100644 --- a/torch/csrc/jit/netdef_converter.cpp +++ b/torch/csrc/jit/netdef_converter.cpp @@ -77,11 +77,13 @@ void convertNetDefToIR( std::unordered_map<std::string, Value*>* valueMapPtr, const std::string& prefix) { std::unordered_map<std::string, Value*>& valueMap = *valueMapPtr; + std::unordered_map<Value*, std::string> namesMap; valueMap.clear(); for (const auto& inputName : net.external_input()) { AT_ASSERT(!valueMap.count(inputName)); valueMap[inputName] = g->addInput(); + namesMap[valueMap.at(inputName)] = inputName; } for (const auto& op : net.op()) { @@ -98,7 +100,9 @@ void convertNetDefToIR( for (const auto& output : op.output()) { // If output already exists in valueMap, overwrite it. This way we will // have the last definition of a value named 'output' in valueMap. - valueMap[output] = node->outputs()[idx++]; + Value* v = node->outputs()[idx++]; + valueMap[output] = v; + namesMap[v] = output; } for (const auto& arg : op.arg()) { convertArg(arg, node); @@ -108,6 +112,35 @@ void convertNetDefToIR( for (const auto& outputName : net.external_output()) { AT_ASSERT(valueMap.count(outputName)); g->registerOutput(valueMap.at(outputName)); + namesMap[valueMap.at(outputName)] = outputName; + } + + // Set proper unique names for all values. + // We will set the names for external inputs and outputs last, so that if the + // names are reused, then intermediate values will be renamed and the external + // values will keep the original names. + for (Node* n : g->nodes()) { + for (Value* v : n->outputs()) { + AT_ASSERT(namesMap.count(v)); + const std::string& name = namesMap.at(v); + if (Value::isValidName(name)) { + v->setUniqueName(name); + } + } + } + for (Value* v : g->inputs()) { + AT_ASSERT(namesMap.count(v)); + const std::string& name = namesMap.at(v); + if (Value::isValidName(name)) { + v->setUniqueName(name); + } + } + for (Value* v : g->outputs()) { + AT_ASSERT(namesMap.count(v)); + const std::string& name = namesMap.at(v); + if (Value::isValidName(name)) { + v->setUniqueName(name); + } } } |