author     Lara <lahaidar@microsoft.com>  2019-04-04 13:15:18 -0700
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>  2019-04-04 13:24:04 -0700
commit     1ec1db477d3963e433f4cbe751142f454bd7e755 (patch)
tree       0212ab0f5a1c520a60c26262166c92d1d0aa6faa
parent     b4d2df1fee35e9f2e8fb01297261e6c19d568e75 (diff)
download   pytorch-1ec1db477d3963e433f4cbe751142f454bd7e755.tar.gz
           pytorch-1ec1db477d3963e433f4cbe751142f454bd7e755.tar.bz2
           pytorch-1ec1db477d3963e433f4cbe751142f454bd7e755.zip
ONNX Export All Cases of Softmax
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18482

Reviewed By: zrphercule

Differential Revision: D14630697

Pulled By: houseroad

fbshipit-source-id: c06f1e3bead10a265c5f4ac3723d49f4caf46801
-rw-r--r--  test/onnx/test_pytorch_onnx_caffe2.py  19
-rw-r--r--  torch/onnx/symbolic.py                  22
2 files changed, 29 insertions, 12 deletions
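For context, the symbolic.py change below decomposes softmax over an arbitrary dimension into Exp, ReduceSum and Div nodes whenever ONNX's Softmax (which only normalizes over the coerced last axis) cannot be used directly. A minimal eager-mode sketch of the identity that fallback relies on (shapes and the dim value are illustrative, not taken from this commit):

import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 4)
dim = 0  # any dimension, not only the last one

# softmax(x, dim) == exp(x) / sum(exp(x), dim), which is what the
# exported Exp -> ReduceSum -> Div subgraph computes
expected = F.softmax(x, dim=dim)
manual = x.exp() / x.exp().sum(dim=dim, keepdim=True)
assert torch.allclose(expected, manual)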
diff --git a/test/onnx/test_pytorch_onnx_caffe2.py b/test/onnx/test_pytorch_onnx_caffe2.py
index 92c7efbc08..daea9ef412 100644
--- a/test/onnx/test_pytorch_onnx_caffe2.py
+++ b/test/onnx/test_pytorch_onnx_caffe2.py
@@ -934,11 +934,20 @@ class TestCaffe2Backend(unittest.TestCase):
     # TODO: Add test cases for prod once Caffe2 has support for ReduceProd
     def test_softmax(self):
-        for i in range(7)[2:]:
-            model = nn.Softmax(dim=i - 1)
-            dims = [2] * (i - 2) + [3, 4]
-            input = torch.ones(*dims, requires_grad=True)
-            self.run_model_test(model, train=False, batch_size=BATCH_SIZE, input=input)
+        for i in range(2, 8):
+            for d in range(0, i - 1):
+                model = nn.Softmax(dim=d)
+                dims = [2] * (i - 2) + [3, 4]
+                input = torch.ones(*dims, requires_grad=True)
+                self.run_model_test(model, train=False, batch_size=BATCH_SIZE, input=input)
+
+    def test_softmax_dtype(self):
+        class SoftmaxModel(torch.nn.Module):
+            def forward(self, input):
+                return nn.functional.softmax(input, dim=0, dtype=torch.float64)
+
+        x = torch.randn(1, 2, 3, requires_grad=True, dtype=torch.float32)
+        self.run_model_test(SoftmaxModel(), train=False, input=x, batch_size=BATCH_SIZE)
 
     def test_logsoftmax(self):
         for i in range(7)[2:]:
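With the symbolic change in the next file, exporting a Softmax over a non-last dimension no longer hits the unimplemented branch ("ONNX and PyTorch use different strategies to split the input."). A minimal usage sketch, roughly what the new test exercises (the in-memory buffer and shapes are illustrative):

import io
import torch
import torch.nn as nn

# Softmax over dim=0 of a rank-3 input now exports via the
# Exp/ReduceSum/Div fallback instead of being rejected.
model = nn.Softmax(dim=0)
dummy = torch.ones(2, 3, 4)
buf = io.BytesIO()
torch.onnx.export(model, dummy, buf)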
diff --git a/torch/onnx/symbolic.py b/torch/onnx/symbolic.py
index 1edc190444..837da84f8b 100644
--- a/torch/onnx/symbolic.py
+++ b/torch/onnx/symbolic.py
@@ -596,14 +596,22 @@ def softmax(g, input, dim, dtype=None):
     #           [0.167, 0.167, 0.167]]
     # So only when dim and axis both equal to ndim - 1 (the last dimension),
     # their semantics are equivalent.
-    if dim < 0:
-        dim = input.type().dim() + dim
-    if input.type().dim() != dim + 1:
-        return _unimplemented("dim", "ONNX and PyTorch use different strategies to split the input.")
-    return_op = g.op('Softmax', input, axis_i=dim)
+    # So use softmax when dim and axis both equal to ndim - 1
+    # otherwise compute softmax using a subgraph with other operators
+    if input.type().kind() == "CompleteTensorType" or input.type().kind() == "DimensionedTensorType":
+        if dim < 0:
+            dim = input.type().dim() + dim
+        if input.type().dim() == dim + 1:
+            softmax = g.op('Softmax', input, axis_i=dim)
+            if dtype:
+                softmax = g.op("Cast", softmax, to_i=scalar_type_to_onnx[dtype])
+            return softmax
+    exp = g.op('Exp', input)
+    sum = g.op('ReduceSum', exp, axes_i=[dim])
+    softmax = g.op('Div', exp, sum)
     if dtype:
-        return_op = g.op("Cast", return_op, to_i=scalar_type_to_onnx[dtype])
-        return return_op
+        softmax = g.op("Cast", softmax, to_i=scalar_type_to_onnx[dtype])
+    return softmax
 
 
 @parse_args('v', 't', 'v')
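The dtype argument handled above is exported by appending an ONNX Cast node after the softmax result, so the graph output carries the requested element type. A quick eager-mode sketch of the corresponding behavior (input shape is illustrative):

import torch
import torch.nn.functional as F

# softmax with an explicit dtype returns a tensor of that dtype; the
# exported graph reaches the same output type through a Cast node.
x = torch.randn(1, 2, 3)
out = F.softmax(x, dim=0, dtype=torch.float64)
assert out.dtype == torch.float64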