author | Lara <lahaidar@microsoft.com> | 2019-04-04 13:15:18 -0700 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-04-04 13:24:04 -0700 |
commit | 1ec1db477d3963e433f4cbe751142f454bd7e755 (patch) | |
tree | 0212ab0f5a1c520a60c26262166c92d1d0aa6faa /torch/onnx | |
parent | b4d2df1fee35e9f2e8fb01297261e6c19d568e75 (diff) | |
ONNX Export All Cases of Softmax
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18482
Reviewed By: zrphercule
Differential Revision: D14630697
Pulled By: houseroad
fbshipit-source-id: c06f1e3bead10a265c5f4ac3723d49f4caf46801
Diffstat (limited to 'torch/onnx')
-rw-r--r-- | torch/onnx/symbolic.py | 22 |
1 file changed, 15 insertions, 7 deletions
```diff
diff --git a/torch/onnx/symbolic.py b/torch/onnx/symbolic.py
index 1edc190444..837da84f8b 100644
--- a/torch/onnx/symbolic.py
+++ b/torch/onnx/symbolic.py
@@ -596,14 +596,22 @@ def softmax(g, input, dim, dtype=None):
     #           [0.167, 0.167, 0.167]]
     # So only when dim and axis both equal to ndim - 1 (the last dimension),
     # their semantics are equivalent.
-    if dim < 0:
-        dim = input.type().dim() + dim
-    if input.type().dim() != dim + 1:
-        return _unimplemented("dim", "ONNX and PyTorch use different strategies to split the input.")
-    return_op = g.op('Softmax', input, axis_i=dim)
+    # So use softmax when dim and axis both equal to ndim - 1
+    # otherwise compute softmax using a subgraph with other operators
+    if input.type().kind() == "CompleteTensorType" or input.type().kind() == "DimensionedTensorType":
+        if dim < 0:
+            dim = input.type().dim() + dim
+        if input.type().dim() == dim + 1:
+            softmax = g.op('Softmax', input, axis_i=dim)
+            if dtype:
+                softmax = g.op("Cast", softmax, to_i=scalar_type_to_onnx[dtype])
+            return softmax
+    exp = g.op('Exp', input)
+    sum = g.op('ReduceSum', exp, axes_i=[dim])
+    softmax = g.op('Div', exp, sum)
     if dtype:
-        return_op = g.op("Cast", return_op, to_i=scalar_type_to_onnx[dtype])
-    return return_op
+        softmax = g.op("Cast", softmax, to_i=scalar_type_to_onnx[dtype])
+    return softmax
 
 
 @parse_args('v', 't', 'v')
```
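For context, the fallback subgraph the exporter now emits for a non-last-dimension softmax is the literal definition exp(x) / sum(exp(x)) along `dim`. Below is a minimal sketch of the previously-unsupported case; it assumes a PyTorch build that contains this change, and the module name and tensor shape are illustrative, not taken from the commit:

```python
import io

import torch
import torch.nn.functional as F


class NonLastDimSoftmax(torch.nn.Module):
    """Hypothetical example module: softmax over dim=0, which the old
    exporter rejected via _unimplemented()."""

    def forward(self, x):
        return F.softmax(x, dim=0)


x = torch.randn(3, 4)

# Eager-mode equivalent of the exported Exp -> ReduceSum -> Div subgraph.
# ONNX ReduceSum keeps the reduced axis by default (keepdims=1), which is
# what lets the Div broadcast; keepdim=True mirrors that here.
manual = torch.exp(x) / torch.exp(x).sum(dim=0, keepdim=True)
assert torch.allclose(F.softmax(x, dim=0), manual)

# With this commit the export takes the new subgraph path instead of
# bailing out in _unimplemented().
buf = io.BytesIO()
torch.onnx.export(NonLastDimSoftmax(), x, buf)
```

Two details worth noting from the diff: the fallback path is also taken when the input's type kind is neither `CompleteTensorType` nor `DimensionedTensorType` (the rank is unknown at export time, so the dim-versus-rank check cannot run), and the decomposition omits the usual max-subtraction trick, so it is less numerically stable than ONNX's native `Softmax` for large-magnitude inputs.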