diff options
author | Chunseok Lee <chunseok.lee@samsung.com> | 2020-10-29 13:12:50 +0900 |
---|---|---|
committer | Chunseok Lee <chunseok.lee@samsung.com> | 2020-10-29 13:12:50 +0900 |
commit | d6b371e095d737922187a518b8faba1ef6f3a2b1 (patch) | |
tree | 9d90c09c887b5111389dbedf924f59206411cd5a /compiler/moco-tf/src/Canonicalization/StopGradientCanonicalizer.cpp | |
parent | c55f8a6db48cda9d3a78048338b7f18c4cca62b8 (diff) | |
download | nnfw-d6b371e095d737922187a518b8faba1ef6f3a2b1.tar.gz nnfw-d6b371e095d737922187a518b8faba1ef6f3a2b1.tar.bz2 nnfw-d6b371e095d737922187a518b8faba1ef6f3a2b1.zip |
Imported Upstream version 0.4upstream/0.4
Diffstat (limited to 'compiler/moco-tf/src/Canonicalization/StopGradientCanonicalizer.cpp')
-rw-r--r-- | compiler/moco-tf/src/Canonicalization/StopGradientCanonicalizer.cpp | 27 |
1 files changed, 23 insertions, 4 deletions
diff --git a/compiler/moco-tf/src/Canonicalization/StopGradientCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/StopGradientCanonicalizer.cpp index 574fa3993..a52af05a5 100644 --- a/compiler/moco-tf/src/Canonicalization/StopGradientCanonicalizer.cpp +++ b/compiler/moco-tf/src/Canonicalization/StopGradientCanonicalizer.cpp @@ -16,14 +16,17 @@ #include "StopGradientCanonicalizer.h" -#include <moco/IR/TFDialect.h> +#include "Dialect/TFDialect.h" +#include "Dialect/TFNodes.h" +#include "Dialect/TFNodeVisitor.h" +#include "Dialect/TFNodeImpl.h" #include <moco/Log.h> namespace { -bool canonicalize_stopgradient(loco::Graph *graph, moco::TFStopGradient *node) +bool canonicalize_stopgradient(loco::Graph *graph, moco::tf::TFStopGradient *node) { LOGGER(l); @@ -62,9 +65,25 @@ namespace moco namespace tf { -bool StopGradientCanonicalizer::transform(TFStopGradient *node) const +bool StopGradientCanonicalizer::run(loco::Graph *graph) { - return canonicalize_stopgradient(node->graph(), node); + auto active_nodes = loco::active_nodes(loco::output_nodes(graph)); + bool changed = false; + + for (auto node : active_nodes) + { + if (node->dialect() == TFDialect::get()) + { + auto tf_stopgradient = dynamic_cast<moco::tf::TFStopGradient *>(node); + if (tf_stopgradient != nullptr) + { + if (canonicalize_stopgradient(graph, tf_stopgradient)) + changed = true; + } + } + } + + return changed; } } // namespace tf |