summaryrefslogtreecommitdiff
path: root/torch
diff options
context:
space:
mode:
authorZachary DeVito <zdevito@fb.com>2019-02-05 12:16:56 -0800
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-02-05 12:51:24 -0800
commit6efa40e07b8e0b3ec27e4d2818172e4ed4239a90 (patch)
tree96815c7d6198ecf72e9f92c0934fc8296e9da1fc /torch
parentf8d4a14f6dc7c709c86b8bb25ccaaa486470e4a7 (diff)
downloadpytorch-6efa40e07b8e0b3ec27e4d2818172e4ed4239a90.tar.gz
pytorch-6efa40e07b8e0b3ec27e4d2818172e4ed4239a90.tar.bz2
pytorch-6efa40e07b8e0b3ec27e4d2818172e4ed4239a90.zip
Preserve method parameter names (#16750)
Summary: Fixes #16591 This uses uniqueBaseName so that parameters do not end up with suffixes. It changes next_id to be per-base-name rather than global to fix jittering issues when re-importing a re-numbered graph. Pull Request resolved: https://github.com/pytorch/pytorch/pull/16750 Differential Revision: D13960282 Pulled By: zdevito fbshipit-source-id: 2156f581d9b95d77bf1f1252074e800b19116555
Diffstat (limited to 'torch')
-rw-r--r--torch/csrc/jit/passes/python_print.cpp9
1 files changed, 4 insertions, 5 deletions
diff --git a/torch/csrc/jit/passes/python_print.cpp b/torch/csrc/jit/passes/python_print.cpp
index ed935dc125..e44c55c870 100644
--- a/torch/csrc/jit/passes/python_print.cpp
+++ b/torch/csrc/jit/passes/python_print.cpp
@@ -359,16 +359,17 @@ struct PythonPrintPass {
buildConstantList(n, constants);
buildConstantList(b->return_node(), constants);
}
+
// get a new name unique across calls to uniqueName() and
// anything we have used.
- size_t next_id = 0;
+ std::unordered_map<std::string, size_t> next_id;
std::string genNameImpl(
const std::string& candidate,
std::unordered_set<std::string>& used) {
std::string name = candidate;
while (used.count(name) || reserved_names.count(name)) {
- name = candidate + std::to_string(next_id++);
+ name = candidate + std::to_string(next_id[name]++);
}
used.insert(name);
return name;
@@ -402,7 +403,7 @@ struct PythonPrintPass {
// use the uniqueName if it was set, otherwise generate a name.
std::string genUniqueNameFor(Value* v) {
return genName(
- v->hasUniqueName() ? makeValidIdentifier(v->uniqueName()) : "_");
+ v->hasUniqueName() ? makeValidIdentifier(v->uniqueNameBase()) : "_");
}
// map from Value to how it should be printed at each use
@@ -1006,7 +1007,6 @@ struct PythonPrintPass {
}
void printMethod(script::Method& method) {
std::unordered_map<at::Tensor*, QualifiedNamePtr> parameter_names;
- ;
createTensorToParameterNameMap(
method.owner(), QualifiedName::create("self"), parameter_names);
printMethod(method, parameter_names);
@@ -1027,7 +1027,6 @@ struct PythonPrintPass {
}
void printModule(script::Module& module) {
std::unordered_map<at::Tensor*, QualifiedNamePtr> parameter_names;
- ;
createTensorToParameterNameMap(
module, QualifiedName::create("self"), parameter_names);
for (auto& method : module.get_methods()) {