diff options
author | Chunseok Lee <chunseok.lee@samsung.com> | 2020-10-29 13:12:50 +0900 |
---|---|---|
committer | Chunseok Lee <chunseok.lee@samsung.com> | 2020-10-29 13:12:50 +0900 |
commit | d6b371e095d737922187a518b8faba1ef6f3a2b1 (patch) | |
tree | 9d90c09c887b5111389dbedf924f59206411cd5a /compiler/moco-tf/src/Canonicalization/ReluCanonicalizer.cpp | |
parent | c55f8a6db48cda9d3a78048338b7f18c4cca62b8 (diff) | |
download | nnfw-d6b371e095d737922187a518b8faba1ef6f3a2b1.tar.gz nnfw-d6b371e095d737922187a518b8faba1ef6f3a2b1.tar.bz2 nnfw-d6b371e095d737922187a518b8faba1ef6f3a2b1.zip |
Imported Upstream version 0.4upstream/0.4
Diffstat (limited to 'compiler/moco-tf/src/Canonicalization/ReluCanonicalizer.cpp')
-rw-r--r-- | compiler/moco-tf/src/Canonicalization/ReluCanonicalizer.cpp | 27 |
1 files changed, 23 insertions, 4 deletions
diff --git a/compiler/moco-tf/src/Canonicalization/ReluCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/ReluCanonicalizer.cpp index 7965dc931..20cd0bab9 100644 --- a/compiler/moco-tf/src/Canonicalization/ReluCanonicalizer.cpp +++ b/compiler/moco-tf/src/Canonicalization/ReluCanonicalizer.cpp @@ -16,14 +16,17 @@ #include "ReluCanonicalizer.h" -#include <moco/IR/TFDialect.h> +#include "Dialect/TFDialect.h" +#include "Dialect/TFNodes.h" +#include "Dialect/TFNodeVisitor.h" +#include "Dialect/TFNodeImpl.h" #include <stdex/Memory.h> namespace { -bool canonicalize_relu(loco::Graph *graph, moco::TFRelu *node) +bool canonicalize_relu(loco::Graph *graph, moco::tf::TFRelu *node) { /** * @note This will replace TFRelu node with Canonical ReLU @@ -61,9 +64,25 @@ namespace moco namespace tf { -bool ReluCanonicalizer::transform(TFRelu *node) const +bool ReluCanonicalizer::run(loco::Graph *graph) { - return canonicalize_relu(node->graph(), node); + auto active_nodes = loco::active_nodes(loco::output_nodes(graph)); + bool changed = false; + + for (auto node : active_nodes) + { + if (node->dialect() == TFDialect::get()) + { + auto tf_node = dynamic_cast<moco::tf::TFRelu *>(node); + if (tf_node != nullptr) + { + if (canonicalize_relu(graph, tf_node)) + changed = true; + } + } + } + + return changed; } } // namespace tf |