diff options
Diffstat (limited to 'compiler/moco-tf/src/Canonicalization/SoftmaxCanonicalizer.cpp')
-rw-r--r-- | compiler/moco-tf/src/Canonicalization/SoftmaxCanonicalizer.cpp | 37 |
1 files changed, 29 insertions, 8 deletions
diff --git a/compiler/moco-tf/src/Canonicalization/SoftmaxCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/SoftmaxCanonicalizer.cpp index 98af7b693..3b5043fa7 100644 --- a/compiler/moco-tf/src/Canonicalization/SoftmaxCanonicalizer.cpp +++ b/compiler/moco-tf/src/Canonicalization/SoftmaxCanonicalizer.cpp @@ -16,15 +16,19 @@ #include "SoftmaxCanonicalizer.h" -#include <moco/IR/TFDialect.h> -#include <moco/Support/TFShapeInferenceHelper.h> +#include "Annotations/ShapeInferenceData.h" + +#include "Dialect/TFDialect.h" +#include "Dialect/TFNodes.h" +#include "Dialect/TFNodeVisitor.h" +#include "Dialect/TFNodeImpl.h" #include <moco/Log.h> namespace { -bool canonicalize_softmax(loco::Graph *graph, moco::TFSoftmax *node) +bool canonicalize_softmax(loco::Graph *graph, moco::tf::TFSoftmax *node) { LOGGER(l); @@ -42,11 +46,12 @@ bool canonicalize_softmax(loco::Graph *graph, moco::TFSoftmax *node) * In ---- TensorSoftmax ----- Out(s) */ - auto nodeshape = moco::node_shape(node); + auto softmax_shape = node->annot<moco::tf::ShapeInferenceData>(); + // Canonicalization into TensorSoftmax is valid when softmax has shape info - assert(nodeshape.domain() != loco::Domain::Unknown); + assert(softmax_shape); - auto softmax_tensor_shape = nodeshape.as<loco::TensorShape>(); + auto softmax_tensor_shape = softmax_shape->tensor_shape(); // Create loco node to replace auto softmax = graph->nodes()->create<loco::TensorSoftmax>(); @@ -69,9 +74,25 @@ namespace moco namespace tf { -bool SoftmaxCanonicalizer::transform(TFSoftmax *node) const +bool SoftmaxCanonicalizer::run(loco::Graph *graph) { - return canonicalize_softmax(node->graph(), node); + auto active_nodes = loco::active_nodes(loco::output_nodes(graph)); + bool changed = false; + + for (auto node : active_nodes) + { + if (node->dialect() == TFDialect::get()) + { + auto tf_softmax = dynamic_cast<moco::tf::TFSoftmax *>(node); + if (tf_softmax != nullptr) + { + if (canonicalize_softmax(graph, tf_softmax)) + changed = true; + } + } + } + + return changed; } } // namespace tf |