diff options
author | Adam Paszke <adam.paszke@gmail.com> | 2018-09-05 06:28:44 -0700 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2018-09-05 06:35:13 -0700 |
commit | b7038f7c37e955f7400459bbfc9382a77b16377d (patch) | |
tree | 21ea7fb1a2a35f8ac6c307d786fbf2e1004c201f /test/test_jit.py | |
parent | b7cd4b692c294d93986da6565f3e0ff88ae6afae (diff) | |
download | pytorch-b7038f7c37e955f7400459bbfc9382a77b16377d.tar.gz pytorch-b7038f7c37e955f7400459bbfc9382a77b16377d.tar.bz2 pytorch-b7038f7c37e955f7400459bbfc9382a77b16377d.zip |
Treat numerical differences as warnings instead of errors when tracing (#11246)
Summary:
Also, make `torch.isclose` work with integral tensors and refactor `_check_trace` a bit.
zdevito
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11246
Differential Revision: D9652701
Pulled By: apaszke
fbshipit-source-id: fb0bdbfd1952e45e153541e4d471b423a5659f25
Diffstat (limited to 'test/test_jit.py')
-rw-r--r-- | test/test_jit.py | 55 |
1 files changed, 36 insertions, 19 deletions
diff --git a/test/test_jit.py b/test/test_jit.py index 9613ced33e..7d4cbef2c8 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -83,12 +83,12 @@ def LSTMCellC(*args, **kwargs): def LSTMCellS(x, hx, cx, w_ih, w_hh, b_ih, b_hh): gates = x.mm(w_ih.t()) + hx.mm(w_hh.t()) + b_ih + b_hh ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) - ingate = F.sigmoid(ingate) - forgetgate = F.sigmoid(forgetgate) - cellgate = F.tanh(cellgate) - outgate = F.sigmoid(outgate) + ingate = torch.sigmoid(ingate) + forgetgate = torch.sigmoid(forgetgate) + cellgate = torch.tanh(cellgate) + outgate = torch.sigmoid(outgate) cy = (forgetgate * cx) + (ingate * cellgate) - hy = outgate * F.tanh(cy) + hy = outgate * torch.tanh(cy) return hy, cy @@ -6239,6 +6239,7 @@ a") y = torch.arange(0, x.shape[0]).double() return x + y.unsqueeze(1) + @suppress_warnings def test_trace_checker_dot_data(self): with self.assertRaisesRegex(torch.jit.TracingCheckError, r'Tensor-valued Constant nodes differed in value ' r'across invocations'): @@ -6249,13 +6250,15 @@ a") @suppress_warnings def test_trace_checker_control_flow(self): + def foo(x): + for _ in range(x.size(0)): + x = torch.neg(x) + return x + with self.assertRaisesRegex(torch.jit.TracingCheckError, r'Graphs differed across invocations!'): - @_trace(torch.rand(3, 4), check_inputs=[(torch.rand(4, 4),)]) - def foo(x): - for _ in range(x.size(0)): - x = torch.neg(x) - return x + torch.jit.trace(foo, torch.randn(3, 4), check_inputs=[torch.randn(4, 4)]) + @suppress_warnings def test_trace_checker_memoization(self): with self.assertRaisesRegex(torch.jit.TracingCheckError, r'Graphs differed across invocations!'): def foo(x): @@ -6277,13 +6280,19 @@ a") for i in range(3): x[i, :] = torch.zeros(4) return x - self.checkTracerWarning(foo, torch.rand(3, 4)) + + self.assertWarnsRegex(lambda: torch.jit.trace(foo, torch.rand(3, 4), check_inputs=[torch.rand(3, 4)]), + 'Output nr 1. of the traced function does not match the ' + 'corresponding output of the Python function') def test_trace_checker_inplace_on_view(self): def foo(x): x.view(-1).add_(-x.view(-1)) return x - self.checkTracerWarning(foo, torch.rand(3, 4), check_trace=False) + + self.assertWarnsRegex(lambda: torch.jit.trace(foo, torch.rand(3, 4), check_inputs=[torch.rand(5, 6)]), + 'Output nr 1. of the traced function does not match the ' + 'corresponding output of the Python function') def test_lhs_index_fails(self): def foo(x): @@ -6297,11 +6306,22 @@ a") return y self.checkTrace(foo, (torch.rand(3, 4), torch.rand(4)), inputs_require_grads=False) + def test_inplace_warn(self): + def foo(x): + x.view(-1).add_(-x.view(-1)) + return x + self.checkTracerWarning(foo, torch.rand(3, 4)) + + @suppress_warnings def test_trace_checker_dropout_train(self): - with self.assertRaisesRegex(torch.jit.TracingCheckError, r'Trace had nondeterministic nodes'): - @_trace(torch.rand(3, 4)) - def foo(x): - return torch.dropout(x, p=0.5, train=True) + def foo(x): + return torch.dropout(x, p=0.5, train=True) + + self.assertWarnsRegex(lambda: torch.jit.trace(foo, torch.rand(3, 4), check_inputs=[torch.rand(5, 6)]), + 'Output nr 1. of the traced function does not match the ' + 'corresponding output of the Python function') + self.assertWarnsRegex(lambda: torch.jit.trace(foo, torch.rand(3, 4), check_inputs=[torch.rand(5, 6)]), + 'Trace had nondeterministic nodes') def test_trace_checker_dropout_notrain(self): input = torch.rand(3, 4) @@ -6516,9 +6536,6 @@ class TestEndToEndHybridFrontendModels(JitTestCase): @skipIfRocm def test_snli(self): - # TODO: - # 1) nn.LSTM is called as a Python function https://github.com/pytorch/pytorch/issues/8449 - # 2) Dropout is called as a Python function https://github.com/pytorch/pytorch/issues/8450 class Bottle(nn.Module): def forward(self, input): |