diff options
author | eellison <elias_ellison@brown.edu> | 2019-04-23 20:31:36 -0700 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-04-23 20:39:09 -0700 |
commit | d902774cadd085c89bd27391d1a3c5a8488235de (patch) | |
tree | 76297df7952ebc9c7a60b2d57267e632f968c5a7 /test | |
parent | ba1cf3871862b2ab5681c2a0e66ad22c7795e806 (diff) | |
download | pytorch-d902774cadd085c89bd27391d1a3c5a8488235de.tar.gz pytorch-d902774cadd085c89bd27391d1a3c5a8488235de.tar.bz2 pytorch-d902774cadd085c89bd27391d1a3c5a8488235de.zip |
Don't introduce aliasing in CSE or Constant Pooling (#19576)
Summary:
We can't introduce aliasing to a graph output, since it may be mutated afterwards.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19576
Differential Revision: D15057734
Pulled By: eellison
fbshipit-source-id: 33594c05d985a0c58edebd6252e1ee2c0efb6f0e
Diffstat (limited to 'test')
-rw-r--r-- | test/cpp/jit/test_alias_analysis.h | 16 | ||||
-rw-r--r-- | test/cpp/jit/test_constant_pooling.h | 33 | ||||
-rw-r--r-- | test/test_jit.py | 50 |
3 files changed, 61 insertions, 38 deletions
diff --git a/test/cpp/jit/test_alias_analysis.h b/test/cpp/jit/test_alias_analysis.h index 87bdcecfcb..9d121c478d 100644 --- a/test/cpp/jit/test_alias_analysis.h +++ b/test/cpp/jit/test_alias_analysis.h @@ -507,7 +507,7 @@ void testContainerAliasing() { &*graph); auto node_iter = graph->block()->nodes().begin(); - node_iter++; // string + auto str_node = node_iter++; // string Node* ten_node = *node_iter++; AliasDb aliasDb(graph); @@ -515,6 +515,8 @@ void testContainerAliasing() { for (auto out : graph->outputs()) { AT_ASSERT(aliasDb.mayContainAlias(ten_node->output(), out)); } + AT_ASSERT(aliasDb.mayContainAlias({ten_node->output()}, graph->outputs())); + AT_ASSERT(!aliasDb.mayContainAlias(str_node->output(), graph->outputs())); } { @@ -533,13 +535,13 @@ void testContainerAliasing() { auto node_iter = graph->block()->nodes().begin(); node_iter++; // string - Node* ten_node = *node_iter++; + Node* int_node = *node_iter++; AliasDb aliasDb(graph); AT_ASSERT(graph->outputs().size() == 3); // primitive values don't need to alias container for (auto out : graph->outputs()) { - AT_ASSERT(!aliasDb.mayContainAlias(ten_node->output(), out)); + AT_ASSERT(!aliasDb.mayContainAlias(int_node->output(), out)); } } @@ -561,6 +563,7 @@ void testContainerAliasing() { for (auto input : graph->inputs()) { AT_ASSERT(aliasDb.mayContainAlias(input, tuple_node->output())); } + AT_ASSERT(aliasDb.mayContainAlias(graph->inputs(), graph->outputs())); } // Test tuple that doesn't come from construct @@ -648,6 +651,13 @@ graph(): AT_ASSERT(aliasDb.mayContainAlias(first_ten->output(), tup_node->output())); AT_ASSERT( !aliasDb.mayContainAlias(second_ten->output(), tup_node->output())); + + std::vector<Value*> first_st = {first_ten->output()}; + std::vector<Value*> second_st = {second_ten->output()}; + std::vector<Value*> tup_st = {tup_node->output()}; + AT_ASSERT(aliasDb.mayContainAlias(first_st, tup_st)); + AT_ASSERT(!aliasDb.mayContainAlias(first_st, second_st)); + 
AT_ASSERT(!aliasDb.mayContainAlias(second_st, tup_st)); } } diff --git a/test/cpp/jit/test_constant_pooling.h b/test/cpp/jit/test_constant_pooling.h index 9a566bbdbc..e8d0da2c7d 100644 --- a/test/cpp/jit/test_constant_pooling.h +++ b/test/cpp/jit/test_constant_pooling.h @@ -34,16 +34,16 @@ graph(): script::parseIR( R"IR( graph(%cond : Tensor): - %a : string = prim::Constant[value="bcd"]() + %a : str = prim::Constant[value="bcd"]() %3 : bool = prim::Bool(%cond) - %b : string = prim::If(%3) + %b : str = prim::If(%3) block0(): - %b.1 : string = prim::Constant[value="abc"]() + %b.1 : str = prim::Constant[value="abc"]() -> (%b.1) block1(): - %b.2 : string = prim::Constant[value="abc"]() + %b.2 : str = prim::Constant[value="abc"]() -> (%b.2) - %7 : (string, string) = prim::TupleConstruct(%a, %b) + %7 : (str, str) = prim::TupleConstruct(%a, %b) return (%7) )IR", &*graph); @@ -69,8 +69,8 @@ graph(): %y : Tensor = aten::tensor(%3, %10, %7, %15) %9 : int[] = prim::ListConstruct(%1, %2) %z : Tensor = aten::tensor(%9, %10, %7, %15) - %14 : (Tensor, Tensor) = prim::TupleConstruct(%x, %y) - return (%14) + %f = prim::Print(%x, %y, %z) + return (%1) )IR", &*graph); // three tensors created - two different devices among the three @@ -82,7 +82,24 @@ graph(): ->check_count("Long(2) = prim::Constant", 1, /*exactly*/ true) ->run(*graph); } + // don't create aliasing of graph outputs in constant pooling + { + auto graph = std::make_shared<Graph>(); + script::parseIR( + R"IR( +graph(%cond : Tensor): + %a : Tensor = prim::Constant() + %b : Tensor = prim::Constant() + %c : Tensor = prim::Constant() + %1 = prim::Print(%c) + return (%a, %b) + )IR", + &*graph); + ConstantPooling(graph); + testing::FileCheck() + .check_count("prim::Constant", 2, /*exactly*/ true) + ->run(*graph); + } } - } // namespace jit } // namespace torch diff --git a/test/test_jit.py b/test/test_jit.py index 1cb922b901..a246a42272 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -1139,22 +1139,40 @@ class 
TestJit(JitTestCase): self.assertExportImport(trace, (x, y)) + def test_cse_not_introduce_aliasing(self): + @torch.jit.script + def tensor_alias_outputs(x): + return x + x, x + x + + self.run_pass('cse', tensor_alias_outputs.graph) + FileCheck().check_count("aten::add", 2).run(tensor_alias_outputs.graph) + + @torch.jit.script + def ints_alias_outputs(x): + # type: (int) -> Tuple[int, int] + return x + x, x + x + + # non-aliasing types can be CSEd + self.run_pass('cse', ints_alias_outputs.graph) + FileCheck().check_count("aten::add", 1, exactly=True).run(ints_alias_outputs.graph) + def test_recursive_cse(self): input_str = """ graph(%x : Tensor, - %y : Tensor): + %y : Tensor, + %20 : int): %2 : int = prim::Constant[value=1]() %3 : Tensor = aten::add(%x, %y, %2) - %4 : Tensor = aten::gt(%3, %x) + %4 : int = aten::add(%2, %20) %5 : bool = prim::Bool(%4) - %z : Tensor = prim::If(%5) + %z : int = prim::If(%5) # CHECK: block block0(): # CHECK-NOT: aten::add - %z.1 : Tensor = aten::add(%x, %y, %2) + %z.1 : int = aten::add(%2, %20) -> (%z.1) block1(): - -> (%x) + -> (%2) return (%z) """ graph = parse_ir(input_str) @@ -12793,28 +12811,6 @@ class TestAutodiffSubgraphSlicing(JitTestCase): # the same group; they should each be a separate DiffGraph self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 2) - def test_mutation_subgraph_inlining(self): - # cannot move a node which has writers into a differentiable subgraph, - # bc CSE might lose context that it has writers - - def fn(x): - a = x.t() - a = a + 1 - c = x.t() - c = c + 1 - e = a + c - b = a.add_(x) - d = c.add_(x) - return e, b, d - - fn_script = torch.jit.script(fn) - outs1 = fn_script(torch.tensor(0.5, requires_grad=True)) - outs2 = fn(torch.tensor(0.5, requires_grad=True)) - for i in range(len(outs1)): - self.assertEqual(outs1[i], outs2[i]) - graph = fn_script.graph_for(torch.tensor(0.5, requires_grad=True)) - FileCheck().check_not("DifferentiableGraph").run(graph) - class 
TestCustomOperators(JitTestCase): |