| author | Lara <lahaidar@microsoft.com> | 2019-04-04 13:15:18 -0700 |
|---|---|---|
| committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-04-04 13:24:04 -0700 |
| commit | 1ec1db477d3963e433f4cbe751142f454bd7e755 | |
| tree | 0212ab0f5a1c520a60c26262166c92d1d0aa6faa | |
| parent | b4d2df1fee35e9f2e8fb01297261e6c19d568e75 | |
ONNX Export All Cases of Softmax
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18482
Reviewed By: zrphercule
Differential Revision: D14630697
Pulled By: houseroad
fbshipit-source-id: c06f1e3bead10a265c5f4ac3723d49f4caf46801
 test/onnx/test_pytorch_onnx_caffe2.py | 19 ++++++++++++++-----
 torch/onnx/symbolic.py                | 22 +++++++++++++++-------
 2 files changed, 29 insertions(+), 12 deletions(-)
diff --git a/test/onnx/test_pytorch_onnx_caffe2.py b/test/onnx/test_pytorch_onnx_caffe2.py
index 92c7efbc08..daea9ef412 100644
--- a/test/onnx/test_pytorch_onnx_caffe2.py
+++ b/test/onnx/test_pytorch_onnx_caffe2.py
@@ -934,11 +934,20 @@ class TestCaffe2Backend(unittest.TestCase):
     # TODO: Add test cases for prod once Caffe2 has support for ReduceProd

     def test_softmax(self):
-        for i in range(7)[2:]:
-            model = nn.Softmax(dim=i - 1)
-            dims = [2] * (i - 2) + [3, 4]
-            input = torch.ones(*dims, requires_grad=True)
-            self.run_model_test(model, train=False, batch_size=BATCH_SIZE, input=input)
+        for i in range(2, 8):
+            for d in range(0, i - 1):
+                model = nn.Softmax(dim=d)
+                dims = [2] * (i - 2) + [3, 4]
+                input = torch.ones(*dims, requires_grad=True)
+                self.run_model_test(model, train=False, batch_size=BATCH_SIZE, input=input)
+
+    def test_softmax_dtype(self):
+        class SoftmaxModel(torch.nn.Module):
+            def forward(self, input):
+                return nn.functional.softmax(input, dim=0, dtype=torch.float64)
+
+        x = torch.randn(1, 2, 3, requires_grad=True, dtype=torch.float32)
+        self.run_model_test(SoftmaxModel(), train=False, input=x, batch_size=BATCH_SIZE)

     def test_logsoftmax(self):
         for i in range(7)[2:]:
diff --git a/torch/onnx/symbolic.py b/torch/onnx/symbolic.py
index 1edc190444..837da84f8b 100644
--- a/torch/onnx/symbolic.py
+++ b/torch/onnx/symbolic.py
@@ -596,14 +596,22 @@ def softmax(g, input, dim, dtype=None):
     #           [0.167, 0.167, 0.167]]
     # So only when dim and axis both equal to ndim - 1 (the last dimension),
     # their semantics are equivalent.
-    if dim < 0:
-        dim = input.type().dim() + dim
-    if input.type().dim() != dim + 1:
-        return _unimplemented("dim", "ONNX and PyTorch use different strategies to split the input.")
-    return_op = g.op('Softmax', input, axis_i=dim)
+    # So use softmax when dim and axis both equal to ndim - 1
+    # otherwise compute softmax using a subgraph with other operators
+    if input.type().kind() == "CompleteTensorType" or input.type().kind() == "DimensionedTensorType":
+        if dim < 0:
+            dim = input.type().dim() + dim
+        if input.type().dim() == dim + 1:
+            softmax = g.op('Softmax', input, axis_i=dim)
+            if dtype:
+                softmax = g.op("Cast", softmax, to_i=scalar_type_to_onnx[dtype])
+            return softmax
+    exp = g.op('Exp', input)
+    sum = g.op('ReduceSum', exp, axes_i=[dim])
+    softmax = g.op('Div', exp, sum)
     if dtype:
-        return_op = g.op("Cast", return_op, to_i=scalar_type_to_onnx[dtype])
-    return return_op
+        softmax = g.op("Cast", softmax, to_i=scalar_type_to_onnx[dtype])
+    return softmax

 @parse_args('v', 't', 'v')
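The comment in `torch/onnx/symbolic.py` refers to how ONNX Softmax (before opset 13) coerces its input to 2D at `axis` and normalizes over the flattened trailing dimension, which only matches PyTorch's per-`dim` softmax when `dim` is the last dimension. A minimal sketch of that mismatch (the helper `onnx_style_softmax` is illustrative, not part of the patch):

```python
import math

import torch
import torch.nn.functional as F


def onnx_style_softmax(x, axis):
    # Emulate pre-opset-13 ONNX Softmax: flatten the input to 2D at `axis`,
    # then normalize over the whole flattened trailing dimension.
    lead = math.prod(x.shape[:axis])
    return F.softmax(x.reshape(lead, -1), dim=1).reshape(x.shape)


x = torch.randn(2, 3, 4)
# Agrees with PyTorch only on the last dimension:
print(torch.allclose(onnx_style_softmax(x, 2), F.softmax(x, dim=2)))  # True
print(torch.allclose(onnx_style_softmax(x, 1), F.softmax(x, dim=1)))  # False: dims 1 and 2 get mixed
```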
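For the non-last-dimension case, the patch therefore emits a small subgraph instead of failing with `_unimplemented`. A sketch of what that Exp -> ReduceSum -> Div subgraph computes, checked against `F.softmax` (the helper `decomposed_softmax` is mine, not from the patch):

```python
import torch
import torch.nn.functional as F


def decomposed_softmax(x, dim):
    # Mirrors the exporter's fallback subgraph: Exp -> ReduceSum -> Div.
    # ONNX ReduceSum keeps the reduced dimension by default (keepdims=1),
    # so the Div broadcasts; keepdim=True reproduces that here.
    exp = torch.exp(x)
    return exp / exp.sum(dim=dim, keepdim=True)


x = torch.randn(2, 3, 4)
for d in range(x.dim()):
    assert torch.allclose(decomposed_softmax(x, d), F.softmax(x, dim=d), atol=1e-6)
```

Note that the subgraph exponentiates the raw input without the usual max subtraction, so for large-magnitude inputs it can overflow where `F.softmax` stays finite.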