summaryrefslogtreecommitdiff
path: root/test/test_jit.py
diff options
context:
space:
mode:
authorAdam Paszke <adam.paszke@gmail.com>2018-08-22 15:21:04 -0700
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2018-08-22 15:37:10 -0700
commitf72e813c2f8995dea26b0d302fd1b363f5c2b9e2 (patch)
tree345f817b8c1e6b34356a2908a5b02f01e3574981 /test/test_jit.py
parent043a2e36e57970fca630880670800447cc75e82c (diff)
downloadpytorch-f72e813c2f8995dea26b0d302fd1b363f5c2b9e2.tar.gz
pytorch-f72e813c2f8995dea26b0d302fd1b363f5c2b9e2.tar.bz2
pytorch-f72e813c2f8995dea26b0d302fd1b363f5c2b9e2.zip
Allow tracing functions that take tuples of tensors as inputs (#10637)
Summary: And return tuples. zdevito Pull Request resolved: https://github.com/pytorch/pytorch/pull/10637 Reviewed By: eellison Differential Revision: D9385892 Pulled By: apaszke fbshipit-source-id: 542f4444d909fb246d7f1d88d6fb98345de2d431
Diffstat (limited to 'test/test_jit.py')
-rw-r--r--test/test_jit.py23
1 files changed, 13 insertions, 10 deletions
diff --git a/test/test_jit.py b/test/test_jit.py
index d34b6a0b6b..508e133daf 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -331,8 +331,9 @@ class JitTestCase(TestCase):
class TestJit(JitTestCase):
def assertExportImport(self, trace, inputs):
+ graph = trace if isinstance(trace, torch._C.Graph) else trace.graph()
m = torch.jit.ScriptModule()
- m._create_method_from_graph("forward", trace.graph())
+ m._create_method_from_graph("forward", graph)
m_import = self.getExportImportCopy(m)
self.assertEqual(m.forward(*inputs), m_import.forward(*inputs))
@@ -932,6 +933,16 @@ class TestJit(JitTestCase):
def test_trace_size_with_grad(self):
self.do_trace_size(True)
+ def test_trace_tuple(self):
+ def fn(x, y):
+ return x, (x * y[1], x * y[0])
+
+ x, y = torch.randn(2, 2), (torch.ones(2, 2), torch.randn(2, 2))
+ traced_fn = torch.jit.trace(x, y)(fn)
+ self.assertEqual(traced_fn(x, y), fn(x, y))
+ self.assertExpectedGraph(traced_fn.graph)
+ self.assertExportImport(traced_fn.graph, (x, y))
+
# TODO: implement
@unittest.expectedFailure
def test_output_unflatten(self):
@@ -6568,20 +6579,12 @@ class TestCustomOperators(JitTestCase):
# Replace with actual test once we support lists.
with self.assertRaisesRegex(
RuntimeError,
- "Lists and tuples are not supported yet"
+ "Lists and strings are not supported yet"
):
a, b = torch.ones(5), torch.zeros(5)
output = torch.ops.aten.stack([a, b])
self.assertEqual(output, torch.ones(10))
- def test_passing_and_returning_tuples(self):
- # Replace with actual test once we support tuples.
- with self.assertRaisesRegex(
- RuntimeError,
- "Lists and tuples are not supported yet"
- ):
- torch.ops.aten.max_pool2d(torch.ones(5, 5), [2, 2])
-
def test_script_graph_contains_custom_op(self):
@torch.jit.script
def func(x):