summaryrefslogtreecommitdiff
path: root/torch/onnx
diff options
context:
space:
mode:
authorZachary DeVito <zdevito@fb.com>2018-08-27 14:30:25 -0700
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2018-08-27 14:40:40 -0700
commit6ce799edd6fbecbf3803d512cff59508d53641eb (patch)
tree5c97fa621d2d08908b87e02044b6ca4cee5ecaaa /torch/onnx
parentf64f6eed3ac5a9f297d570b8077a158cb371b112 (diff)
downloadpytorch-6ce799edd6fbecbf3803d512cff59508d53641eb.tar.gz
pytorch-6ce799edd6fbecbf3803d512cff59508d53641eb.tar.bz2
pytorch-6ce799edd6fbecbf3803d512cff59508d53641eb.zip
Tuples/Lists can now be inputs/outputs to script and other simple fixes. (#10812)
Summary: * Fix the necessary pathways so that tuples and lists can be inputs to the script. * prevent linear algebra functions from being run in shape prop because they frequently will error out for nonsense data. * favor schema-driven python input conversion where possible. remaining cases where we directly create Stacks without schema are only for debugging * Make the error messages when calling script/trace functions more pythonic * Simplify FlattenTuples -- now that tuples are supported we can choose to only flatten tuples when needed. This may have to be revisited pending onnx test results, but is necessary for making tuple io work. Pull Request resolved: https://github.com/pytorch/pytorch/pull/10812 Differential Revision: D9477982 Pulled By: zdevito fbshipit-source-id: ed06fc426e6ef6deb404602a26c435a7fc40ea0c
Diffstat (limited to 'torch/onnx')
-rw-r--r--torch/onnx/utils.py2
1 files changed, 2 insertions, 0 deletions
diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py
index f566dbb53b..34c30aea65 100644
--- a/torch/onnx/utils.py
+++ b/torch/onnx/utils.py
@@ -131,6 +131,8 @@ def _optimize_graph(graph, operator_export_type):
# onnx only supports tensors, so we turn all out number types into tensors
torch._C._jit_pass_erase_number_types(graph)
+ # onnx does not support tuples, so try to remove them
+ torch._C._jit_pass_lower_all_tuples(graph)
torch._C._jit_pass_peephole(graph)
torch._C._jit_pass_lint(graph)