diff options
author | Michael Suo <suo@fb.com> | 2019-01-30 11:06:32 -0800 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-01-30 11:28:03 -0800 |
commit | dc84ff1e5a74164d668fa1afac9e0e0f2315f3b2 (patch) | |
tree | d309fd3dd58ce2477f1d866e53aaf645553ee30a | |
parent | dff8165d04f589dcaca42b76650340e220628f40 (diff) | |
download | pytorch-dc84ff1e5a74164d668fa1afac9e0e0f2315f3b2.tar.gz pytorch-dc84ff1e5a74164d668fa1afac9e0e0f2315f3b2.tar.bz2 pytorch-dc84ff1e5a74164d668fa1afac9e0e0f2315f3b2.zip |
Use a points-to graph for alias analysis (#16386)
Summary:
This PR changes the way we store aliasing information from a "set" approach to a "points-to" analysis. Set-based approaches lose information in ways that make it difficult to do "live" updates to the alias DB as one as mutating the graph.
The tradeoff is that simple queries get more expensive, since they require traversing the points-to graph to answer most questions. In practice, this is unlikely to be that costly since we don't have massive aliasing chains, but we could create an approximation/caching layer if this becomes a problem.
My rough plan is:
1. This PR, switching to a points-to graph
2. Make it "live": analyzing a node should record all the edges the node added, so that we can rollback when the node is destroyed.
3. Reduce wildcard scope: we can make the wildcard a special vertex that points to anything that we're not "sure" about; namely, things that have been put inside lists, or graph inputs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16386
Differential Revision: D13855117
Pulled By: suo
fbshipit-source-id: f009f58143173c275501624eb105d07ab60fe5e1
-rw-r--r-- | test/cpp/jit/no-gtest.cpp | 1 | ||||
-rw-r--r-- | test/cpp/jit/tests.h | 25 | ||||
-rw-r--r-- | test/test_jit.py | 21 | ||||
-rw-r--r-- | torch/csrc/jit/passes/alias_analysis.cpp | 740 | ||||
-rw-r--r-- | torch/csrc/jit/passes/alias_analysis.h | 34 | ||||
-rw-r--r-- | torch/csrc/jit/passes/batch_mm.cpp | 2 | ||||
-rw-r--r-- | torch/csrc/jit/passes/common_subexpression_elimination.cpp | 2 | ||||
-rw-r--r-- | torch/csrc/jit/passes/constant_propagation.cpp | 2 | ||||
-rw-r--r-- | torch/csrc/jit/passes/create_autodiff_subgraphs.cpp | 2 | ||||
-rw-r--r-- | torch/csrc/jit/passes/dead_code_elimination.cpp | 5 | ||||
-rw-r--r-- | torch/csrc/jit/passes/graph_fuser.cpp | 4 | ||||
-rw-r--r-- | torch/csrc/jit/passes/shape_analysis.cpp | 2 | ||||
-rw-r--r-- | torch/csrc/jit/python_ir.cpp | 7 |
13 files changed, 556 insertions, 291 deletions
diff --git a/test/cpp/jit/no-gtest.cpp b/test/cpp/jit/no-gtest.cpp index 6be01f0693..22e3e128e1 100644 --- a/test/cpp/jit/no-gtest.cpp +++ b/test/cpp/jit/no-gtest.cpp @@ -29,6 +29,7 @@ std::string runJITCPPTests() { testTopologicalIndex(); testTopologicalMove(); testSubgraphUtils(); + testAliasAnalysis(); return out.str(); } diff --git a/test/cpp/jit/tests.h b/test/cpp/jit/tests.h index 02c171ffb9..94e88534c1 100644 --- a/test/cpp/jit/tests.h +++ b/test/cpp/jit/tests.h @@ -27,17 +27,17 @@ #endif // defined(USE_GTEST) +#include "ATen/core/interned_strings.h" +#include "c10/util/Exception.h" #include "torch/csrc/autograd/generated/variable_factories.h" #include "torch/csrc/autograd/variable.h" #include "torch/csrc/jit/argument_spec.h" -#include "c10/util/Exception.h" #include "torch/csrc/jit/attributes.h" #include "torch/csrc/jit/autodiff.h" #include "torch/csrc/jit/code_template.h" #include "torch/csrc/jit/custom_operator.h" #include "torch/csrc/jit/dynamic_dag.h" #include "torch/csrc/jit/fuser/interface.h" -#include "ATen/core/interned_strings.h" #include "torch/csrc/jit/interpreter.h" #include "torch/csrc/jit/ir.h" #include "torch/csrc/jit/operator.h" @@ -56,12 +56,13 @@ #include "torch/csrc/jit/symbolic_variable.h" #include "torch/csrc/jit/tracer.h" #include "torch/csrc/utils/hash.h" +#include "torch/csrc/utils/memory.h" #include "torch/csrc/autograd/engine.h" #include "torch/csrc/autograd/variable.h" -#include "torch/csrc/jit/graph_executor.h" #include "ATen/core/ivalue.h" +#include "torch/csrc/jit/graph_executor.h" #include "torch/csrc/jit/script/compiler.h" #include "torch/csrc/jit/script/module.h" @@ -278,8 +279,8 @@ void testAttributes() { auto two = attr::device; auto three = attr::end; auto four = attr::perm; - Node *n = g.create(Symbol::fromQualString("foo::bar")); - Node &attr = *n; + Node* n = g.create(Symbol::fromQualString("foo::bar")); + Node& attr = *n; attr.f_(one, 3.4)->i_(two, 5)->s_(three, "what"); ASSERT_EQ(attr.f(one), 3.4); ASSERT_EQ(attr.s(three), "what"); @@ -291,8 +292,8 @@ void testAttributes() { attr.ss_(two, {"hi", "now"}); ASSERT_EQ(attr.ss(two).at(1), "now"); - Node *n2 = g.create(Symbol::fromQualString("foo::baz")); - Node &attr2 = *n2; + Node* n2 = g.create(Symbol::fromQualString("foo::baz")); + Node& attr2 = *n2; attr2.copyAttributes(attr); ASSERT_EQ(attr2.s(one), "no"); attr2.f_(one, 5); @@ -1783,7 +1784,7 @@ void testDynamicDAG() { struct TopoMoveTestFixture { TopoMoveTestFixture() { createGraph(); - aliasDb = AliasAnalysis(graph); + aliasDb = torch::make_unique<AliasDb>(graph); } // Nodes are named after their output. @@ -1912,7 +1913,7 @@ struct TopoMoveTestFixture { } std::shared_ptr<Graph> graph; - c10::optional<AliasDb> aliasDb; + std::unique_ptr<AliasDb> aliasDb; std::unordered_map<std::string, Node*> nodes; }; @@ -2037,7 +2038,7 @@ void testAliasAnalysis() { graph->lint(); - auto aliasDb = AliasAnalysis(graph); + AliasDb aliasDb(graph); // Can't move past a mutation of a used value AT_ASSERT(!aliasDb.moveAfterTopologicallyValid(c->node(), aMut->node())); AT_ASSERT(aliasDb.moveAfterTopologicallyValid(d->node(), c->node())); @@ -2059,10 +2060,10 @@ void testAliasAnalysis() { auto usesB = graph->insert(aten::add, {b, fresh}); auto aliasesB = graph->insert(aten::select, {a, constant, constant}); auto mutatesAliasOfB = graph->insert(aten::add_, {aliasesB, fresh}); - auto c = graph->insert(aten::add, {fresh, aliasesB}); + graph->insert(aten::add, {fresh, aliasesB}); graph->lint(); - auto aliasDb = AliasAnalysis(graph); + AliasDb aliasDb(graph); AT_ASSERT(!aliasDb.moveAfterTopologicallyValid( aliasesB->node(), mutatesAliasOfB->node())); AT_ASSERT(!aliasDb.moveAfterTopologicallyValid( diff --git a/test/test_jit.py b/test/test_jit.py index bdd31f40f7..22067843c7 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -3837,7 +3837,7 @@ a") return a[3:10] == [3, 4] self.checkScript(test_backward_slice, ()) - def test_mutable_list(self): + def test_mutable_list_append(self): def test_append(): a = [0, 1] a.append(2) @@ -3845,6 +3845,7 @@ a") return a == [0, 1, 2, 3] self.checkScript(test_append, ()) + def test_mutable_list_append_2(self): def test_append_2(): a = [0, 1] a.append(2) @@ -3853,6 +3854,7 @@ a") return a == [1, 4] self.checkScript(test_append_2, ()) + def test_mutable_list_append_if(self): def test_append_if(): a = [1] if True: @@ -3860,6 +3862,7 @@ a") return a == [1, 4] self.checkScript(test_append_if, ()) + def test_mutable_list_append_if_else(self): def test_append_if_else(): a = [1] if False: @@ -3869,6 +3872,7 @@ a") return a == [1, 10] self.checkScript(test_append_if_else, ()) + def test_mutable_list_append_loop(self): def test_append_loop(): a = torch.jit.annotate(List[int], []) for i in range(5): @@ -3877,6 +3881,7 @@ a") return a == [0, 1, 2, 3, 4] self.checkScript(test_append_loop, ()) + def test_mutable_list_append_loop_if(self): def test_append_loop_if(): a = torch.jit.annotate(List[int], []) for i in range(5): @@ -3888,6 +3893,7 @@ a") return a == [0, 0, 0, 0, 4] self.checkScript(test_append_loop_if, ()) + def test_mutable_list_nested_loop(self): def test_nested_loop(): a = torch.jit.annotate(List[int], []) for i in range(2): @@ -3895,7 +3901,7 @@ a") a.append(i + j) return a == [0, 1, 1, 2] - self.checkScript(test_append_loop_if, ()) + self.checkScript(test_nested_loop, ()) def test_mutable_list_function_inline(self): @torch.jit.script @@ -9495,6 +9501,17 @@ a") self.assertExpectedGraph(foo.graph) + def test_mutable_dce_wildcards(self): + def fn(): + x = torch.ones(2, 3) + l = [] + l.append(x) + x_view = l[0] + x.add_(torch.ones(2, 3)) + return x_view + + self.checkScript(fn, ()) + def test_cpp_function_tensor_str(self): x = torch.randn(2, 2) scale = torch.randn(2, 2, requires_grad=True) diff --git a/torch/csrc/jit/passes/alias_analysis.cpp b/torch/csrc/jit/passes/alias_analysis.cpp index 716762b6e3..20da07b56c 100644 --- a/torch/csrc/jit/passes/alias_analysis.cpp +++ b/torch/csrc/jit/passes/alias_analysis.cpp @@ -1,6 +1,8 @@ #include <torch/csrc/jit/passes/alias_analysis.h> #include <torch/csrc/jit/script/error_report.h> +#include <torch/csrc/utils/memory.h> +#include <queue> namespace torch { namespace jit { @@ -21,39 +23,392 @@ bool shouldAnnotate(const Value* v) { } } // namespace -AliasDb::AliasDb(std::shared_ptr<Graph> graph) : graph_(std::move(graph)) { - analyze(graph_); +// class AliasTracker +// +// This class tracks the "A points to B" graph for all values, as well as +// wildcards and writes. It is used by AliasDb to provide a higher-level API. +// +// NOTE: this implementation is not very efficient; it's designed to be easy to +// mutate as you modify the graph. +class AliasTracker { + public: + // Returns true iff `v` is present in the alias set tracker. + bool contains(const Value* v) const { + return map_.count(v); + } - // Build helper indices - // NOTE: that these assume that AliasDb is immutable once constructed. - // - Alias set -> value mapping - for (const auto& pr : valueToAlias_) { - const auto value = pr.first; - const auto& aliasInfo = pr.second; - // We don't support composite types yet - AT_ASSERT(aliasInfo.containedTypes().size() == 0); - for (const auto aliasSet : aliasInfo.sets()) { - aliasToValue_[aliasSet].insert(value); + bool writesTo(Node* n, const Value* v) const { + if (isWildcard(v)) { + return wildcardWriters_.count(n); } + + if (!map_.count(v)) { + return false; + } + + return map_.at(v)->writers.count(n); } -} -bool AliasDb::hasWildcard(const Node* n) const { - return wildcardNodes_.count(n) != 0; + // Whether `a` *may* point to `b` + bool pointsTo(const Value* a, const Value* b) const { + if (!map_.count(a)) { + return false; + } + if (isWildcard(a) || isWildcard(b)) { + return true; + } + + // BFS the subtree where the root is `a`s element and the branches are the + // `pointsTo` relationships. + const auto root = map_.at(a); + return root->bfs( + [&](const Element* el) { return el->value == b; }, + BfsDirection::POINTS_TO, + /*shortCircuit=*/true); + } + + // Make `v` point at `to`. + void makePointerTo(const Value* v, const Value* to) { + if (v == to) { + return; + } + + // If `to` is a wildcard, don't insert anything into the graph; wildcards + // are tracked separately since they have different aliasing rules. + if (isWildcard(to)) { + setWildcard(v); + return; + } + + if (!map_.count(to)) { + makeFreshValue(to); + } + + if (!map_.count(v)) { + makeFreshValue(v); + } + + auto vEl = map_.at(v); + auto toEl = map_.at(to); + + vEl->pointsTo.insert(toEl); + toEl->pointedFrom.insert(vEl); + } + + // Give `v` a fresh alias (i.e. it does not point to any value) + void makeFreshValue(const Value* v) { + auto el = torch::make_unique<Element>(); + el->value = v; + + auto rawPtr = el.get(); + elements_.emplace(rawPtr, std::move(el)); + map_.emplace(v, rawPtr); + } + + // Register `v` as a wildcard value. + void setWildcard(const Value* v) { + wildcards_.insert(v); + } + + // is `v` a wildcard? + bool isWildcard(const Value* v) const { + return wildcards_.count(v); + } + + // Register the fact that `n` writes to `v`. + void registerWrite(const Value* v, Node* n) { + numWrites_++; + + if (isWildcard(v)) { + wildcardWriters_.insert(n); + return; + } + + AT_ASSERT(map_.count(v)); + map_.at(v)->writers.insert(n); + } + + // Return all aliases of `v`. This is the full set of any other value that + // *may* represent the same memory location. + // NOTE: this does not consider wildcard values + std::unordered_set<const Value*> getAliases(const Value* v) const { + std::unordered_set<const Value*> ret; + if (!map_.count(v)) { + return ret; + } + + const auto root = map_.at(v); + + root->bfs( + [&](const Element* el) { + ret.insert(el->value); + return false; // fn has to return bool but we don't use the result + }, + BfsDirection::BOTH); + return ret; + } + + // Get all nodes that write to `v` or a value that may alias `v`. + std::unordered_set<Node*> getWrites(const Value* v) const { + std::unordered_set<Node*> ret; + if (!map_.count(v)) { + return ret; + } + + // Any write to a wilcard may write to `v`. + for (auto writer : wildcardWriters_) { + ret.insert(writer); + } + + if (useCache_) { + for (auto writer : getWritersCached(v)) { + ret.insert(writer); + } + return ret; + } + + const auto root = map_.at(v); + root->bfs( + [&](const Element* el) { + for (auto writer : el->writers) { + ret.insert(writer); + } + return false; // fn has to return bool but we don't use the result + }, + BfsDirection::BOTH); + + return ret; + } + + // Functionally equivalent to getWrites().size() > 0, but with a + // short-circuiting implementation to be faster. + bool hasWriters(const Value* v) const { + if (!map_.count(v)) { + return false; + } + + if (isWildcard(v)) { + // If `n` has a wildcard, any write in the graph may write to it. + // So the only way we know there are no writers is if there are no writes + // at all. + return numWrites_ == 0; + } + + if (wildcardWriters_.size() > 0) { + // A write to the wildcard may be a write to any value. + return true; + } + + if (useCache_) { + return hasWritersCached(v); + } + + const auto root = map_.at(v); + return root->bfs( + [&](const Element* el) { return el->writers.size() > 0; }, + BfsDirection::BOTH, + /*shortCircuit=*/true); + } + + // Get all nodes that write to a wildcard value. + const std::unordered_set<Node*>& getWildcardWriters() const { + return wildcardWriters_; + } + + void dump() const { + std::cout << "\n===2. ALIAS DB===\n"; + for (const auto& ptrPair : elements_) { + const auto element = ptrPair.first; + if (element->pointsTo.size() > 0) { + std::cout << element->value->uniqueName() << " points to: "; + for (const auto pointedTo : element->pointsTo) { + std::cout << pointedTo->value->uniqueName() << ", "; + } + std::cout << "\n"; + } + } + + std::cout << "\n===3. WILDCARDS===\n"; + for (const auto wildcard : wildcards_) { + std::cout << wildcard->uniqueName() << ", "; + } + std::cout << "\n"; + } + + private: + enum class BfsDirection { + POINTS_TO, + POINTED_FROM, + // Consider both pointer directions. The closure obtained from this + // represents the whole "alias set" of a value. + BOTH + }; + // `Element` represents the vertex in the points-to graph. It has a 1:1 + // relationship with IR `Value`s. + struct Element { + const Value* value = nullptr; + // All values that this value *may* point to. It's possible to have multiple + // values that you might point to due to control flow/complex ops + std::unordered_set<Element*> pointsTo; + // Backreference to values that point to `this` + std::unordered_set<Element*> pointedFrom; + // Nodes that write to this specific value. + std::unordered_set<Node*> writers; + + // Do a breadth-first search over the graph, starting at `this` and + // traversing in the direction `dir`.`fn` will be run on each element. + // + // If `shortCircuit` is set, then if `fn` evaluates to true the search will + // short-circuit and return true. You can use this to do existence checks + // on the graph or whatever. + template <typename Fn> + bool bfs(Fn fn, BfsDirection dir, bool shortCircuit = false) const { + std::queue<const Element*> queue; + std::unordered_set<const Element*> seen; + + queue.push(this); + while (!queue.empty()) { + const auto el = queue.front(); + queue.pop(); + seen.insert(el); + + if (fn(el) && shortCircuit) { + return true; + } + + switch (dir) { + case BfsDirection::POINTS_TO: { + for (auto ptr : el->pointsTo) { + if (!seen.count(ptr)) { + queue.push(ptr); + } + } + } break; + + case BfsDirection::POINTED_FROM: { + for (auto ptr : el->pointedFrom) { + if (!seen.count(ptr)) { + queue.push(ptr); + } + } + } break; + + case BfsDirection::BOTH: { + for (auto ptr : el->pointsTo) { + if (!seen.count(ptr)) { + queue.push(ptr); + } + } + for (auto ptr : el->pointedFrom) { + if (!seen.count(ptr)) { + queue.push(ptr); + } + } + } break; + } + } + return false; + } + }; + + // Structure that owns all the element pointers. It's a map of + // raw pointer -> unique_ptr to facilitate easy queries + std::unordered_map<Element*, std::unique_ptr<Element>> elements_; + // Index to look up whatever element corresponds to that value. + std::unordered_map<const Value*, Element*> map_; + // All values that may point to a wildcard value. + std::unordered_set<const Value*> wildcards_; + // All nodes that write to a wildcard + std::unordered_set<Node*> wildcardWriters_; + size_t numWrites_ = 0; + + /** + * Caching layer. + */ + using set_id_t = size_t; + bool useCache_ = true; + mutable std::unordered_map<const Element*, std::unordered_set<set_id_t>> + elementToSet_; + mutable std::unordered_map<set_id_t, std::unordered_set<Node*>> setToWrites_; + mutable bool cacheStale_ = true; + mutable set_id_t lastId = 0; + + // Cache results in a way to make common queries constant time. + void cache() const { + if (!cacheStale_) { + return; + } + + for (const auto& pr : elements_) { + const auto el = pr.first; + // For each value that does point to anything, assign a fresh set. + if (el->pointsTo.size() == 0) { + const auto id = getFreshId(); + assignSet(el, id); + + // Propagate this set to every element that points to `el` + el->bfs( + [&](const Element* pointerTo) { return assignSet(pointerTo, id); }, + BfsDirection::POINTED_FROM); + } + } + + cacheStale_ = false; + } + + bool hasWritersCached(const Value* v) const { + cache(); + for (const auto& set : elementToSet_.at(map_.at(v))) { + if (setToWrites_.count(set) && setToWrites_.at(set).size() > 0) { + return true; + } + } + return false; + } + + std::unordered_set<Node*> getWritersCached(const Value* v) const { + cache(); + std::unordered_set<Node*> ret; + for (const auto& set : elementToSet_.at(map_.at(v))) { + if (setToWrites_.count(set) > 0) { + for (auto write : setToWrites_.at(set)) { + ret.insert(write); + } + } + } + return ret; + } + + bool assignSet(const Element* el, set_id_t id) const { + elementToSet_[el].insert(id); + for (auto write : el->writers) { + setToWrites_[id].insert(write); + } + return true; + } + + set_id_t getFreshId() const { + return ++lastId; + }; +}; + +AliasDb::~AliasDb() = default; + +AliasDb::AliasDb(std::shared_ptr<Graph> graph) : graph_(std::move(graph)) { + aliasTracker_ = torch::make_unique<AliasTracker>(); + analyze(graph_); } // Does `n` use or write to any wildcard aliases? -bool AliasDb::hasWildcardImpl(const Node* n) const { +bool AliasDb::hasWildcard(const Node* n) const { for (const auto input : n->inputs()) { - if (valueToAlias_.count(input) != 0 && - valueToAlias_.at(input).isWildcard()) { + if (aliasTracker_->isWildcard(input)) { return true; } } for (const auto output : n->outputs()) { - if (valueToAlias_.count(output) != 0 && - valueToAlias_.at(output).isWildcard()) { + if (aliasTracker_->isWildcard(output)) { return true; } } @@ -61,34 +416,25 @@ bool AliasDb::hasWildcardImpl(const Node* n) const { } bool AliasDb::writesTo(Node* n, const Value* v) const { - if (valueToAlias_.count(v) == 0) { + if (!shouldAnnotate(v)) { // This is a primitive type return false; } - - const auto& aliasInfo = valueToAlias_.at(v); - AT_ASSERT(aliasInfo.sets().size() > 0); - // We only need to check one alias set, since if this value belongs to - // multiple alias sets they are all written to - const auto& aliasSet = *aliasInfo.sets().begin(); - - if (aliasToWrites_.count(aliasSet) == 0) { - // no writes to this alias set - return false; - } - - const auto& writers = aliasToWrites_.at(aliasSet); - return writers.count(n) != 0; + return aliasTracker_->writesTo(n, v); } bool AliasDb::hasWriters(const Node* n) const { - if (hasWildcard(n)) { - // If `n` has a wildcard, any write in the graph may write to it. - // So the only way we know there are no writers is if there are no writes - // at all. - return !aliasToWrites_.empty(); + for (const auto input : n->inputs()) { + if (aliasTracker_->hasWriters(input)) { + return true; + } + } + for (const auto output : n->outputs()) { + if (aliasTracker_->hasWriters(output)) { + return true; + } } - return getWriters(n).size() != 0; + return false; } bool AliasDb::hasWritersBefore(const Node* n) const { @@ -130,52 +476,27 @@ bool AliasDb::writesToInputAlias(Node* n) const { // For all writes, check if the written value may alias a graph input return std::any_of(writes.cbegin(), writes.cend(), [&](const Value* v) { - const auto& aliasInfo = valueToAlias_.at(v); - const auto& aliasSets = aliasInfo.sets(); - - // Check every distinct alias set this value belongs to return std::any_of( - aliasSets.cbegin(), aliasSets.cend(), [&](const Symbol aliasSet) { - return graphInputAliases_.count(aliasSet) != 0; + graph_->inputs().cbegin(), + graph_->inputs().cend(), + [&](const Value* graphInput) { + return shouldAnnotate(graphInput) && + aliasTracker_->pointsTo(graphInput, v); }); }); } std::unordered_set<Node*> AliasDb::getWriters(const Node* n) const { - // Get all alias sets of this node - // ... check the inputs - std::unordered_set<Symbol> aliasSets; - for (const auto& input : n->inputs()) { - if (valueToAlias_.count(input) != 0) { - for (const auto& aliasSet : valueToAlias_.at(input).sets()) { - aliasSets.insert(aliasSet); - } - } - } - - // ... and the outputs - for (const auto& output : n->outputs()) { - if (valueToAlias_.count(output) != 0) { - for (const auto& aliasSet : valueToAlias_.at(output).sets()) { - aliasSets.insert(aliasSet); - } - } - } - - // Then get the union of all writers to all those alias sets std::unordered_set<Node*> writers; - for (const auto& alias : aliasSets) { - if (aliasToWrites_.count(alias) != 0) { - for (const auto writer : aliasToWrites_.at(alias)) { - writers.insert(writer); - } + + for (const auto input : n->inputs()) { + for (auto writer : aliasTracker_->getWrites(input)) { + writers.insert(writer); } } - // A write to the wildcard set should be considered a write to `n` - if (aliasToWrites_.count(AliasInfo::wildcardSet())) { - const auto& wildcardWriters = aliasToWrites_.at(AliasInfo::wildcardSet()); - for (auto writer : wildcardWriters) { + for (const auto output : n->outputs()) { + for (auto writer : aliasTracker_->getWrites(output)) { writers.insert(writer); } } @@ -185,18 +506,11 @@ std::unordered_set<Node*> AliasDb::getWriters(const Node* n) const { std::unordered_set<const Value*> AliasDb::getAliases(const Value* v) const { std::unordered_set<const Value*> ret; - if (!valueToAlias_.count(v)) { + if (!aliasTracker_->contains(v)) { return ret; } - const auto& aliasSets = valueToAlias_.at(v).sets(); - for (const auto& aliasSet : aliasSets) { - const auto& aliases = aliasToValue_.at(aliasSet); - for (auto alias : aliases) { - ret.insert(alias); - } - } - return ret; + return aliasTracker_->getAliases(v); } std::unordered_set<const Value*> AliasDb::getWrites(Node* n) const { @@ -217,47 +531,32 @@ std::unordered_set<const Value*> AliasDb::getWrites(Node* n) const { void AliasDb::dump() const { std::cout << "\n===1. GRAPH===\n"; graph_->dump(); - std::cout << "===2. ALIAS SETS===\n"; - for (const auto& pr : valueToAlias_) { - std::cout << "%" << pr.first->uniqueName() << " : " - << "("; - - bool first = true; - for (const auto& alias : pr.second.sets()) { - if (first) { - first = false; - } else { - std::cout << ", "; - } - std::cout << alias.toUnqualString(); - } - std::cout << ")\n"; - } - std::cout << "\n===3. WRITES===\n"; - for (const auto& pr : aliasToWrites_) { - std::cout << "Alias set " << pr.first.toUnqualString() << ":\n"; - for (const auto node : pr.second) { - std::cout << " " << *node; - } - std::cout << "\n"; - } + aliasTracker_->dump(); +} - std::cout << "\n===3. WILDCARD INDEX===\n"; - for (const auto node : wildcardNodes_) { - node->dump(); +// TODO: need to create a dummy "graph input alias" value in setTracker for all +// inputs of the same type to point to. Currently they all point to the first +// element, which is technically wrong. +static void makeAllAlias( + const std::vector<Value*> values, + AliasTracker& setTracker) { + if (values.size() > 0) { + setTracker.makeFreshValue(values[0]); + } + for (const auto value : values) { + setTracker.makePointerTo(value, values[0]); } } void AliasDb::analyze(const std::shared_ptr<Graph>& graph) { // Assign aliases to the graph's inputs, assuming that all inputs of a given // type may alias to each other. - const auto tensorAlias = getFreshAlias(/*isGraphInput=*/true); - // Create a separate alias set for each list type - std::map<TypeKind, Symbol> listTypeAliases; - // Create a separate alias set for each tuple type - std::map<TupleTypePtr, Symbol> tupleTypeAliases; - std::map<TypeKind, Symbol> optionalTypeAliases; + + // 1. Partition inputs by their type + std::map<TypeKind, std::vector<Value*>> listTypes; + std::unordered_map<TupleTypePtr, std::vector<Value*>> tupleTypes; + std::vector<Value*> tensors; for (auto input : graph->inputs()) { auto inputType = input->type(); @@ -267,7 +566,7 @@ void AliasDb::analyze(const std::shared_ptr<Graph>& graph) { } if (inputType->isSubtypeOf(DynamicType::get())) { - addAlias(input, tensorAlias); + tensors.push_back(input); } else if (inputType->kind() == TypeKind::ListType) { auto containedType = inputType->containedTypes().at(0); // All tensor subtypes may alias to each other, so we should consider all @@ -275,23 +574,24 @@ void AliasDb::analyze(const std::shared_ptr<Graph>& graph) { if (containedType->isSubtypeOf(DynamicType::get())) { containedType = DynamicType::get(); } - if (listTypeAliases.count(containedType->kind()) == 0) { - listTypeAliases[containedType->kind()] = - getFreshAlias(/*isGraphInput=*/true); - } - - addAlias(input, listTypeAliases.at(containedType->kind())); + listTypes[containedType->kind()].push_back(input); } else if (inputType->kind() == TypeKind::TupleType) { auto tupleType = inputType->cast<TupleType>(); - if (tupleTypeAliases.count(tupleType) == 0) { - tupleTypeAliases[tupleType] = getFreshAlias(/*isGraphInput=*/true); - } - addAlias(input, tupleTypeAliases.at(tupleType)); + tupleTypes[tupleType].push_back(input); } else { AT_ASSERT(!shouldAnnotate(input)); } } + // 2. Make all partitions alias each other + for (const auto& pr : listTypes) { + makeAllAlias(pr.second, *aliasTracker_); + } + for (const auto& pr : tupleTypes) { + makeAllAlias(pr.second, *aliasTracker_); + } + makeAllAlias(tensors, *aliasTracker_); + analyze(graph->block()); } @@ -301,12 +601,21 @@ void AliasDb::analyze(Block* block) { } } +void AliasDb::analyze(Node* node) { + analyzeImpl(node); + + // After analyzing, update the wildcard index + if (hasWildcard(node)) { + wildcardNodes_.insert(node); + } +} + // The basic strategy is: // 1. Retrieve alias information for every input. // 2. Use the node's schema's alias annotations to propgagate alias/write // information to the outputs. For unschematized nodes, a special analyzer // will have to be handwritten. -void AliasDb::analyze(Node* node) { +void AliasDb::analyzeImpl(Node* node) { // These nodes are not schematized, so we need to handle them specially // TODO do the thing that python_printer does to force operator writers to // register aliasing information @@ -372,8 +681,7 @@ void AliasDb::analyze(Node* node) { } // Bind formal alias annotation to actual alias sets - std::unordered_map<Symbol, AliasInfo> formalToActual; - formalToActual[AliasInfo::wildcardSet()] = AliasInfo::createWildcard(); + std::unordered_map<Symbol, Value*> formalToActual; for (size_t i = 0; i < schema.arguments().size(); i++) { const auto& formal = schema.arguments()[i].alias_info(); const auto& actualValue = node->inputs().at(i); @@ -399,16 +707,12 @@ void AliasDb::analyze(Node* node) { continue; } - const auto& actualAlias = valueToAlias_.at(actualValue); - // Bind the formal to the actual - formalToActual[formalAlias] = actualAlias; + formalToActual[formalAlias] = actualValue; - // Record all writes - for (const auto& alias : actualAlias.sets()) { - if (formal->isWrite()) { - aliasToWrites_[alias].insert(node); - } + // Record writes + if (formal->isWrite()) { + aliasTracker_->registerWrite(actualValue, node); } } @@ -430,6 +734,11 @@ void AliasDb::analyze(Node* node) { // We don't support composite types for alias analysis yet. AT_ASSERT(formal->containedTypes().size() == 0); + if (formal->isWildcard()) { + aliasTracker_->setWildcard(actual); + continue; + } + for (const auto& formalAlias : formal->sets()) { // If we encounter an alias annotation that wasn't in the inputs: if (!formalToActual.count(formalAlias)) { @@ -447,22 +756,15 @@ void AliasDb::analyze(Node* node) { continue; } - auto outputAlias = formalToActual.at(formalAlias); - - // Record writes - for (const auto& alias : outputAlias.sets()) { - if (formal->isWrite()) { - aliasToWrites_[alias].insert(node); - } - } + auto toAlias = formalToActual.at(formalAlias); + makeAliasOf(actual, toAlias); + } - addAlias(actual, outputAlias); + // Record writes + if (formal->isWrite()) { + aliasTracker_->registerWrite(actual, node); } } - // Keep the wildcard index up to date. - if (hasWildcardImpl(node)) { - wildcardNodes_.insert(node); - } } void AliasDb::analyzeIf(Node* node) { @@ -479,8 +781,8 @@ void AliasDb::analyzeIf(Node* node) { const auto trueOutput = trueBlock->outputs().at(i); const auto falseOutput = falseBlock->outputs().at(i); - addAlias(nodeOutput, trueOutput); - addAlias(nodeOutput, falseOutput); + makeAliasOf(nodeOutput, trueOutput); + makeAliasOf(nodeOutput, falseOutput); } } @@ -494,33 +796,14 @@ void AliasDb::analyzeLoop(Node* node) { // Run alias analysis on the loop body, iterating until the block output // alias info converges. - auto notConverged = true; - while (notConverged) { - // Copy node input aliases to block input - mapAliases(blockInputs, loopCarriedInputs); - - // Populate block output alias info by analyzing the body - analyze(bodyBlock); - - // Copy the alias info from the block output to the node output - mapAliases(node->outputs(), blockOutputs); - - // Merge alias info from block outputs to the node inputs. - notConverged = false; - for (size_t i = 0; i < blockOutputs.size(); i++) { - const auto input = loopCarriedInputs[i]; - const auto output = blockOutputs[i]; - - // Check whether or not this would change anything - if (valueToAlias_.count(input) != 0) { - AT_ASSERT(valueToAlias_.count(output) != 0) - if (!valueToAlias_[output].isSubsetOf(valueToAlias_[input])) { - notConverged = true; - } - } - addAlias(input, output); - } - } + // Copy node input aliases to block input + mapAliases(blockInputs, loopCarriedInputs); + + // Populate block output alias info by analyzing the body + analyze(bodyBlock); + + // Copy the alias info from the block output to the node output + mapAliases(node->outputs(), blockOutputs); } void AliasDb::analyzeSubgraph(Node* node) { @@ -538,7 +821,7 @@ void AliasDb::analyzeSubgraph(Node* node) { // subgraph block. AT_ASSERT(subgraphBlock->outputs().size() >= node->outputs().size()); for (size_t i = 0; i < node->outputs().size(); i++) { - addAlias(node->outputs()[i], subgraphBlock->outputs()[i]); + makeAliasOf(node->outputs()[i], subgraphBlock->outputs()[i]); } } @@ -553,15 +836,14 @@ void AliasDb::analyzeCreator(Node* node) { // gives up and creates wildcards for everything. void AliasDb::analyzeExtractor(Node* node) { for (const auto output : node->outputs()) { - addAlias(output, AliasInfo::createWildcard()); + aliasTracker_->setWildcard(output); } } // For torch.chunk(), all returned tensors may alias the input tensor void AliasDb::analyzeChunk(Node* node) { - auto alias = valueToAlias_.at(node->input()); for (auto output : node->outputs()) { - addAlias(output, alias); + makeAliasOf(output, node->input()); } } @@ -574,74 +856,42 @@ void AliasDb::analyzeBroadcastingChunk(Node* node) { for (size_t index = 0; index < inputs.size(); ++index) { // Each inputs[i] is aliased by exactly `nchunks` distinct output tensors: // inputs[i] produces chunks outputs[i * nchunks + k] for k in [0..nchunks) - auto alias = valueToAlias_.at(inputs.at(index)); auto output_begin = outputs.begin() + index * nchunks; for (auto it = output_begin; it != output_begin + nchunks; ++it) { - addAlias(*it, alias); + makeAliasOf(*it, inputs.at(index)); } } } -Symbol AliasDb::getFreshAlias(bool isGraphInput) { - auto num = std::stoll(latestSymbol_.toUnqualString()); - latestSymbol_ = Symbol::fromQualString("alias::" + std::to_string(++num)); - if (isGraphInput) { - graphInputAliases_.insert(latestSymbol_); - } - return latestSymbol_; -} - -// Give this alias to the value. If the value already has alias info, union -// with this alias -void AliasDb::addAlias(const Value* value, AliasInfo alias) { +// Register the fact that `value` is a pointer to `to` +void AliasDb::makeAliasOf(const Value* value, const Value* to) { if (!shouldAnnotate(value)) { + AT_ASSERT(!shouldAnnotate(to)); return; } - if (valueToAlias_.count(value) != 0) { - valueToAlias_[value].unionWith(alias); - } else { - valueToAlias_.insert({value, std::move(alias)}); - } + aliasTracker_->makePointerTo(value, to); } -// Give this alias to the value. If the value already has alias info, union -// with this alias -void AliasDb::addAlias(const Value* value, Symbol alias) { - if (!shouldAnnotate(value)) { - return; - } - if (valueToAlias_.count(value) != 0) { - valueToAlias_[value].addSet(alias); - } else { - AliasInfo aliasInfo; - aliasInfo.addSet(alias); - valueToAlias_.insert({value, std::move(aliasInfo)}); +// Make each value in the `from` list point to its partner in the `to` list +void AliasDb::mapAliases(at::ArrayRef<Value*> from, at::ArrayRef<Value*> to) { + AT_ASSERT(to.size() == from.size()); + for (size_t i = 0; i < to.size(); i++) { + makeAliasOf(from[i], to[i]); } } -// Union the alias info of `value` with `from` -void AliasDb::addAlias(const Value* value, const Value* from) { +void AliasDb::giveFreshAlias(const Value* value) { if (!shouldAnnotate(value)) { - AT_ASSERT(!shouldAnnotate(from)); return; } - addAlias(value, valueToAlias_.at(from)); -} -void AliasDb::mapAliases(at::ArrayRef<Value*> to, at::ArrayRef<Value*> from) { - AT_ASSERT(to.size() == from.size()); - for (size_t i = 0; i < to.size(); i++) { - addAlias(to[i], from[i]); - } -} - -void AliasDb::giveFreshAlias(const Value* value) { - if (valueToAlias_.count(value) != 0) { + if (aliasTracker_->contains(value)) { // Inside a loop, we may have given a fresh alias to this value already, so // skip return; } - addAlias(value, getFreshAlias()); + + aliasTracker_->makeFreshValue(value); } bool AliasDb::moveAfterTopologicallyValid(Node* n, Node* movePoint) { @@ -948,18 +1198,6 @@ void AliasDb::move(Node* toMove, Node* movePoint, MoveSide moveSide) { } } -c10::optional<const Node*> AliasDb::getLastWildcard() const { - auto it = std::max_element( - wildcardNodes_.cbegin(), - wildcardNodes_.cend(), - [this](const Node* a, const Node* b) { return isBeforeSameGraph(a, b); }); - if (it != wildcardNodes_.end()) { - return *it; - } else { - return c10::nullopt; - } -} - bool AliasDb::hasUntrackedEffects(Node* node) const { bool touchesWildcard = false; if (const auto lastWildcard = getLastWildcard()) { @@ -996,5 +1234,17 @@ bool AliasDb::isBeforeSameGraph(const Node* a, const Node* b) const { } AT_ASSERT(false); } + +c10::optional<const Node*> AliasDb::getLastWildcard() const { + auto it = std::max_element( + wildcardNodes_.cbegin(), + wildcardNodes_.cend(), + [this](const Node* a, const Node* b) { return isBeforeSameGraph(a, b); }); + if (it != wildcardNodes_.end()) { + return *it; + } else { + return c10::nullopt; + } +} } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/passes/alias_analysis.h b/torch/csrc/jit/passes/alias_analysis.h index 39357e1c58..4ec511dc0b 100644 --- a/torch/csrc/jit/passes/alias_analysis.h +++ b/torch/csrc/jit/passes/alias_analysis.h @@ -5,6 +5,7 @@ namespace torch { namespace jit { +class AliasTracker; /** * Alias analysis pass. @@ -26,7 +27,8 @@ namespace jit { */ class AliasDb { public: - explicit AliasDb(std::shared_ptr<Graph> graph); + TORCH_API explicit AliasDb(std::shared_ptr<Graph> graph); + TORCH_API ~AliasDb(); // Does `n` write to any alias sets? bool hasWrites(Node* n) const; @@ -41,9 +43,6 @@ class AliasDb { // circumstances. bool hasUntrackedEffects(Node* n) const; - // Get all nodes that write to any alias set inputed/outputed by `n` - std::unordered_set<Node*> getWriters(const Node* n) const; - // Get all the values that `n` writes to. std::unordered_set<const Value*> getWrites(Node* n) const; @@ -72,7 +71,7 @@ class AliasDb { bool couldMoveBeforeTopologically(Node* n, Node* movePoint); // For debugging: print alias db state to stdout - void dump() const; + TORCH_API void dump() const; private: // Helper for topologically-safe node moves. @@ -87,12 +86,16 @@ class AliasDb { // Returns nullopt if there are no wildcard nodes c10::optional<const Node*> getLastWildcard() const; + // Get all nodes that write to any alias set inputed/outputed by `n` + std::unordered_set<Node*> getWriters(const Node* n) const; + // Does `n` write to a value that may alias one of the graph inputs? bool writesToInputAlias(Node* n) const; void analyze(const std::shared_ptr<Graph>& graph); void analyze(Block* block); void analyze(Node* node); + void analyzeImpl(Node* node); void analyzeIf(Node* node); void analyzeLoop(Node* node); @@ -102,33 +105,18 @@ class AliasDb { void analyzeChunk(Node* node); void analyzeBroadcastingChunk(Node* node); - Symbol getFreshAlias(bool isGraphInput = false); - void addAlias(const Value* value, AliasInfo alias); - - void addAlias(const Value* value, Symbol alias); - void addAlias(const Value* value, const Value* from); + void makeAliasOf(const Value* value, const Value* to); void mapAliases(at::ArrayRef<Value*> to, at::ArrayRef<Value*> from); void giveFreshAlias(const Value* value); bool hasUsesAfter(Symbol alias, const Node* n) const; - void buildWildcardIndex(const Block* b); - bool hasWildcardImpl(const Node* n) const; bool writesTo(Node* n, const Value* v) const; - bool isBeforeSameGraph(const Node* lhs, const Node* rhs) const; std::shared_ptr<Graph> graph_; - Symbol latestSymbol_ = Symbol::fromQualString("alias::0"); - std::unordered_map<const Value*, AliasInfo> valueToAlias_; - std::unordered_map<Symbol, std::unordered_set<const Value*>> aliasToValue_; - std::unordered_map<Symbol, std::unordered_set<Node*>> aliasToWrites_; - std::unordered_set<const Node*> wildcardNodes_; - std::unordered_set<Symbol> graphInputAliases_; std::unordered_map<const Graph*, const Node*> subgraphToOwner_; + std::unordered_set<const Node*> wildcardNodes_; + std::unique_ptr<AliasTracker> aliasTracker_; }; - -inline TORCH_API AliasDb AliasAnalysis(std::shared_ptr<Graph> graph) { - return AliasDb(std::move(graph)); -} } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/passes/batch_mm.cpp b/torch/csrc/jit/passes/batch_mm.cpp index bad7d72859..72bc81a566 100644 --- a/torch/csrc/jit/passes/batch_mm.cpp +++ b/torch/csrc/jit/passes/batch_mm.cpp @@ -435,7 +435,7 @@ void BatchMM(std::shared_ptr<Graph>& graph) { // TODO(suo): make BatchMM mutability-safe return; } - auto alias_db = AliasAnalysis(graph); + AliasDb alias_db(graph); BatchMMTreeReduce(graph->block()); BatchMMSide(graph->block(), alias_db); EliminateDeadCode(graph); diff --git a/torch/csrc/jit/passes/common_subexpression_elimination.cpp b/torch/csrc/jit/passes/common_subexpression_elimination.cpp index dc4bfc97ec..c8e6c7bf08 100644 --- a/torch/csrc/jit/passes/common_subexpression_elimination.cpp +++ b/torch/csrc/jit/passes/common_subexpression_elimination.cpp @@ -67,7 +67,7 @@ void EliminateCommonSubexpression( } // namespace void EliminateCommonSubexpression(std::shared_ptr<Graph>& graph) { - const auto aliasDb = AliasAnalysis(graph); + AliasDb aliasDb(graph); EliminateCommonSubexpression( graph->block(), aliasDb, [](Node*) { return nullptr; }); } diff --git a/torch/csrc/jit/passes/constant_propagation.cpp b/torch/csrc/jit/passes/constant_propagation.cpp index ddadbdb2df..774ad3b2fb 100644 --- a/torch/csrc/jit/passes/constant_propagation.cpp +++ b/torch/csrc/jit/passes/constant_propagation.cpp @@ -213,7 +213,7 @@ void ConstantPropagation(Block* block, const AliasDb& aliasDb) { } // anonymous namespace void ConstantPropagation(std::shared_ptr<Graph>& graph) { - const auto aliasDb = AliasAnalysis(graph); + AliasDb aliasDb(graph); ConstantPropagation(graph->block(), aliasDb); EliminateDeadCode(graph); } diff --git a/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp b/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp index 16e2363b12..97e36fa148 100644 --- a/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp +++ b/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp @@ -40,7 +40,7 @@ class SubgraphSlicer { bool any_changed = true; while (any_changed) { any_changed = false; - auto aliasDb = AliasAnalysis(graph_); + AliasDb aliasDb(graph_); for (auto it = block_->nodes().rbegin(); it != block_->nodes().rend();) { bool changed; std::tie(it, changed) = scanNode(*it, aliasDb); diff --git a/torch/csrc/jit/passes/dead_code_elimination.cpp b/torch/csrc/jit/passes/dead_code_elimination.cpp index 238e6db7e2..31bc5c95e7 100644 --- a/torch/csrc/jit/passes/dead_code_elimination.cpp +++ b/torch/csrc/jit/passes/dead_code_elimination.cpp @@ -2,6 +2,7 @@ #include <torch/csrc/jit/ir_views.h> #include <torch/csrc/jit/passes/alias_analysis.h> +#include <torch/csrc/utils/memory.h> #include <unordered_map> @@ -15,7 +16,7 @@ using namespace ::c10::prim; class DeadCodeEliminator { public: explicit DeadCodeEliminator(std::shared_ptr<Graph> graph) - : aliasDb_(AliasAnalysis(std::move(graph))) {} + : aliasDb_(torch::make_unique<AliasDb>(std::move(graph))) {} DeadCodeEliminator() = default; // The algorithm is an inverse mark-and-sweep. Starting from the return node, @@ -268,7 +269,7 @@ class DeadCodeEliminator { } } - c10::optional<AliasDb> aliasDb_; + std::unique_ptr<AliasDb> aliasDb_ = nullptr; std::unordered_map<Node*, bool> memo_; std::unordered_set<Node*> marked_; std::unordered_set<const Value*> liveValues_; diff --git a/torch/csrc/jit/passes/graph_fuser.cpp b/torch/csrc/jit/passes/graph_fuser.cpp index 3d12c1bf87..89e27e84bb 100644 --- a/torch/csrc/jit/passes/graph_fuser.cpp +++ b/torch/csrc/jit/passes/graph_fuser.cpp @@ -167,7 +167,7 @@ Value* broadcastSizes(at::ArrayRef<Value*> sizes) { struct GraphFuser { Block* block_; - c10::optional<AliasDb> aliasDb_; + std::unique_ptr<AliasDb> aliasDb_; std::shared_ptr<Graph> graph_; GraphFuser(Block* block, std::shared_ptr<Graph> graph) @@ -1002,7 +1002,7 @@ struct GraphFuser { } void refreshAliasDb() { - aliasDb_ = AliasAnalysis(graph_); + aliasDb_ = torch::make_unique<AliasDb>(graph_); } bool canFuseWithConcat(Value* producer, Node* before_check) { diff --git a/torch/csrc/jit/passes/shape_analysis.cpp b/torch/csrc/jit/passes/shape_analysis.cpp index 8db81076b6..c5254847e4 100644 --- a/torch/csrc/jit/passes/shape_analysis.cpp +++ b/torch/csrc/jit/passes/shape_analysis.cpp @@ -51,7 +51,7 @@ bool isValidReturnForRunning(Value* v) { class ShapePropagator { public: explicit ShapePropagator(std::shared_ptr<Graph> graph) - : aliasDb_(AliasAnalysis(std::move(graph))) {} + : aliasDb_(std::move(graph)) {} void PropagateShapeOnBlock(Block* block, bool insert_expands = true) { for (Node* node : block->nodes()) { diff --git a/torch/csrc/jit/python_ir.cpp b/torch/csrc/jit/python_ir.cpp index bb628a6019..3e73bde7c9 100644 --- a/torch/csrc/jit/python_ir.cpp +++ b/torch/csrc/jit/python_ir.cpp @@ -3,6 +3,7 @@ #include <torch/csrc/jit/argument_spec.h> #include <torch/csrc/jit/export.h> #include <torch/csrc/jit/ir.h> +#include <torch/csrc/jit/passes/alias_analysis.h> #include <torch/csrc/jit/passes/python_print.h> #include <torch/csrc/jit/passes/shape_analysis.h> #include <torch/csrc/jit/pybind.h> @@ -179,6 +180,12 @@ void initPythonIRBindings(PyObject* module_) { return ss.str(); }) .def( + "dump_alias_db", + [](std::shared_ptr<Graph> g) { + AliasDb db(g); + db.dump(); + }) + .def( "propagate_shapes", [](std::shared_ptr<Graph> g, std::vector<at::Tensor> inputs, |