diff options
-rw-r--r-- | compiler/loco/include/loco/IR/CanonicalDialect.h | 2 | ||||
-rw-r--r-- | compiler/loco/include/loco/IR/Graph.h | 15 | ||||
-rw-r--r-- | compiler/loco/src/IR/CanonicalDialect.cpp | 36 | ||||
-rw-r--r-- | compiler/loco/src/IR/Graph.cpp | 21 |
4 files changed, 71 insertions, 3 deletions
diff --git a/compiler/loco/include/loco/IR/CanonicalDialect.h b/compiler/loco/include/loco/IR/CanonicalDialect.h index fad3caac5..940d29a59 100644 --- a/compiler/loco/include/loco/IR/CanonicalDialect.h +++ b/compiler/loco/include/loco/IR/CanonicalDialect.h @@ -30,7 +30,7 @@ namespace loco class CanonicalDialect final : public Dialect { private: - CanonicalDialect() = default; + CanonicalDialect(); public: CanonicalDialect(const CanonicalDialect &) = delete; diff --git a/compiler/loco/include/loco/IR/Graph.h b/compiler/loco/include/loco/IR/Graph.h index f9a53f7ea..86cb65cab 100644 --- a/compiler/loco/include/loco/IR/Graph.h +++ b/compiler/loco/include/loco/IR/Graph.h @@ -234,6 +234,21 @@ private: OutputContext _output_ctx; }; +struct GraphOutputIndexQueryService : public DialectService +{ + virtual ~GraphOutputIndexQueryService() = default; + + /** + * @brief Check whether a given node is associated with any Graph-level output + */ + virtual bool associated(const Node *node) const = 0; + + /** + * WARNING! CALLER SHOULD GUARANTEE that associated(node) is true before invoking this API. + */ + virtual GraphOutputIndex index(const Node *node) const = 0; +}; + // TODO Use "const Graph *" std::vector<Node *> output_nodes(Graph *); diff --git a/compiler/loco/src/IR/CanonicalDialect.cpp b/compiler/loco/src/IR/CanonicalDialect.cpp index b46269d5f..f89ea447b 100644 --- a/compiler/loco/src/IR/CanonicalDialect.cpp +++ b/compiler/loco/src/IR/CanonicalDialect.cpp @@ -15,10 +15,46 @@ */ #include "loco/IR/CanonicalDialect.h" +#include "loco/IR/Graph.h" +#include "loco/IR/Nodes.h" + +#include <stdex/Memory.h> + +#include <cassert> + +namespace +{ + +struct GraphOutputIndexQueryServiceImpl final : public loco::GraphOutputIndexQueryService +{ + bool associated(const loco::Node *node) const final + { + if (auto push = dynamic_cast<const loco::Push *>(node)) + { + return push->indexed(); + } + return false; + } + + loco::GraphOutputIndex index(const loco::Node *node) const final + { + assert(associated(node)); + auto push = dynamic_cast<const loco::Push *>(node); + assert(push != nullptr); + return push->index(); + } +}; + +} // namespace namespace loco { +CanonicalDialect::CanonicalDialect() +{ + service<GraphOutputIndexQueryService>(stdex::make_unique<GraphOutputIndexQueryServiceImpl>()); +} + Dialect *CanonicalDialect::get(void) { static CanonicalDialect d; diff --git a/compiler/loco/src/IR/Graph.cpp b/compiler/loco/src/IR/Graph.cpp index 7029f92d0..345525b9e 100644 --- a/compiler/loco/src/IR/Graph.cpp +++ b/compiler/loco/src/IR/Graph.cpp @@ -74,12 +74,29 @@ std::set<loco::Node *> all_nodes(loco::Graph *g) std::vector<loco::Node *> output_nodes(loco::Graph *g) { + std::map<GraphOutputIndex, loco::Node *> table; + + for (uint32_t n = 0; n < g->nodes()->size(); ++n) + { + auto node = g->nodes()->at(n); + + if (auto service = node->dialect()->service<GraphOutputIndexQueryService>()) + { + if (service->associated(node)) + { + auto output_index = service->index(node); + assert(table.find(output_index) == table.end()); + table[output_index] = node; + } + } + } + std::vector<loco::Node *> res; for (uint32_t n = 0; n < g->outputs()->size(); ++n) { - auto node = push_node(g, n); - res.emplace_back(node); + auto it = table.find(n); + res.emplace_back(it == table.end() ? nullptr : it->second); } return res; |