author    eellison <elias_ellison@brown.edu>  2019-04-23 20:31:36 -0700
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>  2019-04-23 20:39:09 -0700
commit    d902774cadd085c89bd27391d1a3c5a8488235de (patch)
tree      76297df7952ebc9c7a60b2d57267e632f968c5a7 /test/test_jit.py
parent    ba1cf3871862b2ab5681c2a0e66ad22c7795e806 (diff)
Don't introduce aliasing in CSE or Constant Pooling (#19576)
Summary: We can't introduce aliasing to a graph output, since outputs may be mutated afterwards.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/19576
Differential Revision: D15057734
Pulled By: eellison
fbshipit-source-id: 33594c05d985a0c58edebd6252e1ee2c0efb6f0e
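For context, a minimal sketch of the hazard the summary describes (illustrative only; `fn` and the tensor values are hypothetical, not part of this patch). If CSE merged the two syntactically identical `x + x` expressions into a single value, both graph outputs would alias the same tensor, and an in-place mutation of one output would silently change the other:

    import torch

    @torch.jit.script
    def fn(x):
        # Two identical expressions that must stay distinct tensors,
        # because a caller may mutate the outputs after the call.
        return x + x, x + x

    a, b = fn(torch.ones(2))
    a.add_(1)   # must not affect b; with aliased outputs it would
    print(b)    # tensor([2., 2.]) only if no aliasing was introduced

The new test below checks exactly this: for tensor outputs CSE must keep both `aten::add` nodes, while for non-aliasing types such as `int` it is free to merge them.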
Diffstat (limited to 'test/test_jit.py')
-rw-r--r--  test/test_jit.py | 50
1 file changed, 23 insertions(+), 27 deletions(-)
diff --git a/test/test_jit.py b/test/test_jit.py
index 1cb922b901..a246a42272 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -1139,22 +1139,40 @@ class TestJit(JitTestCase):
self.assertExportImport(trace, (x, y))
+ def test_cse_not_introduce_aliasing(self):
+ @torch.jit.script
+ def tensor_alias_outputs(x):
+ return x + x, x + x
+
+ self.run_pass('cse', tensor_alias_outputs.graph)
+ FileCheck().check_count("aten::add", 2).run(tensor_alias_outputs.graph)
+
+ @torch.jit.script
+ def ints_alias_outputs(x):
+ # type: (int) -> Tuple[int, int]
+ return x + x, x + x
+
+ # non-aliasing types can be CSEd
+ self.run_pass('cse', ints_alias_outputs.graph)
+ FileCheck().check_count("aten::add", 1, exactly=True).run(ints_alias_outputs.graph)
+
def test_recursive_cse(self):
input_str = """
graph(%x : Tensor,
- %y : Tensor):
+ %y : Tensor,
+ %20 : int):
%2 : int = prim::Constant[value=1]()
%3 : Tensor = aten::add(%x, %y, %2)
- %4 : Tensor = aten::gt(%3, %x)
+ %4 : int = aten::add(%2, %20)
%5 : bool = prim::Bool(%4)
- %z : Tensor = prim::If(%5)
+ %z : int = prim::If(%5)
# CHECK: block
block0():
# CHECK-NOT: aten::add
- %z.1 : Tensor = aten::add(%x, %y, %2)
+ %z.1 : int = aten::add(%2, %20)
-> (%z.1)
block1():
- -> (%x)
+ -> (%2)
return (%z)
"""
graph = parse_ir(input_str)
@@ -12793,28 +12811,6 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
# the same group; they should each be a separate DiffGraph
self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 2)
- def test_mutation_subgraph_inlining(self):
- # cannot move a node which has writers into a differentiable subgraph,
- # bc CSE might lose context that it has writers
-
- def fn(x):
- a = x.t()
- a = a + 1
- c = x.t()
- c = c + 1
- e = a + c
- b = a.add_(x)
- d = c.add_(x)
- return e, b, d
-
- fn_script = torch.jit.script(fn)
- outs1 = fn_script(torch.tensor(0.5, requires_grad=True))
- outs2 = fn(torch.tensor(0.5, requires_grad=True))
- for i in range(len(outs1)):
- self.assertEqual(outs1[i], outs2[i])
- graph = fn_script.graph_for(torch.tensor(0.5, requires_grad=True))
- FileCheck().check_not("DifferentiableGraph").run(graph)
-
class TestCustomOperators(JitTestCase):