diff options
author | Adam Paszke <adam.paszke@gmail.com> | 2018-08-22 15:21:04 -0700 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2018-08-22 15:37:10 -0700 |
commit | f72e813c2f8995dea26b0d302fd1b363f5c2b9e2 (patch) | |
tree | 345f817b8c1e6b34356a2908a5b02f01e3574981 /test/test_jit.py | |
parent | 043a2e36e57970fca630880670800447cc75e82c (diff) | |
download | pytorch-f72e813c2f8995dea26b0d302fd1b363f5c2b9e2.tar.gz pytorch-f72e813c2f8995dea26b0d302fd1b363f5c2b9e2.tar.bz2 pytorch-f72e813c2f8995dea26b0d302fd1b363f5c2b9e2.zip |
Allow tracing functions that take tuples of tensors as inputs (#10637)
Summary:
And return tuples.
zdevito
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10637
Reviewed By: eellison
Differential Revision: D9385892
Pulled By: apaszke
fbshipit-source-id: 542f4444d909fb246d7f1d88d6fb98345de2d431
Diffstat (limited to 'test/test_jit.py')
-rw-r--r-- | test/test_jit.py | 23 |
1 files changed, 13 insertions, 10 deletions
diff --git a/test/test_jit.py b/test/test_jit.py index d34b6a0b6b..508e133daf 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -331,8 +331,9 @@ class JitTestCase(TestCase): class TestJit(JitTestCase): def assertExportImport(self, trace, inputs): + graph = trace if isinstance(trace, torch._C.Graph) else trace.graph() m = torch.jit.ScriptModule() - m._create_method_from_graph("forward", trace.graph()) + m._create_method_from_graph("forward", graph) m_import = self.getExportImportCopy(m) self.assertEqual(m.forward(*inputs), m_import.forward(*inputs)) @@ -932,6 +933,16 @@ class TestJit(JitTestCase): def test_trace_size_with_grad(self): self.do_trace_size(True) + def test_trace_tuple(self): + def fn(x, y): + return x, (x * y[1], x * y[0]) + + x, y = torch.randn(2, 2), (torch.ones(2, 2), torch.randn(2, 2)) + traced_fn = torch.jit.trace(x, y)(fn) + self.assertEqual(traced_fn(x, y), fn(x, y)) + self.assertExpectedGraph(traced_fn.graph) + self.assertExportImport(traced_fn.graph, (x, y)) + # TODO: implement @unittest.expectedFailure def test_output_unflatten(self): @@ -6568,20 +6579,12 @@ class TestCustomOperators(JitTestCase): # Replace with actual test once we support lists. with self.assertRaisesRegex( RuntimeError, - "Lists and tuples are not supported yet" + "Lists and strings are not supported yet" ): a, b = torch.ones(5), torch.zeros(5) output = torch.ops.aten.stack([a, b]) self.assertEqual(output, torch.ones(10)) - def test_passing_and_returning_tuples(self): - # Replace with actual test once we support tuples. - with self.assertRaisesRegex( - RuntimeError, - "Lists and tuples are not supported yet" - ): - torch.ops.aten.max_pool2d(torch.ones(5, 5), [2, 2]) - def test_script_graph_contains_custom_op(self): @torch.jit.script def func(x): |