summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorElias Ellison <eellison@fb.com>2019-03-27 18:11:45 -0700
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-03-27 18:16:05 -0700
commit0daafe02098825842fbe5d1682e88e63ae6868c1 (patch)
tree9335eaf0dea96f25c4119a026ed2403f1e5efc69 /test
parentad1ebf70827e367c8d0eae8852e11f2289301607 (diff)
downloadpytorch-0daafe02098825842fbe5d1682e88e63ae6868c1.tar.gz
pytorch-0daafe02098825842fbe5d1682e88e63ae6868c1.tar.bz2
pytorch-0daafe02098825842fbe5d1682e88e63ae6868c1.zip
Add parsing to file check (#18304)
Summary: This allows you to embed checks in IR, making the test more readable. E.g. ``` graph_str = 'graph(%0 : Double(5, 5)): # CHECK: aten::relu %1 : Double(5, 5) = aten::relu(%0) return (%1)' FileCheck().run(graph_str, parseIR(graph_str)) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/18304 Differential Revision: D14652372 Pulled By: eellison fbshipit-source-id: 7430b9d1dc2b7584704375aac02d7392ecec76a0
Diffstat (limited to 'test')
-rw-r--r--test/cpp/jit/test_irparser.h14
-rw-r--r--test/test_jit.py63
2 files changed, 76 insertions, 1 deletions
diff --git a/test/cpp/jit/test_irparser.h b/test/cpp/jit/test_irparser.h
index c8c3eeaf9a..af8d28a927 100644
--- a/test/cpp/jit/test_irparser.h
+++ b/test/cpp/jit/test_irparser.h
@@ -2,6 +2,7 @@
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/irparser.h>
+#include <torch/csrc/jit/testing/file_check.h>
#include "test/cpp/jit/test_base.h"
#include <sstream>
@@ -211,6 +212,19 @@ graph(%0 : Tensor,
}
AT_ASSERT(error_thrown);
}
+
+ {
+ auto graph = std::make_shared<Graph>();
+ const std::string& text =
+ R"IR(
+ graph(%a):
+ # CHECK: return
+ return (%a))IR";
+
+ script::parseIR(text, &*graph);
+ graph->inputs()[0]->type()->expect<TensorType>();
+ torch::jit::testing::FileCheck().run(text, *graph);
+ }
}
} // namespace jit
} // namespace torch
diff --git a/test/test_jit.py b/test/test_jit.py
index 694926bf27..c0dc575334 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -43,7 +43,7 @@ from common_methods_invocations import create_input, unpack_variables, \
exclude_tensor_method, non_differentiable, EXCLUDE_GRADCHECK, EXCLUDE_FUNCTIONAL
from torch.testing import FileCheck
from torch._C import TensorType, TupleType, FloatType, IntType, \
- ListType, StringType, DictType
+ ListType, StringType, DictType, parse_ir
from copy import deepcopy
import random
from typing import List, Dict, Optional, Tuple
@@ -6042,6 +6042,14 @@ a")
m2.sub2.a.data.zero_()
self.assertEqual(torch.zeros(2, 2), m2.forward(torch.randn(3, 2)))
+ def test_irparser(self):
+ graph_str = """graph(%0 : Double(5, 5)):
+ # CHECK: aten::relu
+ %1 : Double(5, 5) = aten::relu(%0)
+ return (%1)
+ """
+ FileCheck().run(graph_str, parse_ir(graph_str))
+
def test_filecheck(self):
def test_check():
file = "232"
@@ -6134,6 +6142,59 @@ a")
with self.assertRaisesRegex(RuntimeError, 'Expected to not find "1"'):
fb.run("22 1 22")
+ def test_filecheck_parse(self):
+ def test_check():
+ file = """
+ # CHECK: 2
+ # CHECK: 3
+ # CHECK: 2
+ 232
+ """
+ FileCheck().run(checks_file=file, test_file=file)
+ file = """
+ # CHECK: 232
+ 232
+ """
+ FileCheck().run(file, "232")
+ with self.assertRaisesRegex(RuntimeError, 'Expected to find "232"'):
+ FileCheck().run(file, "22")
+ with self.assertRaisesRegex(RuntimeError, 'Expected to find "22"'):
+ FileCheck().run("# CHECK: 22", "23")
+ test_check()
+
+ def test_check_count():
+ file = "22222"
+ FileCheck().run("# CHECK-COUNT-5: 2", file)
+ FileCheck().run("# CHECK-COUNT-EXACTLY-5: 2", file)
+ FileCheck().run("# CHECK-COUNT-2: 22", file)
+ FileCheck().run("# CHECK-COUNT-1: 222", file)
+
+ with self.assertRaisesRegex(RuntimeError, 'Expected to not find'):
+ FileCheck().run("# CHECK-COUNT-EXACTLY-2: 2", file)
+ test_check_count()
+
+ def test_check_same():
+ file = "22\n33"
+ FileCheck().run("# CHECK-SAME: 22", file)
+
+ with self.assertRaisesRegex(RuntimeError, "Expected to not find"):
+ FileCheck().run("# CHECK-SAME: 33", file)
+
+ file = "22 1 3"
+
+ FileCheck().run("# CHECK: 2\n # CHECK-SAME: 3", file)
+ FileCheck().run("# CHECK-COUNT-2: 2\n # CHECK-SAME: 3", file)
+ test_check_same()
+
+ def test_bad_input():
+ with self.assertRaisesRegex(RuntimeError, "Check for bad input"):
+ FileCheck().run("", "1")
+
+ with self.assertRaisesRegex(RuntimeError, "Could not parse check"):
+ FileCheck().run("# CHECK1", "")
+
+ test_bad_input()
+
def test_script_module_call_noscript(self):
class M(torch.jit.ScriptModule):
def __init__(self):