summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authoreellison <elias_ellison@brown.edu>2019-03-25 21:48:11 -0700
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-03-25 21:50:57 -0700
commitdc6b5b2a52d182019f6f3f5f21b58e2fb592d993 (patch)
tree9573b62e403afd295da897e883298b49f7213cb9 /test
parenta729630cbfa5fe02d4b8651845f293d6c70e728f (diff)
downloadpytorch-dc6b5b2a52d182019f6f3f5f21b58e2fb592d993.tar.gz
pytorch-dc6b5b2a52d182019f6f3f5f21b58e2fb592d993.tar.bz2
pytorch-dc6b5b2a52d182019f6f3f5f21b58e2fb592d993.zip
Optimize boolean expressions & unwraps (#18259)
Summary: Simplify or eliminate boolean and/or expressions, optimize unwrapping a value that cannot be None, and optimize using `is` with a None and a non-None value Since peephole optimize is now introducing constants, i added another constant propagation pass after running it. Previously i had a PR that did this & optimized shape ops - i will add the shape optimizations in a separate PR. Pull Request resolved: https://github.com/pytorch/pytorch/pull/18259 Differential Revision: D14602749 Pulled By: eellison fbshipit-source-id: 1c3f5a67067d8dfdf55d7b78dcb616472ea8a267
Diffstat (limited to 'test')
-rw-r--r--test/cpp/jit/test.cpp4
-rw-r--r--test/cpp/jit/test_peephole_optimize.h104
-rw-r--r--test/test_jit.py20
3 files changed, 127 insertions, 1 deletions
diff --git a/test/cpp/jit/test.cpp b/test/cpp/jit/test.cpp
index 1abb5ed505..1c4823dd98 100644
--- a/test/cpp/jit/test.cpp
+++ b/test/cpp/jit/test.cpp
@@ -24,6 +24,7 @@
#include <test/cpp/jit/test_ivalue.h>
#include <test/cpp/jit/test_misc.h>
#include <test/cpp/jit/test_netdef_converter.h>
+#include <test/cpp/jit/test_peephole_optimize.h>
#include <test/cpp/jit/test_subgraph_utils.h>
using namespace torch::jit::script;
@@ -61,7 +62,8 @@ namespace jit {
_(THNNConv) \
_(ATenNativeBatchNorm) \
_(NoneSchemaMatch) \
- _(ClassParser)
+ _(ClassParser) \
+ _(PeepholeOptimize)
#define TH_FORALL_TESTS_CUDA(_) \
_(ArgumentSpec) \
diff --git a/test/cpp/jit/test_peephole_optimize.h b/test/cpp/jit/test_peephole_optimize.h
new file mode 100644
index 0000000000..32aacf43e4
--- /dev/null
+++ b/test/cpp/jit/test_peephole_optimize.h
@@ -0,0 +1,104 @@
+#pragma once
+
+#include <test/cpp/jit/test_base.h>
+#include <test/cpp/jit/test_utils.h>
+
+#include <torch/csrc/jit/irparser.h>
+#include <torch/csrc/jit/passes/peephole.h>
+
+namespace torch {
+namespace jit {
+
+using namespace script;
+using namespace testing;
+
+namespace test {
+
+void testPeepholeOptimize() {
+ // test is / is not none optimization
+ {
+ auto graph = std::make_shared<Graph>();
+ parseIR(
+ R"IR(
+graph(%0 : int):
+ %1 : None = prim::Constant()
+ %2 : bool = aten::__is__(%0, %1)
+ %3 : bool = aten::__isnot__(%0, %1)
+ return (%2, %3)
+ )IR",
+ graph.get());
+ PeepholeOptimize(graph);
+ testing::FileCheck()
+ .check_not("aten::__is__")
+ ->check_not("aten::__isnot__")
+ ->run(*graph);
+ }
+ {
+ auto graph = std::make_shared<Graph>();
+ parseIR(
+ R"IR(
+graph(%0: int?):
+ %1 : None = prim::Constant()
+ %2 : bool = aten::__is__(%0, %1)
+ %3 : bool = aten::__isnot__(%0, %1)
+ return (%2, %3)
+ )IR",
+ graph.get());
+ PeepholeOptimize(graph);
+ testing::FileCheck()
+ .check("aten::__is__")
+ ->check("aten::__isnot__")
+ ->run(*graph);
+ }
+
+ {
+ auto graph = std::make_shared<Graph>();
+ parseIR(
+ R"IR(
+graph(%0: int?):
+ %1 : Tensor = prim::AutogradZero()
+ %2 : None = prim::Constant()
+ %4 : bool = aten::__is__(%0, %1)
+ %5 : bool = aten::__isnot__(%1, %2)
+ return (%4, %5)
+ )IR",
+ graph.get());
+ PeepholeOptimize(graph);
+ testing::FileCheck()
+ .check("aten::__is__")
+ ->check_not("aten::__isnot__")
+ ->run(*graph);
+ }
+
+ // test unwrap optional
+ {
+ auto graph = std::make_shared<Graph>();
+ parseIR(
+ R"IR(
+graph():
+ %1 : Float(*, *, *) = prim::Constant()
+ %2 : bool = aten::_unwrap_optional(%1)
+ %3 : bool = prim::unchecked_unwrap_optional(%1)
+ return (%2, %3)
+ )IR",
+ graph.get());
+ PeepholeOptimize(graph);
+ testing::FileCheck().check_not("unwrap")->run(*graph);
+ }
+ {
+ auto graph = std::make_shared<Graph>();
+ parseIR(
+ R"IR(
+graph(%1 : Float(*, *, *)?):
+ %2 : bool = aten::_unwrap_optional(%1)
+ %3 : bool = prim::unchecked_unwrap_optional(%1)
+ return (%2, %3)
+ )IR",
+ graph.get());
+ PeepholeOptimize(graph);
+ testing::FileCheck().check_count("unwrap", 2)->run(*graph);
+ }
+}
+} // namespace test
+} // namespace jit
+} // namespace torch
diff --git a/test/test_jit.py b/test/test_jit.py
index c2318e1dfd..6ef2c1fca8 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -1881,6 +1881,26 @@ class TestJit(JitTestCase):
# testing that 1 // 0 error is not thrownn
self.run_pass('constant_propagation', constant_prop.graph)
+ def test_short_circuit_optimization(self):
+ @torch.jit.script
+ def const_expressions(x):
+ # type: (int) -> Tuple[bool, bool]
+ return x == 1 and False, x == 1 or True
+ self.run_pass('constant_propagation', const_expressions.graph)
+ FileCheck().check_not("prim::If").check_not("aten::eq").run(const_expressions.graph)
+ self.assertEqual(const_expressions(1), (False, True))
+
+ @torch.jit.script
+ def redundant_expressions(x):
+ # type: (int) -> Tuple[bool, bool]
+ return x == 1 and True, x == 1 or False
+
+ self.run_pass('peephole', redundant_expressions.graph)
+ self.assertEqual(redundant_expressions(1), (True, True))
+ self.assertEqual(redundant_expressions(0), (False, False))
+ # and True / or False are removed from graph
+ FileCheck().check("aten::eq").check_not("prim::If").run(redundant_expressions.graph)
+
def test_trace_records_names(self):
def foo(bar, baz):
baz = bar + 3