author      Lu Fang <lufang@fb.com>  2019-03-06 14:59:16 -0800
committer   Facebook Github Bot <facebook-github-bot@users.noreply.github.com>  2019-03-06 15:02:08 -0800
commit      4db3f8f8065cec8ca021eb750468f040499dce51 (patch)
tree        06ffb0fd0aa30e3403a359bcef685ebd90660720 /torch/onnx
parent      c78da0c6ede91f8bb778bb566e5aca1d57fc6cf5 (diff)
Improve ONNX symbolic for logsoftmax and softmax (#17672)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17672

Support dtype in the ONNX symbolic.

Reviewed By: zrphercule

Differential Revision: D14313987

fbshipit-source-id: e9364621b3f795191d880599711dfbcb220d0e31
Diffstat (limited to 'torch/onnx')
-rw-r--r--  torch/onnx/symbolic.py  21
1 file changed, 14 insertions(+), 7 deletions(-)
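
As the summary notes, the softmax and log_softmax symbolics now honor the optional dtype argument by appending an ONNX Cast node after the normalization op. A minimal sketch of how the change could be exercised from the exporter side (the module name and output filename are illustrative, not part of this commit):

import torch
import torch.nn.functional as F

class SoftmaxWithDtype(torch.nn.Module):
    def forward(self, x):
        # Asking for a float64 result should now export as Softmax followed by Cast.
        return F.softmax(x, dim=-1, dtype=torch.float64)

# Export a small model; the resulting graph is expected to carry a Cast node
# after Softmax whenever dtype is passed.
torch.onnx.export(SoftmaxWithDtype(), torch.randn(2, 3), "softmax_dtype.onnx")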
diff --git a/torch/onnx/symbolic.py b/torch/onnx/symbolic.py
index 244aa03640..a3c145ff8d 100644
--- a/torch/onnx/symbolic.py
+++ b/torch/onnx/symbolic.py
@@ -114,7 +114,8 @@ def _unpack_list(list_value):
def parse_args(*arg_descriptors):
def decorator(fn):
def wrapper(g, *args):
- assert len(arg_descriptors) == len(args)
+ # some args may be optional, so the length may be smaller
+ assert len(arg_descriptors) >= len(args)
args = [_parse_arg(arg, arg_desc) for arg, arg_desc in zip(args, arg_descriptors)]
return fn(g, *args)
# In Python 2 functools.wraps chokes on partially applied functions, so we need this as a workaround
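
For context, the relaxed assertion lets a symbolic be invoked with fewer arguments than descriptors, so trailing optional parameters such as the new dtype keep their Python defaults. A simplified stand-in for the decorator (not the real parse_args; the real one also converts each argument per its descriptor via _parse_arg):

import functools

def parse_args_sketch(*arg_descriptors):
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(g, *args):
            # Trailing descriptors may go unused when optional args are omitted,
            # so only require that every passed arg has a descriptor.
            assert len(arg_descriptors) >= len(args)
            # The real decorator would run _parse_arg(arg, desc) here.
            return fn(g, *args)
        return wrapper
    return decorator

@parse_args_sketch('v', 'i', 'i')
def softmax_sketch(g, input, dim, dtype=None):
    return (dim, dtype)

print(softmax_sketch(None, "x", 1))      # (1, None) - optional dtype omitted
print(softmax_sketch(None, "x", 1, 7))   # (1, 7)    - dtype supplied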
@@ -567,8 +568,8 @@ def glu(g, input, dim):
return g.op('Mul', first, g.op('Sigmoid', second))
-@parse_args('v', 'i')
-def softmax(g, input, dim):
+@parse_args('v', 'i', 'i')
+def softmax(g, input, dim, dtype=None):
# Softmax does normalization at vector level.
# PyTorch and ONNX use different strategies to split the input tensor into vectors.
# Thus dim and axis have different meanings.
@@ -589,7 +590,10 @@ def softmax(g, input, dim):
dim = input.type().dim() + dim
if input.type().dim() != dim + 1:
return _unimplemented("dim", "ONNX and PyTorch use different strategies to split the input.")
- return g.op('Softmax', input, axis_i=dim)
+ return_op = g.op('Softmax', input, axis_i=dim)
+ if dtype:
+ return_op = g.op("Cast", return_op, to_i=scalar_type_to_onnx[dtype])
+ return return_op
@parse_args('v', 't', 'v')
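
The Cast target above comes from scalar_type_to_onnx, the table in symbolic.py that maps a PyTorch scalar-type index (the integer produced by the 'i' descriptor for dtype) to an ONNX TensorProto element type. An abbreviated, illustrative rendering of such a mapping, assuming the usual at::ScalarType ordering (the authoritative table is the one in symbolic.py):

from onnx import TensorProto

# Illustrative only: PyTorch scalar-type indices -> ONNX TensorProto dtypes.
scalar_type_to_onnx_sketch = {
    0: TensorProto.UINT8,    # torch.uint8
    1: TensorProto.INT8,     # torch.int8
    2: TensorProto.INT16,    # torch.int16
    3: TensorProto.INT32,    # torch.int32
    4: TensorProto.INT64,    # torch.int64
    5: TensorProto.FLOAT16,  # torch.float16
    6: TensorProto.FLOAT,    # torch.float32
    7: TensorProto.DOUBLE,   # torch.float64
}

# e.g. a dtype index of 7 (float64) would cast the Softmax output to DOUBLE.
assert scalar_type_to_onnx_sketch[7] == TensorProto.DOUBLE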
@@ -870,15 +874,18 @@ def where(g, condition, self, other):
return g.op("ATen", condition, self, other, operator_s="where")
-@parse_args('v', 'i')
-def log_softmax(g, input, dim=None):
+@parse_args('v', 'i', 'i')
+def log_softmax(g, input, dim=None, dtype=None):
# PyTorch dim and ONNX axis have different meanings.
# See Softmax comment for details.
if dim < 0:
dim = input.type().dim() + dim
if input.type().dim() != dim + 1:
return _unimplemented("dim", "ONNX and PyTorch use different strategies to split the input.")
- return g.op("LogSoftmax", input, axis_i=dim)
+ return_op = g.op("LogSoftmax", input, axis_i=dim)
+ if dtype:
+ return_op = g.op("Cast", return_op, to_i=scalar_type_to_onnx[dtype])
+ return return_op
@parse_args('v', 'v', 'v', 'is', 'is', 'is', 'i', 'is', 'i', 'i', 'i', 'i')
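
Since log_softmax now handles dtype the same way, a quick sanity check is to inspect an exported graph for the expected node sequence. A sketch using the onnx Python package (the filename is illustrative and assumed to hold a model exported with F.log_softmax(..., dtype=torch.float64)):

import onnx

model = onnx.load("log_softmax_dtype.onnx")
op_types = [node.op_type for node in model.graph.node]
# When dtype was passed, the graph is expected to contain a Cast right after LogSoftmax.
assert "LogSoftmax" in op_types and "Cast" in op_types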