summaryrefslogtreecommitdiff
path: root/compiler/moco-tf/src/Canonicalization/StopGradientCanonicalizer.cpp
diff options
context:
space:
mode:
authorChunseok Lee <chunseok.lee@samsung.com>2020-10-29 13:12:50 +0900
committerChunseok Lee <chunseok.lee@samsung.com>2020-10-29 13:12:50 +0900
commitd6b371e095d737922187a518b8faba1ef6f3a2b1 (patch)
tree9d90c09c887b5111389dbedf924f59206411cd5a /compiler/moco-tf/src/Canonicalization/StopGradientCanonicalizer.cpp
parentc55f8a6db48cda9d3a78048338b7f18c4cca62b8 (diff)
downloadnnfw-d6b371e095d737922187a518b8faba1ef6f3a2b1.tar.gz
nnfw-d6b371e095d737922187a518b8faba1ef6f3a2b1.tar.bz2
nnfw-d6b371e095d737922187a518b8faba1ef6f3a2b1.zip
Imported Upstream version 0.4upstream/0.4
Diffstat (limited to 'compiler/moco-tf/src/Canonicalization/StopGradientCanonicalizer.cpp')
-rw-r--r--compiler/moco-tf/src/Canonicalization/StopGradientCanonicalizer.cpp27
1 files changed, 23 insertions, 4 deletions
diff --git a/compiler/moco-tf/src/Canonicalization/StopGradientCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/StopGradientCanonicalizer.cpp
index 574fa3993..a52af05a5 100644
--- a/compiler/moco-tf/src/Canonicalization/StopGradientCanonicalizer.cpp
+++ b/compiler/moco-tf/src/Canonicalization/StopGradientCanonicalizer.cpp
@@ -16,14 +16,17 @@
#include "StopGradientCanonicalizer.h"
-#include <moco/IR/TFDialect.h>
+#include "Dialect/TFDialect.h"
+#include "Dialect/TFNodes.h"
+#include "Dialect/TFNodeVisitor.h"
+#include "Dialect/TFNodeImpl.h"
#include <moco/Log.h>
namespace
{
-bool canonicalize_stopgradient(loco::Graph *graph, moco::TFStopGradient *node)
+bool canonicalize_stopgradient(loco::Graph *graph, moco::tf::TFStopGradient *node)
{
LOGGER(l);
@@ -62,9 +65,25 @@ namespace moco
namespace tf
{
-bool StopGradientCanonicalizer::transform(TFStopGradient *node) const
+bool StopGradientCanonicalizer::run(loco::Graph *graph)
{
- return canonicalize_stopgradient(node->graph(), node);
+ auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
+ bool changed = false;
+
+ for (auto node : active_nodes)
+ {
+ if (node->dialect() == TFDialect::get())
+ {
+ auto tf_stopgradient = dynamic_cast<moco::tf::TFStopGradient *>(node);
+ if (tf_stopgradient != nullptr)
+ {
+ if (canonicalize_stopgradient(graph, tf_stopgradient))
+ changed = true;
+ }
+ }
+ }
+
+ return changed;
}
} // namespace tf