summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>2019-09-17 16:07:13 +0900
committerGitHub Enterprise <noreply-CODE@samsung.com>2019-09-17 16:07:13 +0900
commit1d5e414d30613a6a7b1f0178d52935bf16b9f661 (patch)
treed2e64595fa8bef39415ce5e50e22189f8f434e5f
parent4731b226438d62ea61edfbe2e532d3cfc225e66f (diff)
downloadnnfw-1d5e414d30613a6a7b1f0178d52935bf16b9f661.tar.gz
nnfw-1d5e414d30613a6a7b1f0178d52935bf16b9f661.tar.bz2
nnfw-1d5e414d30613a6a7b1f0178d52935bf16b9f661.zip
[loco] Introduce GraphOutputIndexQueryService (#7479)
This commit introduces GraphOutputIndexQueryService interface, and revises output_nodes helper to use this new service interface. This commit also implements GraphOutputIndexQueryService for canonical dialect in order to guarantee backward compatibility. Signed-off-by: Jonghyun Park <jh1302.park@samsung.com>
-rw-r--r--compiler/loco/include/loco/IR/CanonicalDialect.h2
-rw-r--r--compiler/loco/include/loco/IR/Graph.h15
-rw-r--r--compiler/loco/src/IR/CanonicalDialect.cpp36
-rw-r--r--compiler/loco/src/IR/Graph.cpp21
4 files changed, 71 insertions, 3 deletions
diff --git a/compiler/loco/include/loco/IR/CanonicalDialect.h b/compiler/loco/include/loco/IR/CanonicalDialect.h
index fad3caac5..940d29a59 100644
--- a/compiler/loco/include/loco/IR/CanonicalDialect.h
+++ b/compiler/loco/include/loco/IR/CanonicalDialect.h
@@ -30,7 +30,7 @@ namespace loco
class CanonicalDialect final : public Dialect
{
private:
- CanonicalDialect() = default;
+ CanonicalDialect();
public:
CanonicalDialect(const CanonicalDialect &) = delete;
diff --git a/compiler/loco/include/loco/IR/Graph.h b/compiler/loco/include/loco/IR/Graph.h
index f9a53f7ea..86cb65cab 100644
--- a/compiler/loco/include/loco/IR/Graph.h
+++ b/compiler/loco/include/loco/IR/Graph.h
@@ -234,6 +234,21 @@ private:
OutputContext _output_ctx;
};
+struct GraphOutputIndexQueryService : public DialectService
+{
+ virtual ~GraphOutputIndexQueryService() = default;
+
+ /**
+ * @brief Check whether a given node is associated with any Graph-level output
+ */
+ virtual bool associated(const Node *node) const = 0;
+
+ /**
+ * WARNING! CALLER SHOULD GUARANTEE that associated(node) is true before invoking this API.
+ */
+ virtual GraphOutputIndex index(const Node *node) const = 0;
+};
+
// TODO Use "const Graph *"
std::vector<Node *> output_nodes(Graph *);
diff --git a/compiler/loco/src/IR/CanonicalDialect.cpp b/compiler/loco/src/IR/CanonicalDialect.cpp
index b46269d5f..f89ea447b 100644
--- a/compiler/loco/src/IR/CanonicalDialect.cpp
+++ b/compiler/loco/src/IR/CanonicalDialect.cpp
@@ -15,10 +15,46 @@
*/
#include "loco/IR/CanonicalDialect.h"
+#include "loco/IR/Graph.h"
+#include "loco/IR/Nodes.h"
+
+#include <stdex/Memory.h>
+
+#include <cassert>
+
+namespace
+{
+
+struct GraphOutputIndexQueryServiceImpl final : public loco::GraphOutputIndexQueryService
+{
+ bool associated(const loco::Node *node) const final
+ {
+ if (auto push = dynamic_cast<const loco::Push *>(node))
+ {
+ return push->indexed();
+ }
+ return false;
+ }
+
+ loco::GraphOutputIndex index(const loco::Node *node) const final
+ {
+ assert(associated(node));
+ auto push = dynamic_cast<const loco::Push *>(node);
+ assert(push != nullptr);
+ return push->index();
+ }
+};
+
+} // namespace
namespace loco
{
+CanonicalDialect::CanonicalDialect()
+{
+ service<GraphOutputIndexQueryService>(stdex::make_unique<GraphOutputIndexQueryServiceImpl>());
+}
+
Dialect *CanonicalDialect::get(void)
{
static CanonicalDialect d;
diff --git a/compiler/loco/src/IR/Graph.cpp b/compiler/loco/src/IR/Graph.cpp
index 7029f92d0..345525b9e 100644
--- a/compiler/loco/src/IR/Graph.cpp
+++ b/compiler/loco/src/IR/Graph.cpp
@@ -74,12 +74,29 @@ std::set<loco::Node *> all_nodes(loco::Graph *g)
std::vector<loco::Node *> output_nodes(loco::Graph *g)
{
+ std::map<GraphOutputIndex, loco::Node *> table;
+
+ for (uint32_t n = 0; n < g->nodes()->size(); ++n)
+ {
+ auto node = g->nodes()->at(n);
+
+ if (auto service = node->dialect()->service<GraphOutputIndexQueryService>())
+ {
+ if (service->associated(node))
+ {
+ auto output_index = service->index(node);
+ assert(table.find(output_index) == table.end());
+ table[output_index] = node;
+ }
+ }
+ }
+
std::vector<loco::Node *> res;
for (uint32_t n = 0; n < g->outputs()->size(); ++n)
{
- auto node = push_node(g, n);
- res.emplace_back(node);
+ auto it = table.find(n);
+ res.emplace_back(it == table.end() ? nullptr : it->second);
}
return res;