summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--test/test_jit.py36
1 files changed, 24 insertions, 12 deletions
diff --git a/test/test_jit.py b/test/test_jit.py
index a246a42272..f725f1c350 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -5890,21 +5890,33 @@ a")
self.checkScript(func, ())
def test_tensor_shape_prop(self):
- template = dedent('''
- def func():
- li = {list_create}
- return torch.tensor(li)
- ''')
+ def func1():
+ return torch.tensor([1])
+
+ def func2():
+ return torch.tensor([False])
+
+ def func3():
+ return torch.tensor([2.5])
- list_input = ["[1]", "[False]", "[2.5]", "0.5", "1", "False", "[[1]]"]
+ def func4():
+ return torch.tensor(0.5)
+
+ def func5():
+ return torch.tensor(1)
+
+ def func6():
+ return torch.tensor(False)
+
+ def func7():
+ return torch.tensor([[1]])
+
+ list_input = [func1, func2, func3, func4, func5, func6, func7]
expected_shape = ["Long(*)", ("Byte(*)"), "Double(*)", "Double()", "Long()", "Byte()", "Long(*, *)"]
- for list_i, expect in zip(list_input, expected_shape):
- code = template.format(list_create=list_i)
- scope = {}
- exec(code, globals(), scope)
- cu = torch.jit.CompilationUnit(code)
- g = cu.func
+ for fn, expect in zip(list_input, expected_shape):
+ self.checkScript(fn, ())
+ g = torch.jit.script(fn)
torch._C._jit_pass_complete_shape_analysis(g.graph, (), False)
FileCheck().check(expect).check("aten::tensor").run(g.graph)