summaryrefslogtreecommitdiff
path: root/compiler/moco-tf/src/Canonicalization/ReluCanonicalizer.cpp
diff options
context:
space:
mode:
authorChunseok Lee <chunseok.lee@samsung.com>2020-10-29 13:12:50 +0900
committerChunseok Lee <chunseok.lee@samsung.com>2020-10-29 13:12:50 +0900
commitd6b371e095d737922187a518b8faba1ef6f3a2b1 (patch)
tree9d90c09c887b5111389dbedf924f59206411cd5a /compiler/moco-tf/src/Canonicalization/ReluCanonicalizer.cpp
parentc55f8a6db48cda9d3a78048338b7f18c4cca62b8 (diff)
downloadnnfw-d6b371e095d737922187a518b8faba1ef6f3a2b1.tar.gz
nnfw-d6b371e095d737922187a518b8faba1ef6f3a2b1.tar.bz2
nnfw-d6b371e095d737922187a518b8faba1ef6f3a2b1.zip
Imported Upstream version 0.4upstream/0.4
Diffstat (limited to 'compiler/moco-tf/src/Canonicalization/ReluCanonicalizer.cpp')
-rw-r--r--compiler/moco-tf/src/Canonicalization/ReluCanonicalizer.cpp27
1 files changed, 23 insertions, 4 deletions
diff --git a/compiler/moco-tf/src/Canonicalization/ReluCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/ReluCanonicalizer.cpp
index 7965dc931..20cd0bab9 100644
--- a/compiler/moco-tf/src/Canonicalization/ReluCanonicalizer.cpp
+++ b/compiler/moco-tf/src/Canonicalization/ReluCanonicalizer.cpp
@@ -16,14 +16,17 @@
#include "ReluCanonicalizer.h"
-#include <moco/IR/TFDialect.h>
+#include "Dialect/TFDialect.h"
+#include "Dialect/TFNodes.h"
+#include "Dialect/TFNodeVisitor.h"
+#include "Dialect/TFNodeImpl.h"
#include <stdex/Memory.h>
namespace
{
-bool canonicalize_relu(loco::Graph *graph, moco::TFRelu *node)
+bool canonicalize_relu(loco::Graph *graph, moco::tf::TFRelu *node)
{
/**
* @note This will replace TFRelu node with Canonical ReLU
@@ -61,9 +64,25 @@ namespace moco
namespace tf
{
-bool ReluCanonicalizer::transform(TFRelu *node) const
+bool ReluCanonicalizer::run(loco::Graph *graph)
{
- return canonicalize_relu(node->graph(), node);
+ auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
+ bool changed = false;
+
+ for (auto node : active_nodes)
+ {
+ if (node->dialect() == TFDialect::get())
+ {
+ auto tf_node = dynamic_cast<moco::tf::TFRelu *>(node);
+ if (tf_node != nullptr)
+ {
+ if (canonicalize_relu(graph, tf_node))
+ changed = true;
+ }
+ }
+ }
+
+ return changed;
}
} // namespace tf