summaryrefslogtreecommitdiff
path: root/test/test_jit.py
diff options
context:
space:
mode:
authorAdam Paszke <adam.paszke@gmail.com>2018-09-05 06:28:44 -0700
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2018-09-05 06:35:13 -0700
commitb7038f7c37e955f7400459bbfc9382a77b16377d (patch)
tree21ea7fb1a2a35f8ac6c307d786fbf2e1004c201f /test/test_jit.py
parentb7cd4b692c294d93986da6565f3e0ff88ae6afae (diff)
downloadpytorch-b7038f7c37e955f7400459bbfc9382a77b16377d.tar.gz
pytorch-b7038f7c37e955f7400459bbfc9382a77b16377d.tar.bz2
pytorch-b7038f7c37e955f7400459bbfc9382a77b16377d.zip
Treat numerical differences as warnings instead of errors when tracing (#11246)
Summary: Also, make `torch.isclose` work with integral tensors and refactor `_check_trace` a bit. zdevito Pull Request resolved: https://github.com/pytorch/pytorch/pull/11246 Differential Revision: D9652701 Pulled By: apaszke fbshipit-source-id: fb0bdbfd1952e45e153541e4d471b423a5659f25
Diffstat (limited to 'test/test_jit.py')
-rw-r--r--test/test_jit.py55
1 files changed, 36 insertions, 19 deletions
diff --git a/test/test_jit.py b/test/test_jit.py
index 9613ced33e..7d4cbef2c8 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -83,12 +83,12 @@ def LSTMCellC(*args, **kwargs):
def LSTMCellS(x, hx, cx, w_ih, w_hh, b_ih, b_hh):
gates = x.mm(w_ih.t()) + hx.mm(w_hh.t()) + b_ih + b_hh
ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
- ingate = F.sigmoid(ingate)
- forgetgate = F.sigmoid(forgetgate)
- cellgate = F.tanh(cellgate)
- outgate = F.sigmoid(outgate)
+ ingate = torch.sigmoid(ingate)
+ forgetgate = torch.sigmoid(forgetgate)
+ cellgate = torch.tanh(cellgate)
+ outgate = torch.sigmoid(outgate)
cy = (forgetgate * cx) + (ingate * cellgate)
- hy = outgate * F.tanh(cy)
+ hy = outgate * torch.tanh(cy)
return hy, cy
@@ -6239,6 +6239,7 @@ a")
y = torch.arange(0, x.shape[0]).double()
return x + y.unsqueeze(1)
+ @suppress_warnings
def test_trace_checker_dot_data(self):
with self.assertRaisesRegex(torch.jit.TracingCheckError, r'Tensor-valued Constant nodes differed in value '
r'across invocations'):
@@ -6249,13 +6250,15 @@ a")
@suppress_warnings
def test_trace_checker_control_flow(self):
+ def foo(x):
+ for _ in range(x.size(0)):
+ x = torch.neg(x)
+ return x
+
with self.assertRaisesRegex(torch.jit.TracingCheckError, r'Graphs differed across invocations!'):
- @_trace(torch.rand(3, 4), check_inputs=[(torch.rand(4, 4),)])
- def foo(x):
- for _ in range(x.size(0)):
- x = torch.neg(x)
- return x
+ torch.jit.trace(foo, torch.randn(3, 4), check_inputs=[torch.randn(4, 4)])
+ @suppress_warnings
def test_trace_checker_memoization(self):
with self.assertRaisesRegex(torch.jit.TracingCheckError, r'Graphs differed across invocations!'):
def foo(x):
@@ -6277,13 +6280,19 @@ a")
for i in range(3):
x[i, :] = torch.zeros(4)
return x
- self.checkTracerWarning(foo, torch.rand(3, 4))
+
+ self.assertWarnsRegex(lambda: torch.jit.trace(foo, torch.rand(3, 4), check_inputs=[torch.rand(3, 4)]),
+ 'Output nr 1. of the traced function does not match the '
+ 'corresponding output of the Python function')
def test_trace_checker_inplace_on_view(self):
def foo(x):
x.view(-1).add_(-x.view(-1))
return x
- self.checkTracerWarning(foo, torch.rand(3, 4), check_trace=False)
+
+ self.assertWarnsRegex(lambda: torch.jit.trace(foo, torch.rand(3, 4), check_inputs=[torch.rand(5, 6)]),
+ 'Output nr 1. of the traced function does not match the '
+ 'corresponding output of the Python function')
def test_lhs_index_fails(self):
def foo(x):
@@ -6297,11 +6306,22 @@ a")
return y
self.checkTrace(foo, (torch.rand(3, 4), torch.rand(4)), inputs_require_grads=False)
+ def test_inplace_warn(self):
+ def foo(x):
+ x.view(-1).add_(-x.view(-1))
+ return x
+ self.checkTracerWarning(foo, torch.rand(3, 4))
+
+ @suppress_warnings
def test_trace_checker_dropout_train(self):
- with self.assertRaisesRegex(torch.jit.TracingCheckError, r'Trace had nondeterministic nodes'):
- @_trace(torch.rand(3, 4))
- def foo(x):
- return torch.dropout(x, p=0.5, train=True)
+ def foo(x):
+ return torch.dropout(x, p=0.5, train=True)
+
+ self.assertWarnsRegex(lambda: torch.jit.trace(foo, torch.rand(3, 4), check_inputs=[torch.rand(5, 6)]),
+ 'Output nr 1. of the traced function does not match the '
+ 'corresponding output of the Python function')
+ self.assertWarnsRegex(lambda: torch.jit.trace(foo, torch.rand(3, 4), check_inputs=[torch.rand(5, 6)]),
+ 'Trace had nondeterministic nodes')
def test_trace_checker_dropout_notrain(self):
input = torch.rand(3, 4)
@@ -6516,9 +6536,6 @@ class TestEndToEndHybridFrontendModels(JitTestCase):
@skipIfRocm
def test_snli(self):
- # TODO:
- # 1) nn.LSTM is called as a Python function https://github.com/pytorch/pytorch/issues/8449
- # 2) Dropout is called as a Python function https://github.com/pytorch/pytorch/issues/8450
class Bottle(nn.Module):
def forward(self, input):