summaryrefslogtreecommitdiff
path: root/compiler/moco-tf/src/Canonicalization
diff options
context:
space:
mode:
authorChunseok Lee <chunseok.lee@samsung.com>2020-10-29 13:12:50 +0900
committerChunseok Lee <chunseok.lee@samsung.com>2020-10-29 13:12:50 +0900
commitd6b371e095d737922187a518b8faba1ef6f3a2b1 (patch)
tree9d90c09c887b5111389dbedf924f59206411cd5a /compiler/moco-tf/src/Canonicalization
parentc55f8a6db48cda9d3a78048338b7f18c4cca62b8 (diff)
downloadnnfw-d6b371e095d737922187a518b8faba1ef6f3a2b1.tar.gz
nnfw-d6b371e095d737922187a518b8faba1ef6f3a2b1.tar.bz2
nnfw-d6b371e095d737922187a518b8faba1ef6f3a2b1.zip
Imported Upstream version 0.4upstream/0.4
Diffstat (limited to 'compiler/moco-tf/src/Canonicalization')
-rw-r--r--compiler/moco-tf/src/Canonicalization/AddCanonicalizer.cpp24
-rw-r--r--compiler/moco-tf/src/Canonicalization/AddCanonicalizer.h7
-rw-r--r--compiler/moco-tf/src/Canonicalization/AvgPoolCanonicalizer.cpp116
-rw-r--r--compiler/moco-tf/src/Canonicalization/AvgPoolCanonicalizer.h7
-rw-r--r--compiler/moco-tf/src/Canonicalization/BiasAddCanonicalizer.cpp29
-rw-r--r--compiler/moco-tf/src/Canonicalization/BiasAddCanonicalizer.h7
-rw-r--r--compiler/moco-tf/src/Canonicalization/ConcatV2Canonicalizer.cpp86
-rw-r--r--compiler/moco-tf/src/Canonicalization/ConcatV2Canonicalizer.h7
-rw-r--r--compiler/moco-tf/src/Canonicalization/ConstCanonicalizer.cpp59
-rw-r--r--compiler/moco-tf/src/Canonicalization/ConstCanonicalizer.h7
-rw-r--r--compiler/moco-tf/src/Canonicalization/Conv2DBackpropInputCanonicalizer.cpp371
-rw-r--r--compiler/moco-tf/src/Canonicalization/Conv2DBackpropInputCanonicalizer.h45
-rw-r--r--compiler/moco-tf/src/Canonicalization/Conv2DCanonicalizer.cpp110
-rw-r--r--compiler/moco-tf/src/Canonicalization/Conv2DCanonicalizer.h7
-rw-r--r--compiler/moco-tf/src/Canonicalization/DepthwiseConv2dNativeCanonicalizer.cpp105
-rw-r--r--compiler/moco-tf/src/Canonicalization/DepthwiseConv2dNativeCanonicalizer.h7
-rw-r--r--compiler/moco-tf/src/Canonicalization/IdentityCanonicalizer.cpp29
-rw-r--r--compiler/moco-tf/src/Canonicalization/IdentityCanonicalizer.h7
-rw-r--r--compiler/moco-tf/src/Canonicalization/MaxPoolCanonicalizer.cpp114
-rw-r--r--compiler/moco-tf/src/Canonicalization/MaxPoolCanonicalizer.h7
-rw-r--r--compiler/moco-tf/src/Canonicalization/MaximumCanonicalizer.cpp34
-rw-r--r--compiler/moco-tf/src/Canonicalization/MaximumCanonicalizer.h47
-rw-r--r--compiler/moco-tf/src/Canonicalization/MeanCanonicalizer.cpp31
-rw-r--r--compiler/moco-tf/src/Canonicalization/MeanCanonicalizer.h47
-rw-r--r--compiler/moco-tf/src/Canonicalization/MulCanonicalizer.cpp23
-rw-r--r--compiler/moco-tf/src/Canonicalization/MulCanonicalizer.h7
-rw-r--r--compiler/moco-tf/src/Canonicalization/PadCanonicalizer.cpp100
-rw-r--r--compiler/moco-tf/src/Canonicalization/PlaceholderCanonicalizer.cpp102
-rw-r--r--compiler/moco-tf/src/Canonicalization/PlaceholderCanonicalizer.h47
-rw-r--r--compiler/moco-tf/src/Canonicalization/RealDivCanonicalizer.cpp23
-rw-r--r--compiler/moco-tf/src/Canonicalization/RealDivCanonicalizer.h7
-rw-r--r--compiler/moco-tf/src/Canonicalization/Relu6Canonicalizer.cpp27
-rw-r--r--compiler/moco-tf/src/Canonicalization/Relu6Canonicalizer.h7
-rw-r--r--compiler/moco-tf/src/Canonicalization/ReluCanonicalizer.cpp27
-rw-r--r--compiler/moco-tf/src/Canonicalization/ReluCanonicalizer.h7
-rw-r--r--compiler/moco-tf/src/Canonicalization/ReshapeCanonicalizer.cpp47
-rw-r--r--compiler/moco-tf/src/Canonicalization/ReshapeCanonicalizer.h7
-rw-r--r--compiler/moco-tf/src/Canonicalization/RsqrtCanonicalizer.cpp62
-rw-r--r--compiler/moco-tf/src/Canonicalization/RsqrtCanonicalizer.h7
-rw-r--r--compiler/moco-tf/src/Canonicalization/SoftmaxCanonicalizer.cpp37
-rw-r--r--compiler/moco-tf/src/Canonicalization/SoftmaxCanonicalizer.h7
-rw-r--r--compiler/moco-tf/src/Canonicalization/SqrtCanonicalizer.cpp27
-rw-r--r--compiler/moco-tf/src/Canonicalization/SqrtCanonicalizer.h7
-rw-r--r--compiler/moco-tf/src/Canonicalization/SquaredDifferenceCanonicalizer.cpp115
-rw-r--r--compiler/moco-tf/src/Canonicalization/SquaredDifferenceCanonicalizer.h (renamed from compiler/moco-tf/src/Canonicalization/PadCanonicalizer.h)17
-rw-r--r--compiler/moco-tf/src/Canonicalization/SqueezeCanonicalizer.cpp36
-rw-r--r--compiler/moco-tf/src/Canonicalization/SqueezeCanonicalizer.h7
-rw-r--r--compiler/moco-tf/src/Canonicalization/StopGradientCanonicalizer.cpp27
-rw-r--r--compiler/moco-tf/src/Canonicalization/StopGradientCanonicalizer.h7
-rw-r--r--compiler/moco-tf/src/Canonicalization/SubCanonicalizer.cpp23
-rw-r--r--compiler/moco-tf/src/Canonicalization/SubCanonicalizer.h7
-rw-r--r--compiler/moco-tf/src/Canonicalization/TFPushCanonicalizer.cpp74
-rw-r--r--compiler/moco-tf/src/Canonicalization/TFPushCanonicalizer.h47
-rw-r--r--compiler/moco-tf/src/Canonicalization/TanhCanonicalizer.cpp27
-rw-r--r--compiler/moco-tf/src/Canonicalization/TanhCanonicalizer.h7
55 files changed, 952 insertions, 1330 deletions
diff --git a/compiler/moco-tf/src/Canonicalization/AddCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/AddCanonicalizer.cpp
index 8028a870c..ef82f3dab 100644
--- a/compiler/moco-tf/src/Canonicalization/AddCanonicalizer.cpp
+++ b/compiler/moco-tf/src/Canonicalization/AddCanonicalizer.cpp
@@ -16,8 +16,8 @@
#include "AddCanonicalizer.h"
-#include <moco/IR/TFDialect.h>
-#include <moco/IR/TFNodes.h>
+#include "Dialect/TFDialect.h"
+#include "Dialect/TFNodes.h"
#include "TFEltwiseBinaryCanonicalzeHelper.h"
@@ -26,9 +26,25 @@ namespace moco
namespace tf
{
-bool AddCanonicalizer::transform(TFAdd *node) const
+bool AddCanonicalizer::run(loco::Graph *graph)
{
- return canonicalize_eltwise_binary_node(node);
+ auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
+ bool changed = false;
+
+ for (auto node : active_nodes)
+ {
+ if (node->dialect() == TFDialect::get())
+ {
+ auto tf_node = dynamic_cast<moco::tf::TFAdd *>(node);
+ if (tf_node != nullptr)
+ {
+ if (canonicalize_eltwise_binary_node(tf_node))
+ changed = true;
+ }
+ }
+ }
+
+ return changed;
}
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/AddCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/AddCanonicalizer.h
index 53ba9ed58..07b8a72de 100644
--- a/compiler/moco-tf/src/Canonicalization/AddCanonicalizer.h
+++ b/compiler/moco-tf/src/Canonicalization/AddCanonicalizer.h
@@ -18,9 +18,6 @@
#define __MOCO_TF_ADD_CANONICALIZER_H__
#include "Transform.h"
-#include "SimpleNodeTransform.h"
-
-#include <moco/IR/TFNodes.h>
#include <loco.h>
@@ -32,13 +29,13 @@ namespace tf
/**
* @brief Convert TFAdd to Canonical EltwiseAdd
*/
-class AddCanonicalizer : public SimpleNodeTransform<TFAdd>
+class AddCanonicalizer : public Transform
{
public:
const char *name(void) const final { return "AddCanonicalizer"; }
public:
- bool transform(TFAdd *node) const final;
+ bool run(loco::Graph *graph) override;
};
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/AvgPoolCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/AvgPoolCanonicalizer.cpp
index e07a4f64f..66a71089e 100644
--- a/compiler/moco-tf/src/Canonicalization/AvgPoolCanonicalizer.cpp
+++ b/compiler/moco-tf/src/Canonicalization/AvgPoolCanonicalizer.cpp
@@ -16,19 +16,71 @@
#include "AvgPoolCanonicalizer.h"
-#include <moco/IR/TFDialect.h>
-#include <moco/Support/TFShapeInferenceHelper.h>
+#include "Annotations/PadData.h"
+#include "Annotations/StrideData.h"
+#include "Annotations/ShapeInferenceData.h"
+#include "Annotations/WindowData.h"
-#include "CodecHelper.h"
-
-#include <loco/IR/NodeShape.h>
+#include "Dialect/TFDialect.h"
+#include "Dialect/TFNodes.h"
+#include "Dialect/TFNodeVisitor.h"
+#include "Dialect/TFNodeImpl.h"
#include <moco/Log.h>
+#include <plier/tf/Convert.h>
+
+#include <stdex/Memory.h>
namespace
{
-bool canonicalize_avgpool2d(loco::Graph *graph, moco::TFAvgPool *node)
+using plier::tf::DataLayout;
+
+void set_feature_enc(loco::FeatureEncode *feature_enc, DataLayout data_layout)
+{
+ auto enc = stdex::make_unique<loco::PermutingEncoder<loco::Domain::Feature>>();
+
+ if (data_layout == DataLayout::NHWC)
+ {
+ enc->perm()->axis(loco::FeatureAxis::Count) = 0;
+ enc->perm()->axis(loco::FeatureAxis::Height) = 1;
+ enc->perm()->axis(loco::FeatureAxis::Width) = 2;
+ enc->perm()->axis(loco::FeatureAxis::Depth) = 3;
+ }
+ else if (data_layout == DataLayout::NCHW)
+ {
+ enc->perm()->axis(loco::FeatureAxis::Count) = 0;
+ enc->perm()->axis(loco::FeatureAxis::Depth) = 1;
+ enc->perm()->axis(loco::FeatureAxis::Height) = 2;
+ enc->perm()->axis(loco::FeatureAxis::Width) = 3;
+ }
+
+ feature_enc->encoder(std::move(enc));
+}
+
+void set_feature_dec(loco::FeatureDecode *feature_dec, DataLayout data_layout)
+{
+ auto dec = stdex::make_unique<loco::PermutingDecoder<loco::Domain::Feature>>();
+
+ if (data_layout == DataLayout::NHWC)
+ {
+ dec->perm()->axis(loco::FeatureAxis::Count) = 0;
+ dec->perm()->axis(loco::FeatureAxis::Height) = 1;
+ dec->perm()->axis(loco::FeatureAxis::Width) = 2;
+ dec->perm()->axis(loco::FeatureAxis::Depth) = 3;
+ }
+ else if (data_layout == DataLayout::NCHW)
+ {
+ dec->perm()->axis(loco::FeatureAxis::Count) = 0;
+ dec->perm()->axis(loco::FeatureAxis::Depth) = 1;
+ dec->perm()->axis(loco::FeatureAxis::Height) = 2;
+ dec->perm()->axis(loco::FeatureAxis::Width) = 3;
+ }
+
+ feature_dec->decoder(std::move(dec));
+}
+
+bool canonicalize_avgpool2d(loco::Graph *graph, moco::tf::TFAvgPool *node)
{
LOGGER(l);
@@ -61,24 +113,30 @@ bool canonicalize_avgpool2d(loco::Graph *graph, moco::TFAvgPool *node)
avgPool2d_node->convention(loco::AvgPool2D::Convention::Valid);
- auto value_shape = moco::node_shape(node->value());
- assert(value_shape.domain() != loco::Domain::Unknown);
+ // paddata to pad
+ auto pad_data = node->annot<moco::tf::PadData>();
+ assert(pad_data != nullptr);
- auto node_stride = moco::stride_of(node->strides(), node->data_layout());
- auto node_window = moco::window_of(node->ksize(), node->data_layout());
+ avgPool2d_node->pad()->top(pad_data->pad()->top());
+ avgPool2d_node->pad()->bottom(pad_data->pad()->bottom());
+ avgPool2d_node->pad()->left(pad_data->pad()->left());
+ avgPool2d_node->pad()->right(pad_data->pad()->right());
- moco::Padding2DInference infer_padding2d;
+ // windowdata to window (ksize to window)
+ auto window_data = node->annot<moco::tf::WindowData>();
+ assert(window_data != nullptr);
- infer_padding2d.padding(node->padding());
- infer_padding2d.stride(node_stride);
- infer_padding2d.window(node_window);
+ auto window = avgPool2d_node->window();
+ window->vertical(window_data->window()->vertical());
+ window->horizontal(window_data->window()->horizontal());
- auto input_feature_shape = moco::as_feature_shape(value_shape, node->data_layout());
- auto input_plane_shape = moco::make_plane_shape(input_feature_shape);
+ // stridedata to stride (strides to stride)
+ auto stride_data = node->annot<moco::tf::StrideData>();
+ assert(stride_data != nullptr);
- *avgPool2d_node->pad() = infer_padding2d(input_plane_shape);
- *avgPool2d_node->stride() = node_stride;
- *avgPool2d_node->window() = node_window;
+ auto stride = avgPool2d_node->stride();
+ stride->vertical(stride_data->stride()->vertical());
+ stride->horizontal(stride_data->stride()->horizontal());
INFO(l) << "Canonicalize TFAvgPool pad = T " << avgPool2d_node->pad()->top() << ", L "
<< avgPool2d_node->pad()->left() << ", B " << avgPool2d_node->pad()->bottom() << ", R "
@@ -105,9 +163,25 @@ namespace moco
namespace tf
{
-bool AvgPoolCanonicalizer::transform(TFAvgPool *node) const
+bool AvgPoolCanonicalizer::run(loco::Graph *graph)
{
- return canonicalize_avgpool2d(node->graph(), node);
+ auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
+ bool changed = false;
+
+ for (auto node : active_nodes)
+ {
+ if (node->dialect() == TFDialect::get())
+ {
+ auto tf_node = dynamic_cast<moco::tf::TFAvgPool *>(node);
+ if (tf_node != nullptr)
+ {
+ if (canonicalize_avgpool2d(graph, tf_node))
+ changed = true;
+ }
+ }
+ }
+
+ return changed;
}
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/AvgPoolCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/AvgPoolCanonicalizer.h
index e9c56c868..7d7e6a80b 100644
--- a/compiler/moco-tf/src/Canonicalization/AvgPoolCanonicalizer.h
+++ b/compiler/moco-tf/src/Canonicalization/AvgPoolCanonicalizer.h
@@ -18,9 +18,6 @@
#define __MOCO_TF_AVGPOOL_CANONICALIZER_H__
#include "Transform.h"
-#include "SimpleNodeTransform.h"
-
-#include <moco/IR/TFNodes.h>
#include <loco.h>
@@ -32,13 +29,13 @@ namespace tf
/**
* @brief Convert TFAvgPool to Canonical AvgPool2D
*/
-class AvgPoolCanonicalizer : public SimpleNodeTransform<moco::TFAvgPool>
+class AvgPoolCanonicalizer : public Transform
{
public:
const char *name(void) const final { return "AvgPoolCanonicalizer"; }
public:
- bool transform(TFAvgPool *node) const final;
+ bool run(loco::Graph *graph) override;
};
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/BiasAddCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/BiasAddCanonicalizer.cpp
index a5568ce1a..37b660e4a 100644
--- a/compiler/moco-tf/src/Canonicalization/BiasAddCanonicalizer.cpp
+++ b/compiler/moco-tf/src/Canonicalization/BiasAddCanonicalizer.cpp
@@ -16,9 +16,12 @@
#include "BiasAddCanonicalizer.h"
-#include <moco/IR/TFDialect.h>
+#include "Dialect/TFDialect.h"
+#include "Dialect/TFNodes.h"
+#include "Dialect/TFNodeVisitor.h"
+#include "Dialect/TFNodeImpl.h"
-#include <moco/Names.h>
+#include <moco/tf/Names.h>
#include <moco/Log.h>
#include <plier/tf/Convert.h>
@@ -26,7 +29,7 @@ namespace
{
using plier::tf::DataLayout;
-bool canonicalize_biasadd(loco::Graph *graph, moco::TFBiasAdd *node)
+bool canonicalize_biasadd(loco::Graph *graph, moco::tf::TFBiasAdd *node)
{
LOGGER(l);
@@ -100,9 +103,25 @@ namespace moco
namespace tf
{
-bool BiasAddCanonicalizer::transform(TFBiasAdd *node) const
+bool BiasAddCanonicalizer::run(loco::Graph *graph)
{
- return canonicalize_biasadd(node->graph(), node);
+ auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
+ bool changed = false;
+
+ for (auto node : active_nodes)
+ {
+ if (node->dialect() == TFDialect::get())
+ {
+ auto tf_biasadd = dynamic_cast<moco::tf::TFBiasAdd *>(node);
+ if (tf_biasadd != nullptr)
+ {
+ if (canonicalize_biasadd(graph, tf_biasadd))
+ changed = true;
+ }
+ }
+ }
+
+ return changed;
}
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/BiasAddCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/BiasAddCanonicalizer.h
index ff4032ca9..a30894708 100644
--- a/compiler/moco-tf/src/Canonicalization/BiasAddCanonicalizer.h
+++ b/compiler/moco-tf/src/Canonicalization/BiasAddCanonicalizer.h
@@ -18,9 +18,6 @@
#define __MOCO_TF_BIASADD_CANONICALIZER_H__
#include "Transform.h"
-#include "SimpleNodeTransform.h"
-
-#include <moco/IR/TFNodes.h>
#include <loco.h>
@@ -32,13 +29,13 @@ namespace tf
/**
* @brief Convert TFBiasAdd to Canonical BiasAdd
*/
-class BiasAddCanonicalizer final : public SimpleNodeTransform<moco::TFBiasAdd>
+class BiasAddCanonicalizer : public Transform
{
public:
const char *name(void) const final { return "BiasAddCanonicalizer"; }
public:
- bool transform(TFBiasAdd *node) const final;
+ bool run(loco::Graph *graph) override;
};
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/ConcatV2Canonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/ConcatV2Canonicalizer.cpp
index b59a3f3d7..e3939adb9 100644
--- a/compiler/moco-tf/src/Canonicalization/ConcatV2Canonicalizer.cpp
+++ b/compiler/moco-tf/src/Canonicalization/ConcatV2Canonicalizer.cpp
@@ -15,39 +15,27 @@
*/
#include "ConcatV2Canonicalizer.h"
+
#include "LogHelper.h"
-#include <moco/IR/TFDialect.h>
-#include <moco/Support/TFShapeInferenceHelper.h>
+#include "Annotations/ConcatData.h"
+#include "Annotations/ShapeInferenceData.h"
-#include <moco/Log.h>
+#include "Dialect/TFDialect.h"
+#include "Dialect/TFNodes.h"
+#include "Dialect/TFNodeVisitor.h"
+#include "Dialect/TFNodeImpl.h"
-#include <loco/Service/ShapeInference.h>
+#include <moco/Log.h>
#include <stdex/Memory.h>
-#include <oops/UserExn.h>
namespace
{
using namespace moco::tf;
-bool scalar_value(moco::TFConst *node, int32_t &ret)
-{
- auto nodeshape = node_shape(node);
- if (!(node->dtype() == loco::DataType::S32))
- return false;
-
- auto tensor_shape = nodeshape.as<loco::TensorShape>();
- if (!(tensor_shape.rank() == 0 || tensor_shape.rank() == 1))
- return false;
-
- ret = node->at<loco::DataType::S32>(0);
-
- return true;
-}
-
-bool canonicalize_concat(loco::Graph *graph, moco::TFConcatV2 *node)
+bool canonicalize_concat(loco::Graph *graph, moco::tf::TFConcatV2 *node)
{
LOGGER(l);
@@ -83,43 +71,19 @@ bool canonicalize_concat(loco::Graph *graph, moco::TFConcatV2 *node)
const int num_values = node->num_values();
assert(num_values >= 2);
- // get axis absolute value
- auto value_a = node->values(0);
- if (!loco::shape_known(value_a))
- return false;
+ // get axis value
+ auto concat_data = node->annot<ConcatData>();
+ assert(concat_data != nullptr);
+ auto axis_value = concat_data->axis();
- uint32_t node_rank = 0;
- {
- auto value_a_shape = moco::node_shape(value_a);
- assert(value_a_shape.domain() == loco::Domain::Tensor);
-
- auto value_a_tensor_shape = value_a_shape.as<loco::TensorShape>();
- node_rank = value_a_tensor_shape.rank();
- }
+ auto shapedata = node->annot<ShapeInferenceData>();
+ auto node_rank = shapedata->rank();
- int32_t axis_value = 0;
- {
- // axis should be TFConst
- auto axis_node = node->axis();
- auto tfconst = dynamic_cast<moco::TFConst *>(axis_node);
- if (tfconst == nullptr)
- {
- // TODO Check this: this error can be from TFOptimizatier.
- throw oops::UserExn("ConcatV2 node has invalid input for axis", node->name());
- }
- auto result = scalar_value(tfconst, axis_value);
- if (!result)
- {
- // TODO Check this: this error can be from TFOptimizatier.
- throw oops::UserExn("ConcatV2 node has invalid input for axis", node->name());
- }
- }
uint32_t axis_absolute = (axis_value >= 0) ? axis_value : (int32_t)node_rank + axis_value;
INFO(l) << "canonicalize_concat axis(" << axis_absolute << "), value(" << axis_value << "), rank("
<< node_rank << ")";
- // Convert series of TensorConcat if num_values > 2
auto concat_node = graph->nodes()->create<loco::TensorConcat>();
concat_node->lhs(node->values(0));
concat_node->rhs(node->values(1));
@@ -151,9 +115,25 @@ namespace moco
namespace tf
{
-bool ConcatV2Canonicalizer::transform(TFConcatV2 *node) const
+bool ConcatV2Canonicalizer::run(loco::Graph *graph)
{
- return canonicalize_concat(node->graph(), node);
+ auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
+ bool changed = false;
+
+ for (auto node : active_nodes)
+ {
+ if (node->dialect() == TFDialect::get())
+ {
+ auto tf_node = dynamic_cast<moco::tf::TFConcatV2 *>(node);
+ if (tf_node != nullptr)
+ {
+ if (canonicalize_concat(graph, tf_node))
+ changed = true;
+ }
+ }
+ }
+
+ return changed;
}
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/ConcatV2Canonicalizer.h b/compiler/moco-tf/src/Canonicalization/ConcatV2Canonicalizer.h
index e6b471b89..4448ddb16 100644
--- a/compiler/moco-tf/src/Canonicalization/ConcatV2Canonicalizer.h
+++ b/compiler/moco-tf/src/Canonicalization/ConcatV2Canonicalizer.h
@@ -18,9 +18,6 @@
#define __MOCO_TF_CONCATV2_CANONICALIZER_H__
#include "Transform.h"
-#include "SimpleNodeTransform.h"
-
-#include <moco/IR/TFNodes.h>
#include <loco.h>
@@ -32,13 +29,13 @@ namespace tf
/**
* @brief Convert TFConcatV2 to Canonical TensorConcat
*/
-class ConcatV2Canonicalizer : public SimpleNodeTransform<moco::TFConcatV2>
+class ConcatV2Canonicalizer : public Transform
{
public:
const char *name(void) const final { return "ConcatV2Canonicalizer"; }
public:
- bool transform(moco::TFConcatV2 *node) const final;
+ bool run(loco::Graph *graph) override;
};
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/ConstCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/ConstCanonicalizer.cpp
index 60629cd5a..dea97f94a 100644
--- a/compiler/moco-tf/src/Canonicalization/ConstCanonicalizer.cpp
+++ b/compiler/moco-tf/src/Canonicalization/ConstCanonicalizer.cpp
@@ -16,17 +16,18 @@
#include "ConstCanonicalizer.h"
-#include <moco/IR/TFDialect.h>
+#include "Dialect/TFDialect.h"
+#include "Dialect/TFNodes.h"
+#include "Dialect/TFNodeVisitor.h"
+#include "Dialect/TFNodeImpl.h"
-#include <moco/Names.h>
+#include <moco/tf/Names.h>
#include <moco/Log.h>
-#include <oops/UserExn.h>
-
namespace
{
-bool canonicalize_const(loco::Graph *graph, moco::TFConst *node)
+bool canonicalize_const(loco::Graph *graph, moco::tf::TFConst *node)
{
LOGGER(l);
@@ -54,27 +55,13 @@ bool canonicalize_const(loco::Graph *graph, moco::TFConst *node)
const_node->dtype(dtype);
auto rank = node->rank();
-
- if (rank == 0)
- {
- // This routine implements a workaround that converts a scalar constant (rank-0 tensor)
- // into a rank-1 tensor of shape [1].
- //
- // TODO Revise this implementation later
- const_node->rank(1);
- const_node->dim(0) = 1;
- }
- else
+ const_node->rank(rank);
+ for (uint32_t r = 0; r < rank; ++r)
{
- const_node->rank(rank);
-
- for (uint32_t r = 0; r < rank; ++r)
- {
- if (node->dim(r).known())
- const_node->dim(r) = node->dim(r);
- else
- const_node->dim(r).unset();
- }
+ if (node->dim(r).known())
+ const_node->dim(r) = node->dim(r);
+ else
+ const_node->dim(r).unset();
}
switch (dtype)
@@ -100,7 +87,7 @@ bool canonicalize_const(loco::Graph *graph, moco::TFConst *node)
break;
}
default:
- throw oops::UserExn("Const has unsupported data type", node->name());
+ throw std::runtime_error("NYI for this DataType");
}
// update graph
@@ -118,9 +105,25 @@ namespace moco
namespace tf
{
-bool ConstCanonicalizer::transform(TFConst *node) const
+bool ConstCanonicalizer::run(loco::Graph *graph)
{
- return canonicalize_const(node->graph(), node);
+ auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
+ bool changed = false;
+
+ for (auto node : active_nodes)
+ {
+ if (node->dialect() == TFDialect::get())
+ {
+ auto tf_const = dynamic_cast<moco::tf::TFConst *>(node);
+ if (tf_const != nullptr)
+ {
+ if (canonicalize_const(graph, tf_const))
+ changed = true;
+ }
+ }
+ }
+
+ return changed;
}
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/ConstCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/ConstCanonicalizer.h
index 1b0b2b867..53f3ca8e3 100644
--- a/compiler/moco-tf/src/Canonicalization/ConstCanonicalizer.h
+++ b/compiler/moco-tf/src/Canonicalization/ConstCanonicalizer.h
@@ -18,9 +18,6 @@
#define __MOCO_TF_CONST_CANONICALIZER_H__
#include "Transform.h"
-#include "SimpleNodeTransform.h"
-
-#include <moco/IR/TFNodes.h>
#include <loco.h>
@@ -32,13 +29,13 @@ namespace tf
/**
* @brief Convert TFConst to Canonical ConstGen
*/
-class ConstCanonicalizer : public SimpleNodeTransform<moco::TFConst>
+class ConstCanonicalizer : public Transform
{
public:
const char *name(void) const final { return "ConstCanonicalizer"; }
public:
- bool transform(moco::TFConst *) const final;
+ bool run(loco::Graph *graph) override;
};
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/Conv2DBackpropInputCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/Conv2DBackpropInputCanonicalizer.cpp
deleted file mode 100644
index d3cbd4ab3..000000000
--- a/compiler/moco-tf/src/Canonicalization/Conv2DBackpropInputCanonicalizer.cpp
+++ /dev/null
@@ -1,371 +0,0 @@
-/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "Conv2DBackpropInputCanonicalizer.h"
-
-#include <moco/IR/TFDialect.h>
-
-#include "CodecHelper.h"
-
-#include <loco/IR/Stride.h>
-#include <loco/IR/Padding2D.h>
-#include <loco/Service/ShapeInference.h>
-
-#include <oops/UserExn.h>
-
-namespace
-{
-using plier::tf::DataLayout;
-
-void set_filter_enc(loco::FilterEncode *filter_enc)
-{
- auto enc = stdex::make_unique<loco::PermutingEncoder<loco::Domain::Filter>>();
-
- // In TensorFlow, Conv2dBackpropInput's filter is a 4-D tensor of following shape:
- // [filter_height, filter_width, out_channels, in_channels] or HWOI or HWNC (in/out in loco sense)
- enc->perm()->axis(loco::FilterAxis::Height) = 0;
- enc->perm()->axis(loco::FilterAxis::Width) = 1;
- enc->perm()->axis(loco::FilterAxis::Count) = 2;
- enc->perm()->axis(loco::FilterAxis::Depth) = 3;
-
- filter_enc->encoder(std::move(enc));
-}
-
-} // namespace
-
-namespace
-{
-
-bool stride_2d_from_4d(loco::Stride<2> &ret, const std::vector<int64_t> &strides_4d,
- const DataLayout data_layout)
-{
- if (!(strides_4d.size() == 4))
- return false;
-
- switch (data_layout)
- {
- case DataLayout::NHWC:
- ret.vertical(strides_4d.at(1));
- ret.horizontal(strides_4d.at(2));
- break;
- case DataLayout::NCHW:
- ret.vertical(strides_4d.at(2));
- ret.horizontal(strides_4d.at(3));
- break;
- default:
- return false;
- }
- return true;
-}
-
-struct PlaneShape
-{
- loco::Dimension vertical;
- loco::Dimension horizontal;
-};
-
-class Padding2DInference final
-{
-public:
- Padding2DInference(const moco::TFNode *node) { _node = node; }
-
-public:
- loco::Padding2D operator()(void);
-
-public:
- PlaneShape &input() { return _input; }
- PlaneShape &output() { return _output; }
- loco::Stride<2> &stride() { return _stride; }
- loco::Window<2> &window() { return _window; }
- moco::TFPadding &padding() { return _padding; }
-
-private:
- /// @brief Check whether ingredients set by non-default values
- bool ready()
- {
- if (not input().vertical.known())
- return false;
- if (not input().horizontal.known())
- return false;
- if (not output().vertical.known())
- return false;
- if (not output().horizontal.known())
- return false;
- if (stride().vertical() == 0)
- return false;
- if (stride().horizontal() == 0)
- return false;
- if (window().vertical() == 0)
- return false;
- if (window().horizontal() == 0)
- return false;
- if (padding().empty())
- return false;
-
- return true;
- }
-
- inline uint32_t tight_output_for_valid_padding(uint32_t input, uint32_t stride, uint32_t filter)
- {
- return stride * (input - 1) + filter;
- }
-
- /**
- * @note For Conv2DBackpropInput SAME padding, TensorFlow requires this condition to hold
- *
- * Reference: `::tensorflow::GetWindowedOutputSizeVerboseV2()` from TensorFlow project
- */
- inline bool same_padding_applicable(uint32_t input, uint32_t output, uint32_t stride)
- {
- // Here 'input' and 'output' means Conv2DBackpropInput's actual node input and output.
- // Then these three conditions are equivalent:
- //
- // input == floor((output + stride - 1) / stride)
- // input == ceil(output / stride)
- // (stride * (input - 1) < output) and (output <= stride * input)
- return (stride * (input - 1) < output) and (output <= stride * input);
- }
-
- inline uint32_t padding_needed(uint32_t input, uint32_t output, uint32_t stride, uint32_t filter)
- {
- return stride * (input - 1) + filter - output;
- }
-
-private:
- const moco::TFNode *_node;
- PlaneShape _input;
- PlaneShape _output;
- loco::Stride<2> _stride;
- loco::Window<2> _window;
- moco::TFPadding _padding;
-};
-
-loco::Padding2D Padding2DInference::operator()(void)
-{
- assert(ready());
-
- if (padding() == "VALID")
- {
- // In case of VALID padding, TensorFlow accepts any size same or larger than
- // 'tight fit' output. When output size (set by 'input sizes' node input) is
- // larger than tight fit, extra spaces filled with zero.
- auto tight_output_vertical = tight_output_for_valid_padding(
- input().vertical.value(), stride().vertical(), window().vertical());
- auto tight_output_horizontal = tight_output_for_valid_padding(
- input().horizontal.value(), stride().horizontal(), window().horizontal());
-
- if (output().vertical.value() < tight_output_vertical or
- output().horizontal.value() < tight_output_horizontal)
- throw oops::UserExn("input_sizes is too small", _node->name());
-
- // Currently, only accept tight fit.
- // TODO Support non-tight case by adding zero padding operation
- assert(output().vertical.value() == tight_output_vertical);
- assert(output().horizontal.value() == tight_output_horizontal);
-
- return loco::Padding2D(0, 0, 0, 0);
- }
-
- if (padding() == "SAME")
- {
- // This condition is required by TensorFlow
- if (not same_padding_applicable(input().vertical.value(), output().vertical.value(),
- stride().vertical()) or
- not same_padding_applicable(input().horizontal.value(), output().horizontal.value(),
- stride().horizontal()))
- throw oops::UserExn("Size mismatch for SAME padding", _node->name());
-
- auto whole_pad_vertical = padding_needed(input().vertical.value(), output().vertical.value(),
- stride().vertical(), window().vertical());
- auto whole_pad_horizontal =
- padding_needed(input().horizontal.value(), output().horizontal.value(),
- stride().horizontal(), window().horizontal());
-
- loco::Padding2D res;
-
- res.top(whole_pad_vertical / 2);
- res.bottom(whole_pad_vertical - res.top());
- res.left(whole_pad_horizontal / 2);
- res.right(whole_pad_horizontal - res.left());
-
- return res;
- }
-
- throw oops::UserExn("Usupported padding " + padding(), _node->name());
-}
-
-/**
- * @param[out] ret PlaneShape extracted from 'node' with given 'data_layout'
- * @param[in] node
- * @param[in] data_layout
- *
- * @return true on success
- */
-bool set_plane_shape(PlaneShape &ret, const loco::Node *node, const DataLayout data_layout)
-{
- auto tensor_shape = loco::shape_get(node).as<loco::TensorShape>();
- if (!(tensor_shape.rank() == 4))
- return false;
-
- switch (data_layout)
- {
- case DataLayout::NHWC:
- ret.vertical = tensor_shape.dim(1).value();
- ret.horizontal = tensor_shape.dim(2).value();
- break;
- case DataLayout::NCHW:
- ret.vertical = tensor_shape.dim(2).value();
- ret.horizontal = tensor_shape.dim(3).value();
- break;
- default:
- return false;
- }
-
- return true;
-}
-
-/**
- * @param[out] ret 2D Window extracted from HW** filter node
- * @param[in] filter_node
- *
- * @return true on success
- */
-bool set_window(loco::Window<2> &ret, const loco::Node *filter_node)
-{
- auto tensor_shape = loco::shape_get(filter_node).as<loco::TensorShape>();
- assert(tensor_shape.rank() == 4);
-
- ret.vertical(tensor_shape.dim(0).value());
- ret.horizontal(tensor_shape.dim(1).value());
-
- return true;
-}
-
-} // namespace
-
-namespace
-{
-
-bool canonicalize_conv2d_backprop_input(loco::Graph *graph,
- moco::TFConv2DBackpropInput *conv2d_backprop)
-{
- /**
- * @note This will replace TFConv2DBackpropInput node with canonical
- * FeatureEncode + FilterEncode + TransposedConv2D + FeatureDecode
- *
- * Before
- * input_sizes ----
- * \
- * filter -------- TFConv2DBackpropInput --- output(s)
- * /
- * out_backprop ---
- *
- * After
- * input_sizes ----
- * \
- * filter -------- TFConv2DBackpropInput ---
- * /
- * out_backprop ---
- *
- * filter ------ FilterEncode ------ TransposedConv2D --- FeatureDecode --- output(s)
- * (as ker) /
- * out_backprop --- FeatureEncode ---
- * (as ifm)
- */
-
- if (!loco::shape_known(conv2d_backprop->out_backprop()))
- return false;
- if (!loco::shape_known(conv2d_backprop))
- return false;
- if (!loco::shape_known(conv2d_backprop->filter()))
- return false;
-
- auto data_layout = plier::tf::as_data_layout(conv2d_backprop->data_layout());
-
- // Nodes to replace
- auto feature_enc = graph->nodes()->create<loco::FeatureEncode>();
- auto filter_enc = graph->nodes()->create<loco::FilterEncode>();
- auto tr_conv2d = graph->nodes()->create<loco::TransposedConv2D>();
- auto feature_dec = graph->nodes()->create<loco::FeatureDecode>();
-
- set_feature_enc(feature_enc, data_layout);
- set_filter_enc(filter_enc);
- set_feature_dec(feature_dec, data_layout);
-
- // Attributes for new TransposedConv2D
- loco::Stride<2> stride;
- loco::Padding2D pad;
-
- // Get attributes
- {
- if (!stride_2d_from_4d(stride, conv2d_backprop->strides(), data_layout))
- throw oops::UserExn("Unsupported strides", conv2d_backprop->name());
-
- Padding2DInference infer_pad(conv2d_backprop);
-
- if (!set_plane_shape(infer_pad.input(), conv2d_backprop->out_backprop(), data_layout))
- throw oops::UserExn("Unsupported out_backprop data_format", conv2d_backprop->name());
- if (!set_plane_shape(infer_pad.output(), conv2d_backprop, data_layout))
- throw oops::UserExn("Unsupported data_format", conv2d_backprop->name());
- if (!set_window(infer_pad.window(), conv2d_backprop->filter()))
- throw oops::UserExn("Unsupported filter shape", conv2d_backprop->name());
- infer_pad.stride() = stride;
- infer_pad.padding() = conv2d_backprop->padding();
-
- // Run padding infer_pad
- pad = infer_pad();
- }
-
- // Set attributes
- tr_conv2d->pad()->top(pad.top());
- tr_conv2d->pad()->bottom(pad.bottom());
- tr_conv2d->pad()->left(pad.left());
- tr_conv2d->pad()->right(pad.right());
-
- tr_conv2d->stride()->vertical(stride.vertical());
- tr_conv2d->stride()->horizontal(stride.horizontal());
-
- // Update graph
- auto input_node = conv2d_backprop->out_backprop();
- auto filter_node = conv2d_backprop->filter();
-
- // Update connections
- feature_enc->input(input_node);
- filter_enc->input(filter_node);
- tr_conv2d->ifm(feature_enc);
- tr_conv2d->ker(filter_enc);
- feature_dec->input(tr_conv2d);
-
- // Replace old conv2d_backprop
- replace(conv2d_backprop).with(feature_dec);
-
- return true;
-}
-
-} // namespace
-
-namespace moco
-{
-namespace tf
-{
-
-bool Conv2DBackpropInputCanonicalizer::transform(TFConv2DBackpropInput *node) const
-{
- return canonicalize_conv2d_backprop_input(node->graph(), node);
-}
-
-} // namespace tf
-} // namespace moco
diff --git a/compiler/moco-tf/src/Canonicalization/Conv2DBackpropInputCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/Conv2DBackpropInputCanonicalizer.h
deleted file mode 100644
index bc37bb9cb..000000000
--- a/compiler/moco-tf/src/Canonicalization/Conv2DBackpropInputCanonicalizer.h
+++ /dev/null
@@ -1,45 +0,0 @@
-/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef __MOCO_TF_CONV2DBACKPROPINPUT_CANONICALIZER_H__
-#define __MOCO_TF_CONV2DBACKPROPINPUT_CANONICALIZER_H__
-
-#include "Transform.h"
-#include "SimpleNodeTransform.h"
-
-#include <moco/IR/TFNodes.h>
-
-#include <loco.h>
-
-namespace moco
-{
-namespace tf
-{
-
-/// @brief Convert TFConv2DBackpropInput to Canonical TransposedConv2D
-class Conv2DBackpropInputCanonicalizer : public SimpleNodeTransform<moco::TFConv2DBackpropInput>
-{
-public:
- const char *name(void) const final { return "Conv2DBackpropInputCanonicalizer"; }
-
-public:
- bool transform(moco::TFConv2DBackpropInput *) const final;
-};
-
-} // namespace tf
-} // namespace moco
-
-#endif // __MOCO_TF_CONV2DBACKPROPINPUT_CANONICALIZER_H__
diff --git a/compiler/moco-tf/src/Canonicalization/Conv2DCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/Conv2DCanonicalizer.cpp
index a955793a8..f34339d0f 100644
--- a/compiler/moco-tf/src/Canonicalization/Conv2DCanonicalizer.cpp
+++ b/compiler/moco-tf/src/Canonicalization/Conv2DCanonicalizer.cpp
@@ -16,18 +16,46 @@
#include "Conv2DCanonicalizer.h"
-#include <moco/IR/TFDialect.h>
-#include <moco/Support/TFShapeInferenceHelper.h>
+#include "Annotations/PadData.h"
+#include "Annotations/StrideData.h"
-#include "CodecHelper.h"
+#include "Dialect/TFDialect.h"
+#include "Dialect/TFNodes.h"
+#include "Dialect/TFNodeVisitor.h"
+#include "Dialect/TFNodeImpl.h"
#include <moco/Log.h>
+#include <plier/tf/Convert.h>
+
+#include <stdex/Memory.h>
namespace
{
using plier::tf::DataLayout;
-void set_filter_enc(loco::FilterEncode *filter_enc)
+void set_feature_enc(loco::FeatureEncode *feature_enc, DataLayout data_layout)
+{
+ auto enc = stdex::make_unique<loco::PermutingEncoder<loco::Domain::Feature>>();
+
+ if (data_layout == DataLayout::NHWC)
+ {
+ enc->perm()->axis(loco::FeatureAxis::Count) = 0;
+ enc->perm()->axis(loco::FeatureAxis::Height) = 1;
+ enc->perm()->axis(loco::FeatureAxis::Width) = 2;
+ enc->perm()->axis(loco::FeatureAxis::Depth) = 3;
+ }
+ else if (data_layout == DataLayout::NCHW)
+ {
+ enc->perm()->axis(loco::FeatureAxis::Count) = 0;
+ enc->perm()->axis(loco::FeatureAxis::Depth) = 1;
+ enc->perm()->axis(loco::FeatureAxis::Height) = 2;
+ enc->perm()->axis(loco::FeatureAxis::Width) = 3;
+ }
+
+ feature_enc->encoder(std::move(enc));
+}
+
+void set_filter_enc(loco::FilterEncode *filter_enc, DataLayout data_layout)
{
auto enc = stdex::make_unique<loco::PermutingEncoder<loco::Domain::Filter>>();
@@ -41,7 +69,29 @@ void set_filter_enc(loco::FilterEncode *filter_enc)
filter_enc->encoder(std::move(enc));
}
-bool canonicalize_conv2d(loco::Graph *graph, moco::TFConv2D *node)
+void set_feature_dec(loco::FeatureDecode *feature_dec, DataLayout data_layout)
+{
+ auto dec = stdex::make_unique<loco::PermutingDecoder<loco::Domain::Feature>>();
+
+ if (data_layout == DataLayout::NHWC)
+ {
+ dec->perm()->axis(loco::FeatureAxis::Count) = 0;
+ dec->perm()->axis(loco::FeatureAxis::Height) = 1;
+ dec->perm()->axis(loco::FeatureAxis::Width) = 2;
+ dec->perm()->axis(loco::FeatureAxis::Depth) = 3;
+ }
+ else if (data_layout == DataLayout::NCHW)
+ {
+ dec->perm()->axis(loco::FeatureAxis::Count) = 0;
+ dec->perm()->axis(loco::FeatureAxis::Depth) = 1;
+ dec->perm()->axis(loco::FeatureAxis::Height) = 2;
+ dec->perm()->axis(loco::FeatureAxis::Width) = 3;
+ }
+
+ feature_dec->decoder(std::move(dec));
+}
+
+bool canonicalize_conv2d(loco::Graph *graph, moco::tf::TFConv2D *node)
{
LOGGER(l);
@@ -75,29 +125,23 @@ bool canonicalize_conv2d(loco::Graph *graph, moco::TFConv2D *node)
auto feature_dec = graph->nodes()->create<loco::FeatureDecode>();
set_feature_enc(feature_enc, data_layout);
- set_filter_enc(filter_enc);
+ set_filter_enc(filter_enc, data_layout);
set_feature_dec(feature_dec, data_layout);
- auto input_shape = moco::node_shape(node->input());
- assert(input_shape.domain() != loco::Domain::Unknown);
-
- auto ker_shape = moco::node_shape(node->filter());
- auto ker_tensor_shape = ker_shape.as<loco::TensorShape>(); // in HWIO
-
- auto node_stride = moco::stride_of(node->strides(), node->data_layout());
- auto node_window = moco::window_of(ker_tensor_shape, "HWIO");
-
- moco::Padding2DInference infer_padding2d;
+ // Set Conv2D attributes from TFConv2D
+ auto pad_data = node->annot<moco::tf::PadData>();
+ assert(pad_data != nullptr);
- infer_padding2d.padding(node->padding());
- infer_padding2d.stride(node_stride);
- infer_padding2d.window(node_window);
+ conv2d->pad()->top(pad_data->pad()->top());
+ conv2d->pad()->bottom(pad_data->pad()->bottom());
+ conv2d->pad()->left(pad_data->pad()->left());
+ conv2d->pad()->right(pad_data->pad()->right());
- auto input_feature_shape = moco::as_feature_shape(input_shape, node->data_layout());
- auto input_plane_shape = moco::make_plane_shape(input_feature_shape);
+ auto stride_data = node->annot<moco::tf::StrideData>();
+ assert(stride_data != nullptr);
- *conv2d->pad() = infer_padding2d(input_plane_shape);
- *conv2d->stride() = node_stride;
+ conv2d->stride()->vertical(stride_data->stride()->vertical());
+ conv2d->stride()->horizontal(stride_data->stride()->horizontal());
// update graph
auto node_A = node->input();
@@ -123,9 +167,25 @@ namespace moco
namespace tf
{
-bool Conv2DCanonicalizer::transform(TFConv2D *node) const
+bool Conv2DCanonicalizer::run(loco::Graph *graph)
{
- return canonicalize_conv2d(node->graph(), node);
+ auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
+ bool changed = false;
+
+ for (auto node : active_nodes)
+ {
+ if (node->dialect() == TFDialect::get())
+ {
+ auto tf_node = dynamic_cast<moco::tf::TFConv2D *>(node);
+ if (tf_node != nullptr)
+ {
+ if (canonicalize_conv2d(graph, tf_node))
+ changed = true;
+ }
+ }
+ }
+
+ return changed;
}
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/Conv2DCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/Conv2DCanonicalizer.h
index ea39667f3..6be264f90 100644
--- a/compiler/moco-tf/src/Canonicalization/Conv2DCanonicalizer.h
+++ b/compiler/moco-tf/src/Canonicalization/Conv2DCanonicalizer.h
@@ -18,9 +18,6 @@
#define __MOCO_TF_CONV2D_CANONICALIZER_H__
#include "Transform.h"
-#include "SimpleNodeTransform.h"
-
-#include <moco/IR/TFNodes.h>
#include <loco.h>
@@ -32,13 +29,13 @@ namespace tf
/**
* @brief Convert TFConv2D to Canonical Conv2D
*/
-class Conv2DCanonicalizer : public SimpleNodeTransform<TFConv2D>
+class Conv2DCanonicalizer : public Transform
{
public:
const char *name(void) const final { return "Conv2DCanonicalizer"; }
public:
- bool transform(TFConv2D *) const final;
+ bool run(loco::Graph *graph) override;
};
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/DepthwiseConv2dNativeCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/DepthwiseConv2dNativeCanonicalizer.cpp
index 50dddf637..ee63efa2f 100644
--- a/compiler/moco-tf/src/Canonicalization/DepthwiseConv2dNativeCanonicalizer.cpp
+++ b/compiler/moco-tf/src/Canonicalization/DepthwiseConv2dNativeCanonicalizer.cpp
@@ -16,18 +16,47 @@
#include "DepthwiseConv2dNativeCanonicalizer.h"
-#include <moco/IR/TFDialect.h>
-#include <moco/Support/TFShapeInferenceHelper.h>
+#include "Annotations/PadData.h"
+#include "Annotations/ShapeInferenceData.h"
+#include "Annotations/StrideData.h"
-#include "CodecHelper.h"
+#include "Dialect/TFDialect.h"
+#include "Dialect/TFNodes.h"
+#include "Dialect/TFNodeVisitor.h"
+#include "Dialect/TFNodeImpl.h"
#include <moco/Log.h>
+#include <plier/tf/Convert.h>
+
+#include <stdex/Memory.h>
namespace
{
using plier::tf::DataLayout;
+void set_feature_enc(loco::FeatureEncode *feature_enc, DataLayout data_layout)
+{
+ auto enc = stdex::make_unique<loco::PermutingEncoder<loco::Domain::Feature>>();
+
+ if (data_layout == DataLayout::NHWC)
+ {
+ enc->perm()->axis(loco::FeatureAxis::Count) = 0;
+ enc->perm()->axis(loco::FeatureAxis::Height) = 1;
+ enc->perm()->axis(loco::FeatureAxis::Width) = 2;
+ enc->perm()->axis(loco::FeatureAxis::Depth) = 3;
+ }
+ else if (data_layout == DataLayout::NCHW)
+ {
+ enc->perm()->axis(loco::FeatureAxis::Count) = 0;
+ enc->perm()->axis(loco::FeatureAxis::Depth) = 1;
+ enc->perm()->axis(loco::FeatureAxis::Height) = 2;
+ enc->perm()->axis(loco::FeatureAxis::Width) = 3;
+ }
+
+ feature_enc->encoder(std::move(enc));
+}
+
void set_filter_enc(loco::DepthwiseFilterEncode *filter_enc)
{
auto enc = stdex::make_unique<loco::PermutingEncoder<loco::Domain::DepthwiseFilter>>();
@@ -42,7 +71,29 @@ void set_filter_enc(loco::DepthwiseFilterEncode *filter_enc)
filter_enc->encoder(std::move(enc));
}
-bool canonicalize_depthwiseconv2dnative(loco::Graph *graph, moco::TFDepthwiseConv2dNative *node)
+void set_feature_dec(loco::FeatureDecode *feature_dec, DataLayout data_layout)
+{
+ auto dec = stdex::make_unique<loco::PermutingDecoder<loco::Domain::Feature>>();
+
+ if (data_layout == DataLayout::NHWC)
+ {
+ dec->perm()->axis(loco::FeatureAxis::Count) = 0;
+ dec->perm()->axis(loco::FeatureAxis::Height) = 1;
+ dec->perm()->axis(loco::FeatureAxis::Width) = 2;
+ dec->perm()->axis(loco::FeatureAxis::Depth) = 3;
+ }
+ else if (data_layout == DataLayout::NCHW)
+ {
+ dec->perm()->axis(loco::FeatureAxis::Count) = 0;
+ dec->perm()->axis(loco::FeatureAxis::Depth) = 1;
+ dec->perm()->axis(loco::FeatureAxis::Height) = 2;
+ dec->perm()->axis(loco::FeatureAxis::Width) = 3;
+ }
+
+ feature_dec->decoder(std::move(dec));
+}
+
+bool canonicalize_depthwiseconv2dnative(loco::Graph *graph, moco::tf::TFDepthwiseConv2dNative *node)
{
LOGGER(l);
@@ -83,24 +134,20 @@ bool canonicalize_depthwiseconv2dnative(loco::Graph *graph, moco::TFDepthwiseCon
set_filter_enc(filter_enc);
set_feature_dec(feature_dec, data_layout);
- // Calculate Pad and Stride from inference
- auto input_shape = moco::node_shape(node->input());
- auto ker_shape = moco::node_shape(node->filter());
- auto ker_tensor_shape = ker_shape.as<loco::TensorShape>();
- auto node_stride = moco::stride_of(node->strides(), node->data_layout());
- auto node_window = moco::window_of(ker_tensor_shape, "HWCM");
-
- moco::Padding2DInference infer_padding2d;
+ // Set DetphwiseConv2D attributes from TFDepthwiseConv2dNative
+ auto pad_data = node->annot<moco::tf::PadData>();
+ assert(pad_data != nullptr);
- infer_padding2d.padding(node->padding());
- infer_padding2d.stride(node_stride);
- infer_padding2d.window(node_window);
+ depthwiseconv2d->pad()->top(pad_data->pad()->top());
+ depthwiseconv2d->pad()->bottom(pad_data->pad()->bottom());
+ depthwiseconv2d->pad()->left(pad_data->pad()->left());
+ depthwiseconv2d->pad()->right(pad_data->pad()->right());
- auto input_feature_shape = moco::as_feature_shape(input_shape, node->data_layout());
- auto input_plane_shape = moco::make_plane_shape(input_feature_shape);
+ auto stride_data = node->annot<moco::tf::StrideData>();
+ assert(stride_data != nullptr);
- *depthwiseconv2d->pad() = infer_padding2d(input_plane_shape);
- *depthwiseconv2d->stride() = node_stride;
+ depthwiseconv2d->stride()->vertical(stride_data->stride()->vertical());
+ depthwiseconv2d->stride()->horizontal(stride_data->stride()->horizontal());
// update graph
auto node_A = node->input();
@@ -128,9 +175,25 @@ namespace moco
namespace tf
{
-bool DepthwiseConv2dNativeCanonicalizer::transform(TFDepthwiseConv2dNative *node) const
+bool DepthwiseConv2dNativeCanonicalizer::run(loco::Graph *graph)
{
- return canonicalize_depthwiseconv2dnative(node->graph(), node);
+ auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
+ bool changed = false;
+
+ for (auto node : active_nodes)
+ {
+ if (node->dialect() == TFDialect::get())
+ {
+ auto tf_node = dynamic_cast<moco::tf::TFDepthwiseConv2dNative *>(node);
+ if (tf_node != nullptr)
+ {
+ if (canonicalize_depthwiseconv2dnative(graph, tf_node))
+ changed = true;
+ }
+ }
+ }
+
+ return changed;
}
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/DepthwiseConv2dNativeCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/DepthwiseConv2dNativeCanonicalizer.h
index 704e1ade9..9bb8c5ad8 100644
--- a/compiler/moco-tf/src/Canonicalization/DepthwiseConv2dNativeCanonicalizer.h
+++ b/compiler/moco-tf/src/Canonicalization/DepthwiseConv2dNativeCanonicalizer.h
@@ -18,9 +18,6 @@
#define __MOCO_TF_DEPTHWISE_CONV2D_NATIVE_CANONICALIZER_H__
#include "Transform.h"
-#include "SimpleNodeTransform.h"
-
-#include <moco/IR/TFNodes.h>
namespace moco
{
@@ -30,13 +27,13 @@ namespace tf
/**
* @brief Convert TFDepthwiseConv2dNative to Canonical DepthwiseConv2D
*/
-class DepthwiseConv2dNativeCanonicalizer : public SimpleNodeTransform<moco::TFDepthwiseConv2dNative>
+class DepthwiseConv2dNativeCanonicalizer : public Transform
{
public:
const char *name(void) const final { return "DepthwiseConv2dNativeCanonicalizer"; }
public:
- bool transform(moco::TFDepthwiseConv2dNative *) const final;
+ bool run(loco::Graph *graph) override;
};
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/IdentityCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/IdentityCanonicalizer.cpp
index 3b680cf04..c4d5d8063 100644
--- a/compiler/moco-tf/src/Canonicalization/IdentityCanonicalizer.cpp
+++ b/compiler/moco-tf/src/Canonicalization/IdentityCanonicalizer.cpp
@@ -18,15 +18,18 @@
#include "Convert.h"
-#include <moco/IR/TFDialect.h>
+#include "Dialect/TFDialect.h"
+#include "Dialect/TFNodes.h"
+#include "Dialect/TFNodeVisitor.h"
+#include "Dialect/TFNodeImpl.h"
-#include <moco/Names.h>
+#include <moco/tf/Names.h>
#include <moco/Log.h>
namespace
{
-bool canonicalize_identity(loco::Graph *graph, moco::TFIdentity *node)
+bool canonicalize_identity(loco::Graph *graph, moco::tf::TFIdentity *node)
{
LOGGER(l);
@@ -69,9 +72,25 @@ namespace moco
namespace tf
{
-bool IdentityCanonicalizer::transform(TFIdentity *node) const
+bool IdentityCanonicalizer::run(loco::Graph *graph)
{
- return canonicalize_identity(node->graph(), node);
+ auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
+ bool changed = false;
+
+ for (auto node : active_nodes)
+ {
+ if (node->dialect() == TFDialect::get())
+ {
+ auto tf_identity = dynamic_cast<moco::tf::TFIdentity *>(node);
+ if (tf_identity != nullptr)
+ {
+ if (canonicalize_identity(graph, tf_identity))
+ changed = true;
+ }
+ }
+ }
+
+ return changed;
}
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/IdentityCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/IdentityCanonicalizer.h
index 59b2894c5..81aee178a 100644
--- a/compiler/moco-tf/src/Canonicalization/IdentityCanonicalizer.h
+++ b/compiler/moco-tf/src/Canonicalization/IdentityCanonicalizer.h
@@ -18,9 +18,6 @@
#define __MOCO_TF_IDENTITY_CANONICALIZER_H__
#include "Transform.h"
-#include "SimpleNodeTransform.h"
-
-#include <moco/IR/TFNodes.h>
#include <loco.h>
@@ -32,13 +29,13 @@ namespace tf
/**
* @brief Convert TFIdentity to Canonical Forward
*/
-class IdentityCanonicalizer : public SimpleNodeTransform<moco::TFIdentity>
+class IdentityCanonicalizer : public Transform
{
public:
const char *name(void) const final { return "IdentityCanonicalizer"; }
public:
- bool transform(moco::TFIdentity *) const final;
+ bool run(loco::Graph *graph) override;
};
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/MaxPoolCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/MaxPoolCanonicalizer.cpp
index 06a605717..c46fbd208 100644
--- a/compiler/moco-tf/src/Canonicalization/MaxPoolCanonicalizer.cpp
+++ b/compiler/moco-tf/src/Canonicalization/MaxPoolCanonicalizer.cpp
@@ -16,17 +16,70 @@
#include "MaxPoolCanonicalizer.h"
-#include <moco/IR/TFDialect.h>
-#include <moco/Support/TFShapeInferenceHelper.h>
+#include "Annotations/PadData.h"
+#include "Annotations/StrideData.h"
+#include "Annotations/WindowData.h"
-#include "CodecHelper.h"
+#include "Dialect/TFDialect.h"
+#include "Dialect/TFNodes.h"
+#include "Dialect/TFNodeVisitor.h"
+#include "Dialect/TFNodeImpl.h"
#include <moco/Log.h>
+#include <plier/tf/Convert.h>
+
+#include <stdex/Memory.h>
namespace
{
-bool canonicalize_maxpool2d(loco::Graph *graph, moco::TFMaxPool *node)
+using plier::tf::DataLayout;
+
+void set_feature_enc(loco::FeatureEncode *feature_enc, DataLayout data_layout)
+{
+ auto enc = stdex::make_unique<loco::PermutingEncoder<loco::Domain::Feature>>();
+
+ if (data_layout == DataLayout::NHWC)
+ {
+ enc->perm()->axis(loco::FeatureAxis::Count) = 0;
+ enc->perm()->axis(loco::FeatureAxis::Height) = 1;
+ enc->perm()->axis(loco::FeatureAxis::Width) = 2;
+ enc->perm()->axis(loco::FeatureAxis::Depth) = 3;
+ }
+ else if (data_layout == DataLayout::NCHW)
+ {
+ enc->perm()->axis(loco::FeatureAxis::Count) = 0;
+ enc->perm()->axis(loco::FeatureAxis::Depth) = 1;
+ enc->perm()->axis(loco::FeatureAxis::Height) = 2;
+ enc->perm()->axis(loco::FeatureAxis::Width) = 3;
+ }
+
+ feature_enc->encoder(std::move(enc));
+}
+
+void set_feature_dec(loco::FeatureDecode *feature_dec, DataLayout data_layout)
+{
+ auto dec = stdex::make_unique<loco::PermutingDecoder<loco::Domain::Feature>>();
+
+ if (data_layout == DataLayout::NHWC)
+ {
+ dec->perm()->axis(loco::FeatureAxis::Count) = 0;
+ dec->perm()->axis(loco::FeatureAxis::Height) = 1;
+ dec->perm()->axis(loco::FeatureAxis::Width) = 2;
+ dec->perm()->axis(loco::FeatureAxis::Depth) = 3;
+ }
+ else if (data_layout == DataLayout::NCHW)
+ {
+ dec->perm()->axis(loco::FeatureAxis::Count) = 0;
+ dec->perm()->axis(loco::FeatureAxis::Depth) = 1;
+ dec->perm()->axis(loco::FeatureAxis::Height) = 2;
+ dec->perm()->axis(loco::FeatureAxis::Width) = 3;
+ }
+
+ feature_dec->decoder(std::move(dec));
+}
+
+bool canonicalize_maxpool2d(loco::Graph *graph, moco::tf::TFMaxPool *node)
{
LOGGER(l);
@@ -58,31 +111,36 @@ bool canonicalize_maxpool2d(loco::Graph *graph, moco::TFMaxPool *node)
set_feature_dec(feature_dec, data_layout);
// paddata to pad
- auto input_shape = moco::node_shape(node->input());
- assert(input_shape.domain() != loco::Domain::Unknown);
+ auto pad_data = node->annot<moco::tf::PadData>();
+ assert(pad_data != nullptr);
- auto node_stride = moco::stride_of(node->strides(), node->data_layout());
- auto node_window = moco::window_of(node->ksize(), node->data_layout());
+ maxPool2d_node->pad()->top(pad_data->pad()->top());
+ maxPool2d_node->pad()->bottom(pad_data->pad()->bottom());
+ maxPool2d_node->pad()->left(pad_data->pad()->left());
+ maxPool2d_node->pad()->right(pad_data->pad()->right());
- moco::Padding2DInference infer_padding2d;
+ // windowdata to window (ksize to window)
+ auto window_data = node->annot<moco::tf::WindowData>();
+ assert(window_data != nullptr);
- infer_padding2d.padding(node->padding());
- infer_padding2d.stride(node_stride);
- infer_padding2d.window(node_window);
+ auto window = maxPool2d_node->window();
+ window->vertical(window_data->window()->vertical());
+ window->horizontal(window_data->window()->horizontal());
- auto input_feature_shape = moco::as_feature_shape(input_shape, node->data_layout());
- auto input_plane_shape = moco::make_plane_shape(input_feature_shape);
+ // stridedata to stride (strides to stride)
+ auto stride_data = node->annot<moco::tf::StrideData>();
+ assert(stride_data != nullptr);
- *maxPool2d_node->pad() = infer_padding2d(input_plane_shape);
- *maxPool2d_node->stride() = node_stride;
- *maxPool2d_node->window() = node_window;
+ auto stride = maxPool2d_node->stride();
+ stride->vertical(stride_data->stride()->vertical());
+ stride->horizontal(stride_data->stride()->horizontal());
INFO(l) << "Canonicalize TFMaxPool pad = T " << maxPool2d_node->pad()->top() << ", L "
<< maxPool2d_node->pad()->left() << ", B " << maxPool2d_node->pad()->bottom() << ", R "
<< maxPool2d_node->pad()->right() << std::endl;
// update graph
- auto node_A = node->input();
+ auto node_A = node->value();
// update connections
feature_enc->input(node_A);
@@ -102,9 +160,25 @@ namespace moco
namespace tf
{
-bool MaxPoolCanonicalizer::transform(TFMaxPool *node) const
+bool MaxPoolCanonicalizer::run(loco::Graph *graph)
{
- return canonicalize_maxpool2d(node->graph(), node);
+ auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
+ bool changed = false;
+
+ for (auto node : active_nodes)
+ {
+ if (node->dialect() == TFDialect::get())
+ {
+ auto tf_node = dynamic_cast<moco::tf::TFMaxPool *>(node);
+ if (tf_node != nullptr)
+ {
+ if (canonicalize_maxpool2d(graph, tf_node))
+ changed = true;
+ }
+ }
+ }
+
+ return changed;
}
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/MaxPoolCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/MaxPoolCanonicalizer.h
index c58ade528..a486c4caa 100644
--- a/compiler/moco-tf/src/Canonicalization/MaxPoolCanonicalizer.h
+++ b/compiler/moco-tf/src/Canonicalization/MaxPoolCanonicalizer.h
@@ -18,9 +18,6 @@
#define __MOCO_TF_MAXPOOL_CANONICALIZER_H__
#include "Transform.h"
-#include "SimpleNodeTransform.h"
-
-#include <moco/IR/TFNodes.h>
#include <loco.h>
@@ -32,13 +29,13 @@ namespace tf
/**
* @brief Convert TFMaxPool to Canonical MaxPool2D
*/
-class MaxPoolCanonicalizer : public SimpleNodeTransform<moco::TFMaxPool>
+class MaxPoolCanonicalizer : public Transform
{
public:
const char *name(void) const final { return "MaxPoolCanonicalizer"; }
public:
- bool transform(moco::TFMaxPool *) const final;
+ bool run(loco::Graph *graph) override;
};
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/MaximumCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/MaximumCanonicalizer.cpp
deleted file mode 100644
index 92634d01f..000000000
--- a/compiler/moco-tf/src/Canonicalization/MaximumCanonicalizer.cpp
+++ /dev/null
@@ -1,34 +0,0 @@
-/*
- * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "MaximumCanonicalizer.h"
-
-#include <moco/IR/TFDialect.h>
-
-#include "TFEltwiseBinaryCanonicalzeHelper.h"
-
-namespace moco
-{
-namespace tf
-{
-
-bool MaximumCanonicalizer::transform(moco::TFMaximum *node) const
-{
- return canonicalize_eltwise_binary_node(node);
-}
-
-} // namespace tf
-} // namespace moco
diff --git a/compiler/moco-tf/src/Canonicalization/MaximumCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/MaximumCanonicalizer.h
deleted file mode 100644
index baff4d7ad..000000000
--- a/compiler/moco-tf/src/Canonicalization/MaximumCanonicalizer.h
+++ /dev/null
@@ -1,47 +0,0 @@
-/*
- * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef __MOCO_TF_MAXIMUM_CANONICALIZER_H__
-#define __MOCO_TF_MAXIMUM_CANONICALIZER_H__
-
-#include "Transform.h"
-#include "SimpleNodeTransform.h"
-
-#include <moco/IR/TFNodes.h>
-
-#include <loco.h>
-
-namespace moco
-{
-namespace tf
-{
-
-/**
- * @brief Convert TFMaximum to Canonical EltwiseMax
- */
-class MaximumCanonicalizer : public SimpleNodeTransform<moco::TFMaximum>
-{
-public:
- const char *name(void) const final { return "MaximumCanonicalizer"; }
-
-public:
- bool transform(moco::TFMaximum *) const final;
-};
-
-} // namespace tf
-} // namespace moco
-
-#endif // __MOCO_TF_MAXIMUM_CANONICALIZER_H__
diff --git a/compiler/moco-tf/src/Canonicalization/MeanCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/MeanCanonicalizer.cpp
deleted file mode 100644
index 69eaf7900..000000000
--- a/compiler/moco-tf/src/Canonicalization/MeanCanonicalizer.cpp
+++ /dev/null
@@ -1,31 +0,0 @@
-/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "MeanCanonicalizer.h"
-#include "TFReduceCanonicalzeHelper.h"
-
-namespace moco
-{
-namespace tf
-{
-
-bool MeanCanonicalizer::transform(moco::TFMean *node) const
-{
- return canonicalize_reduce_node(node);
-}
-
-} // namespace tf
-} // namespace moco
diff --git a/compiler/moco-tf/src/Canonicalization/MeanCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/MeanCanonicalizer.h
deleted file mode 100644
index 469d7e3cd..000000000
--- a/compiler/moco-tf/src/Canonicalization/MeanCanonicalizer.h
+++ /dev/null
@@ -1,47 +0,0 @@
-/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef __MOCO_TF_MEAN_CANONICALIZER_H__
-#define __MOCO_TF_MEAN_CANONICALIZER_H__
-
-#include "Transform.h"
-#include "SimpleNodeTransform.h"
-
-#include <moco/IR/TFNodes.h>
-
-#include <loco.h>
-
-namespace moco
-{
-namespace tf
-{
-
-/**
- * @brief Canonicalize TF-dialect TFMean into canonical TensorReduce(Mean) node
- */
-class MeanCanonicalizer : public SimpleNodeTransform<moco::TFMean>
-{
-public:
- const char *name(void) const final { return "MeanCanonicalizer"; }
-
-public:
- bool transform(moco::TFMean *) const final;
-};
-
-} // namespace tf
-} // namespace moco
-
-#endif // __MOCO_TF_MEAN_CANONICALIZER_H__
diff --git a/compiler/moco-tf/src/Canonicalization/MulCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/MulCanonicalizer.cpp
index d02f71361..78d0ebc48 100644
--- a/compiler/moco-tf/src/Canonicalization/MulCanonicalizer.cpp
+++ b/compiler/moco-tf/src/Canonicalization/MulCanonicalizer.cpp
@@ -16,7 +16,8 @@
#include "MulCanonicalizer.h"
-#include <moco/IR/TFDialect.h>
+#include "Dialect/TFDialect.h"
+#include "Dialect/TFNodes.h"
#include "TFEltwiseBinaryCanonicalzeHelper.h"
@@ -25,9 +26,25 @@ namespace moco
namespace tf
{
-bool MulCanonicalizer::transform(moco::TFMul *node) const
+bool MulCanonicalizer::run(loco::Graph *graph)
{
- return canonicalize_eltwise_binary_node(node);
+ auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
+ bool changed = false;
+
+ for (auto node : active_nodes)
+ {
+ if (node->dialect() == TFDialect::get())
+ {
+ auto tf_node = dynamic_cast<moco::tf::TFMul *>(node);
+ if (tf_node != nullptr)
+ {
+ if (canonicalize_eltwise_binary_node(tf_node))
+ changed = true;
+ }
+ }
+ }
+
+ return changed;
}
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/MulCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/MulCanonicalizer.h
index 480eec700..680f4c315 100644
--- a/compiler/moco-tf/src/Canonicalization/MulCanonicalizer.h
+++ b/compiler/moco-tf/src/Canonicalization/MulCanonicalizer.h
@@ -18,9 +18,6 @@
#define __MOCO_TF_MUL_CANONICALIZER_H__
#include "Transform.h"
-#include "SimpleNodeTransform.h"
-
-#include <moco/IR/TFNodes.h>
#include <loco.h>
@@ -32,13 +29,13 @@ namespace tf
/**
* @brief Convert TFMul to Canonical EltwiseMul
*/
-class MulCanonicalizer : public SimpleNodeTransform<moco::TFMul>
+class MulCanonicalizer : public Transform
{
public:
const char *name(void) const final { return "MulCanonicalizer"; }
public:
- bool transform(moco::TFMul *) const final;
+ bool run(loco::Graph *graph) override;
};
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/PadCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/PadCanonicalizer.cpp
deleted file mode 100644
index 36136aed4..000000000
--- a/compiler/moco-tf/src/Canonicalization/PadCanonicalizer.cpp
+++ /dev/null
@@ -1,100 +0,0 @@
-/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "PadCanonicalizer.h"
-
-#include <moco/IR/TFDialect.h>
-
-#include "loco/Service/TypeInference.h"
-
-#include <stdex/Memory.h>
-
-namespace
-{
-
-bool canonicalize_pad(loco::Graph *graph, moco::TFPad *node)
-{
- /**
- * @note This will replace TFPad node with Canonical TensorConstantPad
- *
- * Before
- * input --- TFPad -- C
- * paddings --/
- * After
- * paddings ------- TFPad --
- * /
- * input ----------- TensorConstantPad -- C
- * ConstGen --------/
- * Where
- * input : input of TFPad
- * paddings : paddings of TFPad. it becomes TensorConstantPad's attribute.
- * C : a node that uses TFPad as an input. TFPad is disconnected from C.
- * ConstGen : constant value of Pad. TFPad has zero value by default.
- */
-
- auto pad_node = graph->nodes()->create<loco::TensorConstantPad>();
-
- auto constant_node = graph->nodes()->create<loco::ConstGen>();
-
- auto input_node = node->input();
- // TODO: support other dtype.
- assert(loco::dtype_get(input_node) == loco::DataType::FLOAT32);
- constant_node->dtype(loco::DataType::FLOAT32);
- // TODO: constant node changes to scalar when it is implemented.
- constant_node->shape({1});
- constant_node->size<loco::DataType::FLOAT32>(1);
- constant_node->at<loco::DataType::FLOAT32>(0) = 0.0f;
-
- auto const_paddings_node = loco::must_cast<loco::ConstGen *>(node->paddings());
- // TODO: support S64 type.
- assert(const_paddings_node->dtype() == loco::DataType::S32);
- assert(const_paddings_node->rank() == 2);
- assert(const_paddings_node->dim(1).value() == 2);
-
- auto padding = pad_node->padding();
- uint32_t padding_rank = const_paddings_node->dim(0).value();
- padding->rank(padding_rank);
-
- for (uint32_t i = 0; i < padding_rank; i++)
- {
- padding->front(i) = const_paddings_node->at<loco::DataType::S32>(i << 1);
- padding->back(i) = const_paddings_node->at<loco::DataType::S32>((i << 1) + 1);
- }
-
- // update connections
- pad_node->input(input_node);
- pad_node->constant(constant_node);
-
- // replace node
- replace(node).with(pad_node);
-
- return true;
-}
-
-} // namespace
-
-namespace moco
-{
-namespace tf
-{
-
-bool PadCanonicalizer::transform(TFPad *node) const
-{
- return canonicalize_pad(node->graph(), node);
-}
-
-} // namespace tf
-} // namespace moco
diff --git a/compiler/moco-tf/src/Canonicalization/PlaceholderCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/PlaceholderCanonicalizer.cpp
deleted file mode 100644
index f568e909f..000000000
--- a/compiler/moco-tf/src/Canonicalization/PlaceholderCanonicalizer.cpp
+++ /dev/null
@@ -1,102 +0,0 @@
-/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "PlaceholderCanonicalizer.h"
-
-#include <moco/IR/TFDialect.h>
-
-#include <moco/Names.h>
-#include <moco/Log.h>
-
-namespace
-{
-
-bool canonicalize_placeholder(loco::Graph *graph, moco::TFPlaceholder *node)
-{
- LOGGER(l);
-
- /**
- * @note This will replace TFPlaceholder node with Canonical Pull
- *
- * Before
- * TFPlaceholder -- C
- *
- * After
- * TFPlaceholder -
- * Pull -- C
- *
- * Where
- * C : a node that uses TFPlaceholder as an input
- * TFPlaceholder is disconnected from other nodes
- */
-
- INFO(l) << "PlaceholderCanonicalizer begin";
-
- auto pull_node = graph->nodes()->create<loco::Pull>();
-
- // copy properties
- auto dtype = node->dtype();
- pull_node->dtype(dtype);
-
- auto rank = node->rank();
-
- if (rank == 0)
- {
- // This routine implements a workaround that converts a scalar constant (rank-0 tensor)
- // into a rank-1 tensor of shape [1].
- //
- // TODO Revise this implementation later
- pull_node->rank(1);
- pull_node->dim(0) = 1;
- }
- else
- {
- pull_node->rank(rank);
-
- for (uint32_t r = 0; r < rank; ++r)
- {
- if (node->dim(r).known())
- pull_node->dim(r) = node->dim(r);
- else
- pull_node->dim(r).unset();
- }
- }
-
- // set loco::Pull GraphInputIndex
- pull_node->index(moco::index(node));
-
- // update graph
- replace(node).with(pull_node);
-
- INFO(l) << "PlaceholderCanonicalizer done";
-
- return true;
-}
-
-} // namespace
-
-namespace moco
-{
-namespace tf
-{
-
-bool PlaceholderCanonicalizer::transform(TFPlaceholder *node) const
-{
- return canonicalize_placeholder(node->graph(), node);
-}
-
-} // namespace tf
-} // namespace moco
diff --git a/compiler/moco-tf/src/Canonicalization/PlaceholderCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/PlaceholderCanonicalizer.h
deleted file mode 100644
index 66eafe6af..000000000
--- a/compiler/moco-tf/src/Canonicalization/PlaceholderCanonicalizer.h
+++ /dev/null
@@ -1,47 +0,0 @@
-/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef __MOCO_TF_PLACEHOLDER_CANONICALIZER_H__
-#define __MOCO_TF_PLACEHOLDER_CANONICALIZER_H__
-
-#include "Transform.h"
-#include "SimpleNodeTransform.h"
-
-#include <moco/IR/Nodes/TFPlaceholder.h>
-
-namespace moco
-{
-namespace tf
-{
-
-/**
- * @brief Convert TFPlaceholder to Canonical Pull
- *
- * @note GraphInputIndex is copied to Pull
- */
-class PlaceholderCanonicalizer : public SimpleNodeTransform<::moco::TFPlaceholder>
-{
-public:
- const char *name(void) const final { return "PlaceholderCanonicalizer"; }
-
-public:
- bool transform(moco::TFPlaceholder *) const final;
-};
-
-} // namespace tf
-} // namespace moco
-
-#endif // __MOCO_TF_PLACEHOLDER_CANONICALIZER_H__
diff --git a/compiler/moco-tf/src/Canonicalization/RealDivCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/RealDivCanonicalizer.cpp
index a448d85fa..9ad15150a 100644
--- a/compiler/moco-tf/src/Canonicalization/RealDivCanonicalizer.cpp
+++ b/compiler/moco-tf/src/Canonicalization/RealDivCanonicalizer.cpp
@@ -16,7 +16,8 @@
#include "RealDivCanonicalizer.h"
-#include <moco/IR/TFDialect.h>
+#include "Dialect/TFDialect.h"
+#include "Dialect/TFNodes.h"
#include "TFEltwiseBinaryCanonicalzeHelper.h"
@@ -25,9 +26,25 @@ namespace moco
namespace tf
{
-bool RealDivCanonicalizer::transform(moco::TFRealDiv *node) const
+bool RealDivCanonicalizer::run(loco::Graph *graph)
{
- return canonicalize_eltwise_binary_node(node);
+ auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
+ bool changed = false;
+
+ for (auto node : active_nodes)
+ {
+ if (node->dialect() == TFDialect::get())
+ {
+ auto tf_node = dynamic_cast<moco::tf::TFRealDiv *>(node);
+ if (tf_node != nullptr)
+ {
+ if (canonicalize_eltwise_binary_node(tf_node))
+ changed = true;
+ }
+ }
+ }
+
+ return changed;
}
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/RealDivCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/RealDivCanonicalizer.h
index 76e1bd377..8e6953396 100644
--- a/compiler/moco-tf/src/Canonicalization/RealDivCanonicalizer.h
+++ b/compiler/moco-tf/src/Canonicalization/RealDivCanonicalizer.h
@@ -18,9 +18,6 @@
#define __MOCO_TF_REALDIV_CANONICALIZER_H__
#include "Transform.h"
-#include "SimpleNodeTransform.h"
-
-#include <moco/IR/TFNodes.h>
#include <loco.h>
@@ -32,13 +29,13 @@ namespace tf
/**
* @brief Convert TFRealDiv to Canonical EltwiseDiv
*/
-class RealDivCanonicalizer : public SimpleNodeTransform<moco::TFRealDiv>
+class RealDivCanonicalizer : public Transform
{
public:
const char *name(void) const final { return "RealDivCanonicalizer"; }
public:
- bool transform(moco::TFRealDiv *) const final;
+ bool run(loco::Graph *graph) override;
};
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/Relu6Canonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/Relu6Canonicalizer.cpp
index c53a880a8..07657244b 100644
--- a/compiler/moco-tf/src/Canonicalization/Relu6Canonicalizer.cpp
+++ b/compiler/moco-tf/src/Canonicalization/Relu6Canonicalizer.cpp
@@ -16,14 +16,17 @@
#include "Relu6Canonicalizer.h"
-#include <moco/IR/TFDialect.h>
+#include "Dialect/TFDialect.h"
+#include "Dialect/TFNodes.h"
+#include "Dialect/TFNodeVisitor.h"
+#include "Dialect/TFNodeImpl.h"
#include <stdex/Memory.h>
namespace
{
-bool canonicalize_relu6(loco::Graph *graph, moco::TFRelu6 *node)
+bool canonicalize_relu6(loco::Graph *graph, moco::tf::TFRelu6 *node)
{
/**
* @note This will replace TFRelu6 node with Canonical ReLU6
@@ -61,9 +64,25 @@ namespace moco
namespace tf
{
-bool Relu6Canonicalizer::transform(TFRelu6 *node) const
+bool Relu6Canonicalizer::run(loco::Graph *graph)
{
- return canonicalize_relu6(node->graph(), node);
+ auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
+ bool changed = false;
+
+ for (auto node : active_nodes)
+ {
+ if (node->dialect() == TFDialect::get())
+ {
+ auto tf_node = dynamic_cast<moco::tf::TFRelu6 *>(node);
+ if (tf_node != nullptr)
+ {
+ if (canonicalize_relu6(graph, tf_node))
+ changed = true;
+ }
+ }
+ }
+
+ return changed;
}
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/Relu6Canonicalizer.h b/compiler/moco-tf/src/Canonicalization/Relu6Canonicalizer.h
index d8ad5db8e..aa1580f28 100644
--- a/compiler/moco-tf/src/Canonicalization/Relu6Canonicalizer.h
+++ b/compiler/moco-tf/src/Canonicalization/Relu6Canonicalizer.h
@@ -18,9 +18,6 @@
#define __MOCO_TF_RELU6_CANONICALIZER_H__
#include "Transform.h"
-#include "SimpleNodeTransform.h"
-
-#include <moco/IR/TFNodes.h>
#include <loco.h>
@@ -32,13 +29,13 @@ namespace tf
/**
* @brief Convert TFRelu6 to Canonical ReLU6
*/
-class Relu6Canonicalizer : public SimpleNodeTransform<moco::TFRelu6>
+class Relu6Canonicalizer : public Transform
{
public:
const char *name(void) const final { return "Relu6Canonicalizer"; }
public:
- bool transform(moco::TFRelu6 *) const final;
+ bool run(loco::Graph *graph) override;
};
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/ReluCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/ReluCanonicalizer.cpp
index 7965dc931..20cd0bab9 100644
--- a/compiler/moco-tf/src/Canonicalization/ReluCanonicalizer.cpp
+++ b/compiler/moco-tf/src/Canonicalization/ReluCanonicalizer.cpp
@@ -16,14 +16,17 @@
#include "ReluCanonicalizer.h"
-#include <moco/IR/TFDialect.h>
+#include "Dialect/TFDialect.h"
+#include "Dialect/TFNodes.h"
+#include "Dialect/TFNodeVisitor.h"
+#include "Dialect/TFNodeImpl.h"
#include <stdex/Memory.h>
namespace
{
-bool canonicalize_relu(loco::Graph *graph, moco::TFRelu *node)
+bool canonicalize_relu(loco::Graph *graph, moco::tf::TFRelu *node)
{
/**
* @note This will replace TFRelu node with Canonical ReLU
@@ -61,9 +64,25 @@ namespace moco
namespace tf
{
-bool ReluCanonicalizer::transform(TFRelu *node) const
+bool ReluCanonicalizer::run(loco::Graph *graph)
{
- return canonicalize_relu(node->graph(), node);
+ auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
+ bool changed = false;
+
+ for (auto node : active_nodes)
+ {
+ if (node->dialect() == TFDialect::get())
+ {
+ auto tf_node = dynamic_cast<moco::tf::TFRelu *>(node);
+ if (tf_node != nullptr)
+ {
+ if (canonicalize_relu(graph, tf_node))
+ changed = true;
+ }
+ }
+ }
+
+ return changed;
}
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/ReluCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/ReluCanonicalizer.h
index e27abe158..97adba308 100644
--- a/compiler/moco-tf/src/Canonicalization/ReluCanonicalizer.h
+++ b/compiler/moco-tf/src/Canonicalization/ReluCanonicalizer.h
@@ -18,9 +18,6 @@
#define __MOCO_TF_RELU_CANONICALIZER_H__
#include "Transform.h"
-#include "SimpleNodeTransform.h"
-
-#include <moco/IR/TFNodes.h>
#include <loco.h>
@@ -32,13 +29,13 @@ namespace tf
/**
* @brief Convert TFRelu to Canonical ReLU
*/
-class ReluCanonicalizer : public SimpleNodeTransform<moco::TFRelu>
+class ReluCanonicalizer : public Transform
{
public:
const char *name(void) const final { return "ReluCanonicalizer"; }
public:
- bool transform(moco::TFRelu *) const final;
+ bool run(loco::Graph *graph) override;
};
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/ReshapeCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/ReshapeCanonicalizer.cpp
index b944568e0..3771d549a 100644
--- a/compiler/moco-tf/src/Canonicalization/ReshapeCanonicalizer.cpp
+++ b/compiler/moco-tf/src/Canonicalization/ReshapeCanonicalizer.cpp
@@ -16,11 +16,11 @@
#include "ReshapeCanonicalizer.h"
-#include <moco/IR/TFDialect.h>
+#include "Dialect/TFDialect.h"
+#include "Dialect/TFNodeImpl.h"
#include <moco/Log.h>
#include <plier/tf/Convert.h>
-#include <oops/UserExn.h>
#include <cassert>
@@ -31,7 +31,7 @@ using plier::tf::DataLayout;
/**
* @brief Check whether given 'new shape' arg is a fixed shape input for Reshape
*
- * ConstNode can be moco::TFConst or loco::ConstGen
+ * ConstNode can be moco::tf::TFConst or loco::ConstGen
*/
template <typename ConstNode> bool is_fixed_shape_input(ConstNode *const_shape_input)
{
@@ -54,16 +54,13 @@ template <typename ConstNode> bool is_fixed_shape_input(ConstNode *const_shape_i
// has wildcard dimension, i.e. dynamic reshape
return false;
}
- if (!(shape_dim >= 1))
- {
- throw oops::UserExn("New shape of Reshape has invalid dimension");
- }
+ assert(shape_dim >= 1 && "Unknown behavior: New shape of Reshape has invalid dimension");
}
return true;
}
/// @note Currently only supports to canonicalize Fixed Reshape
-bool canonicalize_reshape(loco::Graph *graph, moco::TFReshape *node)
+bool canonicalize_reshape(loco::Graph *graph, moco::tf::TFReshape *node)
{
LOGGER(l);
INFO(l) << "TFNodeCanonicalize TFReshape begin";
@@ -102,17 +99,14 @@ bool canonicalize_reshape(loco::Graph *graph, moco::TFReshape *node)
// Supports 2 cases for Reshape's shape input:
// TF-dialect TFConst or Canonical ConstGen
loco::Node *shape_input = node->shape();
- auto tfconst_shape_input = dynamic_cast<moco::TFConst *>(shape_input);
+ auto tfconst_shape_input = dynamic_cast<moco::tf::TFConst *>(shape_input);
auto constgen_shape_input = dynamic_cast<loco::ConstGen *>(shape_input);
if (tfconst_shape_input)
{
// Only support fixed reshape
// TODO support dynamic reshape
- if (!(is_fixed_shape_input(tfconst_shape_input)))
- {
- throw oops::UserExn("Supports only fixed reshape", node->name());
- }
+ assert(is_fixed_shape_input(tfconst_shape_input));
auto rank = tfconst_shape_input->dim(0).value();
fixed_reshape->rank(rank);
@@ -124,10 +118,7 @@ bool canonicalize_reshape(loco::Graph *graph, moco::TFReshape *node)
else if (constgen_shape_input)
{
// ditto
- if (!(is_fixed_shape_input(constgen_shape_input)))
- {
- throw oops::UserExn("Supports only fixed reshape", node->name());
- }
+ assert(is_fixed_shape_input(constgen_shape_input));
auto rank = constgen_shape_input->dim(0).value();
fixed_reshape->rank(rank);
@@ -139,7 +130,7 @@ bool canonicalize_reshape(loco::Graph *graph, moco::TFReshape *node)
else
{
// TODO support dynamic reshape from not const node
- throw oops::UserExn("Supports only const node as input shape", node->name());
+ throw std::runtime_error("ReshapeCanonicalizer: only support const node as input shape");
}
// replace
@@ -160,9 +151,25 @@ namespace moco
namespace tf
{
-bool ReshapeCanonicalizer::transform(TFReshape *node) const
+bool ReshapeCanonicalizer::run(loco::Graph *graph)
{
- return canonicalize_reshape(node->graph(), node);
+ auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
+ bool changed = false;
+
+ for (auto node : active_nodes)
+ {
+ if (node->dialect() == TFDialect::get())
+ {
+ auto tf_reshape = dynamic_cast<moco::tf::TFReshape *>(node);
+ if (tf_reshape != nullptr)
+ {
+ if (canonicalize_reshape(graph, tf_reshape))
+ changed = true;
+ }
+ }
+ }
+
+ return changed;
}
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/ReshapeCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/ReshapeCanonicalizer.h
index 1a792024e..c9deee7a4 100644
--- a/compiler/moco-tf/src/Canonicalization/ReshapeCanonicalizer.h
+++ b/compiler/moco-tf/src/Canonicalization/ReshapeCanonicalizer.h
@@ -18,9 +18,6 @@
#define __MOCO_TF_RESHAPE_CANONICALIZER_H__
#include "Transform.h"
-#include "SimpleNodeTransform.h"
-
-#include <moco/IR/TFNodes.h>
#include <loco.h>
@@ -32,13 +29,13 @@ namespace tf
/**
* @brief Convert TFReshape to Canonical Reshape
*/
-class ReshapeCanonicalizer : public SimpleNodeTransform<moco::TFReshape>
+class ReshapeCanonicalizer : public Transform
{
public:
const char *name(void) const final { return "ReshapeCanonicalizer"; }
public:
- bool transform(moco::TFReshape *) const final;
+ bool run(loco::Graph *graph) override;
};
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/RsqrtCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/RsqrtCanonicalizer.cpp
index c31dbf6d6..b4fbcac3c 100644
--- a/compiler/moco-tf/src/Canonicalization/RsqrtCanonicalizer.cpp
+++ b/compiler/moco-tf/src/Canonicalization/RsqrtCanonicalizer.cpp
@@ -16,25 +16,29 @@
#include "RsqrtCanonicalizer.h"
-#include <moco/IR/TFDialect.h>
-#include <moco/Support/TFShapeInferenceHelper.h>
+#include "Annotations/ShapeInferenceData.h"
+
+#include "Dialect/TFDialect.h"
+#include "Dialect/TFNodes.h"
+#include "Dialect/TFNodeVisitor.h"
+#include "Dialect/TFNodeImpl.h"
#include <moco/Log.h>
#include <loco/Service/TypeInference.h>
#include <stdex/Memory.h>
-#include <oops/UserExn.h>
namespace
{
template <typename T>
-bool prepare_const_gen(loco::ConstGen *const_node, const loco::TensorShape &tensorshape, T value);
+void prepare_const_gen(loco::ConstGen *const_node, const moco::tf::ShapeInferenceData *shapedata,
+ T value);
template <>
-bool prepare_const_gen<float>(loco::ConstGen *const_node, const loco::TensorShape &tensorshape,
- float value)
+void prepare_const_gen<float>(loco::ConstGen *const_node,
+ const moco::tf::ShapeInferenceData *shapedata, float value)
{
LOGGER(l);
@@ -43,18 +47,18 @@ bool prepare_const_gen<float>(loco::ConstGen *const_node, const loco::TensorShap
auto dtype = loco::DataType::FLOAT32;
const_node->dtype(dtype);
- auto rank = tensorshape.rank();
+ auto rank = shapedata->rank();
const_node->rank(rank);
for (uint32_t r = 0; r < rank; ++r)
{
- if (tensorshape.dim(r).known())
- const_node->dim(r) = tensorshape.dim(r);
+ if (shapedata->dim(r).known())
+ const_node->dim(r) = shapedata->dim(r);
else
- return false;
+ throw std::runtime_error("Cannot handle unknown shape");
- assert(tensorshape.dim(r).value() > 0);
+ assert(shapedata->dim(r).value() > 0);
- const_num_elements *= tensorshape.dim(r).value();
+ const_num_elements *= shapedata->dim(r).value();
}
INFO(l) << "prepare_const_gen : Elements = " << const_num_elements;
@@ -64,11 +68,9 @@ bool prepare_const_gen<float>(loco::ConstGen *const_node, const loco::TensorShap
{
const_node->at<loco::DataType::FLOAT32>(i) = value;
}
-
- return true;
}
-bool canonicalize_rsqrt(loco::Graph *graph, moco::TFRsqrt *node)
+bool canonicalize_rsqrt(loco::Graph *graph, moco::tf::TFRsqrt *node)
{
/**
* @note This will replace TFRsqrt node with Canonical EltwiseSqrt + EltwiseRealDiv
@@ -89,14 +91,13 @@ bool canonicalize_rsqrt(loco::Graph *graph, moco::TFRsqrt *node)
* TFRsqrt is converted to 1 / EltwiseSqrt
*/
- auto nodeshape = moco::node_shape(node);
- if (nodeshape.domain() == loco::Domain::Unknown)
+ auto rsqrt_shapedata = node->annot<moco::tf::ShapeInferenceData>();
+ if (rsqrt_shapedata == nullptr)
{
// We need this shape information
assert(false); // this shouldn't happen, let's add an alarm
return false;
}
- auto tensorshape = nodeshape.as<loco::TensorShape>();
if (!loco::dtype_known(node))
{
@@ -113,12 +114,11 @@ bool canonicalize_rsqrt(loco::Graph *graph, moco::TFRsqrt *node)
switch (dtype)
{
case loco::DataType::FLOAT32:
- if (!prepare_const_gen<float>(const_node, tensorshape, 1.0f))
- throw oops::UserExn("Cannot handle unknown shape", node->name());
+ prepare_const_gen<float>(const_node, rsqrt_shapedata, 1.0f);
break;
default:
- throw oops::UserExn("Unsupported data type", node->name());
+ throw std::runtime_error("NYI for this DataType");
}
auto node_A = node->x();
@@ -141,9 +141,25 @@ namespace moco
namespace tf
{
-bool RsqrtCanonicalizer::transform(TFRsqrt *node) const
+bool RsqrtCanonicalizer::run(loco::Graph *graph)
{
- return canonicalize_rsqrt(node->graph(), node);
+ auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
+ bool changed = false;
+
+ for (auto node : active_nodes)
+ {
+ if (node->dialect() == TFDialect::get())
+ {
+ auto tf_node = dynamic_cast<moco::tf::TFRsqrt *>(node);
+ if (tf_node != nullptr)
+ {
+ if (canonicalize_rsqrt(graph, tf_node))
+ changed = true;
+ }
+ }
+ }
+
+ return changed;
}
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/RsqrtCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/RsqrtCanonicalizer.h
index 7fd4ff697..a58c0adcb 100644
--- a/compiler/moco-tf/src/Canonicalization/RsqrtCanonicalizer.h
+++ b/compiler/moco-tf/src/Canonicalization/RsqrtCanonicalizer.h
@@ -18,9 +18,6 @@
#define __MOCO_TF_RSQRT_CANONICALIZER_H__
#include "Transform.h"
-#include "SimpleNodeTransform.h"
-
-#include <moco/IR/TFNodes.h>
#include <loco.h>
@@ -32,13 +29,13 @@ namespace tf
/**
* @brief Convert TFRsqrt to Canonical EltwiseDiv + EltwiseSqrt
*/
-class RsqrtCanonicalizer : public SimpleNodeTransform<moco::TFRsqrt>
+class RsqrtCanonicalizer : public Transform
{
public:
const char *name(void) const final { return "RsqrtCanonicalizer"; }
public:
- bool transform(moco::TFRsqrt *) const final;
+ bool run(loco::Graph *graph) override;
};
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/SoftmaxCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/SoftmaxCanonicalizer.cpp
index 98af7b693..3b5043fa7 100644
--- a/compiler/moco-tf/src/Canonicalization/SoftmaxCanonicalizer.cpp
+++ b/compiler/moco-tf/src/Canonicalization/SoftmaxCanonicalizer.cpp
@@ -16,15 +16,19 @@
#include "SoftmaxCanonicalizer.h"
-#include <moco/IR/TFDialect.h>
-#include <moco/Support/TFShapeInferenceHelper.h>
+#include "Annotations/ShapeInferenceData.h"
+
+#include "Dialect/TFDialect.h"
+#include "Dialect/TFNodes.h"
+#include "Dialect/TFNodeVisitor.h"
+#include "Dialect/TFNodeImpl.h"
#include <moco/Log.h>
namespace
{
-bool canonicalize_softmax(loco::Graph *graph, moco::TFSoftmax *node)
+bool canonicalize_softmax(loco::Graph *graph, moco::tf::TFSoftmax *node)
{
LOGGER(l);
@@ -42,11 +46,12 @@ bool canonicalize_softmax(loco::Graph *graph, moco::TFSoftmax *node)
* In ---- TensorSoftmax ----- Out(s)
*/
- auto nodeshape = moco::node_shape(node);
+ auto softmax_shape = node->annot<moco::tf::ShapeInferenceData>();
+
// Canonicalization into TensorSoftmax is valid when softmax has shape info
- assert(nodeshape.domain() != loco::Domain::Unknown);
+ assert(softmax_shape);
- auto softmax_tensor_shape = nodeshape.as<loco::TensorShape>();
+ auto softmax_tensor_shape = softmax_shape->tensor_shape();
// Create loco node to replace
auto softmax = graph->nodes()->create<loco::TensorSoftmax>();
@@ -69,9 +74,25 @@ namespace moco
namespace tf
{
-bool SoftmaxCanonicalizer::transform(TFSoftmax *node) const
+bool SoftmaxCanonicalizer::run(loco::Graph *graph)
{
- return canonicalize_softmax(node->graph(), node);
+ auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
+ bool changed = false;
+
+ for (auto node : active_nodes)
+ {
+ if (node->dialect() == TFDialect::get())
+ {
+ auto tf_softmax = dynamic_cast<moco::tf::TFSoftmax *>(node);
+ if (tf_softmax != nullptr)
+ {
+ if (canonicalize_softmax(graph, tf_softmax))
+ changed = true;
+ }
+ }
+ }
+
+ return changed;
}
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/SoftmaxCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/SoftmaxCanonicalizer.h
index ebaf04cfe..6debf4194 100644
--- a/compiler/moco-tf/src/Canonicalization/SoftmaxCanonicalizer.h
+++ b/compiler/moco-tf/src/Canonicalization/SoftmaxCanonicalizer.h
@@ -18,9 +18,6 @@
#define __MOCO_TF_SOFTMAx_CANONICALIZER_H__
#include "Transform.h"
-#include "SimpleNodeTransform.h"
-
-#include <moco/IR/TFNodes.h>
#include <loco.h>
@@ -32,13 +29,13 @@ namespace tf
/**
* @brief Canonicalize TF-dialect TFSoftmax into canonical Softmax node
*/
-class SoftmaxCanonicalizer : public SimpleNodeTransform<moco::TFSoftmax>
+class SoftmaxCanonicalizer : public Transform
{
public:
const char *name(void) const final { return "SoftmaxCanonicalizer"; }
public:
- bool transform(moco::TFSoftmax *) const final;
+ bool run(loco::Graph *graph) override;
};
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/SqrtCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/SqrtCanonicalizer.cpp
index 89b9b8a44..347265121 100644
--- a/compiler/moco-tf/src/Canonicalization/SqrtCanonicalizer.cpp
+++ b/compiler/moco-tf/src/Canonicalization/SqrtCanonicalizer.cpp
@@ -16,12 +16,15 @@
#include "SqrtCanonicalizer.h"
-#include <moco/IR/TFDialect.h>
+#include "Dialect/TFDialect.h"
+#include "Dialect/TFNodes.h"
+#include "Dialect/TFNodeVisitor.h"
+#include "Dialect/TFNodeImpl.h"
namespace
{
-bool canonicalize_sqrt(loco::Graph *graph, moco::TFSqrt *node)
+bool canonicalize_sqrt(loco::Graph *graph, moco::tf::TFSqrt *node)
{
/**
* @note This will replace TFSqrt node with Canonical EltwiseSqrt
@@ -59,9 +62,25 @@ namespace moco
namespace tf
{
-bool SqrtCanonicalizer::transform(TFSqrt *node) const
+bool SqrtCanonicalizer::run(loco::Graph *graph)
{
- return canonicalize_sqrt(node->graph(), node);
+ auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
+ bool changed = false;
+
+ for (auto node : active_nodes)
+ {
+ if (node->dialect() == TFDialect::get())
+ {
+ auto tf_node = dynamic_cast<moco::tf::TFSqrt *>(node);
+ if (tf_node != nullptr)
+ {
+ if (canonicalize_sqrt(graph, tf_node))
+ changed = true;
+ }
+ }
+ }
+
+ return changed;
}
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/SqrtCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/SqrtCanonicalizer.h
index 3f7ffead8..b4e6da09a 100644
--- a/compiler/moco-tf/src/Canonicalization/SqrtCanonicalizer.h
+++ b/compiler/moco-tf/src/Canonicalization/SqrtCanonicalizer.h
@@ -18,9 +18,6 @@
#define __MOCO_TF_SQRT_CANONICALIZER_H__
#include "Transform.h"
-#include "SimpleNodeTransform.h"
-
-#include <moco/IR/TFNodes.h>
#include <loco.h>
@@ -32,13 +29,13 @@ namespace tf
/**
* @brief Convert TFsqrt to Canonical EltwiseSqrt
*/
-class SqrtCanonicalizer : public SimpleNodeTransform<moco::TFSqrt>
+class SqrtCanonicalizer : public Transform
{
public:
const char *name(void) const final { return "SqrtCanonicalizer"; }
public:
- bool transform(moco::TFSqrt *) const final;
+ bool run(loco::Graph *graph) override;
};
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/SquaredDifferenceCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/SquaredDifferenceCanonicalizer.cpp
new file mode 100644
index 000000000..4eb7a7217
--- /dev/null
+++ b/compiler/moco-tf/src/Canonicalization/SquaredDifferenceCanonicalizer.cpp
@@ -0,0 +1,115 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "SquaredDifferenceCanonicalizer.h"
+
+#include "Dialect/TFDialect.h"
+#include "Dialect/TFNodes.h"
+#include "Dialect/TFNodeVisitor.h"
+#include "Dialect/TFNodeImpl.h"
+
+#include <loco/IR/NodeShape.h>
+#include <loco/Service/ShapeInference.h>
+
+#include <stdex/Memory.h>
+
+namespace
+{
+
+bool canonicalize_sqdiff(loco::Graph *graph, moco::tf::TFSquaredDifference *node)
+{
+ /**
+ * @note This will replace TFSquaredDifference node with Canonical EltwiseSub and EltwiseMul
+ *
+ * Before
+ * A --- TFSquaredDifference -- C
+ * B --/
+ * After
+ * A --- TFSquaredDifference --
+ * B --/
+ * A --- EltwiseSub == EltwiseMul -- C
+ * B --/
+ * Where
+ * A : x of TFSquaredDifference
+ * B : y of TFSquaredDifference
+ * C : a node that uses TFSquaredDifference as an input
+ * TFSquaredDifference is disconnected from C
+ * A and B are drawn multiple times to simplify the diagram
+ */
+
+ auto node_A = node->x();
+ auto node_B = node->y();
+
+ if (!loco::shape_known(node_A) || !loco::shape_known(node_B))
+ {
+ // Wait for shape inference
+ return false;
+ }
+
+ const auto &x_shape = loco::shape_get(node_A);
+ const auto &y_shape = loco::shape_get(node_B);
+
+ if (!(x_shape == y_shape))
+ {
+ // TODO support broadcast
+ return false;
+ }
+
+ auto sub_node = graph->nodes()->create<loco::EltwiseSub>();
+ auto mul_node = graph->nodes()->create<loco::EltwiseMul>();
+
+ // update connections
+ sub_node->lhs(node_A);
+ sub_node->rhs(node_B);
+ mul_node->lhs(sub_node);
+ mul_node->rhs(sub_node);
+
+ // replace node
+ replace(node).with(mul_node);
+
+ return true;
+}
+
+} // namespace
+
+namespace moco
+{
+namespace tf
+{
+
+bool SquaredDifferenceCanonicalizer::run(loco::Graph *graph)
+{
+ auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
+ bool changed = false;
+
+ for (auto node : active_nodes)
+ {
+ if (node->dialect() == TFDialect::get())
+ {
+ auto tf_node = dynamic_cast<moco::tf::TFSquaredDifference *>(node);
+ if (tf_node != nullptr)
+ {
+ if (canonicalize_sqdiff(graph, tf_node))
+ changed = true;
+ }
+ }
+ }
+
+ return changed;
+}
+
+} // namespace tf
+} // namespace moco
diff --git a/compiler/moco-tf/src/Canonicalization/PadCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/SquaredDifferenceCanonicalizer.h
index 64bb6041a..afd65be32 100644
--- a/compiler/moco-tf/src/Canonicalization/PadCanonicalizer.h
+++ b/compiler/moco-tf/src/Canonicalization/SquaredDifferenceCanonicalizer.h
@@ -14,13 +14,10 @@
* limitations under the License.
*/
-#ifndef __MOCO_TF_PAD_CANONICALIZER_H__
-#define __MOCO_TF_PAD_CANONICALIZER_H__
+#ifndef __MOCO_TF_SQUAREDDIFFERENCE_CANONICALIZER_H__
+#define __MOCO_TF_SQUAREDDIFFERENCE_CANONICALIZER_H__
#include "Transform.h"
-#include "SimpleNodeTransform.h"
-
-#include <moco/IR/TFNodes.h>
namespace moco
{
@@ -28,18 +25,18 @@ namespace tf
{
/**
- * @brief Convert TFPad to Canonical TensorConstantPad
+ * @brief Convert TFSquaredDifference to Canonical EltwiseSub and EltwiseMul
*/
-class PadCanonicalizer final : public SimpleNodeTransform<moco::TFPad>
+class SquaredDifferenceCanonicalizer final : public Transform
{
public:
- const char *name(void) const final { return "PadCanonicalizer"; }
+ const char *name(void) const final { return "SquaredDifferenceCanonicalizer"; }
public:
- bool transform(moco::TFPad *) const final;
+ bool run(loco::Graph *graph) final;
};
} // namespace tf
} // namespace moco
-#endif // __MOCO_TF_PAD_CANONICALIZER_H__
+#endif // __MOCO_TF_SQUAREDDIFFERENCE_CANONICALIZER_H__
diff --git a/compiler/moco-tf/src/Canonicalization/SqueezeCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/SqueezeCanonicalizer.cpp
index f5b991206..a3fcc3b47 100644
--- a/compiler/moco-tf/src/Canonicalization/SqueezeCanonicalizer.cpp
+++ b/compiler/moco-tf/src/Canonicalization/SqueezeCanonicalizer.cpp
@@ -16,15 +16,19 @@
#include "SqueezeCanonicalizer.h"
-#include <moco/IR/TFDialect.h>
-#include <moco/Support/TFShapeInferenceHelper.h>
+#include "Annotations/ShapeInferenceData.h"
+
+#include "Dialect/TFDialect.h"
+#include "Dialect/TFNodes.h"
+#include "Dialect/TFNodeVisitor.h"
+#include "Dialect/TFNodeImpl.h"
#include <moco/Log.h>
namespace
{
-bool canonicalize_squeeze_to_reshape(loco::Graph *graph, moco::TFSqueeze *node)
+bool canonicalize_squeeze_to_reshape(loco::Graph *graph, moco::tf::TFSqueeze *node)
{
LOGGER(l);
@@ -42,12 +46,12 @@ bool canonicalize_squeeze_to_reshape(loco::Graph *graph, moco::TFSqueeze *node)
* In ---- FixedReshape ----- Out(s)
*/
- auto nodeshape = moco::node_shape(node);
+ auto squeeze_shape = node->annot<moco::tf::ShapeInferenceData>();
// canonicalize into FixedReshape is valid when squeeze has shape info
// TODO Support general Squeeze case
- assert(nodeshape.domain() != loco::Domain::Unknown);
+ assert(squeeze_shape);
- auto squeeze_tensor_shape = nodeshape.as<loco::TensorShape>();
+ auto squeeze_tensor_shape = squeeze_shape->tensor_shape();
// Create loco node to replace
auto reshape = graph->nodes()->create<loco::FixedReshape>();
@@ -77,9 +81,25 @@ namespace moco
namespace tf
{
-bool SqueezeCanonicalizer::transform(TFSqueeze *node) const
+bool SqueezeCanonicalizer::run(loco::Graph *graph)
{
- return canonicalize_squeeze_to_reshape(node->graph(), node);
+ auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
+ bool changed = false;
+
+ for (auto node : active_nodes)
+ {
+ if (node->dialect() == TFDialect::get())
+ {
+ auto tf_squeeze = dynamic_cast<moco::tf::TFSqueeze *>(node);
+ if (tf_squeeze != nullptr)
+ {
+ if (canonicalize_squeeze_to_reshape(graph, tf_squeeze))
+ changed = true;
+ }
+ }
+ }
+
+ return changed;
}
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/SqueezeCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/SqueezeCanonicalizer.h
index 28a1442bd..dc5b2d7b1 100644
--- a/compiler/moco-tf/src/Canonicalization/SqueezeCanonicalizer.h
+++ b/compiler/moco-tf/src/Canonicalization/SqueezeCanonicalizer.h
@@ -18,9 +18,6 @@
#define __MOCO_TF_SQUEEZE_CANONICALIZER_H__
#include "Transform.h"
-#include "SimpleNodeTransform.h"
-
-#include <moco/IR/TFNodes.h>
#include <loco.h>
@@ -34,13 +31,13 @@ namespace tf
*
* @note There is no canonical Squeeze node
*/
-class SqueezeCanonicalizer : public SimpleNodeTransform<moco::TFSqueeze>
+class SqueezeCanonicalizer : public Transform
{
public:
const char *name(void) const final { return "SqueezeCanonicalizer"; }
public:
- bool transform(moco::TFSqueeze *) const final;
+ bool run(loco::Graph *graph) override;
};
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/StopGradientCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/StopGradientCanonicalizer.cpp
index 574fa3993..a52af05a5 100644
--- a/compiler/moco-tf/src/Canonicalization/StopGradientCanonicalizer.cpp
+++ b/compiler/moco-tf/src/Canonicalization/StopGradientCanonicalizer.cpp
@@ -16,14 +16,17 @@
#include "StopGradientCanonicalizer.h"
-#include <moco/IR/TFDialect.h>
+#include "Dialect/TFDialect.h"
+#include "Dialect/TFNodes.h"
+#include "Dialect/TFNodeVisitor.h"
+#include "Dialect/TFNodeImpl.h"
#include <moco/Log.h>
namespace
{
-bool canonicalize_stopgradient(loco::Graph *graph, moco::TFStopGradient *node)
+bool canonicalize_stopgradient(loco::Graph *graph, moco::tf::TFStopGradient *node)
{
LOGGER(l);
@@ -62,9 +65,25 @@ namespace moco
namespace tf
{
-bool StopGradientCanonicalizer::transform(TFStopGradient *node) const
+bool StopGradientCanonicalizer::run(loco::Graph *graph)
{
- return canonicalize_stopgradient(node->graph(), node);
+ auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
+ bool changed = false;
+
+ for (auto node : active_nodes)
+ {
+ if (node->dialect() == TFDialect::get())
+ {
+ auto tf_stopgradient = dynamic_cast<moco::tf::TFStopGradient *>(node);
+ if (tf_stopgradient != nullptr)
+ {
+ if (canonicalize_stopgradient(graph, tf_stopgradient))
+ changed = true;
+ }
+ }
+ }
+
+ return changed;
}
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/StopGradientCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/StopGradientCanonicalizer.h
index 6a17728a6..a23a801f0 100644
--- a/compiler/moco-tf/src/Canonicalization/StopGradientCanonicalizer.h
+++ b/compiler/moco-tf/src/Canonicalization/StopGradientCanonicalizer.h
@@ -18,9 +18,6 @@
#define __MOCO_TF_STOPGRADIENT_CANONICALIZER_H__
#include "Transform.h"
-#include "SimpleNodeTransform.h"
-
-#include <moco/IR/TFNodes.h>
#include <loco.h>
@@ -32,13 +29,13 @@ namespace tf
/**
* @brief Canonicalize TF-dialect TFStopGradient into canonical Forward node
*/
-class StopGradientCanonicalizer : public SimpleNodeTransform<moco::TFStopGradient>
+class StopGradientCanonicalizer : public Transform
{
public:
const char *name(void) const final { return "StopGradientCanonicalizer"; }
public:
- bool transform(moco::TFStopGradient *) const final;
+ bool run(loco::Graph *graph) override;
};
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/SubCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/SubCanonicalizer.cpp
index c518b7d64..21f4210eb 100644
--- a/compiler/moco-tf/src/Canonicalization/SubCanonicalizer.cpp
+++ b/compiler/moco-tf/src/Canonicalization/SubCanonicalizer.cpp
@@ -16,7 +16,8 @@
#include "SubCanonicalizer.h"
-#include <moco/IR/TFDialect.h>
+#include "Dialect/TFDialect.h"
+#include "Dialect/TFNodes.h"
#include "TFEltwiseBinaryCanonicalzeHelper.h"
@@ -25,9 +26,25 @@ namespace moco
namespace tf
{
-bool SubCanonicalizer::transform(moco::TFSub *node) const
+bool SubCanonicalizer::run(loco::Graph *graph)
{
- return canonicalize_eltwise_binary_node(node);
+ auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
+ bool changed = false;
+
+ for (auto node : active_nodes)
+ {
+ if (node->dialect() == TFDialect::get())
+ {
+ auto tf_node = dynamic_cast<moco::tf::TFSub *>(node);
+ if (tf_node != nullptr)
+ {
+ if (canonicalize_eltwise_binary_node(tf_node))
+ changed = true;
+ }
+ }
+ }
+
+ return changed;
}
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/SubCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/SubCanonicalizer.h
index f715cc86c..4ab470685 100644
--- a/compiler/moco-tf/src/Canonicalization/SubCanonicalizer.h
+++ b/compiler/moco-tf/src/Canonicalization/SubCanonicalizer.h
@@ -18,9 +18,6 @@
#define __MOCO_TF_SUB_CANONICALIZER_H__
#include "Transform.h"
-#include "SimpleNodeTransform.h"
-
-#include <moco/IR/TFNodes.h>
#include <loco.h>
@@ -32,13 +29,13 @@ namespace tf
/**
* @brief Convert TFSub to Canonical EltwiseSub
*/
-class SubCanonicalizer : public SimpleNodeTransform<moco::TFSub>
+class SubCanonicalizer : public Transform
{
public:
const char *name(void) const final { return "SubCanonicalizer"; }
public:
- bool transform(moco::TFSub *) const final;
+ bool run(loco::Graph *graph) override;
};
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/TFPushCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/TFPushCanonicalizer.cpp
deleted file mode 100644
index 081e0e5f9..000000000
--- a/compiler/moco-tf/src/Canonicalization/TFPushCanonicalizer.cpp
+++ /dev/null
@@ -1,74 +0,0 @@
-/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "TFPushCanonicalizer.h"
-
-#include <moco/IR/TFDialect.h>
-
-#include <stdex/Memory.h>
-
-namespace
-{
-
-bool canonicalize_push(loco::Graph *graph, moco::TFPush *node)
-{
- /**
- * @note This will replace TFRelu node with Canonical ReLU
- *
- * Before
- * A --- TFPush
- * After
- * +- TFPush
- * |
- * A -+- Push
- *
- * Where
- * A : from of TFPush
- * TFPush will have no GraphOutputIndex
- * Push will have GraphOutputIndex that from TFPush
- */
-
- auto push_node = graph->nodes()->create<loco::Push>();
-
- auto node_A = node->from();
-
- // update connections
- push_node->from(node_A);
-
- // update output index
- push_node->index(node->index());
- node->index_reset();
-
- // replace node
- replace(node).with(push_node);
-
- return true;
-}
-
-} // namespace
-
-namespace moco
-{
-namespace tf
-{
-
-bool TFPushCanonicalizer::transform(TFPush *node) const
-{
- return canonicalize_push(node->graph(), node);
-}
-
-} // namespace tf
-} // namespace moco
diff --git a/compiler/moco-tf/src/Canonicalization/TFPushCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/TFPushCanonicalizer.h
deleted file mode 100644
index 569a71f82..000000000
--- a/compiler/moco-tf/src/Canonicalization/TFPushCanonicalizer.h
+++ /dev/null
@@ -1,47 +0,0 @@
-/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef __MOCO_TF_PUSH_CANONICALIZER_H__
-#define __MOCO_TF_PUSH_CANONICALIZER_H__
-
-#include "Transform.h"
-#include "SimpleNodeTransform.h"
-
-#include <moco/IR/TFNodes.h>
-
-#include <loco.h>
-
-namespace moco
-{
-namespace tf
-{
-
-/**
- * @brief Convert TFPush to Canonical Push
- */
-class TFPushCanonicalizer : public SimpleNodeTransform<moco::TFPush>
-{
-public:
- const char *name(void) const final { return "TFPushCanonicalizer"; }
-
-public:
- bool transform(moco::TFPush *) const final;
-};
-
-} // namespace tf
-} // namespace moco
-
-#endif // __MOCO_TF_PUSH_CANONICALIZER_H__
diff --git a/compiler/moco-tf/src/Canonicalization/TanhCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/TanhCanonicalizer.cpp
index 3f48a50fc..9b7b073e1 100644
--- a/compiler/moco-tf/src/Canonicalization/TanhCanonicalizer.cpp
+++ b/compiler/moco-tf/src/Canonicalization/TanhCanonicalizer.cpp
@@ -16,14 +16,17 @@
#include "TanhCanonicalizer.h"
-#include <moco/IR/TFDialect.h>
+#include "Dialect/TFDialect.h"
+#include "Dialect/TFNodes.h"
+#include "Dialect/TFNodeVisitor.h"
+#include "Dialect/TFNodeImpl.h"
#include <stdex/Memory.h>
namespace
{
-bool canonicalize_tanh(loco::Graph *graph, moco::TFTanh *node)
+bool canonicalize_tanh(loco::Graph *graph, moco::tf::TFTanh *node)
{
/**
* @note This will replace TFTanh node with Canonical Tanh
@@ -61,9 +64,25 @@ namespace moco
namespace tf
{
-bool TanhCanonicalizer::transform(TFTanh *node) const
+bool TanhCanonicalizer::run(loco::Graph *graph)
{
- return canonicalize_tanh(node->graph(), node);
+ auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
+ bool changed = false;
+
+ for (auto node : active_nodes)
+ {
+ if (node->dialect() == TFDialect::get())
+ {
+ auto tf_node = dynamic_cast<moco::tf::TFTanh *>(node);
+ if (tf_node != nullptr)
+ {
+ if (canonicalize_tanh(graph, tf_node))
+ changed = true;
+ }
+ }
+ }
+
+ return changed;
}
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/TanhCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/TanhCanonicalizer.h
index af5e79fb5..cf566a4d4 100644
--- a/compiler/moco-tf/src/Canonicalization/TanhCanonicalizer.h
+++ b/compiler/moco-tf/src/Canonicalization/TanhCanonicalizer.h
@@ -18,9 +18,6 @@
#define __MOCO_TF_TANH_CANONICALIZER_H__
#include "Transform.h"
-#include "SimpleNodeTransform.h"
-
-#include <moco/IR/TFNodes.h>
#include <loco.h>
@@ -32,13 +29,13 @@ namespace tf
/**
* @brief Convert TFTanh to Canonical Tanh
*/
-class TanhCanonicalizer : public SimpleNodeTransform<moco::TFTanh>
+class TanhCanonicalizer : public Transform
{
public:
const char *name(void) const final { return "TanhCanonicalizer"; }
public:
- bool transform(moco::TFTanh *) const override;
+ bool run(loco::Graph *graph) override;
};
} // namespace tf