diff options
author | Zachary DeVito <zdevito@fb.com> | 2018-08-27 14:30:25 -0700 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2018-08-27 14:40:40 -0700 |
commit | 6ce799edd6fbecbf3803d512cff59508d53641eb (patch) | |
tree | 5c97fa621d2d08908b87e02044b6ca4cee5ecaaa /torch/onnx | |
parent | f64f6eed3ac5a9f297d570b8077a158cb371b112 (diff) | |
download | pytorch-6ce799edd6fbecbf3803d512cff59508d53641eb.tar.gz pytorch-6ce799edd6fbecbf3803d512cff59508d53641eb.tar.bz2 pytorch-6ce799edd6fbecbf3803d512cff59508d53641eb.zip |
Tuples/Lists can now be inputs/outputs to script and other simple fixes. (#10812)
Summary:
* Fix the necessary pathways so that tuples and lists can be inputs to the script.
* prevent linear algebra functions from being run in shape prop because
they frequently will error out for nonsense data.
* favor schema-driven python input conversion where possible.
remaining cases where we directly create Stacks without schema are
only for debugging
* Make the error messages when calling script/trace functions more pythonic
* Simplify FlattenTuples -- now that tuples are supported we can choose to only flatten tuples when needed. This may have to be revisited pending onnx test results, but is necessary for making tuple io work.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10812
Differential Revision: D9477982
Pulled By: zdevito
fbshipit-source-id: ed06fc426e6ef6deb404602a26c435a7fc40ea0c
Diffstat (limited to 'torch/onnx')
-rw-r--r-- | torch/onnx/utils.py | 2 |
1 files changed, 2 insertions, 0 deletions
diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py index f566dbb53b..34c30aea65 100644 --- a/torch/onnx/utils.py +++ b/torch/onnx/utils.py @@ -131,6 +131,8 @@ def _optimize_graph(graph, operator_export_type): # onnx only supports tensors, so we turn all out number types into tensors torch._C._jit_pass_erase_number_types(graph) + # onnx does not support tuples, so try to remove them + torch._C._jit_pass_lower_all_tuples(graph) torch._C._jit_pass_peephole(graph) torch._C._jit_pass_lint(graph) |