diff options
author | Lu Fang <lufang@fb.com> | 2019-03-06 14:59:16 -0800 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-03-06 15:02:08 -0800 |
commit | 4db3f8f8065cec8ca021eb750468f040499dce51 (patch) | |
tree | 06ffb0fd0aa30e3403a359bcef685ebd90660720 /torch/onnx | |
parent | c78da0c6ede91f8bb778bb566e5aca1d57fc6cf5 (diff) | |
download | pytorch-4db3f8f8065cec8ca021eb750468f040499dce51.tar.gz pytorch-4db3f8f8065cec8ca021eb750468f040499dce51.tar.bz2 pytorch-4db3f8f8065cec8ca021eb750468f040499dce51.zip |
Improve ONNX symbolic for logsoftmax and softmax (#17672)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17672
support dtype in the onnx symbolic
Reviewed By: zrphercule
Differential Revision: D14313987
fbshipit-source-id: e9364621b3f795191d880599711dfbcb220d0e31
Diffstat (limited to 'torch/onnx')
-rw-r--r-- | torch/onnx/symbolic.py | 21 |
1 files changed, 14 insertions, 7 deletions
diff --git a/torch/onnx/symbolic.py b/torch/onnx/symbolic.py index 244aa03640..a3c145ff8d 100644 --- a/torch/onnx/symbolic.py +++ b/torch/onnx/symbolic.py @@ -114,7 +114,8 @@ def _unpack_list(list_value): def parse_args(*arg_descriptors): def decorator(fn): def wrapper(g, *args): - assert len(arg_descriptors) == len(args) + # some args may be optional, so the length may be smaller + assert len(arg_descriptors) >= len(args) args = [_parse_arg(arg, arg_desc) for arg, arg_desc in zip(args, arg_descriptors)] return fn(g, *args) # In Python 2 functools.wraps chokes on partially applied functions, so we need this as a workaround @@ -567,8 +568,8 @@ def glu(g, input, dim): return g.op('Mul', first, g.op('Sigmoid', second)) -@parse_args('v', 'i') -def softmax(g, input, dim): +@parse_args('v', 'i', 'i') +def softmax(g, input, dim, dtype=None): # Softmax does normalization at vector level. # PyTorch and ONNX use different strategies to split the input tensor into vectors. # Thus dim and axis have different meanings. @@ -589,7 +590,10 @@ def softmax(g, input, dim): dim = input.type().dim() + dim if input.type().dim() != dim + 1: return _unimplemented("dim", "ONNX and PyTorch use different strategies to split the input.") - return g.op('Softmax', input, axis_i=dim) + return_op = g.op('Softmax', input, axis_i=dim) + if dtype: + return_op = g.op("Cast", return_op, to_i=scalar_type_to_onnx[dtype]) + return return_op @parse_args('v', 't', 'v') @@ -870,15 +874,18 @@ def where(g, condition, self, other): return g.op("ATen", condition, self, other, operator_s="where") -@parse_args('v', 'i') -def log_softmax(g, input, dim=None): +@parse_args('v', 'i', 'i') +def log_softmax(g, input, dim=None, dtype=None): # PyTorch dim and ONNX axis have different meanings. # See Softmax comment for details. if dim < 0: dim = input.type().dim() + dim if input.type().dim() != dim + 1: return _unimplemented("dim", "ONNX and PyTorch use different strategies to split the input.") - return g.op("LogSoftmax", input, axis_i=dim) + return_op = g.op("LogSoftmax", input, axis_i=dim) + if dtype: + return_op = g.op("Cast", return_op, to_i=scalar_type_to_onnx[dtype]) + return return_op @parse_args('v', 'v', 'v', 'is', 'is', 'is', 'i', 'is', 'i', 'i', 'i', 'i') |