summaryrefslogtreecommitdiff
path: root/test/onnx
diff options
context:
space:
mode:
authorSpandan Tiwari <sptiwari@microsoft.com>2019-01-03 10:29:03 -0800
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-01-03 10:35:25 -0800
commit947229ebd780b0ecead66029691fd18881947b59 (patch)
treedff0a58ff48a718d109aee4dcebee616959bb035 /test/onnx
parent279ca4acd2e1a48d30a69d17ee7f82e5fa959e1f (diff)
downloadpytorch-947229ebd780b0ecead66029691fd18881947b59.tar.gz
pytorch-947229ebd780b0ecead66029691fd18881947b59.tar.bz2
pytorch-947229ebd780b0ecead66029691fd18881947b59.zip
Fix ONNX export of logical ops, including torch.ne, to have correct output datatype (#15677)
Summary: This is the an updated version of the earlier PR https://github.com/pytorch/pytorch/pull/15185, since that one was closed. Currently PyTorch ONNX exporter exports the logical ops (lt, gt, le, ge, eq, ne) with output type in corresponding ONNX ops as type tensor(uint8). But ONNX spec allows for only tensor(bool), which is why models that have these ops fail to load properly. This issue is captured in #11339. Part of this issue, relating to the allowed input types, has been fixed in ONNX spec by houseroad. This PR fixes the other part pertaining to output type. Pull Request resolved: https://github.com/pytorch/pytorch/pull/15677 Reviewed By: dzhulgakov Differential Revision: D13568450 Pulled By: houseroad fbshipit-source-id: a6afbea1afdb4edad8f8b1bc492f50b14e5f2fce
Diffstat (limited to 'test/onnx')
-rw-r--r--test/onnx/expect/TestOperators.test_equal.expect12
-rw-r--r--test/onnx/expect/TestOperators.test_ge.expect12
-rw-r--r--test/onnx/expect/TestOperators.test_gt.expect12
-rw-r--r--test/onnx/expect/TestOperators.test_le.expect12
-rw-r--r--test/onnx/expect/TestOperators.test_lt.expect12
-rw-r--r--test/onnx/expect/TestOperators.test_ne.expect12
6 files changed, 66 insertions, 6 deletions
diff --git a/test/onnx/expect/TestOperators.test_equal.expect b/test/onnx/expect/TestOperators.test_equal.expect
index c5117370dc..53d4f844fa 100644
--- a/test/onnx/expect/TestOperators.test_equal.expect
+++ b/test/onnx/expect/TestOperators.test_equal.expect
@@ -8,6 +8,16 @@ graph {
output: "2"
op_type: "Equal"
}
+ node {
+ input: "2"
+ output: "3"
+ op_type: "Cast"
+ attribute {
+ name: "to"
+ i: 2
+ type: INT
+ }
+ }
name: "torch-jit-export"
input {
name: "0"
@@ -48,7 +58,7 @@ graph {
}
}
output {
- name: "2"
+ name: "3"
type {
tensor_type {
elem_type: 2
diff --git a/test/onnx/expect/TestOperators.test_ge.expect b/test/onnx/expect/TestOperators.test_ge.expect
index a66c5f47f0..8a4b5d5758 100644
--- a/test/onnx/expect/TestOperators.test_ge.expect
+++ b/test/onnx/expect/TestOperators.test_ge.expect
@@ -13,6 +13,16 @@ graph {
output: "3"
op_type: "Not"
}
+ node {
+ input: "3"
+ output: "4"
+ op_type: "Cast"
+ attribute {
+ name: "to"
+ i: 2
+ type: INT
+ }
+ }
name: "torch-jit-export"
input {
name: "0"
@@ -47,7 +57,7 @@ graph {
}
}
output {
- name: "3"
+ name: "4"
type {
tensor_type {
elem_type: 2
diff --git a/test/onnx/expect/TestOperators.test_gt.expect b/test/onnx/expect/TestOperators.test_gt.expect
index 7680e0a66f..542bf81e34 100644
--- a/test/onnx/expect/TestOperators.test_gt.expect
+++ b/test/onnx/expect/TestOperators.test_gt.expect
@@ -8,6 +8,16 @@ graph {
output: "2"
op_type: "Greater"
}
+ node {
+ input: "2"
+ output: "3"
+ op_type: "Cast"
+ attribute {
+ name: "to"
+ i: 2
+ type: INT
+ }
+ }
name: "torch-jit-export"
input {
name: "0"
@@ -48,7 +58,7 @@ graph {
}
}
output {
- name: "2"
+ name: "3"
type {
tensor_type {
elem_type: 2
diff --git a/test/onnx/expect/TestOperators.test_le.expect b/test/onnx/expect/TestOperators.test_le.expect
index 0b17740b3d..923489e0b2 100644
--- a/test/onnx/expect/TestOperators.test_le.expect
+++ b/test/onnx/expect/TestOperators.test_le.expect
@@ -13,6 +13,16 @@ graph {
output: "3"
op_type: "Not"
}
+ node {
+ input: "3"
+ output: "4"
+ op_type: "Cast"
+ attribute {
+ name: "to"
+ i: 2
+ type: INT
+ }
+ }
name: "torch-jit-export"
input {
name: "0"
@@ -47,7 +57,7 @@ graph {
}
}
output {
- name: "3"
+ name: "4"
type {
tensor_type {
elem_type: 2
diff --git a/test/onnx/expect/TestOperators.test_lt.expect b/test/onnx/expect/TestOperators.test_lt.expect
index a5c67409d7..2befe25dc9 100644
--- a/test/onnx/expect/TestOperators.test_lt.expect
+++ b/test/onnx/expect/TestOperators.test_lt.expect
@@ -8,6 +8,16 @@ graph {
output: "2"
op_type: "Less"
}
+ node {
+ input: "2"
+ output: "3"
+ op_type: "Cast"
+ attribute {
+ name: "to"
+ i: 2
+ type: INT
+ }
+ }
name: "torch-jit-export"
input {
name: "0"
@@ -48,7 +58,7 @@ graph {
}
}
output {
- name: "2"
+ name: "3"
type {
tensor_type {
elem_type: 2
diff --git a/test/onnx/expect/TestOperators.test_ne.expect b/test/onnx/expect/TestOperators.test_ne.expect
index 078204f6aa..50505c4e1b 100644
--- a/test/onnx/expect/TestOperators.test_ne.expect
+++ b/test/onnx/expect/TestOperators.test_ne.expect
@@ -13,6 +13,16 @@ graph {
output: "3"
op_type: "Not"
}
+ node {
+ input: "3"
+ output: "4"
+ op_type: "Cast"
+ attribute {
+ name: "to"
+ i: 2
+ type: INT
+ }
+ }
name: "torch-jit-export"
input {
name: "0"
@@ -53,7 +63,7 @@ graph {
}
}
output {
- name: "3"
+ name: "4"
type {
tensor_type {
elem_type: 2