summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
Diffstat (limited to 'test')
-rw-r--r--test/cpp/jit/test_alias_analysis.h16
-rw-r--r--test/cpp/jit/test_constant_pooling.h33
-rw-r--r--test/test_jit.py50
3 files changed, 61 insertions, 38 deletions
diff --git a/test/cpp/jit/test_alias_analysis.h b/test/cpp/jit/test_alias_analysis.h
index 87bdcecfcb..9d121c478d 100644
--- a/test/cpp/jit/test_alias_analysis.h
+++ b/test/cpp/jit/test_alias_analysis.h
@@ -507,7 +507,7 @@ void testContainerAliasing() {
&*graph);
auto node_iter = graph->block()->nodes().begin();
- node_iter++; // string
+ auto str_node = node_iter++; // string
Node* ten_node = *node_iter++;
AliasDb aliasDb(graph);
@@ -515,6 +515,8 @@ void testContainerAliasing() {
for (auto out : graph->outputs()) {
AT_ASSERT(aliasDb.mayContainAlias(ten_node->output(), out));
}
+ AT_ASSERT(aliasDb.mayContainAlias({ten_node->output()}, graph->outputs()));
+ AT_ASSERT(!aliasDb.mayContainAlias(str_node->output(), graph->outputs()));
}
{
@@ -533,13 +535,13 @@ void testContainerAliasing() {
auto node_iter = graph->block()->nodes().begin();
node_iter++; // string
- Node* ten_node = *node_iter++;
+ Node* int_node = *node_iter++;
AliasDb aliasDb(graph);
AT_ASSERT(graph->outputs().size() == 3);
// primitive values don't need to alias container
for (auto out : graph->outputs()) {
- AT_ASSERT(!aliasDb.mayContainAlias(ten_node->output(), out));
+ AT_ASSERT(!aliasDb.mayContainAlias(int_node->output(), out));
}
}
@@ -561,6 +563,7 @@ void testContainerAliasing() {
for (auto input : graph->inputs()) {
AT_ASSERT(aliasDb.mayContainAlias(input, tuple_node->output()));
}
+ AT_ASSERT(aliasDb.mayContainAlias(graph->inputs(), graph->outputs()));
}
// Test tuple that doesn't come from construct
@@ -648,6 +651,13 @@ graph():
AT_ASSERT(aliasDb.mayContainAlias(first_ten->output(), tup_node->output()));
AT_ASSERT(
!aliasDb.mayContainAlias(second_ten->output(), tup_node->output()));
+
+ std::vector<Value*> first_st = {first_ten->output()};
+ std::vector<Value*> second_st = {second_ten->output()};
+ std::vector<Value*> tup_st = {tup_node->output()};
+ AT_ASSERT(aliasDb.mayContainAlias(first_st, tup_st));
+ AT_ASSERT(!aliasDb.mayContainAlias(first_st, second_st));
+ AT_ASSERT(!aliasDb.mayContainAlias(second_st, tup_st));
}
}
diff --git a/test/cpp/jit/test_constant_pooling.h b/test/cpp/jit/test_constant_pooling.h
index 9a566bbdbc..e8d0da2c7d 100644
--- a/test/cpp/jit/test_constant_pooling.h
+++ b/test/cpp/jit/test_constant_pooling.h
@@ -34,16 +34,16 @@ graph():
script::parseIR(
R"IR(
graph(%cond : Tensor):
- %a : string = prim::Constant[value="bcd"]()
+ %a : str = prim::Constant[value="bcd"]()
%3 : bool = prim::Bool(%cond)
- %b : string = prim::If(%3)
+ %b : str = prim::If(%3)
block0():
- %b.1 : string = prim::Constant[value="abc"]()
+ %b.1 : str = prim::Constant[value="abc"]()
-> (%b.1)
block1():
- %b.2 : string = prim::Constant[value="abc"]()
+ %b.2 : str = prim::Constant[value="abc"]()
-> (%b.2)
- %7 : (string, string) = prim::TupleConstruct(%a, %b)
+ %7 : (str, str) = prim::TupleConstruct(%a, %b)
return (%7)
)IR",
&*graph);
@@ -69,8 +69,8 @@ graph():
%y : Tensor = aten::tensor(%3, %10, %7, %15)
%9 : int[] = prim::ListConstruct(%1, %2)
%z : Tensor = aten::tensor(%9, %10, %7, %15)
- %14 : (Tensor, Tensor) = prim::TupleConstruct(%x, %y)
- return (%14)
+ %f = prim::Print(%x, %y, %z)
+ return (%1)
)IR",
&*graph);
// three tensors created - two different devices among the three
@@ -82,7 +82,24 @@ graph():
->check_count("Long(2) = prim::Constant", 1, /*exactly*/ true)
->run(*graph);
}
+ // don't create aliasing of graph outputs in constant pooling
+ {
+ auto graph = std::make_shared<Graph>();
+ script::parseIR(
+ R"IR(
+graph(%cond : Tensor):
+ %a : Tensor = prim::Constant()
+ %b : Tensor = prim::Constant()
+ %c : Tensor = prim::Constant()
+ %1 = prim::Print(%c)
+ return (%a, %b)
+ )IR",
+ &*graph);
+ ConstantPooling(graph);
+ testing::FileCheck()
+ .check_count("prim::Constant", 2, /*exactly*/ true)
+ ->run(*graph);
+ }
}
-
} // namespace jit
} // namespace torch
diff --git a/test/test_jit.py b/test/test_jit.py
index 1cb922b901..a246a42272 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -1139,22 +1139,40 @@ class TestJit(JitTestCase):
self.assertExportImport(trace, (x, y))
+ def test_cse_not_introduce_aliasing(self):
+ @torch.jit.script
+ def tensor_alias_outputs(x):
+ return x + x, x + x
+
+ self.run_pass('cse', tensor_alias_outputs.graph)
+ FileCheck().check_count("aten::add", 2).run(tensor_alias_outputs.graph)
+
+ @torch.jit.script
+ def ints_alias_outputs(x):
+ # type: (int) -> Tuple[int, int]
+ return x + x, x + x
+
+ # non-aliasing types can be CSEd
+ self.run_pass('cse', ints_alias_outputs.graph)
+ FileCheck().check_count("aten::add", 1, exactly=True).run(ints_alias_outputs.graph)
+
def test_recursive_cse(self):
input_str = """
graph(%x : Tensor,
- %y : Tensor):
+ %y : Tensor,
+ %20 : int):
%2 : int = prim::Constant[value=1]()
%3 : Tensor = aten::add(%x, %y, %2)
- %4 : Tensor = aten::gt(%3, %x)
+ %4 : int = aten::add(%2, %20)
%5 : bool = prim::Bool(%4)
- %z : Tensor = prim::If(%5)
+ %z : int = prim::If(%5)
# CHECK: block
block0():
# CHECK-NOT: aten::add
- %z.1 : Tensor = aten::add(%x, %y, %2)
+ %z.1 : int = aten::add(%2, %20)
-> (%z.1)
block1():
- -> (%x)
+ -> (%2)
return (%z)
"""
graph = parse_ir(input_str)
@@ -12793,28 +12811,6 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
# the same group; they should each be a separate DiffGraph
self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 2)
- def test_mutation_subgraph_inlining(self):
- # cannot move a node which has writers into a differentiable subgraph,
- # bc CSE might lose context that it has writers
-
- def fn(x):
- a = x.t()
- a = a + 1
- c = x.t()
- c = c + 1
- e = a + c
- b = a.add_(x)
- d = c.add_(x)
- return e, b, d
-
- fn_script = torch.jit.script(fn)
- outs1 = fn_script(torch.tensor(0.5, requires_grad=True))
- outs2 = fn(torch.tensor(0.5, requires_grad=True))
- for i in range(len(outs1)):
- self.assertEqual(outs1[i], outs2[i])
- graph = fn_script.graph_for(torch.tensor(0.5, requires_grad=True))
- FileCheck().check_not("DifferentiableGraph").run(graph)
-
class TestCustomOperators(JitTestCase):