diff options
-rw-r--r-- | test/onnx/expect/TestOperators.test_nonzero.expect | 49 | ||||
-rw-r--r-- | test/onnx/test_operators.py | 3 | ||||
-rw-r--r-- | torch/onnx/symbolic.py | 5 |
3 files changed, 57 insertions, 0 deletions
diff --git a/test/onnx/expect/TestOperators.test_nonzero.expect b/test/onnx/expect/TestOperators.test_nonzero.expect new file mode 100644 index 0000000000..ffe30b3601 --- /dev/null +++ b/test/onnx/expect/TestOperators.test_nonzero.expect @@ -0,0 +1,49 @@ +ir_version: 4 +producer_name: "pytorch" +producer_version: "0.4" +graph { + node { + input: "0" + output: "1" + op_type: "NonZero" + } + name: "torch-jit-export" + input { + name: "0" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 2 + } + dim { + dim_value: 2 + } + } + } + } + } + output { + name: "1" + type { + tensor_type { + elem_type: 7 + shape { + dim { + dim_value: 5 + } + dim { + dim_value: 3 + } + } + } + } + } +} +opset_import { + version: 9 +} diff --git a/test/onnx/test_operators.py b/test/onnx/test_operators.py index aaffdfb972..4e1b6f276f 100644 --- a/test/onnx/test_operators.py +++ b/test/onnx/test_operators.py @@ -526,6 +526,9 @@ class TestOperators(TestCase): x = torch.randn(3, 4, requires_grad=True) self.assertONNX(lambda x: torch.max(functional.dropout(x, training=False)), x) + def test_nonzero(self): + x = torch.tensor([[[2., 2.], [1., 0.]], [[0., 0.], [1., 1.]]], requires_grad=True) + self.assertONNX(lambda x: torch.nonzero(x), x) if __name__ == '__main__': no_onnx_dep_flag = '--no-onnx' diff --git a/torch/onnx/symbolic.py b/torch/onnx/symbolic.py index 593bd308f6..3f00a7b978 100644 --- a/torch/onnx/symbolic.py +++ b/torch/onnx/symbolic.py @@ -1515,3 +1515,8 @@ def flatten(g, input, start_dim, end_dim): shape = g.op("Constant", value_t=torch.LongTensor(output_dims)) p = _reshape_from_tensor(g, input, shape) return p + + +@parse_args('v') +def nonzero(g, input): + return g.op('NonZero', input) |