summaryrefslogtreecommitdiff
path: root/torch/onnx
diff options
context:
space:
mode:
authorLu Fang <lufang@fb.com>2019-04-12 11:58:06 -0700
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-04-12 12:01:14 -0700
commitbd55abb4635bf935c8fb6e222fc0061e87318f77 (patch)
treec624ef7e2280cd0f4cd21b7d4b732e8ee8da1f81 /torch/onnx
parentc480798a1cb704b542bc74707108df575f3d4ee5 (diff)
downloadpytorch-bd55abb4635bf935c8fb6e222fc0061e87318f77.tar.gz
pytorch-bd55abb4635bf935c8fb6e222fc0061e87318f77.tar.bz2
pytorch-bd55abb4635bf935c8fb6e222fc0061e87318f77.zip
Fix onnx ints (#19102)
Summary: If JIT constant propagation doesn't work, we have to handle the ListConstructor in symbolic. Pull Request resolved: https://github.com/pytorch/pytorch/pull/19102 Reviewed By: zrphercule Differential Revision: D14875588 Pulled By: houseroad fbshipit-source-id: d25c847d224d2d32db50aae1751100080e115022
Diffstat (limited to 'torch/onnx')
-rw-r--r--torch/onnx/symbolic.py37
1 files changed, 24 insertions, 13 deletions
diff --git a/torch/onnx/symbolic.py b/torch/onnx/symbolic.py
index 27ba60ae69..bb656bda93 100644
--- a/torch/onnx/symbolic.py
+++ b/torch/onnx/symbolic.py
@@ -69,19 +69,30 @@ def _parse_arg(value, desc):
return value
if desc == 'v' or not _is_value(value):
return value
- if value.node().kind() != 'onnx::Constant':
- raise RuntimeError("ONNX symbolic expected a constant value in the trace")
- tval = value.node()['value']
- if desc == 'i':
- return int(tval)
- elif desc == 'f':
- return float(tval)
- elif desc == 't':
- return tval
- elif desc == 'is':
- return [int(v) for v in tval]
- else:
- raise RuntimeError("Casting constants to `{}` is not implemented".format(desc))
+ if value.node().kind() == 'onnx::Constant':
+ tval = value.node()['value']
+ if desc == 'i':
+ return int(tval)
+ elif desc == 'f':
+ return float(tval)
+ elif desc == 't':
+ return tval
+ elif desc == 'is':
+ return [int(v) for v in tval]
+ else:
+ raise RuntimeError("ONNX symbolic doesn't know to interpret Constant node")
+ elif value.node().kind() == 'prim::ListConstruct':
+ if desc == 'is':
+ for v in value.node().inputs():
+ if v.node().kind() != 'onnx::Constant':
+ raise RuntimeError("Failed to export an ONNX attribute, "
+ "since it's not constant, please try to make "
+ "things (e.g., kernel size) static if possible")
+ return [int(v.node()['value']) for v in value.node().inputs()]
+ else:
+ raise RuntimeError("ONNX symbolic doesn't know to interpret ListConstruct node")
+
+ raise RuntimeError("Unexpected node type: {}".format(value.node().kind()))
def _maybe_get_const(value, desc):