summaryrefslogtreecommitdiff
path: root/test/test_jit.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_jit.py')
-rw-r--r--test/test_jit.py63
1 files changed, 62 insertions, 1 deletions
diff --git a/test/test_jit.py b/test/test_jit.py
index 694926bf27..c0dc575334 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -43,7 +43,7 @@ from common_methods_invocations import create_input, unpack_variables, \
exclude_tensor_method, non_differentiable, EXCLUDE_GRADCHECK, EXCLUDE_FUNCTIONAL
from torch.testing import FileCheck
from torch._C import TensorType, TupleType, FloatType, IntType, \
- ListType, StringType, DictType
+ ListType, StringType, DictType, parse_ir
from copy import deepcopy
import random
from typing import List, Dict, Optional, Tuple
@@ -6042,6 +6042,14 @@ a")
m2.sub2.a.data.zero_()
self.assertEqual(torch.zeros(2, 2), m2.forward(torch.randn(3, 2)))
+ def test_irparser(self):
+ graph_str = """graph(%0 : Double(5, 5)):
+ # CHECK: aten::relu
+ %1 : Double(5, 5) = aten::relu(%0)
+ return (%1)
+ """
+ FileCheck().run(graph_str, parse_ir(graph_str))
+
def test_filecheck(self):
def test_check():
file = "232"
@@ -6134,6 +6142,59 @@ a")
with self.assertRaisesRegex(RuntimeError, 'Expected to not find "1"'):
fb.run("22 1 22")
+ def test_filecheck_parse(self):
+ def test_check():
+ file = """
+ # CHECK: 2
+ # CHECK: 3
+ # CHECK: 2
+ 232
+ """
+ FileCheck().run(checks_file=file, test_file=file)
+ file = """
+ # CHECK: 232
+ 232
+ """
+ FileCheck().run(file, "232")
+ with self.assertRaisesRegex(RuntimeError, 'Expected to find "232"'):
+ FileCheck().run(file, "22")
+ with self.assertRaisesRegex(RuntimeError, 'Expected to find "22"'):
+ FileCheck().run("# CHECK: 22", "23")
+ test_check()
+
+ def test_check_count():
+ file = "22222"
+ FileCheck().run("# CHECK-COUNT-5: 2", file)
+ FileCheck().run("# CHECK-COUNT-EXACTLY-5: 2", file)
+ FileCheck().run("# CHECK-COUNT-2: 22", file)
+ FileCheck().run("# CHECK-COUNT-1: 222", file)
+
+ with self.assertRaisesRegex(RuntimeError, 'Expected to not find'):
+ FileCheck().run("# CHECK-COUNT-EXACTLY-2: 2", file)
+ test_check_count()
+
+ def test_check_same():
+ file = "22\n33"
+ FileCheck().run("# CHECK-SAME: 22", file)
+
+ with self.assertRaisesRegex(RuntimeError, "Expected to not find"):
+ FileCheck().run("# CHECK-SAME: 33", file)
+
+ file = "22 1 3"
+
+ FileCheck().run("# CHECK: 2\n # CHECK-SAME: 3", file)
+ FileCheck().run("# CHECK-COUNT-2: 2\n # CHECK-SAME: 3", file)
+ test_check_same()
+
+ def test_bad_input():
+ with self.assertRaisesRegex(RuntimeError, "Check for bad input"):
+ FileCheck().run("", "1")
+
+ with self.assertRaisesRegex(RuntimeError, "Could not parse check"):
+ FileCheck().run("# CHECK1", "")
+
+ test_bad_input()
+
def test_script_module_call_noscript(self):
class M(torch.jit.ScriptModule):
def __init__(self):