diff options
author | Lu Fang <lufang@fb.com> | 2019-04-12 11:58:06 -0700 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-04-12 12:01:14 -0700 |
commit | bd55abb4635bf935c8fb6e222fc0061e87318f77 (patch) | |
tree | c624ef7e2280cd0f4cd21b7d4b732e8ee8da1f81 /torch/onnx | |
parent | c480798a1cb704b542bc74707108df575f3d4ee5 (diff) | |
download | pytorch-bd55abb4635bf935c8fb6e222fc0061e87318f77.tar.gz pytorch-bd55abb4635bf935c8fb6e222fc0061e87318f77.tar.bz2 pytorch-bd55abb4635bf935c8fb6e222fc0061e87318f77.zip |
Fix onnx ints (#19102)
Summary:
If JIT constant propagation doesn't work, we have to handle the ListConstructor in symbolic.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19102
Reviewed By: zrphercule
Differential Revision: D14875588
Pulled By: houseroad
fbshipit-source-id: d25c847d224d2d32db50aae1751100080e115022
Diffstat (limited to 'torch/onnx')
-rw-r--r-- | torch/onnx/symbolic.py | 37 |
1 files changed, 24 insertions, 13 deletions
diff --git a/torch/onnx/symbolic.py b/torch/onnx/symbolic.py index 27ba60ae69..bb656bda93 100644 --- a/torch/onnx/symbolic.py +++ b/torch/onnx/symbolic.py @@ -69,19 +69,30 @@ def _parse_arg(value, desc): return value if desc == 'v' or not _is_value(value): return value - if value.node().kind() != 'onnx::Constant': - raise RuntimeError("ONNX symbolic expected a constant value in the trace") - tval = value.node()['value'] - if desc == 'i': - return int(tval) - elif desc == 'f': - return float(tval) - elif desc == 't': - return tval - elif desc == 'is': - return [int(v) for v in tval] - else: - raise RuntimeError("Casting constants to `{}` is not implemented".format(desc)) + if value.node().kind() == 'onnx::Constant': + tval = value.node()['value'] + if desc == 'i': + return int(tval) + elif desc == 'f': + return float(tval) + elif desc == 't': + return tval + elif desc == 'is': + return [int(v) for v in tval] + else: + raise RuntimeError("ONNX symbolic doesn't know to interpret Constant node") + elif value.node().kind() == 'prim::ListConstruct': + if desc == 'is': + for v in value.node().inputs(): + if v.node().kind() != 'onnx::Constant': + raise RuntimeError("Failed to export an ONNX attribute, " + "since it's not constant, please try to make " + "things (e.g., kernel size) static if possible") + return [int(v.node()['value']) for v in value.node().inputs()] + else: + raise RuntimeError("ONNX symbolic doesn't know to interpret ListConstruct node") + + raise RuntimeError("Unexpected node type: {}".format(value.node().kind())) def _maybe_get_const(value, desc): |