diff options
Diffstat (limited to 'test/test_jit.py')
-rw-r--r-- | test/test_jit.py | 63 |
1 files changed, 62 insertions, 1 deletions
diff --git a/test/test_jit.py b/test/test_jit.py index 694926bf27..c0dc575334 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -43,7 +43,7 @@ from common_methods_invocations import create_input, unpack_variables, \ exclude_tensor_method, non_differentiable, EXCLUDE_GRADCHECK, EXCLUDE_FUNCTIONAL from torch.testing import FileCheck from torch._C import TensorType, TupleType, FloatType, IntType, \ - ListType, StringType, DictType + ListType, StringType, DictType, parse_ir from copy import deepcopy import random from typing import List, Dict, Optional, Tuple @@ -6042,6 +6042,14 @@ a") m2.sub2.a.data.zero_() self.assertEqual(torch.zeros(2, 2), m2.forward(torch.randn(3, 2))) + def test_irparser(self): + graph_str = """graph(%0 : Double(5, 5)): + # CHECK: aten::relu + %1 : Double(5, 5) = aten::relu(%0) + return (%1) + """ + FileCheck().run(graph_str, parse_ir(graph_str)) + def test_filecheck(self): def test_check(): file = "232" @@ -6134,6 +6142,59 @@ a") with self.assertRaisesRegex(RuntimeError, 'Expected to not find "1"'): fb.run("22 1 22") + def test_filecheck_parse(self): + def test_check(): + file = """ + # CHECK: 2 + # CHECK: 3 + # CHECK: 2 + 232 + """ + FileCheck().run(checks_file=file, test_file=file) + file = """ + # CHECK: 232 + 232 + """ + FileCheck().run(file, "232") + with self.assertRaisesRegex(RuntimeError, 'Expected to find "232"'): + FileCheck().run(file, "22") + with self.assertRaisesRegex(RuntimeError, 'Expected to find "22"'): + FileCheck().run("# CHECK: 22", "23") + test_check() + + def test_check_count(): + file = "22222" + FileCheck().run("# CHECK-COUNT-5: 2", file) + FileCheck().run("# CHECK-COUNT-EXACTLY-5: 2", file) + FileCheck().run("# CHECK-COUNT-2: 22", file) + FileCheck().run("# CHECK-COUNT-1: 222", file) + + with self.assertRaisesRegex(RuntimeError, 'Expected to not find'): + FileCheck().run("# CHECK-COUNT-EXACTLY-2: 2", file) + test_check_count() + + def test_check_same(): + file = "22\n33" + FileCheck().run("# CHECK-SAME: 22", file) + + with self.assertRaisesRegex(RuntimeError, "Expected to not find"): + FileCheck().run("# CHECK-SAME: 33", file) + + file = "22 1 3" + + FileCheck().run("# CHECK: 2\n # CHECK-SAME: 3", file) + FileCheck().run("# CHECK-COUNT-2: 2\n # CHECK-SAME: 3", file) + test_check_same() + + def test_bad_input(): + with self.assertRaisesRegex(RuntimeError, "Check for bad input"): + FileCheck().run("", "1") + + with self.assertRaisesRegex(RuntimeError, "Could not parse check"): + FileCheck().run("# CHECK1", "") + + test_bad_input() + def test_script_module_call_noscript(self): class M(torch.jit.ScriptModule): def __init__(self): |