author     Lara <lahaidar@microsoft.com>  2019-04-04 13:15:18 -0700
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>  2019-04-04 13:24:04 -0700
commit     1ec1db477d3963e433f4cbe751142f454bd7e755 (patch)
tree       0212ab0f5a1c520a60c26262166c92d1d0aa6faa /torch/onnx
parent     b4d2df1fee35e9f2e8fb01297261e6c19d568e75 (diff)
ONNX Export All Cases of Softmax
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18482
Reviewed By: zrphercule
Differential Revision: D14630697
Pulled By: houseroad
fbshipit-source-id: c06f1e3bead10a265c5f4ac3723d49f4caf46801
Diffstat (limited to 'torch/onnx')
-rw-r--r--  torch/onnx/symbolic.py | 22
1 file changed, 15 insertions(+), 7 deletions(-)
diff --git a/torch/onnx/symbolic.py b/torch/onnx/symbolic.py
index 1edc190444..837da84f8b 100644
--- a/torch/onnx/symbolic.py
+++ b/torch/onnx/symbolic.py
@@ -596,14 +596,22 @@ def softmax(g, input, dim, dtype=None):
     # [0.167, 0.167, 0.167]]
     # So only when dim and axis both equal to ndim - 1 (the last dimension),
     # their semantics are equivalent.
-    if dim < 0:
-        dim = input.type().dim() + dim
-    if input.type().dim() != dim + 1:
-        return _unimplemented("dim", "ONNX and PyTorch use different strategies to split the input.")
-    return_op = g.op('Softmax', input, axis_i=dim)
+    # So use softmax when dim and axis both equal to ndim - 1
+    # otherwise compute softmax using a subgraph with other operators
+    if input.type().kind() == "CompleteTensorType" or input.type().kind() == "DimensionedTensorType":
+        if dim < 0:
+            dim = input.type().dim() + dim
+        if input.type().dim() == dim + 1:
+            softmax = g.op('Softmax', input, axis_i=dim)
+            if dtype:
+                softmax = g.op("Cast", softmax, to_i=scalar_type_to_onnx[dtype])
+            return softmax
+    exp = g.op('Exp', input)
+    sum = g.op('ReduceSum', exp, axes_i=[dim])
+    softmax = g.op('Div', exp, sum)
     if dtype:
-        return_op = g.op("Cast", return_op, to_i=scalar_type_to_onnx[dtype])
-    return return_op
+        softmax = g.op("Cast", softmax, to_i=scalar_type_to_onnx[dtype])
+    return softmax
 
 
 @parse_args('v', 't', 'v')
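
The change drops the old hard failure ("ONNX and PyTorch use different strategies to split the input.") for softmax over a non-last dimension: when the tensor type is known and dim is the last dimension, the exporter still emits a single ONNX Softmax node; in every other case it now builds an Exp -> ReduceSum -> Div subgraph. Below is a minimal sketch, not part of the commit (module and file names are illustrative), of what that decomposition computes and of an export call the old symbolic would have rejected:

# Minimal sketch of the two export paths (assumes a PyTorch build with ONNX export).
import torch
import torch.nn.functional as F

x = torch.randn(3, 4)

# dim == ndim - 1: PyTorch softmax matches the single ONNX Softmax(axis=dim) node.
last_dim = F.softmax(x, dim=1)

# dim != ndim - 1: the exported Exp -> ReduceSum -> Div subgraph computes the same
# values (ONNX ReduceSum keeps the reduced axis by default, so Div broadcasts).
dim = 0
exp = torch.exp(x)
decomposed = exp / exp.sum(dim=dim, keepdim=True)
print(torch.allclose(decomposed, F.softmax(x, dim=dim)))  # True


# Hypothetical export of a softmax over dim=0, which previously hit _unimplemented.
class SoftmaxDim0(torch.nn.Module):
    def forward(self, x):
        return F.softmax(x, dim=0)

torch.onnx.export(SoftmaxDim0(), torch.randn(3, 4), "softmax_dim0.onnx")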