summaryrefslogtreecommitdiff
path: root/compiler/moco-tf/src/Canonicalization/SqueezeCanonicalizer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/moco-tf/src/Canonicalization/SqueezeCanonicalizer.cpp')
-rw-r--r--compiler/moco-tf/src/Canonicalization/SqueezeCanonicalizer.cpp36
1 files changed, 8 insertions, 28 deletions
diff --git a/compiler/moco-tf/src/Canonicalization/SqueezeCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/SqueezeCanonicalizer.cpp
index a3fcc3b47..f5b991206 100644
--- a/compiler/moco-tf/src/Canonicalization/SqueezeCanonicalizer.cpp
+++ b/compiler/moco-tf/src/Canonicalization/SqueezeCanonicalizer.cpp
@@ -16,19 +16,15 @@
#include "SqueezeCanonicalizer.h"
-#include "Annotations/ShapeInferenceData.h"
-
-#include "Dialect/TFDialect.h"
-#include "Dialect/TFNodes.h"
-#include "Dialect/TFNodeVisitor.h"
-#include "Dialect/TFNodeImpl.h"
+#include <moco/IR/TFDialect.h>
+#include <moco/Support/TFShapeInferenceHelper.h>
#include <moco/Log.h>
namespace
{
-bool canonicalize_squeeze_to_reshape(loco::Graph *graph, moco::tf::TFSqueeze *node)
+bool canonicalize_squeeze_to_reshape(loco::Graph *graph, moco::TFSqueeze *node)
{
LOGGER(l);
@@ -46,12 +42,12 @@ bool canonicalize_squeeze_to_reshape(loco::Graph *graph, moco::tf::TFSqueeze *no
* In ---- FixedReshape ----- Out(s)
*/
- auto squeeze_shape = node->annot<moco::tf::ShapeInferenceData>();
+ auto nodeshape = moco::node_shape(node);
// canonicalize into FixedReshape is valid when squeeze has shape info
// TODO Support general Squeeze case
- assert(squeeze_shape);
+ assert(nodeshape.domain() != loco::Domain::Unknown);
- auto squeeze_tensor_shape = squeeze_shape->tensor_shape();
+ auto squeeze_tensor_shape = nodeshape.as<loco::TensorShape>();
// Create loco node to replace
auto reshape = graph->nodes()->create<loco::FixedReshape>();
@@ -81,25 +77,9 @@ namespace moco
namespace tf
{
-bool SqueezeCanonicalizer::run(loco::Graph *graph)
+bool SqueezeCanonicalizer::transform(TFSqueeze *node) const
{
- auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
- bool changed = false;
-
- for (auto node : active_nodes)
- {
- if (node->dialect() == TFDialect::get())
- {
- auto tf_squeeze = dynamic_cast<moco::tf::TFSqueeze *>(node);
- if (tf_squeeze != nullptr)
- {
- if (canonicalize_squeeze_to_reshape(graph, tf_squeeze))
- changed = true;
- }
- }
- }
-
- return changed;
+ return canonicalize_squeeze_to_reshape(node->graph(), node);
}
} // namespace tf