author     eellison <elias_ellison@brown.edu>    2019-03-25 21:48:11 -0700
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>    2019-03-25 21:50:57 -0700
commit     dc6b5b2a52d182019f6f3f5f21b58e2fb592d993
tree       9573b62e403afd295da897e883298b49f7213cb9 /test
parent     a729630cbfa5fe02d4b8651845f293d6c70e728f
Optimize boolean expressions & unwraps (#18259)
Summary:
Simplify or eliminate boolean and/or expressions, optimize unwrapping a value that cannot be None, and optimize `is`/`is not` comparisons between a None and a non-None value.
Since peephole optimization now introduces constants, I added another constant propagation pass after it runs.
I previously had a PR that did this and also optimized shape ops; I will add the shape optimizations in a separate PR.
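As a concrete illustration (a minimal standalone sketch, not part of this commit: the function name `fn` and the direct `torch._C._jit_pass_*` calls are mine; the test suite reaches the same passes through JitTestCase.run_pass):

    from typing import Tuple

    import torch

    @torch.jit.script
    def fn(x):
        # type: (int) -> Tuple[bool, bool]
        # `and False` / `or True` make both results constant regardless of x
        return x == 1 and False, x == 1 or True

    # Peephole folds the short-circuit branches to constants; the follow-up
    # constant propagation pass then removes the now-dead comparison.
    torch._C._jit_pass_peephole(fn.graph)
    torch._C._jit_pass_constant_propagation(fn.graph)
    print(fn.graph)  # expect constant outputs: no aten::eq, no prim::If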
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18259
Differential Revision: D14602749
Pulled By: eellison
fbshipit-source-id: 1c3f5a67067d8dfdf55d7b78dcb616472ea8a267
Diffstat (limited to 'test')
-rw-r--r--  test/cpp/jit/test.cpp                    4
-rw-r--r--  test/cpp/jit/test_peephole_optimize.h  104
-rw-r--r--  test/test_jit.py                        20
3 files changed, 127 insertions(+), 1 deletion(-)
diff --git a/test/cpp/jit/test.cpp b/test/cpp/jit/test.cpp
index 1abb5ed505..1c4823dd98 100644
--- a/test/cpp/jit/test.cpp
+++ b/test/cpp/jit/test.cpp
@@ -24,6 +24,7 @@
 #include <test/cpp/jit/test_ivalue.h>
 #include <test/cpp/jit/test_misc.h>
 #include <test/cpp/jit/test_netdef_converter.h>
+#include <test/cpp/jit/test_peephole_optimize.h>
 #include <test/cpp/jit/test_subgraph_utils.h>
 
 using namespace torch::jit::script;
@@ -61,7 +62,8 @@ namespace jit {
   _(THNNConv)                   \
   _(ATenNativeBatchNorm)        \
   _(NoneSchemaMatch)            \
-  _(ClassParser)
+  _(ClassParser)                \
+  _(PeepholeOptimize)
 
 #define TH_FORALL_TESTS_CUDA(_) \
   _(ArgumentSpec)               \
diff --git a/test/cpp/jit/test_peephole_optimize.h b/test/cpp/jit/test_peephole_optimize.h
new file mode 100644
index 0000000000..32aacf43e4
--- /dev/null
+++ b/test/cpp/jit/test_peephole_optimize.h
@@ -0,0 +1,104 @@
+#pragma once
+
+#include <test/cpp/jit/test_base.h>
+#include <test/cpp/jit/test_utils.h>
+
+#include <torch/csrc/jit/irparser.h>
+#include <torch/csrc/jit/passes/peephole.h>
+
+namespace torch {
+namespace jit {
+
+using namespace script;
+using namespace testing;
+
+namespace test {
+
+void testPeepholeOptimize() {
+  // test is / is not none optimization
+  {
+    auto graph = std::make_shared<Graph>();
+    parseIR(
+        R"IR(
+graph(%0 : int):
+  %1 : None = prim::Constant()
+  %2 : bool = aten::__is__(%0, %1)
+  %3 : bool = aten::__isnot__(%0, %1)
+  return (%2, %3)
+  )IR",
+        graph.get());
+    PeepholeOptimize(graph);
+    testing::FileCheck()
+        .check_not("aten::__is__")
+        ->check_not("aten::__isnot__")
+        ->run(*graph);
+  }
+  {
+    auto graph = std::make_shared<Graph>();
+    parseIR(
+        R"IR(
+graph(%0: int?):
+  %1 : None = prim::Constant()
+  %2 : bool = aten::__is__(%0, %1)
+  %3 : bool = aten::__isnot__(%0, %1)
+  return (%2, %3)
+  )IR",
+        graph.get());
+    PeepholeOptimize(graph);
+    testing::FileCheck()
+        .check("aten::__is__")
+        ->check("aten::__isnot__")
+        ->run(*graph);
+  }
+
+  {
+    auto graph = std::make_shared<Graph>();
+    parseIR(
+        R"IR(
+graph(%0: int?):
+  %1 : Tensor = prim::AutogradZero()
+  %2 : None = prim::Constant()
+  %4 : bool = aten::__is__(%0, %1)
+  %5 : bool = aten::__isnot__(%1, %2)
+  return (%4, %5)
+  )IR",
+        graph.get());
+    PeepholeOptimize(graph);
+    testing::FileCheck()
+        .check("aten::__is__")
+        ->check_not("aten::__isnot__")
+        ->run(*graph);
+  }
+
+  // test unwrap optional
+  {
+    auto graph = std::make_shared<Graph>();
+    parseIR(
+        R"IR(
+graph():
+  %1 : Float(*, *, *) = prim::Constant()
+  %2 : bool = aten::_unwrap_optional(%1)
+  %3 : bool = prim::unchecked_unwrap_optional(%1)
+  return (%2, %3)
+  )IR",
+        graph.get());
+    PeepholeOptimize(graph);
+    testing::FileCheck().check_not("unwrap")->run(*graph);
+  }
+  {
+    auto graph = std::make_shared<Graph>();
+    parseIR(
+        R"IR(
+graph(%1 : Float(*, *, *)?):
+  %2 : bool = aten::_unwrap_optional(%1)
+  %3 : bool = prim::unchecked_unwrap_optional(%1)
+  return (%2, %3)
+  )IR",
+        graph.get());
+    PeepholeOptimize(graph);
+    testing::FileCheck().check_count("unwrap", 2)->run(*graph);
+  }
+}
+} // namespace test
+} // namespace jit
+} // namespace torch
diff --git a/test/test_jit.py b/test/test_jit.py
index c2318e1dfd..6ef2c1fca8 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -1881,6 +1881,26 @@ class TestJit(JitTestCase):
         # testing that 1 // 0 error is not thrown
         self.run_pass('constant_propagation', constant_prop.graph)
 
+    def test_short_circuit_optimization(self):
+        @torch.jit.script
+        def const_expressions(x):
+            # type: (int) -> Tuple[bool, bool]
+            return x == 1 and False, x == 1 or True
+        self.run_pass('constant_propagation', const_expressions.graph)
+        FileCheck().check_not("prim::If").check_not("aten::eq").run(const_expressions.graph)
+        self.assertEqual(const_expressions(1), (False, True))
+
+        @torch.jit.script
+        def redundant_expressions(x):
+            # type: (int) -> Tuple[bool, bool]
+            return x == 1 and True, x == 1 or False
+
+        self.run_pass('peephole', redundant_expressions.graph)
+        self.assertEqual(redundant_expressions(1), (True, True))
+        self.assertEqual(redundant_expressions(0), (False, False))
+        # and True / or False are removed from graph
+        FileCheck().check("aten::eq").check_not("prim::If").run(redundant_expressions.graph)
+
     def test_trace_records_names(self):
         def foo(bar, baz):
             baz = bar + 3
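For readers who want to see the None-check folding from Python rather than from parsed IR, here is a hedged sketch (the function name `none_checks` is illustrative, and it assumes the script compiler accepts an `is None` check on a non-Optional value, which the C++ tests above exercise directly at the IR level):

    from typing import Optional, Tuple

    import torch

    @torch.jit.script
    def none_checks(x, y):
        # type: (int, Optional[int]) -> Tuple[bool, bool]
        # x can never be None, so its check can fold to a constant False;
        # y is genuinely optional, so its check must survive
        # (mirroring the first two IR tests in test_peephole_optimize.h).
        return x is None, y is None

    torch._C._jit_pass_peephole(none_checks.graph)
    print(none_checks.graph)  # expect one surviving aten::__is__, for y only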