author     eellison <elias_ellison@brown.edu>  2019-04-23 20:31:36 -0700
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>  2019-04-23 20:39:09 -0700
commit     d902774cadd085c89bd27391d1a3c5a8488235de (patch)
tree       76297df7952ebc9c7a60b2d57267e632f968c5a7
parent     ba1cf3871862b2ab5681c2a0e66ad22c7795e806 (diff)
Don't introduce aliasing in CSE or Constant Pooling (#19576)

Summary: We can't introduce aliasing to a graph output, since outputs may be mutated after they are returned.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/19576
Differential Revision: D15057734
Pulled By: eellison
fbshipit-source-id: 33594c05d985a0c58edebd6252e1ee2c0efb6f0e
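To see why this matters, here is a minimal eager-mode sketch (illustrative only, not taken from the patch): if an optimization pass collapsed two expressions that both feed graph outputs into a single value, the outputs would share storage, and an in-place mutation of one after the call would silently change the other.

    import torch

    def aliased_outputs(x):
        y = x + x
        return y, y  # both outputs are the same tensor

    a, b = aliased_outputs(torch.ones(2))
    a.add_(1)   # in-place mutation of `a` also changes `b`
    print(b)    # tensor([3., 3.]); a caller expecting an independent `x + x` would see tensor([2., 2.])

This is exactly the merge that the new guards in CSE and Constant Pooling refuse to perform when both candidate values may contain an alias of a graph output.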
-rw-r--r--  test/cpp/jit/test_alias_analysis.h                         | 16
-rw-r--r--  test/cpp/jit/test_constant_pooling.h                       | 33
-rw-r--r--  test/test_jit.py                                           | 50
-rw-r--r--  torch/csrc/jit/autodiff.cpp                                | 10
-rw-r--r--  torch/csrc/jit/passes/alias_analysis.cpp                   | 52
-rw-r--r--  torch/csrc/jit/passes/alias_analysis.h                     |  9
-rw-r--r--  torch/csrc/jit/passes/common_subexpression_elimination.cpp | 17
-rw-r--r--  torch/csrc/jit/passes/constant_pooling.cpp                 | 23
-rw-r--r--  torch/csrc/jit/passes/create_autodiff_subgraphs.cpp        | 12
-rw-r--r--  torch/csrc/jit/passes/utils/memory_dag.cpp                 | 42
-rw-r--r--  torch/csrc/jit/passes/utils/memory_dag.h                   |  8
11 files changed, 198 insertions(+), 74 deletions(-)
diff --git a/test/cpp/jit/test_alias_analysis.h b/test/cpp/jit/test_alias_analysis.h
index 87bdcecfcb..9d121c478d 100644
--- a/test/cpp/jit/test_alias_analysis.h
+++ b/test/cpp/jit/test_alias_analysis.h
@@ -507,7 +507,7 @@ void testContainerAliasing() {
&*graph);
auto node_iter = graph->block()->nodes().begin();
- node_iter++; // string
+ auto str_node = node_iter++; // string
Node* ten_node = *node_iter++;
AliasDb aliasDb(graph);
@@ -515,6 +515,8 @@ void testContainerAliasing() {
for (auto out : graph->outputs()) {
AT_ASSERT(aliasDb.mayContainAlias(ten_node->output(), out));
}
+ AT_ASSERT(aliasDb.mayContainAlias({ten_node->output()}, graph->outputs()));
+ AT_ASSERT(!aliasDb.mayContainAlias(str_node->output(), graph->outputs()));
}
{
@@ -533,13 +535,13 @@ void testContainerAliasing() {
auto node_iter = graph->block()->nodes().begin();
node_iter++; // string
- Node* ten_node = *node_iter++;
+ Node* int_node = *node_iter++;
AliasDb aliasDb(graph);
AT_ASSERT(graph->outputs().size() == 3);
// primitive values don't need to alias container
for (auto out : graph->outputs()) {
- AT_ASSERT(!aliasDb.mayContainAlias(ten_node->output(), out));
+ AT_ASSERT(!aliasDb.mayContainAlias(int_node->output(), out));
}
}
@@ -561,6 +563,7 @@ void testContainerAliasing() {
for (auto input : graph->inputs()) {
AT_ASSERT(aliasDb.mayContainAlias(input, tuple_node->output()));
}
+ AT_ASSERT(aliasDb.mayContainAlias(graph->inputs(), graph->outputs()));
}
// Test tuple that doesn't come from construct
@@ -648,6 +651,13 @@ graph():
AT_ASSERT(aliasDb.mayContainAlias(first_ten->output(), tup_node->output()));
AT_ASSERT(
!aliasDb.mayContainAlias(second_ten->output(), tup_node->output()));
+
+ std::vector<Value*> first_st = {first_ten->output()};
+ std::vector<Value*> second_st = {second_ten->output()};
+ std::vector<Value*> tup_st = {tup_node->output()};
+ AT_ASSERT(aliasDb.mayContainAlias(first_st, tup_st));
+ AT_ASSERT(!aliasDb.mayContainAlias(first_st, second_st));
+ AT_ASSERT(!aliasDb.mayContainAlias(second_st, tup_st));
}
}
diff --git a/test/cpp/jit/test_constant_pooling.h b/test/cpp/jit/test_constant_pooling.h
index 9a566bbdbc..e8d0da2c7d 100644
--- a/test/cpp/jit/test_constant_pooling.h
+++ b/test/cpp/jit/test_constant_pooling.h
@@ -34,16 +34,16 @@ graph():
script::parseIR(
R"IR(
graph(%cond : Tensor):
- %a : string = prim::Constant[value="bcd"]()
+ %a : str = prim::Constant[value="bcd"]()
%3 : bool = prim::Bool(%cond)
- %b : string = prim::If(%3)
+ %b : str = prim::If(%3)
block0():
- %b.1 : string = prim::Constant[value="abc"]()
+ %b.1 : str = prim::Constant[value="abc"]()
-> (%b.1)
block1():
- %b.2 : string = prim::Constant[value="abc"]()
+ %b.2 : str = prim::Constant[value="abc"]()
-> (%b.2)
- %7 : (string, string) = prim::TupleConstruct(%a, %b)
+ %7 : (str, str) = prim::TupleConstruct(%a, %b)
return (%7)
)IR",
&*graph);
@@ -69,8 +69,8 @@ graph():
%y : Tensor = aten::tensor(%3, %10, %7, %15)
%9 : int[] = prim::ListConstruct(%1, %2)
%z : Tensor = aten::tensor(%9, %10, %7, %15)
- %14 : (Tensor, Tensor) = prim::TupleConstruct(%x, %y)
- return (%14)
+ %f = prim::Print(%x, %y, %z)
+ return (%1)
)IR",
&*graph);
// three tensors created - two different devices among the three
@@ -82,7 +82,24 @@ graph():
->check_count("Long(2) = prim::Constant", 1, /*exactly*/ true)
->run(*graph);
}
+ // don't create aliasing of graph outputs in constant pooling
+ {
+ auto graph = std::make_shared<Graph>();
+ script::parseIR(
+ R"IR(
+graph(%cond : Tensor):
+ %a : Tensor = prim::Constant()
+ %b : Tensor = prim::Constant()
+ %c : Tensor = prim::Constant()
+ %1 = prim::Print(%c)
+ return (%a, %b)
+ )IR",
+ &*graph);
+ ConstantPooling(graph);
+ testing::FileCheck()
+ .check_count("prim::Constant", 2, /*exactly*/ true)
+ ->run(*graph);
+ }
}
-
} // namespace jit
} // namespace torch
diff --git a/test/test_jit.py b/test/test_jit.py
index 1cb922b901..a246a42272 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -1139,22 +1139,40 @@ class TestJit(JitTestCase):
self.assertExportImport(trace, (x, y))
+ def test_cse_not_introduce_aliasing(self):
+ @torch.jit.script
+ def tensor_alias_outputs(x):
+ return x + x, x + x
+
+ self.run_pass('cse', tensor_alias_outputs.graph)
+ FileCheck().check_count("aten::add", 2).run(tensor_alias_outputs.graph)
+
+ @torch.jit.script
+ def ints_alias_outputs(x):
+ # type: (int) -> Tuple[int, int]
+ return x + x, x + x
+
+ # non-aliasing types can be CSEd
+ self.run_pass('cse', ints_alias_outputs.graph)
+ FileCheck().check_count("aten::add", 1, exactly=True).run(ints_alias_outputs.graph)
+
def test_recursive_cse(self):
input_str = """
graph(%x : Tensor,
- %y : Tensor):
+ %y : Tensor,
+ %20 : int):
%2 : int = prim::Constant[value=1]()
%3 : Tensor = aten::add(%x, %y, %2)
- %4 : Tensor = aten::gt(%3, %x)
+ %4 : int = aten::add(%2, %20)
%5 : bool = prim::Bool(%4)
- %z : Tensor = prim::If(%5)
+ %z : int = prim::If(%5)
# CHECK: block
block0():
# CHECK-NOT: aten::add
- %z.1 : Tensor = aten::add(%x, %y, %2)
+ %z.1 : int = aten::add(%2, %20)
-> (%z.1)
block1():
- -> (%x)
+ -> (%2)
return (%z)
"""
graph = parse_ir(input_str)
@@ -12793,28 +12811,6 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
# the same group; they should each be a separate DiffGraph
self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 2)
- def test_mutation_subgraph_inlining(self):
- # cannot move a node which has writers into a differentiable subgraph,
- # bc CSE might lose context that it has writers
-
- def fn(x):
- a = x.t()
- a = a + 1
- c = x.t()
- c = c + 1
- e = a + c
- b = a.add_(x)
- d = c.add_(x)
- return e, b, d
-
- fn_script = torch.jit.script(fn)
- outs1 = fn_script(torch.tensor(0.5, requires_grad=True))
- outs2 = fn(torch.tensor(0.5, requires_grad=True))
- for i in range(len(outs1)):
- self.assertEqual(outs1[i], outs2[i])
- graph = fn_script.graph_for(torch.tensor(0.5, requires_grad=True))
- FileCheck().check_not("DifferentiableGraph").run(graph)
-
class TestCustomOperators(JitTestCase):
diff --git a/torch/csrc/jit/autodiff.cpp b/torch/csrc/jit/autodiff.cpp
index 922681ec85..5120a04e2b 100644
--- a/torch/csrc/jit/autodiff.cpp
+++ b/torch/csrc/jit/autodiff.cpp
@@ -804,7 +804,14 @@ static void lambdaLiftReverse(Gradient& grad_desc, ReverseDetails& rev_info) {
// we create an incorrect sum that doesn't use prev vjp, replace uses, and
// fix the sum.
Value* new_vjp = createAutogradAdd(tmp_vjp_in, tmp_vjp_in);
- new_vjp->node()->moveAfter(tmp_vjp_prev->node());
+ if (tmp_vjp_prev->node()->kind() == prim::Param) {
+ // can't move a node after a block param node
+ new_vjp->node()->moveBefore(
+ *tmp_vjp_prev->node()->owningBlock()->nodes().begin());
+ } else {
+ new_vjp->node()->moveAfter(tmp_vjp_prev->node());
+ }
+
tmp_vjp_prev->replaceAllUsesWith(new_vjp);
new_vjp->node()->replaceInput(1, tmp_vjp_prev);
grad_desc.df_input_vjps.emplace_back(i);
@@ -859,6 +866,5 @@ Gradient differentiate(std::shared_ptr<Graph>& graph) {
ConstantPooling(grad_desc.df);
return grad_desc;
}
-
} // namespace jit
} // namespace torch
diff --git a/torch/csrc/jit/passes/alias_analysis.cpp b/torch/csrc/jit/passes/alias_analysis.cpp
index 99581fab64..ab4c33febc 100644
--- a/torch/csrc/jit/passes/alias_analysis.cpp
+++ b/torch/csrc/jit/passes/alias_analysis.cpp
@@ -384,10 +384,13 @@ void AliasDb::analyzeImpl(Node* node) {
return analyzeWait(node);
case prim::TupleConstruct:
return analyzeTupleConstruct(node);
+ case prim::GradOf:
+ return analyzeGradOf(node);
case prim::Constant:
case prim::DictConstruct:
case prim::ListConstruct:
case prim::AutogradZero:
+ case prim::AutogradAdd:
case prim::FusedConcat:
case prim::MMTreeReduce:
case prim::MMBatchSide:
@@ -594,6 +597,12 @@ void AliasDb::analyzeLoop(Node* node) {
mapAliases(node->outputs(), blockOutputs);
}
+void AliasDb::analyzeGradOf(Node* node) {
+ const auto grad_of_block = node->blocks().at(0);
+ analyze(grad_of_block);
+ mapAliases(node->outputs(), grad_of_block->outputs());
+}
+
void AliasDb::analyzeSubgraph(Node* node) {
const auto subgraph = node->g(attr::Subgraph).get();
@@ -704,6 +713,11 @@ void AliasDb::analyzeWait(Node* node) {
}
void AliasDb::analyzeTupleConstruct(Node* node) {
+ // Because we currently mark all Tuples as needing annotation
+ // (even those containing just primitive types), an element needs to be created
+ // for TupleConstruct. When that changes, we can create an element
+ // only if it contains elements which need annotation
+ getOrCreateElement(node->output());
for (const auto& input : node->inputs()) {
if (shouldAnnotate(input)) {
addToContainedElements(input, node->output());
@@ -831,16 +845,40 @@ bool AliasDb::cannotCheckAliasContainment(const Value* elem) const {
return false;
}
-bool AliasDb::mayContainAlias(const Value* a, const Value* b) const {
- if (!shouldAnnotate(a) || !shouldAnnotate(b)) {
- return false;
+bool AliasDb::mayContainAlias(Value* a, Value* b) const {
+ const std::vector<Value*> a_vec = {a};
+ const std::vector<Value*> b_vec = {b};
+
+ return mayContainAlias(a_vec, b_vec);
+}
+
+bool AliasDb::mayContainAlias(
+ const at::ArrayRef<Value*>& a,
+ const at::ArrayRef<Value*>& b) const {
+ std::vector<Element*> a_elements;
+ for (const auto& val : a) {
+ if (cannotCheckAliasContainment(val)) {
+ return true;
+ }
+ if (shouldAnnotate(val)) {
+ a_elements.push_back(elementMap_.at(val));
+ }
}
- if (cannotCheckAliasContainment(a) || cannotCheckAliasContainment(b)) {
- return true;
+ if (a_elements.size() == 0) {
+ return false;
}
- return memoryDAG_->mayContainAlias(elementMap_.at(a), elementMap_.at(b));
+ std::vector<Element*> b_elements;
+ for (const auto& val : b) {
+ if (cannotCheckAliasContainment(val)) {
+ return true;
+ }
+ if (shouldAnnotate(val)) {
+ b_elements.push_back(elementMap_.at(val));
+ }
+ }
+ return memoryDAG_->mayContainAlias(a_elements, b_elements);
}
// Make each value in the `from` list point to its partner in the `to` list
@@ -1241,6 +1279,7 @@ TORCH_API bool aliasAnalysisHasSpecialCaseFor(Symbol symbol) {
prim::TupleConstruct,
prim::AutogradZero,
prim::FusedConcat,
+ prim::GradOf,
prim::MMTreeReduce,
prim::MMBatchSide,
prim::BroadcastSizes,
@@ -1256,6 +1295,7 @@ TORCH_API bool aliasAnalysisHasSpecialCaseFor(Symbol symbol) {
prim::BroadcastingChunk,
prim::fork,
prim::CreateObject,
+ prim::AutogradAdd,
prim::GetAttr,
prim::SetAttr,
aten::wait,
diff --git a/torch/csrc/jit/passes/alias_analysis.h b/torch/csrc/jit/passes/alias_analysis.h
index e173e91eeb..7de15cd61d 100644
--- a/torch/csrc/jit/passes/alias_analysis.h
+++ b/torch/csrc/jit/passes/alias_analysis.h
@@ -47,7 +47,13 @@ class AliasDb {
// Do `a` and `b` potentially share a memory location, or does either
// hold in memory any element that exists in the other?
- bool mayContainAlias(const Value* a, const Value* b) const;
+ bool mayContainAlias(Value* a, Value* b) const;
+
+ // Do any values in group `a` share a memory location or hold in memory
+ // any element that exists in group `b`
+ bool mayContainAlias(
+ const at::ArrayRef<Value*>& a,
+ const at::ArrayRef<Value*>& b) const;
// Do `a` and `b` potentially share a memory location?
bool mayAlias(const Value* a, const Value* b) const;
@@ -189,6 +195,7 @@ class AliasDb {
void analyzeBroadcastingChunk(Node* node);
void analyzeFork(Node* node);
void analyzeWait(Node* node);
+ void analyzeGradOf(Node* node);
void analyzeSetAttr(Node* node);
void analyzeTupleConstruct(Node* node);
void analyzeCustomOp(Node* node);
diff --git a/torch/csrc/jit/passes/common_subexpression_elimination.cpp b/torch/csrc/jit/passes/common_subexpression_elimination.cpp
index 2d082ceb2c..2d6f897956 100644
--- a/torch/csrc/jit/passes/common_subexpression_elimination.cpp
+++ b/torch/csrc/jit/passes/common_subexpression_elimination.cpp
@@ -9,6 +9,7 @@
namespace torch {
namespace jit {
namespace {
+
// The function implements common subexpression elimination.
// Since the nodes are visited in topological order, one pass is enough.
void EliminateCommonSubexpression(
@@ -42,7 +43,15 @@ void EliminateCommonSubexpression(
// Check for CSE opportunities in the parent block.
auto parent_lookup = parent_lookup_fn(node);
+ auto g_out = node->owningGraph()->outputs();
if (parent_lookup) {
+ // since the graph outputs may be mutated after they are returned,
+ // don't introduce new aliasing among graph outputs
+ if (aliasDb.mayContainAlias(node->outputs(), g_out) &&
+ aliasDb.mayContainAlias(parent_lookup->outputs(), g_out)) {
+ continue;
+ }
+
node->replaceAllUsesWith(parent_lookup);
it.destroyCurrent();
continue;
@@ -53,6 +62,14 @@ void EliminateCommonSubexpression(
if (!subit.second) {
// Subexpression exists, replace the uses of node, and destroy it.
auto existing = *subit.first;
+
+ // don't introduce new aliasing among graph outputs
+ if (aliasDb.mayContainAlias(
+ node->outputs(), node->owningGraph()->outputs()) &&
+ aliasDb.mayContainAlias(existing->outputs(), g_out)) {
+ continue;
+ }
+
node->replaceAllUsesWith(existing);
// Destroy the node.
it.destroyCurrent();
diff --git a/torch/csrc/jit/passes/constant_pooling.cpp b/torch/csrc/jit/passes/constant_pooling.cpp
index 5421c8ccf1..e577b107bb 100644
--- a/torch/csrc/jit/passes/constant_pooling.cpp
+++ b/torch/csrc/jit/passes/constant_pooling.cpp
@@ -1,7 +1,8 @@
+#include <torch/csrc/jit/passes/constant_pooling.h>
#include <ATen/core/interned_strings.h>
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/node_hashing.h>
-#include <torch/csrc/jit/passes/constant_pooling.h>
+#include <torch/csrc/jit/passes/alias_analysis.h>
#include <unordered_set>
namespace torch {
@@ -13,7 +14,8 @@ namespace {
// Move all constants to the beginning of the graph, and deduplicate
void ConstantPooling(
Block* block,
- std::unordered_set<Node*, HashNode, EqualNode>& constants) {
+ std::unordered_set<Node*, HashNode, EqualNode>& constants,
+ const AliasDb& aliasDb) {
for (auto it = block->nodes().begin(); it != block->nodes().end();) {
auto node = *it;
// node may be moved to a different block so advance iterator now
@@ -21,7 +23,7 @@ void ConstantPooling(
if (!node->blocks().empty()) {
// Traverse sub-blocks.
for (auto block : node->blocks()) {
- ConstantPooling(block, constants);
+ ConstantPooling(block, constants, aliasDb);
}
continue;
}
@@ -35,6 +37,16 @@ void ConstantPooling(
if (!subit.second) {
// constant exists, replace the uses of node, and destroy it.
auto existing = *subit.first;
+
+ // since the graph outputs may be mutated after they are returned,
+ // don't introduce new aliasing among graph outputs
+ if (aliasDb.mayContainAlias(
+ node->outputs(), node->owningGraph()->outputs()) &&
+ aliasDb.mayContainAlias(
+ existing->outputs(), node->owningGraph()->outputs())) {
+ continue;
+ }
+
node->replaceAllUsesWith(existing);
node->destroy();
continue;
@@ -46,13 +58,12 @@ void ConstantPooling(
node->moveBefore(first_node);
}
}
-
} // anonymous namespace
void ConstantPooling(const std::shared_ptr<Graph>& graph) {
+ AliasDb aliasDb(graph);
std::unordered_set<Node*, HashNode, EqualNode> constants;
- ConstantPooling(graph->block(), constants);
+ ConstantPooling(graph->block(), constants, aliasDb);
}
-
} // namespace jit
} // namespace torch
diff --git a/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp b/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp
index b7a55b82b9..cd814a3530 100644
--- a/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp
+++ b/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp
@@ -110,7 +110,7 @@ class SubgraphSlicer {
return result;
}
- bool shouldConsiderForMerge(Node* node, const AliasDb& aliasDb) {
+ bool shouldConsiderForMerge(Node* node) {
// if we're already in the process of merging
if (node->kind() == prim::DifferentiableGraph) {
return true;
@@ -118,19 +118,13 @@ class SubgraphSlicer {
if (node->kind() == prim::Constant) {
return false;
}
- // when a node which has writers is moved into a subgraph it may lose
- // context and CSE could merge it with another node that has writers
- // TODO: @eellison Fix problem more generally in CSE, land PR #18500
- if (aliasDb.hasWriters(node)) {
- return false;
- }
return isDifferentiable(node);
}
std::pair<graph_node_list::iterator, bool> scanNode(
Node* consumer,
AliasDb& aliasDb) {
- if (shouldConsiderForMerge(consumer, aliasDb)) {
+ if (shouldConsiderForMerge(consumer)) {
if (consumer->kind() != prim::DifferentiableGraph) {
consumer = SubgraphUtils::createSingletonSubgraph(
consumer, prim::DifferentiableGraph);
@@ -155,7 +149,7 @@ class SubgraphSlicer {
Node* producer,
AliasDb& aliasDb) {
AT_ASSERT(consumer->kind() == prim::DifferentiableGraph);
- bool canMerge = shouldConsiderForMerge(producer, aliasDb) &&
+ bool canMerge = shouldConsiderForMerge(producer) &&
aliasDb.moveBeforeTopologicallyValid(producer, consumer);
if (!canMerge) {
diff --git a/torch/csrc/jit/passes/utils/memory_dag.cpp b/torch/csrc/jit/passes/utils/memory_dag.cpp
index 6b56588cd1..1dc74c55c1 100644
--- a/torch/csrc/jit/passes/utils/memory_dag.cpp
+++ b/torch/csrc/jit/passes/utils/memory_dag.cpp
@@ -2,7 +2,6 @@
#include <torch/csrc/utils/memory.h>
#include <algorithm>
-#include <iostream>
#include <queue>
namespace torch {
@@ -16,10 +15,9 @@ bool MemoryDAG::mayAlias(const Element* a, const Element* b) const {
return mayAliasImpl(a, b);
}
-bool MemoryDAG::mayAliasImpl(const Element* a, const Element* b) const {
- const auto aMemLoc = a->getMemoryLocations();
- const auto bMemLoc = b->getMemoryLocations();
-
+bool MemoryDAG::memoryLocationOverlap(
+ const std::unordered_set<const Element*>& aMemLoc,
+ const std::unordered_set<const Element*>& bMemLoc) const {
// XXX: This could be more efficiently done as a bitwise AND on two bitfields
// that represent memory location membership. If these comparisons end up
// being a bottleneck, consider implementing it that way.
@@ -30,9 +28,17 @@ bool MemoryDAG::mayAliasImpl(const Element* a, const Element* b) const {
}
}
}
+
return false;
}
+bool MemoryDAG::mayAliasImpl(const Element* a, const Element* b) const {
+ const auto aMemLoc = a->getMemoryLocations();
+ const auto bMemLoc = b->getMemoryLocations();
+
+ return memoryLocationOverlap(aMemLoc, bMemLoc);
+}
+
bool MemoryDAG::mayContainAlias(const Element* a, const Element* b) const {
return mayContainAliasImpl(a, b);
}
@@ -67,15 +73,27 @@ bool MemoryDAG::mayContainAliasImpl(const Element* a, const Element* b) const {
collectAllContainedMemoryLocations(a, all_a_mlocs);
collectAllContainedMemoryLocations(b, all_b_mlocs);
- for (const auto a_mem : all_a_mlocs) {
- for (const auto b_mem : all_b_mlocs) {
- if (a_mem == b_mem) {
- return true;
- }
- }
+ return memoryLocationOverlap(all_a_mlocs, all_b_mlocs);
+}
+
+bool MemoryDAG::mayContainAlias(
+ const at::ArrayRef<Element*>& a,
+ const at::ArrayRef<Element*>& b) const {
+ if (a.size() == 0 || b.size() == 0) {
+ return false;
}
- return false;
+ std::unordered_set<const Element*> all_a_mlocs;
+ for (const auto& elem : a) {
+ collectAllContainedMemoryLocations(elem, all_a_mlocs);
+ }
+
+ std::unordered_set<const Element*> all_b_mlocs;
+ for (const auto& elem : b) {
+ collectAllContainedMemoryLocations(elem, all_b_mlocs);
+ }
+
+ return memoryLocationOverlap(all_a_mlocs, all_b_mlocs);
}
// Make `v` point at `to`.
diff --git a/torch/csrc/jit/passes/utils/memory_dag.h b/torch/csrc/jit/passes/utils/memory_dag.h
index 76193cc526..cffb3eb116 100644
--- a/torch/csrc/jit/passes/utils/memory_dag.h
+++ b/torch/csrc/jit/passes/utils/memory_dag.h
@@ -1,5 +1,6 @@
#pragma once
+#include <c10/util/ArrayRef.h>
#include <memory>
#include <unordered_map>
#include <unordered_set>
@@ -45,6 +46,10 @@ class MemoryDAG {
bool mayContainAlias(const Element* a, const Element* b) const;
bool mayContainAlias(Element* a, Element* b) const;
+ bool mayContainAlias(
+ const at::ArrayRef<Element*>& a,
+ const at::ArrayRef<Element*>& b) const;
+
// Do any values in group `a` potentially share a memory location with any
// value in group `b`?
//
@@ -86,6 +91,9 @@ class MemoryDAG {
}
private:
+ bool memoryLocationOverlap(
+ const std::unordered_set<const Element*>& a,
+ const std::unordered_set<const Element*>& b) const;
bool mayAliasImpl(const Element* a, const Element* b) const;
bool mayContainAliasImpl(const Element* contained, const Element* container)
const;
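As a quick end-to-end check of the new CSE behavior, a sketch from Python (this drives the same internal pass that test_jit.py reaches through run_pass; torch._C._jit_pass_cse is not a public API and may change):

    import torch
    from torch.testing import FileCheck

    @torch.jit.script
    def tensor_outputs(x):
        return x + x, x + x

    torch._C._jit_pass_cse(tensor_outputs.graph)
    # Both adds must survive CSE: merging them would make the two graph outputs alias.
    FileCheck().check_count("aten::add", 2).run(tensor_outputs.graph)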