diff options
author | Elias Ellison <eellison@fb.com> | 2019-01-23 17:47:29 -0800 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-01-23 17:50:33 -0800 |
commit | 8710184eea6e20ad64c33f40cea584152bd8f3a7 (patch) | |
tree | 203a646aea8a5cb28e33ba85588b3b6d3ccf669e /test | |
parent | 4b06c063a5259f50ec4c3cdde621857ea125fa97 (diff) | |
download | pytorch-8710184eea6e20ad64c33f40cea584152bd8f3a7.tar.gz pytorch-8710184eea6e20ad64c33f40cea584152bd8f3a7.tar.bz2 pytorch-8710184eea6e20ad64c33f40cea584152bd8f3a7.zip |
Constant propagation changes (#16244)
Summary:
- remove loop node that is guaranteed not to execute
- remove extra loop outputs that are no longer needed
- if we are inlining an if node, only run constant propagation on the block that will execute
- remove the recurse argument since we only expose the Graph Constant Propagation and it's not used
This also includes a few extra hooks to python_ir that I think make it a little be easier to test graph conditions from python.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16244
Differential Revision: D13791635
Pulled By: eellison
fbshipit-source-id: d16351fffcfc8013b02015db200f8fde002e0577
Diffstat (limited to 'test')
-rw-r--r-- | test/test_jit.py | 55 |
1 files changed, 51 insertions, 4 deletions
diff --git a/test/test_jit.py b/test/test_jit.py index b26521f7cc..63df8cf3bd 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -1719,6 +1719,20 @@ class TestJit(JitTestCase): graph_str = str(constant_prop.graph) self.assertTrue(graph_str.count("prim::None") == 0) + def test_constant_prop_if_inline(self): + @torch.jit.script + def constant_prop(): + cond = True + a = 1 + if cond: + a = 1 * 2 + else: + a = 1 // 0 + return a + + # testing that 1 // 0 error is not thrownn + self.run_pass('constant_propagation', constant_prop.graph) + def test_trace_records_names(self): def foo(bar, baz): baz = bar + 3 @@ -1759,16 +1773,49 @@ class TestJit(JitTestCase): def test_constant_prop_loop_constant(self): @torch.jit.script - def constant_prop(): + def constant_prop(cond, iter): + # type: (bool, int) -> int b = 0 while True: - b = 1 + print("stays") + for _ in range(2): + print("stays") + for _ in range(iter): + print("stays") + while cond: + print("stays") while False: - b = 2 + print("removed") + for _i in range(0): + print("removed") + for _i in range(-4): + print("removed") return b self.run_pass('constant_propagation', constant_prop.graph) - self.assertExpected(canonical(constant_prop.graph)) + graph = canonical(constant_prop.graph) + self.assertTrue(graph.count("removed") == 0) + self.assertTrue(graph.count("stays") == 1) # constant gets pooled + self.assertTrue(graph.count("prim::Print") == 4) + + def test_constant_prop_remove_output(self): + @torch.jit.script + def constant_prop(iter): + # type: (int) -> None + a = 1 + b = 1 + c = 1 + for i in range(iter): + if False: + a = 10 + if i == 5: + b = 2 + c = 3 + print(a, b, c) + + graph = constant_prop.graph + self.run_pass('constant_propagation', graph) + self.assertTrue(graph.findNode("prim::Loop").outputsSize() == 2) def test_trace_detach(self): def foo(x, w): |