diff options
author | Zachary DeVito <zdevito@fb.com> | 2019-02-05 12:16:56 -0800 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-02-05 12:51:24 -0800 |
commit | 6efa40e07b8e0b3ec27e4d2818172e4ed4239a90 (patch) | |
tree | 96815c7d6198ecf72e9f92c0934fc8296e9da1fc /torch | |
parent | f8d4a14f6dc7c709c86b8bb25ccaaa486470e4a7 (diff) | |
download | pytorch-6efa40e07b8e0b3ec27e4d2818172e4ed4239a90.tar.gz pytorch-6efa40e07b8e0b3ec27e4d2818172e4ed4239a90.tar.bz2 pytorch-6efa40e07b8e0b3ec27e4d2818172e4ed4239a90.zip |
Preserve method parameter names (#16750)
Summary:
Fixes #16591
This uses uniqueBaseName so that parameters do not end up with suffixes. It changes next_id to be per-base-name rather than global to fix jittering issues when re-importing a re-numbered graph.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16750
Differential Revision: D13960282
Pulled By: zdevito
fbshipit-source-id: 2156f581d9b95d77bf1f1252074e800b19116555
Diffstat (limited to 'torch')
-rw-r--r-- | torch/csrc/jit/passes/python_print.cpp | 9 |
1 files changed, 4 insertions, 5 deletions
diff --git a/torch/csrc/jit/passes/python_print.cpp b/torch/csrc/jit/passes/python_print.cpp index ed935dc125..e44c55c870 100644 --- a/torch/csrc/jit/passes/python_print.cpp +++ b/torch/csrc/jit/passes/python_print.cpp @@ -359,16 +359,17 @@ struct PythonPrintPass { buildConstantList(n, constants); buildConstantList(b->return_node(), constants); } + // get a new name unique across calls to uniqueName() and // anything we have used. - size_t next_id = 0; + std::unordered_map<std::string, size_t> next_id; std::string genNameImpl( const std::string& candidate, std::unordered_set<std::string>& used) { std::string name = candidate; while (used.count(name) || reserved_names.count(name)) { - name = candidate + std::to_string(next_id++); + name = candidate + std::to_string(next_id[name]++); } used.insert(name); return name; @@ -402,7 +403,7 @@ struct PythonPrintPass { // use the uniqueName if it was set, otherwise generate a name. std::string genUniqueNameFor(Value* v) { return genName( - v->hasUniqueName() ? makeValidIdentifier(v->uniqueName()) : "_"); + v->hasUniqueName() ? makeValidIdentifier(v->uniqueNameBase()) : "_"); } // map from Value to how it should be printed at each use @@ -1006,7 +1007,6 @@ struct PythonPrintPass { } void printMethod(script::Method& method) { std::unordered_map<at::Tensor*, QualifiedNamePtr> parameter_names; - ; createTensorToParameterNameMap( method.owner(), QualifiedName::create("self"), parameter_names); printMethod(method, parameter_names); @@ -1027,7 +1027,6 @@ struct PythonPrintPass { } void printModule(script::Module& module) { std::unordered_map<at::Tensor*, QualifiedNamePtr> parameter_names; - ; createTensorToParameterNameMap( module, QualifiedName::create("self"), parameter_names); for (auto& method : module.get_methods()) { |