summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--test/onnx/expect/TestOperators.test_nonzero.expect49
-rw-r--r--test/onnx/test_operators.py3
-rw-r--r--torch/onnx/symbolic.py5
3 files changed, 57 insertions, 0 deletions
diff --git a/test/onnx/expect/TestOperators.test_nonzero.expect b/test/onnx/expect/TestOperators.test_nonzero.expect
new file mode 100644
index 0000000000..ffe30b3601
--- /dev/null
+++ b/test/onnx/expect/TestOperators.test_nonzero.expect
@@ -0,0 +1,49 @@
+ir_version: 4
+producer_name: "pytorch"
+producer_version: "0.4"
+graph {
+ node {
+ input: "0"
+ output: "1"
+ op_type: "NonZero"
+ }
+ name: "torch-jit-export"
+ input {
+ name: "0"
+ type {
+ tensor_type {
+ elem_type: 1
+ shape {
+ dim {
+ dim_value: 2
+ }
+ dim {
+ dim_value: 2
+ }
+ dim {
+ dim_value: 2
+ }
+ }
+ }
+ }
+ }
+ output {
+ name: "1"
+ type {
+ tensor_type {
+ elem_type: 7
+ shape {
+ dim {
+ dim_value: 5
+ }
+ dim {
+ dim_value: 3
+ }
+ }
+ }
+ }
+ }
+}
+opset_import {
+ version: 9
+}
diff --git a/test/onnx/test_operators.py b/test/onnx/test_operators.py
index aaffdfb972..4e1b6f276f 100644
--- a/test/onnx/test_operators.py
+++ b/test/onnx/test_operators.py
@@ -526,6 +526,9 @@ class TestOperators(TestCase):
x = torch.randn(3, 4, requires_grad=True)
self.assertONNX(lambda x: torch.max(functional.dropout(x, training=False)), x)
+ def test_nonzero(self):
+ x = torch.tensor([[[2., 2.], [1., 0.]], [[0., 0.], [1., 1.]]], requires_grad=True)
+ self.assertONNX(lambda x: torch.nonzero(x), x)
if __name__ == '__main__':
no_onnx_dep_flag = '--no-onnx'
diff --git a/torch/onnx/symbolic.py b/torch/onnx/symbolic.py
index 593bd308f6..3f00a7b978 100644
--- a/torch/onnx/symbolic.py
+++ b/torch/onnx/symbolic.py
@@ -1515,3 +1515,8 @@ def flatten(g, input, start_dim, end_dim):
shape = g.op("Constant", value_t=torch.LongTensor(output_dims))
p = _reshape_from_tensor(g, input, shape)
return p
+
+
+@parse_args('v')
+def nonzero(g, input):
+ return g.op('NonZero', input)