diff options
Diffstat (limited to 'compiler/moco-tf/src/Canonicalization/SqueezeCanonicalizer.cpp')
-rw-r--r-- | compiler/moco-tf/src/Canonicalization/SqueezeCanonicalizer.cpp | 36 |
1 files changed, 8 insertions, 28 deletions
diff --git a/compiler/moco-tf/src/Canonicalization/SqueezeCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/SqueezeCanonicalizer.cpp index a3fcc3b47..f5b991206 100644 --- a/compiler/moco-tf/src/Canonicalization/SqueezeCanonicalizer.cpp +++ b/compiler/moco-tf/src/Canonicalization/SqueezeCanonicalizer.cpp @@ -16,19 +16,15 @@ #include "SqueezeCanonicalizer.h" -#include "Annotations/ShapeInferenceData.h" - -#include "Dialect/TFDialect.h" -#include "Dialect/TFNodes.h" -#include "Dialect/TFNodeVisitor.h" -#include "Dialect/TFNodeImpl.h" +#include <moco/IR/TFDialect.h> +#include <moco/Support/TFShapeInferenceHelper.h> #include <moco/Log.h> namespace { -bool canonicalize_squeeze_to_reshape(loco::Graph *graph, moco::tf::TFSqueeze *node) +bool canonicalize_squeeze_to_reshape(loco::Graph *graph, moco::TFSqueeze *node) { LOGGER(l); @@ -46,12 +42,12 @@ bool canonicalize_squeeze_to_reshape(loco::Graph *graph, moco::tf::TFSqueeze *no * In ---- FixedReshape ----- Out(s) */ - auto squeeze_shape = node->annot<moco::tf::ShapeInferenceData>(); + auto nodeshape = moco::node_shape(node); // canonicalize into FixedReshape is valid when squeeze has shape info // TODO Support general Squeeze case - assert(squeeze_shape); + assert(nodeshape.domain() != loco::Domain::Unknown); - auto squeeze_tensor_shape = squeeze_shape->tensor_shape(); + auto squeeze_tensor_shape = nodeshape.as<loco::TensorShape>(); // Create loco node to replace auto reshape = graph->nodes()->create<loco::FixedReshape>(); @@ -81,25 +77,9 @@ namespace moco namespace tf { -bool SqueezeCanonicalizer::run(loco::Graph *graph) +bool SqueezeCanonicalizer::transform(TFSqueeze *node) const { - auto active_nodes = loco::active_nodes(loco::output_nodes(graph)); - bool changed = false; - - for (auto node : active_nodes) - { - if (node->dialect() == TFDialect::get()) - { - auto tf_squeeze = dynamic_cast<moco::tf::TFSqueeze *>(node); - if (tf_squeeze != nullptr) - { - if (canonicalize_squeeze_to_reshape(graph, tf_squeeze)) - changed = true; - } - } - } - - return changed; + return canonicalize_squeeze_to_reshape(node->graph(), node); } } // namespace tf |