diff options
author | Elias Ellison <eellison@fb.com> | 2019-03-27 18:11:45 -0700 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-03-27 18:16:05 -0700 |
commit | 0daafe02098825842fbe5d1682e88e63ae6868c1 (patch) | |
tree | 9335eaf0dea96f25c4119a026ed2403f1e5efc69 /test | |
parent | ad1ebf70827e367c8d0eae8852e11f2289301607 (diff) | |
download | pytorch-0daafe02098825842fbe5d1682e88e63ae6868c1.tar.gz pytorch-0daafe02098825842fbe5d1682e88e63ae6868c1.tar.bz2 pytorch-0daafe02098825842fbe5d1682e88e63ae6868c1.zip |
Add parsing to file check (#18304)
Summary:
This allows you to embed checks in IR, making the test more readable.
E.g.
```
graph_str = 'graph(%0 : Double(5, 5)):
# CHECK: aten::relu
%1 : Double(5, 5) = aten::relu(%0)
return (%1)'
FileCheck().run(graph_str, parseIR(graph_str))
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18304
Differential Revision: D14652372
Pulled By: eellison
fbshipit-source-id: 7430b9d1dc2b7584704375aac02d7392ecec76a0
Diffstat (limited to 'test')
-rw-r--r-- | test/cpp/jit/test_irparser.h | 14 | ||||
-rw-r--r-- | test/test_jit.py | 63 |
2 files changed, 76 insertions, 1 deletions
diff --git a/test/cpp/jit/test_irparser.h b/test/cpp/jit/test_irparser.h index c8c3eeaf9a..af8d28a927 100644 --- a/test/cpp/jit/test_irparser.h +++ b/test/cpp/jit/test_irparser.h @@ -2,6 +2,7 @@ #include <torch/csrc/jit/ir.h> #include <torch/csrc/jit/irparser.h> +#include <torch/csrc/jit/testing/file_check.h> #include "test/cpp/jit/test_base.h" #include <sstream> @@ -211,6 +212,19 @@ graph(%0 : Tensor, } AT_ASSERT(error_thrown); } + + { + auto graph = std::make_shared<Graph>(); + const std::string& text = + R"IR( + graph(%a): + # CHECK: return + return (%a))IR"; + + script::parseIR(text, &*graph); + graph->inputs()[0]->type()->expect<TensorType>(); + torch::jit::testing::FileCheck().run(text, *graph); + } } } // namespace jit } // namespace torch diff --git a/test/test_jit.py b/test/test_jit.py index 694926bf27..c0dc575334 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -43,7 +43,7 @@ from common_methods_invocations import create_input, unpack_variables, \ exclude_tensor_method, non_differentiable, EXCLUDE_GRADCHECK, EXCLUDE_FUNCTIONAL from torch.testing import FileCheck from torch._C import TensorType, TupleType, FloatType, IntType, \ - ListType, StringType, DictType + ListType, StringType, DictType, parse_ir from copy import deepcopy import random from typing import List, Dict, Optional, Tuple @@ -6042,6 +6042,14 @@ a") m2.sub2.a.data.zero_() self.assertEqual(torch.zeros(2, 2), m2.forward(torch.randn(3, 2))) + def test_irparser(self): + graph_str = """graph(%0 : Double(5, 5)): + # CHECK: aten::relu + %1 : Double(5, 5) = aten::relu(%0) + return (%1) + """ + FileCheck().run(graph_str, parse_ir(graph_str)) + def test_filecheck(self): def test_check(): file = "232" @@ -6134,6 +6142,59 @@ a") with self.assertRaisesRegex(RuntimeError, 'Expected to not find "1"'): fb.run("22 1 22") + def test_filecheck_parse(self): + def test_check(): + file = """ + # CHECK: 2 + # CHECK: 3 + # CHECK: 2 + 232 + """ + FileCheck().run(checks_file=file, test_file=file) + file = """ + # CHECK: 232 + 232 + """ + FileCheck().run(file, "232") + with self.assertRaisesRegex(RuntimeError, 'Expected to find "232"'): + FileCheck().run(file, "22") + with self.assertRaisesRegex(RuntimeError, 'Expected to find "22"'): + FileCheck().run("# CHECK: 22", "23") + test_check() + + def test_check_count(): + file = "22222" + FileCheck().run("# CHECK-COUNT-5: 2", file) + FileCheck().run("# CHECK-COUNT-EXACTLY-5: 2", file) + FileCheck().run("# CHECK-COUNT-2: 22", file) + FileCheck().run("# CHECK-COUNT-1: 222", file) + + with self.assertRaisesRegex(RuntimeError, 'Expected to not find'): + FileCheck().run("# CHECK-COUNT-EXACTLY-2: 2", file) + test_check_count() + + def test_check_same(): + file = "22\n33" + FileCheck().run("# CHECK-SAME: 22", file) + + with self.assertRaisesRegex(RuntimeError, "Expected to not find"): + FileCheck().run("# CHECK-SAME: 33", file) + + file = "22 1 3" + + FileCheck().run("# CHECK: 2\n # CHECK-SAME: 3", file) + FileCheck().run("# CHECK-COUNT-2: 2\n # CHECK-SAME: 3", file) + test_check_same() + + def test_bad_input(): + with self.assertRaisesRegex(RuntimeError, "Check for bad input"): + FileCheck().run("", "1") + + with self.assertRaisesRegex(RuntimeError, "Could not parse check"): + FileCheck().run("# CHECK1", "") + + test_bad_input() + def test_script_module_call_noscript(self): class M(torch.jit.ScriptModule): def __init__(self): |