summaryrefslogtreecommitdiff
path: root/compiler/moco-tf/src/Canonicalization/SoftmaxCanonicalizer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/moco-tf/src/Canonicalization/SoftmaxCanonicalizer.cpp')
-rw-r--r--compiler/moco-tf/src/Canonicalization/SoftmaxCanonicalizer.cpp37
1 files changed, 29 insertions, 8 deletions
diff --git a/compiler/moco-tf/src/Canonicalization/SoftmaxCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/SoftmaxCanonicalizer.cpp
index 98af7b693..3b5043fa7 100644
--- a/compiler/moco-tf/src/Canonicalization/SoftmaxCanonicalizer.cpp
+++ b/compiler/moco-tf/src/Canonicalization/SoftmaxCanonicalizer.cpp
@@ -16,15 +16,19 @@
#include "SoftmaxCanonicalizer.h"
-#include <moco/IR/TFDialect.h>
-#include <moco/Support/TFShapeInferenceHelper.h>
+#include "Annotations/ShapeInferenceData.h"
+
+#include "Dialect/TFDialect.h"
+#include "Dialect/TFNodes.h"
+#include "Dialect/TFNodeVisitor.h"
+#include "Dialect/TFNodeImpl.h"
#include <moco/Log.h>
namespace
{
-bool canonicalize_softmax(loco::Graph *graph, moco::TFSoftmax *node)
+bool canonicalize_softmax(loco::Graph *graph, moco::tf::TFSoftmax *node)
{
LOGGER(l);
@@ -42,11 +46,12 @@ bool canonicalize_softmax(loco::Graph *graph, moco::TFSoftmax *node)
* In ---- TensorSoftmax ----- Out(s)
*/
- auto nodeshape = moco::node_shape(node);
+ auto softmax_shape = node->annot<moco::tf::ShapeInferenceData>();
+
// Canonicalization into TensorSoftmax is valid when softmax has shape info
- assert(nodeshape.domain() != loco::Domain::Unknown);
+ assert(softmax_shape);
- auto softmax_tensor_shape = nodeshape.as<loco::TensorShape>();
+ auto softmax_tensor_shape = softmax_shape->tensor_shape();
// Create loco node to replace
auto softmax = graph->nodes()->create<loco::TensorSoftmax>();
@@ -69,9 +74,25 @@ namespace moco
namespace tf
{
-bool SoftmaxCanonicalizer::transform(TFSoftmax *node) const
+bool SoftmaxCanonicalizer::run(loco::Graph *graph)
{
- return canonicalize_softmax(node->graph(), node);
+ auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
+ bool changed = false;
+
+ for (auto node : active_nodes)
+ {
+ if (node->dialect() == TFDialect::get())
+ {
+ auto tf_softmax = dynamic_cast<moco::tf::TFSoftmax *>(node);
+ if (tf_softmax != nullptr)
+ {
+ if (canonicalize_softmax(graph, tf_softmax))
+ changed = true;
+ }
+ }
+ }
+
+ return changed;
}
} // namespace tf