summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorElias Ellison <eellison@fb.com>2019-01-23 17:47:29 -0800
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-01-23 17:50:33 -0800
commit8710184eea6e20ad64c33f40cea584152bd8f3a7 (patch)
tree203a646aea8a5cb28e33ba85588b3b6d3ccf669e /test
parent4b06c063a5259f50ec4c3cdde621857ea125fa97 (diff)
downloadpytorch-8710184eea6e20ad64c33f40cea584152bd8f3a7.tar.gz
pytorch-8710184eea6e20ad64c33f40cea584152bd8f3a7.tar.bz2
pytorch-8710184eea6e20ad64c33f40cea584152bd8f3a7.zip
Constant propagation changes (#16244)
Summary: - remove loop node that is guaranteed not to execute - remove extra loop outputs that are no longer needed - if we are inlining an if node, only run constant propagation on the block that will execute - remove the recurse argument since we only expose the Graph Constant Propagation and it's not used This also includes a few extra hooks to python_ir that I think make it a little be easier to test graph conditions from python. Pull Request resolved: https://github.com/pytorch/pytorch/pull/16244 Differential Revision: D13791635 Pulled By: eellison fbshipit-source-id: d16351fffcfc8013b02015db200f8fde002e0577
Diffstat (limited to 'test')
-rw-r--r--test/test_jit.py55
1 files changed, 51 insertions, 4 deletions
diff --git a/test/test_jit.py b/test/test_jit.py
index b26521f7cc..63df8cf3bd 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -1719,6 +1719,20 @@ class TestJit(JitTestCase):
graph_str = str(constant_prop.graph)
self.assertTrue(graph_str.count("prim::None") == 0)
+ def test_constant_prop_if_inline(self):
+ @torch.jit.script
+ def constant_prop():
+ cond = True
+ a = 1
+ if cond:
+ a = 1 * 2
+ else:
+ a = 1 // 0
+ return a
+
+ # testing that 1 // 0 error is not thrownn
+ self.run_pass('constant_propagation', constant_prop.graph)
+
def test_trace_records_names(self):
def foo(bar, baz):
baz = bar + 3
@@ -1759,16 +1773,49 @@ class TestJit(JitTestCase):
def test_constant_prop_loop_constant(self):
@torch.jit.script
- def constant_prop():
+ def constant_prop(cond, iter):
+ # type: (bool, int) -> int
b = 0
while True:
- b = 1
+ print("stays")
+ for _ in range(2):
+ print("stays")
+ for _ in range(iter):
+ print("stays")
+ while cond:
+ print("stays")
while False:
- b = 2
+ print("removed")
+ for _i in range(0):
+ print("removed")
+ for _i in range(-4):
+ print("removed")
return b
self.run_pass('constant_propagation', constant_prop.graph)
- self.assertExpected(canonical(constant_prop.graph))
+ graph = canonical(constant_prop.graph)
+ self.assertTrue(graph.count("removed") == 0)
+ self.assertTrue(graph.count("stays") == 1) # constant gets pooled
+ self.assertTrue(graph.count("prim::Print") == 4)
+
+ def test_constant_prop_remove_output(self):
+ @torch.jit.script
+ def constant_prop(iter):
+ # type: (int) -> None
+ a = 1
+ b = 1
+ c = 1
+ for i in range(iter):
+ if False:
+ a = 10
+ if i == 5:
+ b = 2
+ c = 3
+ print(a, b, c)
+
+ graph = constant_prop.graph
+ self.run_pass('constant_propagation', graph)
+ self.assertTrue(graph.findNode("prim::Loop").outputsSize() == 2)
def test_trace_detach(self):
def foo(x, w):