diff options
author | Spandan Tiwari <sptiwari@microsoft.com> | 2019-01-03 10:29:03 -0800 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-01-03 10:35:25 -0800 |
commit | 947229ebd780b0ecead66029691fd18881947b59 (patch) | |
tree | dff0a58ff48a718d109aee4dcebee616959bb035 /test/onnx | |
parent | 279ca4acd2e1a48d30a69d17ee7f82e5fa959e1f (diff) | |
download | pytorch-947229ebd780b0ecead66029691fd18881947b59.tar.gz pytorch-947229ebd780b0ecead66029691fd18881947b59.tar.bz2 pytorch-947229ebd780b0ecead66029691fd18881947b59.zip |
Fix ONNX export of logical ops, including torch.ne, to have correct output datatype (#15677)
Summary:
This is the an updated version of the earlier PR https://github.com/pytorch/pytorch/pull/15185, since that one was closed.
Currently PyTorch ONNX exporter exports the logical ops (lt, gt, le, ge, eq, ne) with output type in corresponding ONNX ops as type tensor(uint8). But ONNX spec allows for only tensor(bool), which is why models that have these ops fail to load properly.
This issue is captured in #11339. Part of this issue, relating to the allowed input types, has been fixed in ONNX spec by houseroad. This PR fixes the other part pertaining to output type.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15677
Reviewed By: dzhulgakov
Differential Revision: D13568450
Pulled By: houseroad
fbshipit-source-id: a6afbea1afdb4edad8f8b1bc492f50b14e5f2fce
Diffstat (limited to 'test/onnx')
-rw-r--r-- | test/onnx/expect/TestOperators.test_equal.expect | 12 | ||||
-rw-r--r-- | test/onnx/expect/TestOperators.test_ge.expect | 12 | ||||
-rw-r--r-- | test/onnx/expect/TestOperators.test_gt.expect | 12 | ||||
-rw-r--r-- | test/onnx/expect/TestOperators.test_le.expect | 12 | ||||
-rw-r--r-- | test/onnx/expect/TestOperators.test_lt.expect | 12 | ||||
-rw-r--r-- | test/onnx/expect/TestOperators.test_ne.expect | 12 |
6 files changed, 66 insertions, 6 deletions
diff --git a/test/onnx/expect/TestOperators.test_equal.expect b/test/onnx/expect/TestOperators.test_equal.expect index c5117370dc..53d4f844fa 100644 --- a/test/onnx/expect/TestOperators.test_equal.expect +++ b/test/onnx/expect/TestOperators.test_equal.expect @@ -8,6 +8,16 @@ graph { output: "2" op_type: "Equal" } + node { + input: "2" + output: "3" + op_type: "Cast" + attribute { + name: "to" + i: 2 + type: INT + } + } name: "torch-jit-export" input { name: "0" @@ -48,7 +58,7 @@ graph { } } output { - name: "2" + name: "3" type { tensor_type { elem_type: 2 diff --git a/test/onnx/expect/TestOperators.test_ge.expect b/test/onnx/expect/TestOperators.test_ge.expect index a66c5f47f0..8a4b5d5758 100644 --- a/test/onnx/expect/TestOperators.test_ge.expect +++ b/test/onnx/expect/TestOperators.test_ge.expect @@ -13,6 +13,16 @@ graph { output: "3" op_type: "Not" } + node { + input: "3" + output: "4" + op_type: "Cast" + attribute { + name: "to" + i: 2 + type: INT + } + } name: "torch-jit-export" input { name: "0" @@ -47,7 +57,7 @@ graph { } } output { - name: "3" + name: "4" type { tensor_type { elem_type: 2 diff --git a/test/onnx/expect/TestOperators.test_gt.expect b/test/onnx/expect/TestOperators.test_gt.expect index 7680e0a66f..542bf81e34 100644 --- a/test/onnx/expect/TestOperators.test_gt.expect +++ b/test/onnx/expect/TestOperators.test_gt.expect @@ -8,6 +8,16 @@ graph { output: "2" op_type: "Greater" } + node { + input: "2" + output: "3" + op_type: "Cast" + attribute { + name: "to" + i: 2 + type: INT + } + } name: "torch-jit-export" input { name: "0" @@ -48,7 +58,7 @@ graph { } } output { - name: "2" + name: "3" type { tensor_type { elem_type: 2 diff --git a/test/onnx/expect/TestOperators.test_le.expect b/test/onnx/expect/TestOperators.test_le.expect index 0b17740b3d..923489e0b2 100644 --- a/test/onnx/expect/TestOperators.test_le.expect +++ b/test/onnx/expect/TestOperators.test_le.expect @@ -13,6 +13,16 @@ graph { output: "3" op_type: "Not" } + node { + input: "3" + output: "4" + op_type: "Cast" + attribute { + name: "to" + i: 2 + type: INT + } + } name: "torch-jit-export" input { name: "0" @@ -47,7 +57,7 @@ graph { } } output { - name: "3" + name: "4" type { tensor_type { elem_type: 2 diff --git a/test/onnx/expect/TestOperators.test_lt.expect b/test/onnx/expect/TestOperators.test_lt.expect index a5c67409d7..2befe25dc9 100644 --- a/test/onnx/expect/TestOperators.test_lt.expect +++ b/test/onnx/expect/TestOperators.test_lt.expect @@ -8,6 +8,16 @@ graph { output: "2" op_type: "Less" } + node { + input: "2" + output: "3" + op_type: "Cast" + attribute { + name: "to" + i: 2 + type: INT + } + } name: "torch-jit-export" input { name: "0" @@ -48,7 +58,7 @@ graph { } } output { - name: "2" + name: "3" type { tensor_type { elem_type: 2 diff --git a/test/onnx/expect/TestOperators.test_ne.expect b/test/onnx/expect/TestOperators.test_ne.expect index 078204f6aa..50505c4e1b 100644 --- a/test/onnx/expect/TestOperators.test_ne.expect +++ b/test/onnx/expect/TestOperators.test_ne.expect @@ -13,6 +13,16 @@ graph { output: "3" op_type: "Not" } + node { + input: "3" + output: "4" + op_type: "Cast" + attribute { + name: "to" + i: 2 + type: INT + } + } name: "torch-jit-export" input { name: "0" @@ -53,7 +63,7 @@ graph { } } output { - name: "3" + name: "4" type { tensor_type { elem_type: 2 |