diff options
author | eellison <elias_ellison@brown.edu> | 2019-04-23 20:31:36 -0700 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-04-23 20:39:09 -0700 |
commit | d902774cadd085c89bd27391d1a3c5a8488235de (patch) | |
tree | 76297df7952ebc9c7a60b2d57267e632f968c5a7 /test/test_jit.py | |
parent | ba1cf3871862b2ab5681c2a0e66ad22c7795e806 (diff) | |
download | pytorch-d902774cadd085c89bd27391d1a3c5a8488235de.tar.gz pytorch-d902774cadd085c89bd27391d1a3c5a8488235de.tar.bz2 pytorch-d902774cadd085c89bd27391d1a3c5a8488235de.zip |
Dont introduce aliasing in CSE or Constant Pooling (#19576)
Summary:
We can't introduce aliasing to a graph output, since they may be mutated after.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19576
Differential Revision: D15057734
Pulled By: eellison
fbshipit-source-id: 33594c05d985a0c58edebd6252e1ee2c0efb6f0e
Diffstat (limited to 'test/test_jit.py')
-rw-r--r-- | test/test_jit.py | 50 |
1 files changed, 23 insertions, 27 deletions
diff --git a/test/test_jit.py b/test/test_jit.py index 1cb922b901..a246a42272 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -1139,22 +1139,40 @@ class TestJit(JitTestCase): self.assertExportImport(trace, (x, y)) + def test_cse_not_introduce_aliasing(self): + @torch.jit.script + def tensor_alias_outputs(x): + return x + x, x + x + + self.run_pass('cse', tensor_alias_outputs.graph) + FileCheck().check_count("aten::add", 2).run(tensor_alias_outputs.graph) + + @torch.jit.script + def ints_alias_outputs(x): + # type: (int) -> Tuple[int, int] + return x + x, x + x + + # non-aliasing types can be CSEd + self.run_pass('cse', ints_alias_outputs.graph) + FileCheck().check_count("aten::add", 1, exactly=True).run(ints_alias_outputs.graph) + def test_recursive_cse(self): input_str = """ graph(%x : Tensor, - %y : Tensor): + %y : Tensor, + %20 : int): %2 : int = prim::Constant[value=1]() %3 : Tensor = aten::add(%x, %y, %2) - %4 : Tensor = aten::gt(%3, %x) + %4 : int = aten::add(%2, %20) %5 : bool = prim::Bool(%4) - %z : Tensor = prim::If(%5) + %z : int = prim::If(%5) # CHECK: block block0(): # CHECK-NOT: aten::add - %z.1 : Tensor = aten::add(%x, %y, %2) + %z.1 : int = aten::add(%2, %20) -> (%z.1) block1(): - -> (%x) + -> (%2) return (%z) """ graph = parse_ir(input_str) @@ -12793,28 +12811,6 @@ class TestAutodiffSubgraphSlicing(JitTestCase): # the same group; they should each be a separate DiffGraph self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 2) - def test_mutation_subgraph_inlining(self): - # cannot move a node which has writers into a differentiable subgraph, - # bc CSE might lose context that it has writers - - def fn(x): - a = x.t() - a = a + 1 - c = x.t() - c = c + 1 - e = a + c - b = a.add_(x) - d = c.add_(x) - return e, b, d - - fn_script = torch.jit.script(fn) - outs1 = fn_script(torch.tensor(0.5, requires_grad=True)) - outs2 = fn(torch.tensor(0.5, requires_grad=True)) - for i in range(len(outs1)): - self.assertEqual(outs1[i], outs2[i]) - graph = fn_script.graph_for(torch.tensor(0.5, requires_grad=True)) - FileCheck().check_not("DifferentiableGraph").run(graph) - class TestCustomOperators(JitTestCase): |