diff options
-rw-r--r-- | test/test_jit.py | 36 |
1 files changed, 24 insertions, 12 deletions
diff --git a/test/test_jit.py b/test/test_jit.py index a246a42272..f725f1c350 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -5890,21 +5890,33 @@ a") self.checkScript(func, ()) def test_tensor_shape_prop(self): - template = dedent(''' - def func(): - li = {list_create} - return torch.tensor(li) - ''') + def func1(): + return torch.tensor([1]) + + def func2(): + return torch.tensor([False]) + + def func3(): + return torch.tensor([2.5]) - list_input = ["[1]", "[False]", "[2.5]", "0.5", "1", "False", "[[1]]"] + def func4(): + return torch.tensor(0.5) + + def func5(): + return torch.tensor(1) + + def func6(): + return torch.tensor(False) + + def func7(): + return torch.tensor([[1]]) + + list_input = [func1, func2, func3, func4, func5, func6, func7] expected_shape = ["Long(*)", ("Byte(*)"), "Double(*)", "Double()", "Long()", "Byte()", "Long(*, *)"] - for list_i, expect in zip(list_input, expected_shape): - code = template.format(list_create=list_i) - scope = {} - exec(code, globals(), scope) - cu = torch.jit.CompilationUnit(code) - g = cu.func + for fn, expect in zip(list_input, expected_shape): + self.checkScript(fn, ()) + g = torch.jit.script(fn) torch._C._jit_pass_complete_shape_analysis(g.graph, (), False) FileCheck().check(expect).check("aten::tensor").run(g.graph) |