summaryrefslogtreecommitdiff
path: root/test/onnx/test_verify.py
blob: d25d48442b9c1c70cbdba4067cd0ad912a56925b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import torch
from torch.autograd import Function
from torch.nn import Module, Parameter
import caffe2.python.onnx.backend as backend
from verify import verify

from test_pytorch_common import TestCase, run_tests

import unittest


class TestVerify(TestCase):
    """Negative tests for ``verify``.

    Each test builds a model whose ONNX export should not faithfully
    reproduce its PyTorch behavior and checks that ``verify`` reports it.
    """

    # Show full diffs in assertion failure messages.
    maxDiff = None

    def assertVerifyExpectFail(self, *args, **kwargs):
        """Assert that ``verify(*args, **kwargs)`` fails with a message.

        Passes when ``verify`` raises an ``AssertionError`` with a non-empty
        message. An ``AssertionError`` with an empty message is re-raised
        unchanged. If ``verify`` returns normally, the test fails.
        """
        try:
            verify(*args, **kwargs)
        except AssertionError as e:
            # The message text is deliberately not compared against an expect
            # file: it depends on the system's formatting settings and numpy
            # keeps changing its error format. Only require that a message
            # was produced.
            if not str(e):
                raise
        else:
            # In the ``else`` clause so the AssertionError raised by
            # ``self.fail`` is not caught by the handler above.
            self.fail("verify() did not fail when expected to")

    def test_result_different(self):
        """The exported op (Add) differs from what forward() computes (sub)."""
        class BrokenAdd(Function):
            @staticmethod
            def symbolic(g, a, b):
                return g.op("Add", a, b)

            @staticmethod
            def forward(ctx, a, b):
                # Deliberately wrong: symbolic() exports Add, but PyTorch
                # computes a subtraction. (yahaha! you found me!)
                return a.sub(b)

        class MyModel(Module):
            def forward(self, x, y):
                # ``apply`` is invoked on the class itself; autograd Functions
                # with static methods should not be instantiated.
                return BrokenAdd.apply(x, y)

        x = torch.tensor([1, 2])
        y = torch.tensor([3, 4])
        self.assertVerifyExpectFail(MyModel(), (x, y), backend)

    def test_jumbled_params(self):
        """Registering a parameter inside forward() changes the state_dict.

        ``verify`` is expected to raise a RuntimeError mentioning
        "state_dict changed".
        """
        class MyModel(Module):
            def forward(self, x):
                y = x * x
                # Assigning a Parameter attribute registers it on the module,
                # so the state_dict differs after the first call.
                self.param = Parameter(torch.tensor([2.0]))
                return y

        x = torch.tensor([1, 2])
        with self.assertRaisesRegex(RuntimeError, "state_dict changed"):
            verify(MyModel(), x, backend)

    def test_modifying_params(self):
        """forward() mutates ``self.param`` in place on every call.

        The output does not depend on the parameter, but ``verify`` is still
        expected to fail on such a model.
        """
        class MyModel(Module):
            def __init__(self):
                super().__init__()
                self.param = Parameter(torch.tensor([2.0]))

            def forward(self, x):
                y = x * x
                self.param.data.add_(1.0)
                return y

        x = torch.tensor([1, 2])
        self.assertVerifyExpectFail(MyModel(), x, backend)

    def test_dynamic_model_structure(self):
        """forward() alternates between ``x * x`` and ``x + x`` across calls.

        A graph captured on one call therefore does not describe the next.
        """
        class MyModel(Module):
            def __init__(self):
                super().__init__()
                self.iters = 0

            def forward(self, x):
                if self.iters % 2 == 0:
                    r = x * x
                else:
                    r = x + x
                self.iters += 1
                return r

        x = torch.tensor([1, 2])
        self.assertVerifyExpectFail(MyModel(), x, backend)

    @unittest.skip("Indexing is broken by #3725")
    def test_embedded_constant_difference(self):
        """The row index used by forward() alternates between 0 and 1 across calls."""
        class MyModel(Module):
            def __init__(self):
                super().__init__()
                self.iters = 0

            def forward(self, x):
                r = x[self.iters % 2]
                self.iters += 1
                return r

        x = torch.tensor([[1, 2], [3, 4]])
        self.assertVerifyExpectFail(MyModel(), x, backend)

    def test_explicit_test_args(self):
        """forward() branches on input data, so different inputs take different paths.

        ``x`` sums to 8 (takes ``x * x``) while ``y`` sums to 1 (takes
        ``x + x``). ``y`` is presumably run against the model exported from
        ``x`` via ``test_args`` (TODO confirm), which should disagree.
        """
        class MyModel(Module):
            def forward(self, x):
                if x.data.sum() == 1.0:
                    return x + x
                else:
                    return x * x

        x = torch.tensor([[6, 2]])
        y = torch.tensor([[2, -1]])
        self.assertVerifyExpectFail(MyModel(), x, backend, test_args=[(y,)])


if __name__ == '__main__':
    # Run this module's tests via the shared ``run_tests`` helper when the
    # file is executed directly as a script.
    run_tests()