summaryrefslogtreecommitdiff
path: root/test/onnx
diff options
context:
space:
mode:
authorBowenBao <semisqg@gmail.com>2019-02-13 23:43:14 -0800
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-02-13 23:52:42 -0800
commit19addc7eb0bbb0074a3a57b6598382aeaa2222c9 (patch)
treef0c4f877ce90b0de5265b1c51745bdb3c62f5185 /test/onnx
parent5a26579e27aab8ab13860ff25f8c370ab2a60f91 (diff)
downloadpytorch-19addc7eb0bbb0074a3a57b6598382aeaa2222c9.tar.gz
pytorch-19addc7eb0bbb0074a3a57b6598382aeaa2222c9.tar.bz2
pytorch-19addc7eb0bbb0074a3a57b6598382aeaa2222c9.zip
Support nonzero onnx export
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17036 Differential Revision: D14079676 Pulled By: houseroad fbshipit-source-id: 562b538dd9ab330c26f15fdb34c98dc7a23571a1
Diffstat (limited to 'test/onnx')
-rw-r--r--test/onnx/expect/TestOperators.test_nonzero.expect49
-rw-r--r--test/onnx/test_operators.py3
2 files changed, 52 insertions, 0 deletions
diff --git a/test/onnx/expect/TestOperators.test_nonzero.expect b/test/onnx/expect/TestOperators.test_nonzero.expect
new file mode 100644
index 0000000000..ffe30b3601
--- /dev/null
+++ b/test/onnx/expect/TestOperators.test_nonzero.expect
@@ -0,0 +1,49 @@
+ir_version: 4
+producer_name: "pytorch"
+producer_version: "0.4"
+graph {
+ node {
+ input: "0"
+ output: "1"
+ op_type: "NonZero"
+ }
+ name: "torch-jit-export"
+ input {
+ name: "0"
+ type {
+ tensor_type {
+ elem_type: 1
+ shape {
+ dim {
+ dim_value: 2
+ }
+ dim {
+ dim_value: 2
+ }
+ dim {
+ dim_value: 2
+ }
+ }
+ }
+ }
+ }
+ output {
+ name: "1"
+ type {
+ tensor_type {
+ elem_type: 7
+ shape {
+ dim {
+ dim_value: 5
+ }
+ dim {
+ dim_value: 3
+ }
+ }
+ }
+ }
+ }
+}
+opset_import {
+ version: 9
+}
diff --git a/test/onnx/test_operators.py b/test/onnx/test_operators.py
index aaffdfb972..4e1b6f276f 100644
--- a/test/onnx/test_operators.py
+++ b/test/onnx/test_operators.py
@@ -526,6 +526,9 @@ class TestOperators(TestCase):
x = torch.randn(3, 4, requires_grad=True)
self.assertONNX(lambda x: torch.max(functional.dropout(x, training=False)), x)
+ def test_nonzero(self):
+ x = torch.tensor([[[2., 2.], [1., 0.]], [[0., 0.], [1., 1.]]], requires_grad=True)
+ self.assertONNX(lambda x: torch.nonzero(x), x)
if __name__ == '__main__':
no_onnx_dep_flag = '--no-onnx'