summaryrefslogtreecommitdiff
path: root/compiler/moco-tf/src/Canonicalization
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/moco-tf/src/Canonicalization')
-rw-r--r--compiler/moco-tf/src/Canonicalization/AddCanonicalizer.cpp24
-rw-r--r--compiler/moco-tf/src/Canonicalization/AddCanonicalizer.h7
-rw-r--r--compiler/moco-tf/src/Canonicalization/AvgPoolCanonicalizer.cpp116
-rw-r--r--compiler/moco-tf/src/Canonicalization/AvgPoolCanonicalizer.h7
-rw-r--r--compiler/moco-tf/src/Canonicalization/BiasAddCanonicalizer.cpp29
-rw-r--r--compiler/moco-tf/src/Canonicalization/BiasAddCanonicalizer.h7
-rw-r--r--compiler/moco-tf/src/Canonicalization/ConcatV2Canonicalizer.cpp86
-rw-r--r--compiler/moco-tf/src/Canonicalization/ConcatV2Canonicalizer.h7
-rw-r--r--compiler/moco-tf/src/Canonicalization/ConstCanonicalizer.cpp59
-rw-r--r--compiler/moco-tf/src/Canonicalization/ConstCanonicalizer.h7
-rw-r--r--compiler/moco-tf/src/Canonicalization/Conv2DBackpropInputCanonicalizer.cpp371
-rw-r--r--compiler/moco-tf/src/Canonicalization/Conv2DBackpropInputCanonicalizer.h45
-rw-r--r--compiler/moco-tf/src/Canonicalization/Conv2DCanonicalizer.cpp110
-rw-r--r--compiler/moco-tf/src/Canonicalization/Conv2DCanonicalizer.h7
-rw-r--r--compiler/moco-tf/src/Canonicalization/DepthwiseConv2dNativeCanonicalizer.cpp105
-rw-r--r--compiler/moco-tf/src/Canonicalization/DepthwiseConv2dNativeCanonicalizer.h7
-rw-r--r--compiler/moco-tf/src/Canonicalization/IdentityCanonicalizer.cpp29
-rw-r--r--compiler/moco-tf/src/Canonicalization/IdentityCanonicalizer.h7
-rw-r--r--compiler/moco-tf/src/Canonicalization/MaxPoolCanonicalizer.cpp114
-rw-r--r--compiler/moco-tf/src/Canonicalization/MaxPoolCanonicalizer.h7
-rw-r--r--compiler/moco-tf/src/Canonicalization/MaximumCanonicalizer.cpp34
-rw-r--r--compiler/moco-tf/src/Canonicalization/MaximumCanonicalizer.h47
-rw-r--r--compiler/moco-tf/src/Canonicalization/MeanCanonicalizer.cpp31
-rw-r--r--compiler/moco-tf/src/Canonicalization/MeanCanonicalizer.h47
-rw-r--r--compiler/moco-tf/src/Canonicalization/MulCanonicalizer.cpp23
-rw-r--r--compiler/moco-tf/src/Canonicalization/MulCanonicalizer.h7
-rw-r--r--compiler/moco-tf/src/Canonicalization/PadCanonicalizer.cpp100
-rw-r--r--compiler/moco-tf/src/Canonicalization/PadCanonicalizer.h (renamed from compiler/moco-tf/src/Canonicalization/SquaredDifferenceCanonicalizer.h)17
-rw-r--r--compiler/moco-tf/src/Canonicalization/PlaceholderCanonicalizer.cpp102
-rw-r--r--compiler/moco-tf/src/Canonicalization/PlaceholderCanonicalizer.h47
-rw-r--r--compiler/moco-tf/src/Canonicalization/RealDivCanonicalizer.cpp23
-rw-r--r--compiler/moco-tf/src/Canonicalization/RealDivCanonicalizer.h7
-rw-r--r--compiler/moco-tf/src/Canonicalization/Relu6Canonicalizer.cpp27
-rw-r--r--compiler/moco-tf/src/Canonicalization/Relu6Canonicalizer.h7
-rw-r--r--compiler/moco-tf/src/Canonicalization/ReluCanonicalizer.cpp27
-rw-r--r--compiler/moco-tf/src/Canonicalization/ReluCanonicalizer.h7
-rw-r--r--compiler/moco-tf/src/Canonicalization/ReshapeCanonicalizer.cpp47
-rw-r--r--compiler/moco-tf/src/Canonicalization/ReshapeCanonicalizer.h7
-rw-r--r--compiler/moco-tf/src/Canonicalization/RsqrtCanonicalizer.cpp62
-rw-r--r--compiler/moco-tf/src/Canonicalization/RsqrtCanonicalizer.h7
-rw-r--r--compiler/moco-tf/src/Canonicalization/SoftmaxCanonicalizer.cpp37
-rw-r--r--compiler/moco-tf/src/Canonicalization/SoftmaxCanonicalizer.h7
-rw-r--r--compiler/moco-tf/src/Canonicalization/SqrtCanonicalizer.cpp27
-rw-r--r--compiler/moco-tf/src/Canonicalization/SqrtCanonicalizer.h7
-rw-r--r--compiler/moco-tf/src/Canonicalization/SquaredDifferenceCanonicalizer.cpp115
-rw-r--r--compiler/moco-tf/src/Canonicalization/SqueezeCanonicalizer.cpp36
-rw-r--r--compiler/moco-tf/src/Canonicalization/SqueezeCanonicalizer.h7
-rw-r--r--compiler/moco-tf/src/Canonicalization/StopGradientCanonicalizer.cpp27
-rw-r--r--compiler/moco-tf/src/Canonicalization/StopGradientCanonicalizer.h7
-rw-r--r--compiler/moco-tf/src/Canonicalization/SubCanonicalizer.cpp23
-rw-r--r--compiler/moco-tf/src/Canonicalization/SubCanonicalizer.h7
-rw-r--r--compiler/moco-tf/src/Canonicalization/TFPushCanonicalizer.cpp74
-rw-r--r--compiler/moco-tf/src/Canonicalization/TFPushCanonicalizer.h47
-rw-r--r--compiler/moco-tf/src/Canonicalization/TanhCanonicalizer.cpp27
-rw-r--r--compiler/moco-tf/src/Canonicalization/TanhCanonicalizer.h7
55 files changed, 1330 insertions, 952 deletions
diff --git a/compiler/moco-tf/src/Canonicalization/AddCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/AddCanonicalizer.cpp
index ef82f3dab..8028a870c 100644
--- a/compiler/moco-tf/src/Canonicalization/AddCanonicalizer.cpp
+++ b/compiler/moco-tf/src/Canonicalization/AddCanonicalizer.cpp
@@ -16,8 +16,8 @@
#include "AddCanonicalizer.h"
-#include "Dialect/TFDialect.h"
-#include "Dialect/TFNodes.h"
+#include <moco/IR/TFDialect.h>
+#include <moco/IR/TFNodes.h>
#include "TFEltwiseBinaryCanonicalzeHelper.h"
@@ -26,25 +26,9 @@ namespace moco
namespace tf
{
-bool AddCanonicalizer::run(loco::Graph *graph)
+bool AddCanonicalizer::transform(TFAdd *node) const
{
- auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
- bool changed = false;
-
- for (auto node : active_nodes)
- {
- if (node->dialect() == TFDialect::get())
- {
- auto tf_node = dynamic_cast<moco::tf::TFAdd *>(node);
- if (tf_node != nullptr)
- {
- if (canonicalize_eltwise_binary_node(tf_node))
- changed = true;
- }
- }
- }
-
- return changed;
+ return canonicalize_eltwise_binary_node(node);
}
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/AddCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/AddCanonicalizer.h
index 07b8a72de..53ba9ed58 100644
--- a/compiler/moco-tf/src/Canonicalization/AddCanonicalizer.h
+++ b/compiler/moco-tf/src/Canonicalization/AddCanonicalizer.h
@@ -18,6 +18,9 @@
#define __MOCO_TF_ADD_CANONICALIZER_H__
#include "Transform.h"
+#include "SimpleNodeTransform.h"
+
+#include <moco/IR/TFNodes.h>
#include <loco.h>
@@ -29,13 +32,13 @@ namespace tf
/**
* @brief Convert TFAdd to Canonical EltwiseAdd
*/
-class AddCanonicalizer : public Transform
+class AddCanonicalizer : public SimpleNodeTransform<TFAdd>
{
public:
const char *name(void) const final { return "AddCanonicalizer"; }
public:
- bool run(loco::Graph *graph) override;
+ bool transform(TFAdd *node) const final;
};
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/AvgPoolCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/AvgPoolCanonicalizer.cpp
index 66a71089e..e07a4f64f 100644
--- a/compiler/moco-tf/src/Canonicalization/AvgPoolCanonicalizer.cpp
+++ b/compiler/moco-tf/src/Canonicalization/AvgPoolCanonicalizer.cpp
@@ -16,71 +16,19 @@
#include "AvgPoolCanonicalizer.h"
-#include "Annotations/PadData.h"
-#include "Annotations/StrideData.h"
-#include "Annotations/ShapeInferenceData.h"
-#include "Annotations/WindowData.h"
+#include <moco/IR/TFDialect.h>
+#include <moco/Support/TFShapeInferenceHelper.h>
-#include "Dialect/TFDialect.h"
-#include "Dialect/TFNodes.h"
-#include "Dialect/TFNodeVisitor.h"
-#include "Dialect/TFNodeImpl.h"
+#include "CodecHelper.h"
-#include <moco/Log.h>
-#include <plier/tf/Convert.h>
+#include <loco/IR/NodeShape.h>
-#include <stdex/Memory.h>
+#include <moco/Log.h>
namespace
{
-using plier::tf::DataLayout;
-
-void set_feature_enc(loco::FeatureEncode *feature_enc, DataLayout data_layout)
-{
- auto enc = stdex::make_unique<loco::PermutingEncoder<loco::Domain::Feature>>();
-
- if (data_layout == DataLayout::NHWC)
- {
- enc->perm()->axis(loco::FeatureAxis::Count) = 0;
- enc->perm()->axis(loco::FeatureAxis::Height) = 1;
- enc->perm()->axis(loco::FeatureAxis::Width) = 2;
- enc->perm()->axis(loco::FeatureAxis::Depth) = 3;
- }
- else if (data_layout == DataLayout::NCHW)
- {
- enc->perm()->axis(loco::FeatureAxis::Count) = 0;
- enc->perm()->axis(loco::FeatureAxis::Depth) = 1;
- enc->perm()->axis(loco::FeatureAxis::Height) = 2;
- enc->perm()->axis(loco::FeatureAxis::Width) = 3;
- }
-
- feature_enc->encoder(std::move(enc));
-}
-
-void set_feature_dec(loco::FeatureDecode *feature_dec, DataLayout data_layout)
-{
- auto dec = stdex::make_unique<loco::PermutingDecoder<loco::Domain::Feature>>();
-
- if (data_layout == DataLayout::NHWC)
- {
- dec->perm()->axis(loco::FeatureAxis::Count) = 0;
- dec->perm()->axis(loco::FeatureAxis::Height) = 1;
- dec->perm()->axis(loco::FeatureAxis::Width) = 2;
- dec->perm()->axis(loco::FeatureAxis::Depth) = 3;
- }
- else if (data_layout == DataLayout::NCHW)
- {
- dec->perm()->axis(loco::FeatureAxis::Count) = 0;
- dec->perm()->axis(loco::FeatureAxis::Depth) = 1;
- dec->perm()->axis(loco::FeatureAxis::Height) = 2;
- dec->perm()->axis(loco::FeatureAxis::Width) = 3;
- }
-
- feature_dec->decoder(std::move(dec));
-}
-
-bool canonicalize_avgpool2d(loco::Graph *graph, moco::tf::TFAvgPool *node)
+bool canonicalize_avgpool2d(loco::Graph *graph, moco::TFAvgPool *node)
{
LOGGER(l);
@@ -113,30 +61,24 @@ bool canonicalize_avgpool2d(loco::Graph *graph, moco::tf::TFAvgPool *node)
avgPool2d_node->convention(loco::AvgPool2D::Convention::Valid);
- // paddata to pad
- auto pad_data = node->annot<moco::tf::PadData>();
- assert(pad_data != nullptr);
+ auto value_shape = moco::node_shape(node->value());
+ assert(value_shape.domain() != loco::Domain::Unknown);
- avgPool2d_node->pad()->top(pad_data->pad()->top());
- avgPool2d_node->pad()->bottom(pad_data->pad()->bottom());
- avgPool2d_node->pad()->left(pad_data->pad()->left());
- avgPool2d_node->pad()->right(pad_data->pad()->right());
+ auto node_stride = moco::stride_of(node->strides(), node->data_layout());
+ auto node_window = moco::window_of(node->ksize(), node->data_layout());
- // windowdata to window (ksize to window)
- auto window_data = node->annot<moco::tf::WindowData>();
- assert(window_data != nullptr);
+ moco::Padding2DInference infer_padding2d;
- auto window = avgPool2d_node->window();
- window->vertical(window_data->window()->vertical());
- window->horizontal(window_data->window()->horizontal());
+ infer_padding2d.padding(node->padding());
+ infer_padding2d.stride(node_stride);
+ infer_padding2d.window(node_window);
- // stridedata to stride (strides to stride)
- auto stride_data = node->annot<moco::tf::StrideData>();
- assert(stride_data != nullptr);
+ auto input_feature_shape = moco::as_feature_shape(value_shape, node->data_layout());
+ auto input_plane_shape = moco::make_plane_shape(input_feature_shape);
- auto stride = avgPool2d_node->stride();
- stride->vertical(stride_data->stride()->vertical());
- stride->horizontal(stride_data->stride()->horizontal());
+ *avgPool2d_node->pad() = infer_padding2d(input_plane_shape);
+ *avgPool2d_node->stride() = node_stride;
+ *avgPool2d_node->window() = node_window;
INFO(l) << "Canonicalize TFAvgPool pad = T " << avgPool2d_node->pad()->top() << ", L "
<< avgPool2d_node->pad()->left() << ", B " << avgPool2d_node->pad()->bottom() << ", R "
@@ -163,25 +105,9 @@ namespace moco
namespace tf
{
-bool AvgPoolCanonicalizer::run(loco::Graph *graph)
+bool AvgPoolCanonicalizer::transform(TFAvgPool *node) const
{
- auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
- bool changed = false;
-
- for (auto node : active_nodes)
- {
- if (node->dialect() == TFDialect::get())
- {
- auto tf_node = dynamic_cast<moco::tf::TFAvgPool *>(node);
- if (tf_node != nullptr)
- {
- if (canonicalize_avgpool2d(graph, tf_node))
- changed = true;
- }
- }
- }
-
- return changed;
+ return canonicalize_avgpool2d(node->graph(), node);
}
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/AvgPoolCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/AvgPoolCanonicalizer.h
index 7d7e6a80b..e9c56c868 100644
--- a/compiler/moco-tf/src/Canonicalization/AvgPoolCanonicalizer.h
+++ b/compiler/moco-tf/src/Canonicalization/AvgPoolCanonicalizer.h
@@ -18,6 +18,9 @@
#define __MOCO_TF_AVGPOOL_CANONICALIZER_H__
#include "Transform.h"
+#include "SimpleNodeTransform.h"
+
+#include <moco/IR/TFNodes.h>
#include <loco.h>
@@ -29,13 +32,13 @@ namespace tf
/**
* @brief Convert TFAvgPool to Canonical AvgPool2D
*/
-class AvgPoolCanonicalizer : public Transform
+class AvgPoolCanonicalizer : public SimpleNodeTransform<moco::TFAvgPool>
{
public:
const char *name(void) const final { return "AvgPoolCanonicalizer"; }
public:
- bool run(loco::Graph *graph) override;
+ bool transform(TFAvgPool *node) const final;
};
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/BiasAddCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/BiasAddCanonicalizer.cpp
index 37b660e4a..a5568ce1a 100644
--- a/compiler/moco-tf/src/Canonicalization/BiasAddCanonicalizer.cpp
+++ b/compiler/moco-tf/src/Canonicalization/BiasAddCanonicalizer.cpp
@@ -16,12 +16,9 @@
#include "BiasAddCanonicalizer.h"
-#include "Dialect/TFDialect.h"
-#include "Dialect/TFNodes.h"
-#include "Dialect/TFNodeVisitor.h"
-#include "Dialect/TFNodeImpl.h"
+#include <moco/IR/TFDialect.h>
-#include <moco/tf/Names.h>
+#include <moco/Names.h>
#include <moco/Log.h>
#include <plier/tf/Convert.h>
@@ -29,7 +26,7 @@ namespace
{
using plier::tf::DataLayout;
-bool canonicalize_biasadd(loco::Graph *graph, moco::tf::TFBiasAdd *node)
+bool canonicalize_biasadd(loco::Graph *graph, moco::TFBiasAdd *node)
{
LOGGER(l);
@@ -103,25 +100,9 @@ namespace moco
namespace tf
{
-bool BiasAddCanonicalizer::run(loco::Graph *graph)
+bool BiasAddCanonicalizer::transform(TFBiasAdd *node) const
{
- auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
- bool changed = false;
-
- for (auto node : active_nodes)
- {
- if (node->dialect() == TFDialect::get())
- {
- auto tf_biasadd = dynamic_cast<moco::tf::TFBiasAdd *>(node);
- if (tf_biasadd != nullptr)
- {
- if (canonicalize_biasadd(graph, tf_biasadd))
- changed = true;
- }
- }
- }
-
- return changed;
+ return canonicalize_biasadd(node->graph(), node);
}
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/BiasAddCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/BiasAddCanonicalizer.h
index a30894708..ff4032ca9 100644
--- a/compiler/moco-tf/src/Canonicalization/BiasAddCanonicalizer.h
+++ b/compiler/moco-tf/src/Canonicalization/BiasAddCanonicalizer.h
@@ -18,6 +18,9 @@
#define __MOCO_TF_BIASADD_CANONICALIZER_H__
#include "Transform.h"
+#include "SimpleNodeTransform.h"
+
+#include <moco/IR/TFNodes.h>
#include <loco.h>
@@ -29,13 +32,13 @@ namespace tf
/**
* @brief Convert TFBiasAdd to Canonical BiasAdd
*/
-class BiasAddCanonicalizer : public Transform
+class BiasAddCanonicalizer final : public SimpleNodeTransform<moco::TFBiasAdd>
{
public:
const char *name(void) const final { return "BiasAddCanonicalizer"; }
public:
- bool run(loco::Graph *graph) override;
+ bool transform(TFBiasAdd *node) const final;
};
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/ConcatV2Canonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/ConcatV2Canonicalizer.cpp
index e3939adb9..b59a3f3d7 100644
--- a/compiler/moco-tf/src/Canonicalization/ConcatV2Canonicalizer.cpp
+++ b/compiler/moco-tf/src/Canonicalization/ConcatV2Canonicalizer.cpp
@@ -15,27 +15,39 @@
*/
#include "ConcatV2Canonicalizer.h"
-
#include "LogHelper.h"
-#include "Annotations/ConcatData.h"
-#include "Annotations/ShapeInferenceData.h"
-
-#include "Dialect/TFDialect.h"
-#include "Dialect/TFNodes.h"
-#include "Dialect/TFNodeVisitor.h"
-#include "Dialect/TFNodeImpl.h"
+#include <moco/IR/TFDialect.h>
+#include <moco/Support/TFShapeInferenceHelper.h>
#include <moco/Log.h>
+#include <loco/Service/ShapeInference.h>
+
#include <stdex/Memory.h>
+#include <oops/UserExn.h>
namespace
{
using namespace moco::tf;
-bool canonicalize_concat(loco::Graph *graph, moco::tf::TFConcatV2 *node)
+bool scalar_value(moco::TFConst *node, int32_t &ret)
+{
+ auto nodeshape = node_shape(node);
+ if (!(node->dtype() == loco::DataType::S32))
+ return false;
+
+ auto tensor_shape = nodeshape.as<loco::TensorShape>();
+ if (!(tensor_shape.rank() == 0 || tensor_shape.rank() == 1))
+ return false;
+
+ ret = node->at<loco::DataType::S32>(0);
+
+ return true;
+}
+
+bool canonicalize_concat(loco::Graph *graph, moco::TFConcatV2 *node)
{
LOGGER(l);
@@ -71,19 +83,43 @@ bool canonicalize_concat(loco::Graph *graph, moco::tf::TFConcatV2 *node)
const int num_values = node->num_values();
assert(num_values >= 2);
- // get axis value
- auto concat_data = node->annot<ConcatData>();
- assert(concat_data != nullptr);
- auto axis_value = concat_data->axis();
+ // get axis absolute value
+ auto value_a = node->values(0);
+ if (!loco::shape_known(value_a))
+ return false;
- auto shapedata = node->annot<ShapeInferenceData>();
- auto node_rank = shapedata->rank();
+ uint32_t node_rank = 0;
+ {
+ auto value_a_shape = moco::node_shape(value_a);
+ assert(value_a_shape.domain() == loco::Domain::Tensor);
+
+ auto value_a_tensor_shape = value_a_shape.as<loco::TensorShape>();
+ node_rank = value_a_tensor_shape.rank();
+ }
+ int32_t axis_value = 0;
+ {
+ // axis should be TFConst
+ auto axis_node = node->axis();
+ auto tfconst = dynamic_cast<moco::TFConst *>(axis_node);
+ if (tfconst == nullptr)
+ {
+ // TODO Check this: this error can be from TFOptimizatier.
+ throw oops::UserExn("ConcatV2 node has invalid input for axis", node->name());
+ }
+ auto result = scalar_value(tfconst, axis_value);
+ if (!result)
+ {
+ // TODO Check this: this error can be from TFOptimizatier.
+ throw oops::UserExn("ConcatV2 node has invalid input for axis", node->name());
+ }
+ }
uint32_t axis_absolute = (axis_value >= 0) ? axis_value : (int32_t)node_rank + axis_value;
INFO(l) << "canonicalize_concat axis(" << axis_absolute << "), value(" << axis_value << "), rank("
<< node_rank << ")";
+ // Convert series of TensorConcat if num_values > 2
auto concat_node = graph->nodes()->create<loco::TensorConcat>();
concat_node->lhs(node->values(0));
concat_node->rhs(node->values(1));
@@ -115,25 +151,9 @@ namespace moco
namespace tf
{
-bool ConcatV2Canonicalizer::run(loco::Graph *graph)
+bool ConcatV2Canonicalizer::transform(TFConcatV2 *node) const
{
- auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
- bool changed = false;
-
- for (auto node : active_nodes)
- {
- if (node->dialect() == TFDialect::get())
- {
- auto tf_node = dynamic_cast<moco::tf::TFConcatV2 *>(node);
- if (tf_node != nullptr)
- {
- if (canonicalize_concat(graph, tf_node))
- changed = true;
- }
- }
- }
-
- return changed;
+ return canonicalize_concat(node->graph(), node);
}
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/ConcatV2Canonicalizer.h b/compiler/moco-tf/src/Canonicalization/ConcatV2Canonicalizer.h
index 4448ddb16..e6b471b89 100644
--- a/compiler/moco-tf/src/Canonicalization/ConcatV2Canonicalizer.h
+++ b/compiler/moco-tf/src/Canonicalization/ConcatV2Canonicalizer.h
@@ -18,6 +18,9 @@
#define __MOCO_TF_CONCATV2_CANONICALIZER_H__
#include "Transform.h"
+#include "SimpleNodeTransform.h"
+
+#include <moco/IR/TFNodes.h>
#include <loco.h>
@@ -29,13 +32,13 @@ namespace tf
/**
* @brief Convert TFConcatV2 to Canonical TensorConcat
*/
-class ConcatV2Canonicalizer : public Transform
+class ConcatV2Canonicalizer : public SimpleNodeTransform<moco::TFConcatV2>
{
public:
const char *name(void) const final { return "ConcatV2Canonicalizer"; }
public:
- bool run(loco::Graph *graph) override;
+ bool transform(moco::TFConcatV2 *node) const final;
};
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/ConstCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/ConstCanonicalizer.cpp
index dea97f94a..60629cd5a 100644
--- a/compiler/moco-tf/src/Canonicalization/ConstCanonicalizer.cpp
+++ b/compiler/moco-tf/src/Canonicalization/ConstCanonicalizer.cpp
@@ -16,18 +16,17 @@
#include "ConstCanonicalizer.h"
-#include "Dialect/TFDialect.h"
-#include "Dialect/TFNodes.h"
-#include "Dialect/TFNodeVisitor.h"
-#include "Dialect/TFNodeImpl.h"
+#include <moco/IR/TFDialect.h>
-#include <moco/tf/Names.h>
+#include <moco/Names.h>
#include <moco/Log.h>
+#include <oops/UserExn.h>
+
namespace
{
-bool canonicalize_const(loco::Graph *graph, moco::tf::TFConst *node)
+bool canonicalize_const(loco::Graph *graph, moco::TFConst *node)
{
LOGGER(l);
@@ -55,13 +54,27 @@ bool canonicalize_const(loco::Graph *graph, moco::tf::TFConst *node)
const_node->dtype(dtype);
auto rank = node->rank();
- const_node->rank(rank);
- for (uint32_t r = 0; r < rank; ++r)
+
+ if (rank == 0)
+ {
+ // This routine implements a workaround that converts a scalar constant (rank-0 tensor)
+ // into a rank-1 tensor of shape [1].
+ //
+ // TODO Revise this implementation later
+ const_node->rank(1);
+ const_node->dim(0) = 1;
+ }
+ else
{
- if (node->dim(r).known())
- const_node->dim(r) = node->dim(r);
- else
- const_node->dim(r).unset();
+ const_node->rank(rank);
+
+ for (uint32_t r = 0; r < rank; ++r)
+ {
+ if (node->dim(r).known())
+ const_node->dim(r) = node->dim(r);
+ else
+ const_node->dim(r).unset();
+ }
}
switch (dtype)
@@ -87,7 +100,7 @@ bool canonicalize_const(loco::Graph *graph, moco::tf::TFConst *node)
break;
}
default:
- throw std::runtime_error("NYI for this DataType");
+ throw oops::UserExn("Const has unsupported data type", node->name());
}
// update graph
@@ -105,25 +118,9 @@ namespace moco
namespace tf
{
-bool ConstCanonicalizer::run(loco::Graph *graph)
+bool ConstCanonicalizer::transform(TFConst *node) const
{
- auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
- bool changed = false;
-
- for (auto node : active_nodes)
- {
- if (node->dialect() == TFDialect::get())
- {
- auto tf_const = dynamic_cast<moco::tf::TFConst *>(node);
- if (tf_const != nullptr)
- {
- if (canonicalize_const(graph, tf_const))
- changed = true;
- }
- }
- }
-
- return changed;
+ return canonicalize_const(node->graph(), node);
}
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/ConstCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/ConstCanonicalizer.h
index 53f3ca8e3..1b0b2b867 100644
--- a/compiler/moco-tf/src/Canonicalization/ConstCanonicalizer.h
+++ b/compiler/moco-tf/src/Canonicalization/ConstCanonicalizer.h
@@ -18,6 +18,9 @@
#define __MOCO_TF_CONST_CANONICALIZER_H__
#include "Transform.h"
+#include "SimpleNodeTransform.h"
+
+#include <moco/IR/TFNodes.h>
#include <loco.h>
@@ -29,13 +32,13 @@ namespace tf
/**
* @brief Convert TFConst to Canonical ConstGen
*/
-class ConstCanonicalizer : public Transform
+class ConstCanonicalizer : public SimpleNodeTransform<moco::TFConst>
{
public:
const char *name(void) const final { return "ConstCanonicalizer"; }
public:
- bool run(loco::Graph *graph) override;
+ bool transform(moco::TFConst *) const final;
};
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/Conv2DBackpropInputCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/Conv2DBackpropInputCanonicalizer.cpp
new file mode 100644
index 000000000..d3cbd4ab3
--- /dev/null
+++ b/compiler/moco-tf/src/Canonicalization/Conv2DBackpropInputCanonicalizer.cpp
@@ -0,0 +1,371 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Conv2DBackpropInputCanonicalizer.h"
+
+#include <moco/IR/TFDialect.h>
+
+#include "CodecHelper.h"
+
+#include <loco/IR/Stride.h>
+#include <loco/IR/Padding2D.h>
+#include <loco/Service/ShapeInference.h>
+
+#include <oops/UserExn.h>
+
+namespace
+{
+using plier::tf::DataLayout;
+
+void set_filter_enc(loco::FilterEncode *filter_enc)
+{
+ auto enc = stdex::make_unique<loco::PermutingEncoder<loco::Domain::Filter>>();
+
+ // In TensorFlow, Conv2dBackpropInput's filter is a 4-D tensor of following shape:
+ // [filter_height, filter_width, out_channels, in_channels] or HWOI or HWNC (in/out in loco sense)
+ enc->perm()->axis(loco::FilterAxis::Height) = 0;
+ enc->perm()->axis(loco::FilterAxis::Width) = 1;
+ enc->perm()->axis(loco::FilterAxis::Count) = 2;
+ enc->perm()->axis(loco::FilterAxis::Depth) = 3;
+
+ filter_enc->encoder(std::move(enc));
+}
+
+} // namespace
+
+namespace
+{
+
+bool stride_2d_from_4d(loco::Stride<2> &ret, const std::vector<int64_t> &strides_4d,
+ const DataLayout data_layout)
+{
+ if (!(strides_4d.size() == 4))
+ return false;
+
+ switch (data_layout)
+ {
+ case DataLayout::NHWC:
+ ret.vertical(strides_4d.at(1));
+ ret.horizontal(strides_4d.at(2));
+ break;
+ case DataLayout::NCHW:
+ ret.vertical(strides_4d.at(2));
+ ret.horizontal(strides_4d.at(3));
+ break;
+ default:
+ return false;
+ }
+ return true;
+}
+
+struct PlaneShape
+{
+ loco::Dimension vertical;
+ loco::Dimension horizontal;
+};
+
+class Padding2DInference final
+{
+public:
+ Padding2DInference(const moco::TFNode *node) { _node = node; }
+
+public:
+ loco::Padding2D operator()(void);
+
+public:
+ PlaneShape &input() { return _input; }
+ PlaneShape &output() { return _output; }
+ loco::Stride<2> &stride() { return _stride; }
+ loco::Window<2> &window() { return _window; }
+ moco::TFPadding &padding() { return _padding; }
+
+private:
+ /// @brief Check whether ingredients set by non-default values
+ bool ready()
+ {
+ if (not input().vertical.known())
+ return false;
+ if (not input().horizontal.known())
+ return false;
+ if (not output().vertical.known())
+ return false;
+ if (not output().horizontal.known())
+ return false;
+ if (stride().vertical() == 0)
+ return false;
+ if (stride().horizontal() == 0)
+ return false;
+ if (window().vertical() == 0)
+ return false;
+ if (window().horizontal() == 0)
+ return false;
+ if (padding().empty())
+ return false;
+
+ return true;
+ }
+
+ inline uint32_t tight_output_for_valid_padding(uint32_t input, uint32_t stride, uint32_t filter)
+ {
+ return stride * (input - 1) + filter;
+ }
+
+ /**
+ * @note For Conv2DBackpropInput SAME padding, TensorFlow requires this condition to hold
+ *
+ * Reference: `::tensorflow::GetWindowedOutputSizeVerboseV2()` from TensorFlow project
+ */
+ inline bool same_padding_applicable(uint32_t input, uint32_t output, uint32_t stride)
+ {
+ // Here 'input' and 'output' means Conv2DBackpropInput's actual node input and output.
+ // Then these three conditions are equivalent:
+ //
+ // input == floor((output + stride - 1) / stride)
+ // input == ceil(output / stride)
+ // (stride * (input - 1) < output) and (output <= stride * input)
+ return (stride * (input - 1) < output) and (output <= stride * input);
+ }
+
+ inline uint32_t padding_needed(uint32_t input, uint32_t output, uint32_t stride, uint32_t filter)
+ {
+ return stride * (input - 1) + filter - output;
+ }
+
+private:
+ const moco::TFNode *_node;
+ PlaneShape _input;
+ PlaneShape _output;
+ loco::Stride<2> _stride;
+ loco::Window<2> _window;
+ moco::TFPadding _padding;
+};
+
+loco::Padding2D Padding2DInference::operator()(void)
+{
+ assert(ready());
+
+ if (padding() == "VALID")
+ {
+ // In case of VALID padding, TensorFlow accepts any size same or larger than
+ // 'tight fit' output. When output size (set by 'input sizes' node input) is
+ // larger than tight fit, extra spaces filled with zero.
+ auto tight_output_vertical = tight_output_for_valid_padding(
+ input().vertical.value(), stride().vertical(), window().vertical());
+ auto tight_output_horizontal = tight_output_for_valid_padding(
+ input().horizontal.value(), stride().horizontal(), window().horizontal());
+
+ if (output().vertical.value() < tight_output_vertical or
+ output().horizontal.value() < tight_output_horizontal)
+ throw oops::UserExn("input_sizes is too small", _node->name());
+
+ // Currently, only accept tight fit.
+ // TODO Support non-tight case by adding zero padding operation
+ assert(output().vertical.value() == tight_output_vertical);
+ assert(output().horizontal.value() == tight_output_horizontal);
+
+ return loco::Padding2D(0, 0, 0, 0);
+ }
+
+ if (padding() == "SAME")
+ {
+ // This condition is required by TensorFlow
+ if (not same_padding_applicable(input().vertical.value(), output().vertical.value(),
+ stride().vertical()) or
+ not same_padding_applicable(input().horizontal.value(), output().horizontal.value(),
+ stride().horizontal()))
+ throw oops::UserExn("Size mismatch for SAME padding", _node->name());
+
+ auto whole_pad_vertical = padding_needed(input().vertical.value(), output().vertical.value(),
+ stride().vertical(), window().vertical());
+ auto whole_pad_horizontal =
+ padding_needed(input().horizontal.value(), output().horizontal.value(),
+ stride().horizontal(), window().horizontal());
+
+ loco::Padding2D res;
+
+ res.top(whole_pad_vertical / 2);
+ res.bottom(whole_pad_vertical - res.top());
+ res.left(whole_pad_horizontal / 2);
+ res.right(whole_pad_horizontal - res.left());
+
+ return res;
+ }
+
+ throw oops::UserExn("Usupported padding " + padding(), _node->name());
+}
+
+/**
+ * @param[out] ret PlaneShape extracted from 'node' with given 'data_layout'
+ * @param[in] node
+ * @param[in] data_layout
+ *
+ * @return true on success
+ */
+bool set_plane_shape(PlaneShape &ret, const loco::Node *node, const DataLayout data_layout)
+{
+ auto tensor_shape = loco::shape_get(node).as<loco::TensorShape>();
+ if (!(tensor_shape.rank() == 4))
+ return false;
+
+ switch (data_layout)
+ {
+ case DataLayout::NHWC:
+ ret.vertical = tensor_shape.dim(1).value();
+ ret.horizontal = tensor_shape.dim(2).value();
+ break;
+ case DataLayout::NCHW:
+ ret.vertical = tensor_shape.dim(2).value();
+ ret.horizontal = tensor_shape.dim(3).value();
+ break;
+ default:
+ return false;
+ }
+
+ return true;
+}
+
+/**
+ * @param[out] ret 2D Window extracted from HW** filter node
+ * @param[in] filter_node
+ *
+ * @return true on success
+ */
+bool set_window(loco::Window<2> &ret, const loco::Node *filter_node)
+{
+ auto tensor_shape = loco::shape_get(filter_node).as<loco::TensorShape>();
+ assert(tensor_shape.rank() == 4);
+
+ ret.vertical(tensor_shape.dim(0).value());
+ ret.horizontal(tensor_shape.dim(1).value());
+
+ return true;
+}
+
+} // namespace
+
+namespace
+{
+
+bool canonicalize_conv2d_backprop_input(loco::Graph *graph,
+ moco::TFConv2DBackpropInput *conv2d_backprop)
+{
+ /**
+ * @note This will replace TFConv2DBackpropInput node with canonical
+ * FeatureEncode + FilterEncode + TransposedConv2D + FeatureDecode
+ *
+ * Before
+ * input_sizes ----
+ * \
+ * filter -------- TFConv2DBackpropInput --- output(s)
+ * /
+ * out_backprop ---
+ *
+ * After
+ * input_sizes ----
+ * \
+ * filter -------- TFConv2DBackpropInput ---
+ * /
+ * out_backprop ---
+ *
+ * filter ------ FilterEncode ------ TransposedConv2D --- FeatureDecode --- output(s)
+ * (as ker) /
+ * out_backprop --- FeatureEncode ---
+ * (as ifm)
+ */
+
+ if (!loco::shape_known(conv2d_backprop->out_backprop()))
+ return false;
+ if (!loco::shape_known(conv2d_backprop))
+ return false;
+ if (!loco::shape_known(conv2d_backprop->filter()))
+ return false;
+
+ auto data_layout = plier::tf::as_data_layout(conv2d_backprop->data_layout());
+
+ // Nodes to replace
+ auto feature_enc = graph->nodes()->create<loco::FeatureEncode>();
+ auto filter_enc = graph->nodes()->create<loco::FilterEncode>();
+ auto tr_conv2d = graph->nodes()->create<loco::TransposedConv2D>();
+ auto feature_dec = graph->nodes()->create<loco::FeatureDecode>();
+
+ set_feature_enc(feature_enc, data_layout);
+ set_filter_enc(filter_enc);
+ set_feature_dec(feature_dec, data_layout);
+
+ // Attributes for new TransposedConv2D
+ loco::Stride<2> stride;
+ loco::Padding2D pad;
+
+ // Get attributes
+ {
+ if (!stride_2d_from_4d(stride, conv2d_backprop->strides(), data_layout))
+ throw oops::UserExn("Unsupported strides", conv2d_backprop->name());
+
+ Padding2DInference infer_pad(conv2d_backprop);
+
+ if (!set_plane_shape(infer_pad.input(), conv2d_backprop->out_backprop(), data_layout))
+ throw oops::UserExn("Unsupported out_backprop data_format", conv2d_backprop->name());
+ if (!set_plane_shape(infer_pad.output(), conv2d_backprop, data_layout))
+ throw oops::UserExn("Unsupported data_format", conv2d_backprop->name());
+ if (!set_window(infer_pad.window(), conv2d_backprop->filter()))
+ throw oops::UserExn("Unsupported filter shape", conv2d_backprop->name());
+ infer_pad.stride() = stride;
+ infer_pad.padding() = conv2d_backprop->padding();
+
+ // Run padding infer_pad
+ pad = infer_pad();
+ }
+
+ // Set attributes
+ tr_conv2d->pad()->top(pad.top());
+ tr_conv2d->pad()->bottom(pad.bottom());
+ tr_conv2d->pad()->left(pad.left());
+ tr_conv2d->pad()->right(pad.right());
+
+ tr_conv2d->stride()->vertical(stride.vertical());
+ tr_conv2d->stride()->horizontal(stride.horizontal());
+
+ // Update graph
+ auto input_node = conv2d_backprop->out_backprop();
+ auto filter_node = conv2d_backprop->filter();
+
+ // Update connections
+ feature_enc->input(input_node);
+ filter_enc->input(filter_node);
+ tr_conv2d->ifm(feature_enc);
+ tr_conv2d->ker(filter_enc);
+ feature_dec->input(tr_conv2d);
+
+ // Replace old conv2d_backprop
+ replace(conv2d_backprop).with(feature_dec);
+
+ return true;
+}
+
+} // namespace
+
+namespace moco
+{
+namespace tf
+{
+
+bool Conv2DBackpropInputCanonicalizer::transform(TFConv2DBackpropInput *node) const
+{
+ return canonicalize_conv2d_backprop_input(node->graph(), node);
+}
+
+} // namespace tf
+} // namespace moco
diff --git a/compiler/moco-tf/src/Canonicalization/Conv2DBackpropInputCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/Conv2DBackpropInputCanonicalizer.h
new file mode 100644
index 000000000..bc37bb9cb
--- /dev/null
+++ b/compiler/moco-tf/src/Canonicalization/Conv2DBackpropInputCanonicalizer.h
@@ -0,0 +1,45 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_CONV2DBACKPROPINPUT_CANONICALIZER_H__
+#define __MOCO_TF_CONV2DBACKPROPINPUT_CANONICALIZER_H__
+
+#include "Transform.h"
+#include "SimpleNodeTransform.h"
+
+#include <moco/IR/TFNodes.h>
+
+#include <loco.h>
+
+namespace moco
+{
+namespace tf
+{
+
+/// @brief Convert TFConv2DBackpropInput to Canonical TransposedConv2D
+class Conv2DBackpropInputCanonicalizer : public SimpleNodeTransform<moco::TFConv2DBackpropInput>
+{
+public:
+ const char *name(void) const final { return "Conv2DBackpropInputCanonicalizer"; }
+
+public:
+ bool transform(moco::TFConv2DBackpropInput *) const final;
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __MOCO_TF_CONV2DBACKPROPINPUT_CANONICALIZER_H__
diff --git a/compiler/moco-tf/src/Canonicalization/Conv2DCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/Conv2DCanonicalizer.cpp
index f34339d0f..a955793a8 100644
--- a/compiler/moco-tf/src/Canonicalization/Conv2DCanonicalizer.cpp
+++ b/compiler/moco-tf/src/Canonicalization/Conv2DCanonicalizer.cpp
@@ -16,46 +16,18 @@
#include "Conv2DCanonicalizer.h"
-#include "Annotations/PadData.h"
-#include "Annotations/StrideData.h"
+#include <moco/IR/TFDialect.h>
+#include <moco/Support/TFShapeInferenceHelper.h>
-#include "Dialect/TFDialect.h"
-#include "Dialect/TFNodes.h"
-#include "Dialect/TFNodeVisitor.h"
-#include "Dialect/TFNodeImpl.h"
+#include "CodecHelper.h"
#include <moco/Log.h>
-#include <plier/tf/Convert.h>
-
-#include <stdex/Memory.h>
namespace
{
using plier::tf::DataLayout;
-void set_feature_enc(loco::FeatureEncode *feature_enc, DataLayout data_layout)
-{
- auto enc = stdex::make_unique<loco::PermutingEncoder<loco::Domain::Feature>>();
-
- if (data_layout == DataLayout::NHWC)
- {
- enc->perm()->axis(loco::FeatureAxis::Count) = 0;
- enc->perm()->axis(loco::FeatureAxis::Height) = 1;
- enc->perm()->axis(loco::FeatureAxis::Width) = 2;
- enc->perm()->axis(loco::FeatureAxis::Depth) = 3;
- }
- else if (data_layout == DataLayout::NCHW)
- {
- enc->perm()->axis(loco::FeatureAxis::Count) = 0;
- enc->perm()->axis(loco::FeatureAxis::Depth) = 1;
- enc->perm()->axis(loco::FeatureAxis::Height) = 2;
- enc->perm()->axis(loco::FeatureAxis::Width) = 3;
- }
-
- feature_enc->encoder(std::move(enc));
-}
-
-void set_filter_enc(loco::FilterEncode *filter_enc, DataLayout data_layout)
+void set_filter_enc(loco::FilterEncode *filter_enc)
{
auto enc = stdex::make_unique<loco::PermutingEncoder<loco::Domain::Filter>>();
@@ -69,29 +41,7 @@ void set_filter_enc(loco::FilterEncode *filter_enc, DataLayout data_layout)
filter_enc->encoder(std::move(enc));
}
-void set_feature_dec(loco::FeatureDecode *feature_dec, DataLayout data_layout)
-{
- auto dec = stdex::make_unique<loco::PermutingDecoder<loco::Domain::Feature>>();
-
- if (data_layout == DataLayout::NHWC)
- {
- dec->perm()->axis(loco::FeatureAxis::Count) = 0;
- dec->perm()->axis(loco::FeatureAxis::Height) = 1;
- dec->perm()->axis(loco::FeatureAxis::Width) = 2;
- dec->perm()->axis(loco::FeatureAxis::Depth) = 3;
- }
- else if (data_layout == DataLayout::NCHW)
- {
- dec->perm()->axis(loco::FeatureAxis::Count) = 0;
- dec->perm()->axis(loco::FeatureAxis::Depth) = 1;
- dec->perm()->axis(loco::FeatureAxis::Height) = 2;
- dec->perm()->axis(loco::FeatureAxis::Width) = 3;
- }
-
- feature_dec->decoder(std::move(dec));
-}
-
-bool canonicalize_conv2d(loco::Graph *graph, moco::tf::TFConv2D *node)
+bool canonicalize_conv2d(loco::Graph *graph, moco::TFConv2D *node)
{
LOGGER(l);
@@ -125,23 +75,29 @@ bool canonicalize_conv2d(loco::Graph *graph, moco::tf::TFConv2D *node)
auto feature_dec = graph->nodes()->create<loco::FeatureDecode>();
set_feature_enc(feature_enc, data_layout);
- set_filter_enc(filter_enc, data_layout);
+ set_filter_enc(filter_enc);
set_feature_dec(feature_dec, data_layout);
- // Set Conv2D attributes from TFConv2D
- auto pad_data = node->annot<moco::tf::PadData>();
- assert(pad_data != nullptr);
+ auto input_shape = moco::node_shape(node->input());
+ assert(input_shape.domain() != loco::Domain::Unknown);
+
+ auto ker_shape = moco::node_shape(node->filter());
+ auto ker_tensor_shape = ker_shape.as<loco::TensorShape>(); // in HWIO
+
+ auto node_stride = moco::stride_of(node->strides(), node->data_layout());
+ auto node_window = moco::window_of(ker_tensor_shape, "HWIO");
+
+ moco::Padding2DInference infer_padding2d;
- conv2d->pad()->top(pad_data->pad()->top());
- conv2d->pad()->bottom(pad_data->pad()->bottom());
- conv2d->pad()->left(pad_data->pad()->left());
- conv2d->pad()->right(pad_data->pad()->right());
+ infer_padding2d.padding(node->padding());
+ infer_padding2d.stride(node_stride);
+ infer_padding2d.window(node_window);
- auto stride_data = node->annot<moco::tf::StrideData>();
- assert(stride_data != nullptr);
+ auto input_feature_shape = moco::as_feature_shape(input_shape, node->data_layout());
+ auto input_plane_shape = moco::make_plane_shape(input_feature_shape);
- conv2d->stride()->vertical(stride_data->stride()->vertical());
- conv2d->stride()->horizontal(stride_data->stride()->horizontal());
+ *conv2d->pad() = infer_padding2d(input_plane_shape);
+ *conv2d->stride() = node_stride;
// update graph
auto node_A = node->input();
@@ -167,25 +123,9 @@ namespace moco
namespace tf
{
-bool Conv2DCanonicalizer::run(loco::Graph *graph)
+bool Conv2DCanonicalizer::transform(TFConv2D *node) const
{
- auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
- bool changed = false;
-
- for (auto node : active_nodes)
- {
- if (node->dialect() == TFDialect::get())
- {
- auto tf_node = dynamic_cast<moco::tf::TFConv2D *>(node);
- if (tf_node != nullptr)
- {
- if (canonicalize_conv2d(graph, tf_node))
- changed = true;
- }
- }
- }
-
- return changed;
+ return canonicalize_conv2d(node->graph(), node);
}
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/Conv2DCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/Conv2DCanonicalizer.h
index 6be264f90..ea39667f3 100644
--- a/compiler/moco-tf/src/Canonicalization/Conv2DCanonicalizer.h
+++ b/compiler/moco-tf/src/Canonicalization/Conv2DCanonicalizer.h
@@ -18,6 +18,9 @@
#define __MOCO_TF_CONV2D_CANONICALIZER_H__
#include "Transform.h"
+#include "SimpleNodeTransform.h"
+
+#include <moco/IR/TFNodes.h>
#include <loco.h>
@@ -29,13 +32,13 @@ namespace tf
/**
* @brief Convert TFConv2D to Canonical Conv2D
*/
-class Conv2DCanonicalizer : public Transform
+class Conv2DCanonicalizer : public SimpleNodeTransform<TFConv2D>
{
public:
const char *name(void) const final { return "Conv2DCanonicalizer"; }
public:
- bool run(loco::Graph *graph) override;
+ bool transform(TFConv2D *) const final;
};
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/DepthwiseConv2dNativeCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/DepthwiseConv2dNativeCanonicalizer.cpp
index ee63efa2f..50dddf637 100644
--- a/compiler/moco-tf/src/Canonicalization/DepthwiseConv2dNativeCanonicalizer.cpp
+++ b/compiler/moco-tf/src/Canonicalization/DepthwiseConv2dNativeCanonicalizer.cpp
@@ -16,47 +16,18 @@
#include "DepthwiseConv2dNativeCanonicalizer.h"
-#include "Annotations/PadData.h"
-#include "Annotations/ShapeInferenceData.h"
-#include "Annotations/StrideData.h"
+#include <moco/IR/TFDialect.h>
+#include <moco/Support/TFShapeInferenceHelper.h>
-#include "Dialect/TFDialect.h"
-#include "Dialect/TFNodes.h"
-#include "Dialect/TFNodeVisitor.h"
-#include "Dialect/TFNodeImpl.h"
+#include "CodecHelper.h"
#include <moco/Log.h>
-#include <plier/tf/Convert.h>
-
-#include <stdex/Memory.h>
namespace
{
using plier::tf::DataLayout;
-void set_feature_enc(loco::FeatureEncode *feature_enc, DataLayout data_layout)
-{
- auto enc = stdex::make_unique<loco::PermutingEncoder<loco::Domain::Feature>>();
-
- if (data_layout == DataLayout::NHWC)
- {
- enc->perm()->axis(loco::FeatureAxis::Count) = 0;
- enc->perm()->axis(loco::FeatureAxis::Height) = 1;
- enc->perm()->axis(loco::FeatureAxis::Width) = 2;
- enc->perm()->axis(loco::FeatureAxis::Depth) = 3;
- }
- else if (data_layout == DataLayout::NCHW)
- {
- enc->perm()->axis(loco::FeatureAxis::Count) = 0;
- enc->perm()->axis(loco::FeatureAxis::Depth) = 1;
- enc->perm()->axis(loco::FeatureAxis::Height) = 2;
- enc->perm()->axis(loco::FeatureAxis::Width) = 3;
- }
-
- feature_enc->encoder(std::move(enc));
-}
-
void set_filter_enc(loco::DepthwiseFilterEncode *filter_enc)
{
auto enc = stdex::make_unique<loco::PermutingEncoder<loco::Domain::DepthwiseFilter>>();
@@ -71,29 +42,7 @@ void set_filter_enc(loco::DepthwiseFilterEncode *filter_enc)
filter_enc->encoder(std::move(enc));
}
-void set_feature_dec(loco::FeatureDecode *feature_dec, DataLayout data_layout)
-{
- auto dec = stdex::make_unique<loco::PermutingDecoder<loco::Domain::Feature>>();
-
- if (data_layout == DataLayout::NHWC)
- {
- dec->perm()->axis(loco::FeatureAxis::Count) = 0;
- dec->perm()->axis(loco::FeatureAxis::Height) = 1;
- dec->perm()->axis(loco::FeatureAxis::Width) = 2;
- dec->perm()->axis(loco::FeatureAxis::Depth) = 3;
- }
- else if (data_layout == DataLayout::NCHW)
- {
- dec->perm()->axis(loco::FeatureAxis::Count) = 0;
- dec->perm()->axis(loco::FeatureAxis::Depth) = 1;
- dec->perm()->axis(loco::FeatureAxis::Height) = 2;
- dec->perm()->axis(loco::FeatureAxis::Width) = 3;
- }
-
- feature_dec->decoder(std::move(dec));
-}
-
-bool canonicalize_depthwiseconv2dnative(loco::Graph *graph, moco::tf::TFDepthwiseConv2dNative *node)
+bool canonicalize_depthwiseconv2dnative(loco::Graph *graph, moco::TFDepthwiseConv2dNative *node)
{
LOGGER(l);
@@ -134,20 +83,24 @@ bool canonicalize_depthwiseconv2dnative(loco::Graph *graph, moco::tf::TFDepthwis
set_filter_enc(filter_enc);
set_feature_dec(feature_dec, data_layout);
- // Set DetphwiseConv2D attributes from TFDepthwiseConv2dNative
- auto pad_data = node->annot<moco::tf::PadData>();
- assert(pad_data != nullptr);
+ // Calculate Pad and Stride from inference
+ auto input_shape = moco::node_shape(node->input());
+ auto ker_shape = moco::node_shape(node->filter());
+ auto ker_tensor_shape = ker_shape.as<loco::TensorShape>();
+ auto node_stride = moco::stride_of(node->strides(), node->data_layout());
+ auto node_window = moco::window_of(ker_tensor_shape, "HWCM");
+
+ moco::Padding2DInference infer_padding2d;
- depthwiseconv2d->pad()->top(pad_data->pad()->top());
- depthwiseconv2d->pad()->bottom(pad_data->pad()->bottom());
- depthwiseconv2d->pad()->left(pad_data->pad()->left());
- depthwiseconv2d->pad()->right(pad_data->pad()->right());
+ infer_padding2d.padding(node->padding());
+ infer_padding2d.stride(node_stride);
+ infer_padding2d.window(node_window);
- auto stride_data = node->annot<moco::tf::StrideData>();
- assert(stride_data != nullptr);
+ auto input_feature_shape = moco::as_feature_shape(input_shape, node->data_layout());
+ auto input_plane_shape = moco::make_plane_shape(input_feature_shape);
- depthwiseconv2d->stride()->vertical(stride_data->stride()->vertical());
- depthwiseconv2d->stride()->horizontal(stride_data->stride()->horizontal());
+ *depthwiseconv2d->pad() = infer_padding2d(input_plane_shape);
+ *depthwiseconv2d->stride() = node_stride;
// update graph
auto node_A = node->input();
@@ -175,25 +128,9 @@ namespace moco
namespace tf
{
-bool DepthwiseConv2dNativeCanonicalizer::run(loco::Graph *graph)
+bool DepthwiseConv2dNativeCanonicalizer::transform(TFDepthwiseConv2dNative *node) const
{
- auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
- bool changed = false;
-
- for (auto node : active_nodes)
- {
- if (node->dialect() == TFDialect::get())
- {
- auto tf_node = dynamic_cast<moco::tf::TFDepthwiseConv2dNative *>(node);
- if (tf_node != nullptr)
- {
- if (canonicalize_depthwiseconv2dnative(graph, tf_node))
- changed = true;
- }
- }
- }
-
- return changed;
+ return canonicalize_depthwiseconv2dnative(node->graph(), node);
}
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/DepthwiseConv2dNativeCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/DepthwiseConv2dNativeCanonicalizer.h
index 9bb8c5ad8..704e1ade9 100644
--- a/compiler/moco-tf/src/Canonicalization/DepthwiseConv2dNativeCanonicalizer.h
+++ b/compiler/moco-tf/src/Canonicalization/DepthwiseConv2dNativeCanonicalizer.h
@@ -18,6 +18,9 @@
#define __MOCO_TF_DEPTHWISE_CONV2D_NATIVE_CANONICALIZER_H__
#include "Transform.h"
+#include "SimpleNodeTransform.h"
+
+#include <moco/IR/TFNodes.h>
namespace moco
{
@@ -27,13 +30,13 @@ namespace tf
/**
* @brief Convert TFDepthwiseConv2dNative to Canonical DepthwiseConv2D
*/
-class DepthwiseConv2dNativeCanonicalizer : public Transform
+class DepthwiseConv2dNativeCanonicalizer : public SimpleNodeTransform<moco::TFDepthwiseConv2dNative>
{
public:
const char *name(void) const final { return "DepthwiseConv2dNativeCanonicalizer"; }
public:
- bool run(loco::Graph *graph) override;
+ bool transform(moco::TFDepthwiseConv2dNative *) const final;
};
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/IdentityCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/IdentityCanonicalizer.cpp
index c4d5d8063..3b680cf04 100644
--- a/compiler/moco-tf/src/Canonicalization/IdentityCanonicalizer.cpp
+++ b/compiler/moco-tf/src/Canonicalization/IdentityCanonicalizer.cpp
@@ -18,18 +18,15 @@
#include "Convert.h"
-#include "Dialect/TFDialect.h"
-#include "Dialect/TFNodes.h"
-#include "Dialect/TFNodeVisitor.h"
-#include "Dialect/TFNodeImpl.h"
+#include <moco/IR/TFDialect.h>
-#include <moco/tf/Names.h>
+#include <moco/Names.h>
#include <moco/Log.h>
namespace
{
-bool canonicalize_identity(loco::Graph *graph, moco::tf::TFIdentity *node)
+bool canonicalize_identity(loco::Graph *graph, moco::TFIdentity *node)
{
LOGGER(l);
@@ -72,25 +69,9 @@ namespace moco
namespace tf
{
-bool IdentityCanonicalizer::run(loco::Graph *graph)
+bool IdentityCanonicalizer::transform(TFIdentity *node) const
{
- auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
- bool changed = false;
-
- for (auto node : active_nodes)
- {
- if (node->dialect() == TFDialect::get())
- {
- auto tf_identity = dynamic_cast<moco::tf::TFIdentity *>(node);
- if (tf_identity != nullptr)
- {
- if (canonicalize_identity(graph, tf_identity))
- changed = true;
- }
- }
- }
-
- return changed;
+ return canonicalize_identity(node->graph(), node);
}
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/IdentityCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/IdentityCanonicalizer.h
index 81aee178a..59b2894c5 100644
--- a/compiler/moco-tf/src/Canonicalization/IdentityCanonicalizer.h
+++ b/compiler/moco-tf/src/Canonicalization/IdentityCanonicalizer.h
@@ -18,6 +18,9 @@
#define __MOCO_TF_IDENTITY_CANONICALIZER_H__
#include "Transform.h"
+#include "SimpleNodeTransform.h"
+
+#include <moco/IR/TFNodes.h>
#include <loco.h>
@@ -29,13 +32,13 @@ namespace tf
/**
* @brief Convert TFIdentity to Canonical Forward
*/
-class IdentityCanonicalizer : public Transform
+class IdentityCanonicalizer : public SimpleNodeTransform<moco::TFIdentity>
{
public:
const char *name(void) const final { return "IdentityCanonicalizer"; }
public:
- bool run(loco::Graph *graph) override;
+ bool transform(moco::TFIdentity *) const final;
};
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/MaxPoolCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/MaxPoolCanonicalizer.cpp
index c46fbd208..06a605717 100644
--- a/compiler/moco-tf/src/Canonicalization/MaxPoolCanonicalizer.cpp
+++ b/compiler/moco-tf/src/Canonicalization/MaxPoolCanonicalizer.cpp
@@ -16,70 +16,17 @@
#include "MaxPoolCanonicalizer.h"
-#include "Annotations/PadData.h"
-#include "Annotations/StrideData.h"
-#include "Annotations/WindowData.h"
+#include <moco/IR/TFDialect.h>
+#include <moco/Support/TFShapeInferenceHelper.h>
-#include "Dialect/TFDialect.h"
-#include "Dialect/TFNodes.h"
-#include "Dialect/TFNodeVisitor.h"
-#include "Dialect/TFNodeImpl.h"
+#include "CodecHelper.h"
#include <moco/Log.h>
-#include <plier/tf/Convert.h>
-
-#include <stdex/Memory.h>
namespace
{
-using plier::tf::DataLayout;
-
-void set_feature_enc(loco::FeatureEncode *feature_enc, DataLayout data_layout)
-{
- auto enc = stdex::make_unique<loco::PermutingEncoder<loco::Domain::Feature>>();
-
- if (data_layout == DataLayout::NHWC)
- {
- enc->perm()->axis(loco::FeatureAxis::Count) = 0;
- enc->perm()->axis(loco::FeatureAxis::Height) = 1;
- enc->perm()->axis(loco::FeatureAxis::Width) = 2;
- enc->perm()->axis(loco::FeatureAxis::Depth) = 3;
- }
- else if (data_layout == DataLayout::NCHW)
- {
- enc->perm()->axis(loco::FeatureAxis::Count) = 0;
- enc->perm()->axis(loco::FeatureAxis::Depth) = 1;
- enc->perm()->axis(loco::FeatureAxis::Height) = 2;
- enc->perm()->axis(loco::FeatureAxis::Width) = 3;
- }
-
- feature_enc->encoder(std::move(enc));
-}
-
-void set_feature_dec(loco::FeatureDecode *feature_dec, DataLayout data_layout)
-{
- auto dec = stdex::make_unique<loco::PermutingDecoder<loco::Domain::Feature>>();
-
- if (data_layout == DataLayout::NHWC)
- {
- dec->perm()->axis(loco::FeatureAxis::Count) = 0;
- dec->perm()->axis(loco::FeatureAxis::Height) = 1;
- dec->perm()->axis(loco::FeatureAxis::Width) = 2;
- dec->perm()->axis(loco::FeatureAxis::Depth) = 3;
- }
- else if (data_layout == DataLayout::NCHW)
- {
- dec->perm()->axis(loco::FeatureAxis::Count) = 0;
- dec->perm()->axis(loco::FeatureAxis::Depth) = 1;
- dec->perm()->axis(loco::FeatureAxis::Height) = 2;
- dec->perm()->axis(loco::FeatureAxis::Width) = 3;
- }
-
- feature_dec->decoder(std::move(dec));
-}
-
-bool canonicalize_maxpool2d(loco::Graph *graph, moco::tf::TFMaxPool *node)
+bool canonicalize_maxpool2d(loco::Graph *graph, moco::TFMaxPool *node)
{
LOGGER(l);
@@ -111,36 +58,31 @@ bool canonicalize_maxpool2d(loco::Graph *graph, moco::tf::TFMaxPool *node)
set_feature_dec(feature_dec, data_layout);
// paddata to pad
- auto pad_data = node->annot<moco::tf::PadData>();
- assert(pad_data != nullptr);
+ auto input_shape = moco::node_shape(node->input());
+ assert(input_shape.domain() != loco::Domain::Unknown);
- maxPool2d_node->pad()->top(pad_data->pad()->top());
- maxPool2d_node->pad()->bottom(pad_data->pad()->bottom());
- maxPool2d_node->pad()->left(pad_data->pad()->left());
- maxPool2d_node->pad()->right(pad_data->pad()->right());
+ auto node_stride = moco::stride_of(node->strides(), node->data_layout());
+ auto node_window = moco::window_of(node->ksize(), node->data_layout());
- // windowdata to window (ksize to window)
- auto window_data = node->annot<moco::tf::WindowData>();
- assert(window_data != nullptr);
+ moco::Padding2DInference infer_padding2d;
- auto window = maxPool2d_node->window();
- window->vertical(window_data->window()->vertical());
- window->horizontal(window_data->window()->horizontal());
+ infer_padding2d.padding(node->padding());
+ infer_padding2d.stride(node_stride);
+ infer_padding2d.window(node_window);
- // stridedata to stride (strides to stride)
- auto stride_data = node->annot<moco::tf::StrideData>();
- assert(stride_data != nullptr);
+ auto input_feature_shape = moco::as_feature_shape(input_shape, node->data_layout());
+ auto input_plane_shape = moco::make_plane_shape(input_feature_shape);
- auto stride = maxPool2d_node->stride();
- stride->vertical(stride_data->stride()->vertical());
- stride->horizontal(stride_data->stride()->horizontal());
+ *maxPool2d_node->pad() = infer_padding2d(input_plane_shape);
+ *maxPool2d_node->stride() = node_stride;
+ *maxPool2d_node->window() = node_window;
INFO(l) << "Canonicalize TFMaxPool pad = T " << maxPool2d_node->pad()->top() << ", L "
<< maxPool2d_node->pad()->left() << ", B " << maxPool2d_node->pad()->bottom() << ", R "
<< maxPool2d_node->pad()->right() << std::endl;
// update graph
- auto node_A = node->value();
+ auto node_A = node->input();
// update connections
feature_enc->input(node_A);
@@ -160,25 +102,9 @@ namespace moco
namespace tf
{
-bool MaxPoolCanonicalizer::run(loco::Graph *graph)
+bool MaxPoolCanonicalizer::transform(TFMaxPool *node) const
{
- auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
- bool changed = false;
-
- for (auto node : active_nodes)
- {
- if (node->dialect() == TFDialect::get())
- {
- auto tf_node = dynamic_cast<moco::tf::TFMaxPool *>(node);
- if (tf_node != nullptr)
- {
- if (canonicalize_maxpool2d(graph, tf_node))
- changed = true;
- }
- }
- }
-
- return changed;
+ return canonicalize_maxpool2d(node->graph(), node);
}
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/MaxPoolCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/MaxPoolCanonicalizer.h
index a486c4caa..c58ade528 100644
--- a/compiler/moco-tf/src/Canonicalization/MaxPoolCanonicalizer.h
+++ b/compiler/moco-tf/src/Canonicalization/MaxPoolCanonicalizer.h
@@ -18,6 +18,9 @@
#define __MOCO_TF_MAXPOOL_CANONICALIZER_H__
#include "Transform.h"
+#include "SimpleNodeTransform.h"
+
+#include <moco/IR/TFNodes.h>
#include <loco.h>
@@ -29,13 +32,13 @@ namespace tf
/**
* @brief Convert TFMaxPool to Canonical MaxPool2D
*/
-class MaxPoolCanonicalizer : public Transform
+class MaxPoolCanonicalizer : public SimpleNodeTransform<moco::TFMaxPool>
{
public:
const char *name(void) const final { return "MaxPoolCanonicalizer"; }
public:
- bool run(loco::Graph *graph) override;
+ bool transform(moco::TFMaxPool *) const final;
};
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/MaximumCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/MaximumCanonicalizer.cpp
new file mode 100644
index 000000000..92634d01f
--- /dev/null
+++ b/compiler/moco-tf/src/Canonicalization/MaximumCanonicalizer.cpp
@@ -0,0 +1,34 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "MaximumCanonicalizer.h"
+
+#include <moco/IR/TFDialect.h>
+
+#include "TFEltwiseBinaryCanonicalzeHelper.h"
+
+namespace moco
+{
+namespace tf
+{
+
+bool MaximumCanonicalizer::transform(moco::TFMaximum *node) const
+{
+ return canonicalize_eltwise_binary_node(node);
+}
+
+} // namespace tf
+} // namespace moco
diff --git a/compiler/moco-tf/src/Canonicalization/MaximumCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/MaximumCanonicalizer.h
new file mode 100644
index 000000000..baff4d7ad
--- /dev/null
+++ b/compiler/moco-tf/src/Canonicalization/MaximumCanonicalizer.h
@@ -0,0 +1,47 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_MAXIMUM_CANONICALIZER_H__
+#define __MOCO_TF_MAXIMUM_CANONICALIZER_H__
+
+#include "Transform.h"
+#include "SimpleNodeTransform.h"
+
+#include <moco/IR/TFNodes.h>
+
+#include <loco.h>
+
+namespace moco
+{
+namespace tf
+{
+
+/**
+ * @brief Convert TFMaximum to Canonical EltwiseMax
+ */
+class MaximumCanonicalizer : public SimpleNodeTransform<moco::TFMaximum>
+{
+public:
+ const char *name(void) const final { return "MaximumCanonicalizer"; }
+
+public:
+ bool transform(moco::TFMaximum *) const final;
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __MOCO_TF_MAXIMUM_CANONICALIZER_H__
diff --git a/compiler/moco-tf/src/Canonicalization/MeanCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/MeanCanonicalizer.cpp
new file mode 100644
index 000000000..69eaf7900
--- /dev/null
+++ b/compiler/moco-tf/src/Canonicalization/MeanCanonicalizer.cpp
@@ -0,0 +1,31 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "MeanCanonicalizer.h"
+#include "TFReduceCanonicalzeHelper.h"
+
+namespace moco
+{
+namespace tf
+{
+
+bool MeanCanonicalizer::transform(moco::TFMean *node) const
+{
+ return canonicalize_reduce_node(node);
+}
+
+} // namespace tf
+} // namespace moco
diff --git a/compiler/moco-tf/src/Canonicalization/MeanCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/MeanCanonicalizer.h
new file mode 100644
index 000000000..469d7e3cd
--- /dev/null
+++ b/compiler/moco-tf/src/Canonicalization/MeanCanonicalizer.h
@@ -0,0 +1,47 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_MEAN_CANONICALIZER_H__
+#define __MOCO_TF_MEAN_CANONICALIZER_H__
+
+#include "Transform.h"
+#include "SimpleNodeTransform.h"
+
+#include <moco/IR/TFNodes.h>
+
+#include <loco.h>
+
+namespace moco
+{
+namespace tf
+{
+
+/**
+ * @brief Canonicalize TF-dialect TFMean into canonical TensorReduce(Mean) node
+ */
+class MeanCanonicalizer : public SimpleNodeTransform<moco::TFMean>
+{
+public:
+ const char *name(void) const final { return "MeanCanonicalizer"; }
+
+public:
+ bool transform(moco::TFMean *) const final;
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __MOCO_TF_MEAN_CANONICALIZER_H__
diff --git a/compiler/moco-tf/src/Canonicalization/MulCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/MulCanonicalizer.cpp
index 78d0ebc48..d02f71361 100644
--- a/compiler/moco-tf/src/Canonicalization/MulCanonicalizer.cpp
+++ b/compiler/moco-tf/src/Canonicalization/MulCanonicalizer.cpp
@@ -16,8 +16,7 @@
#include "MulCanonicalizer.h"
-#include "Dialect/TFDialect.h"
-#include "Dialect/TFNodes.h"
+#include <moco/IR/TFDialect.h>
#include "TFEltwiseBinaryCanonicalzeHelper.h"
@@ -26,25 +25,9 @@ namespace moco
namespace tf
{
-bool MulCanonicalizer::run(loco::Graph *graph)
+bool MulCanonicalizer::transform(moco::TFMul *node) const
{
- auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
- bool changed = false;
-
- for (auto node : active_nodes)
- {
- if (node->dialect() == TFDialect::get())
- {
- auto tf_node = dynamic_cast<moco::tf::TFMul *>(node);
- if (tf_node != nullptr)
- {
- if (canonicalize_eltwise_binary_node(tf_node))
- changed = true;
- }
- }
- }
-
- return changed;
+ return canonicalize_eltwise_binary_node(node);
}
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/MulCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/MulCanonicalizer.h
index 680f4c315..480eec700 100644
--- a/compiler/moco-tf/src/Canonicalization/MulCanonicalizer.h
+++ b/compiler/moco-tf/src/Canonicalization/MulCanonicalizer.h
@@ -18,6 +18,9 @@
#define __MOCO_TF_MUL_CANONICALIZER_H__
#include "Transform.h"
+#include "SimpleNodeTransform.h"
+
+#include <moco/IR/TFNodes.h>
#include <loco.h>
@@ -29,13 +32,13 @@ namespace tf
/**
* @brief Convert TFMul to Canonical EltwiseMul
*/
-class MulCanonicalizer : public Transform
+class MulCanonicalizer : public SimpleNodeTransform<moco::TFMul>
{
public:
const char *name(void) const final { return "MulCanonicalizer"; }
public:
- bool run(loco::Graph *graph) override;
+ bool transform(moco::TFMul *) const final;
};
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/PadCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/PadCanonicalizer.cpp
new file mode 100644
index 000000000..36136aed4
--- /dev/null
+++ b/compiler/moco-tf/src/Canonicalization/PadCanonicalizer.cpp
@@ -0,0 +1,100 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "PadCanonicalizer.h"
+
+#include <moco/IR/TFDialect.h>
+
+#include "loco/Service/TypeInference.h"
+
+#include <stdex/Memory.h>
+
+namespace
+{
+
+bool canonicalize_pad(loco::Graph *graph, moco::TFPad *node)
+{
+ /**
+ * @note This will replace TFPad node with Canonical TensorConstantPad
+ *
+ * Before
+ * input --- TFPad -- C
+ * paddings --/
+ * After
+ * paddings ------- TFPad --
+ * /
+ * input ----------- TensorConstantPad -- C
+ * ConstGen --------/
+ * Where
+ * input : input of TFPad
+ * paddings : paddings of TFPad. it becomes TensorConstantPad's attribute.
+ * C : a node that uses TFPad as an input. TFPad is disconnected from C.
+ * ConstGen : constant value of Pad. TFPad has zero value by default.
+ */
+
+ auto pad_node = graph->nodes()->create<loco::TensorConstantPad>();
+
+ auto constant_node = graph->nodes()->create<loco::ConstGen>();
+
+ auto input_node = node->input();
+ // TODO: support other dtype.
+ assert(loco::dtype_get(input_node) == loco::DataType::FLOAT32);
+ constant_node->dtype(loco::DataType::FLOAT32);
+ // TODO: constant node changes to scalar when it is implemented.
+ constant_node->shape({1});
+ constant_node->size<loco::DataType::FLOAT32>(1);
+ constant_node->at<loco::DataType::FLOAT32>(0) = 0.0f;
+
+ auto const_paddings_node = loco::must_cast<loco::ConstGen *>(node->paddings());
+ // TODO: support S64 type.
+ assert(const_paddings_node->dtype() == loco::DataType::S32);
+ assert(const_paddings_node->rank() == 2);
+ assert(const_paddings_node->dim(1).value() == 2);
+
+ auto padding = pad_node->padding();
+ uint32_t padding_rank = const_paddings_node->dim(0).value();
+ padding->rank(padding_rank);
+
+ for (uint32_t i = 0; i < padding_rank; i++)
+ {
+ padding->front(i) = const_paddings_node->at<loco::DataType::S32>(i << 1);
+ padding->back(i) = const_paddings_node->at<loco::DataType::S32>((i << 1) + 1);
+ }
+
+ // update connections
+ pad_node->input(input_node);
+ pad_node->constant(constant_node);
+
+ // replace node
+ replace(node).with(pad_node);
+
+ return true;
+}
+
+} // namespace
+
+namespace moco
+{
+namespace tf
+{
+
+bool PadCanonicalizer::transform(TFPad *node) const
+{
+ return canonicalize_pad(node->graph(), node);
+}
+
+} // namespace tf
+} // namespace moco
diff --git a/compiler/moco-tf/src/Canonicalization/SquaredDifferenceCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/PadCanonicalizer.h
index afd65be32..64bb6041a 100644
--- a/compiler/moco-tf/src/Canonicalization/SquaredDifferenceCanonicalizer.h
+++ b/compiler/moco-tf/src/Canonicalization/PadCanonicalizer.h
@@ -14,10 +14,13 @@
* limitations under the License.
*/
-#ifndef __MOCO_TF_SQUAREDDIFFERENCE_CANONICALIZER_H__
-#define __MOCO_TF_SQUAREDDIFFERENCE_CANONICALIZER_H__
+#ifndef __MOCO_TF_PAD_CANONICALIZER_H__
+#define __MOCO_TF_PAD_CANONICALIZER_H__
#include "Transform.h"
+#include "SimpleNodeTransform.h"
+
+#include <moco/IR/TFNodes.h>
namespace moco
{
@@ -25,18 +28,18 @@ namespace tf
{
/**
- * @brief Convert TFSquaredDifference to Canonical EltwiseSub and EltwiseMul
+ * @brief Convert TFPad to Canonical TensorConstantPad
*/
-class SquaredDifferenceCanonicalizer final : public Transform
+class PadCanonicalizer final : public SimpleNodeTransform<moco::TFPad>
{
public:
- const char *name(void) const final { return "SquaredDifferenceCanonicalizer"; }
+ const char *name(void) const final { return "PadCanonicalizer"; }
public:
- bool run(loco::Graph *graph) final;
+ bool transform(moco::TFPad *) const final;
};
} // namespace tf
} // namespace moco
-#endif // __MOCO_TF_SQUAREDDIFFERENCE_CANONICALIZER_H__
+#endif // __MOCO_TF_PAD_CANONICALIZER_H__
diff --git a/compiler/moco-tf/src/Canonicalization/PlaceholderCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/PlaceholderCanonicalizer.cpp
new file mode 100644
index 000000000..f568e909f
--- /dev/null
+++ b/compiler/moco-tf/src/Canonicalization/PlaceholderCanonicalizer.cpp
@@ -0,0 +1,102 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "PlaceholderCanonicalizer.h"
+
+#include <moco/IR/TFDialect.h>
+
+#include <moco/Names.h>
+#include <moco/Log.h>
+
+namespace
+{
+
+bool canonicalize_placeholder(loco::Graph *graph, moco::TFPlaceholder *node)
+{
+ LOGGER(l);
+
+ /**
+ * @note This will replace TFPlaceholder node with Canonical Pull
+ *
+ * Before
+ * TFPlaceholder -- C
+ *
+ * After
+ * TFPlaceholder -
+ * Pull -- C
+ *
+ * Where
+ * C : a node that uses TFPlaceholder as an input
+ * TFPlaceholder is disconnected from other nodes
+ */
+
+ INFO(l) << "PlaceholderCanonicalizer begin";
+
+ auto pull_node = graph->nodes()->create<loco::Pull>();
+
+ // copy properties
+ auto dtype = node->dtype();
+ pull_node->dtype(dtype);
+
+ auto rank = node->rank();
+
+ if (rank == 0)
+ {
+ // This routine implements a workaround that converts a scalar constant (rank-0 tensor)
+ // into a rank-1 tensor of shape [1].
+ //
+ // TODO Revise this implementation later
+ pull_node->rank(1);
+ pull_node->dim(0) = 1;
+ }
+ else
+ {
+ pull_node->rank(rank);
+
+ for (uint32_t r = 0; r < rank; ++r)
+ {
+ if (node->dim(r).known())
+ pull_node->dim(r) = node->dim(r);
+ else
+ pull_node->dim(r).unset();
+ }
+ }
+
+ // set loco::Pull GraphInputIndex
+ pull_node->index(moco::index(node));
+
+ // update graph
+ replace(node).with(pull_node);
+
+ INFO(l) << "PlaceholderCanonicalizer done";
+
+ return true;
+}
+
+} // namespace
+
+namespace moco
+{
+namespace tf
+{
+
+bool PlaceholderCanonicalizer::transform(TFPlaceholder *node) const
+{
+ return canonicalize_placeholder(node->graph(), node);
+}
+
+} // namespace tf
+} // namespace moco
diff --git a/compiler/moco-tf/src/Canonicalization/PlaceholderCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/PlaceholderCanonicalizer.h
new file mode 100644
index 000000000..66eafe6af
--- /dev/null
+++ b/compiler/moco-tf/src/Canonicalization/PlaceholderCanonicalizer.h
@@ -0,0 +1,47 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_PLACEHOLDER_CANONICALIZER_H__
+#define __MOCO_TF_PLACEHOLDER_CANONICALIZER_H__
+
+#include "Transform.h"
+#include "SimpleNodeTransform.h"
+
+#include <moco/IR/Nodes/TFPlaceholder.h>
+
+namespace moco
+{
+namespace tf
+{
+
+/**
+ * @brief Convert TFPlaceholder to Canonical Pull
+ *
+ * @note GraphInputIndex is copied to Pull
+ */
+class PlaceholderCanonicalizer : public SimpleNodeTransform<::moco::TFPlaceholder>
+{
+public:
+ const char *name(void) const final { return "PlaceholderCanonicalizer"; }
+
+public:
+ bool transform(moco::TFPlaceholder *) const final;
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __MOCO_TF_PLACEHOLDER_CANONICALIZER_H__
diff --git a/compiler/moco-tf/src/Canonicalization/RealDivCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/RealDivCanonicalizer.cpp
index 9ad15150a..a448d85fa 100644
--- a/compiler/moco-tf/src/Canonicalization/RealDivCanonicalizer.cpp
+++ b/compiler/moco-tf/src/Canonicalization/RealDivCanonicalizer.cpp
@@ -16,8 +16,7 @@
#include "RealDivCanonicalizer.h"
-#include "Dialect/TFDialect.h"
-#include "Dialect/TFNodes.h"
+#include <moco/IR/TFDialect.h>
#include "TFEltwiseBinaryCanonicalzeHelper.h"
@@ -26,25 +25,9 @@ namespace moco
namespace tf
{
-bool RealDivCanonicalizer::run(loco::Graph *graph)
+bool RealDivCanonicalizer::transform(moco::TFRealDiv *node) const
{
- auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
- bool changed = false;
-
- for (auto node : active_nodes)
- {
- if (node->dialect() == TFDialect::get())
- {
- auto tf_node = dynamic_cast<moco::tf::TFRealDiv *>(node);
- if (tf_node != nullptr)
- {
- if (canonicalize_eltwise_binary_node(tf_node))
- changed = true;
- }
- }
- }
-
- return changed;
+ return canonicalize_eltwise_binary_node(node);
}
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/RealDivCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/RealDivCanonicalizer.h
index 8e6953396..76e1bd377 100644
--- a/compiler/moco-tf/src/Canonicalization/RealDivCanonicalizer.h
+++ b/compiler/moco-tf/src/Canonicalization/RealDivCanonicalizer.h
@@ -18,6 +18,9 @@
#define __MOCO_TF_REALDIV_CANONICALIZER_H__
#include "Transform.h"
+#include "SimpleNodeTransform.h"
+
+#include <moco/IR/TFNodes.h>
#include <loco.h>
@@ -29,13 +32,13 @@ namespace tf
/**
* @brief Convert TFRealDiv to Canonical EltwiseDiv
*/
-class RealDivCanonicalizer : public Transform
+class RealDivCanonicalizer : public SimpleNodeTransform<moco::TFRealDiv>
{
public:
const char *name(void) const final { return "RealDivCanonicalizer"; }
public:
- bool run(loco::Graph *graph) override;
+ bool transform(moco::TFRealDiv *) const final;
};
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/Relu6Canonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/Relu6Canonicalizer.cpp
index 07657244b..c53a880a8 100644
--- a/compiler/moco-tf/src/Canonicalization/Relu6Canonicalizer.cpp
+++ b/compiler/moco-tf/src/Canonicalization/Relu6Canonicalizer.cpp
@@ -16,17 +16,14 @@
#include "Relu6Canonicalizer.h"
-#include "Dialect/TFDialect.h"
-#include "Dialect/TFNodes.h"
-#include "Dialect/TFNodeVisitor.h"
-#include "Dialect/TFNodeImpl.h"
+#include <moco/IR/TFDialect.h>
#include <stdex/Memory.h>
namespace
{
-bool canonicalize_relu6(loco::Graph *graph, moco::tf::TFRelu6 *node)
+bool canonicalize_relu6(loco::Graph *graph, moco::TFRelu6 *node)
{
/**
* @note This will replace TFRelu6 node with Canonical ReLU6
@@ -64,25 +61,9 @@ namespace moco
namespace tf
{
-bool Relu6Canonicalizer::run(loco::Graph *graph)
+bool Relu6Canonicalizer::transform(TFRelu6 *node) const
{
- auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
- bool changed = false;
-
- for (auto node : active_nodes)
- {
- if (node->dialect() == TFDialect::get())
- {
- auto tf_node = dynamic_cast<moco::tf::TFRelu6 *>(node);
- if (tf_node != nullptr)
- {
- if (canonicalize_relu6(graph, tf_node))
- changed = true;
- }
- }
- }
-
- return changed;
+ return canonicalize_relu6(node->graph(), node);
}
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/Relu6Canonicalizer.h b/compiler/moco-tf/src/Canonicalization/Relu6Canonicalizer.h
index aa1580f28..d8ad5db8e 100644
--- a/compiler/moco-tf/src/Canonicalization/Relu6Canonicalizer.h
+++ b/compiler/moco-tf/src/Canonicalization/Relu6Canonicalizer.h
@@ -18,6 +18,9 @@
#define __MOCO_TF_RELU6_CANONICALIZER_H__
#include "Transform.h"
+#include "SimpleNodeTransform.h"
+
+#include <moco/IR/TFNodes.h>
#include <loco.h>
@@ -29,13 +32,13 @@ namespace tf
/**
* @brief Convert TFRelu6 to Canonical ReLU6
*/
-class Relu6Canonicalizer : public Transform
+class Relu6Canonicalizer : public SimpleNodeTransform<moco::TFRelu6>
{
public:
const char *name(void) const final { return "Relu6Canonicalizer"; }
public:
- bool run(loco::Graph *graph) override;
+ bool transform(moco::TFRelu6 *) const final;
};
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/ReluCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/ReluCanonicalizer.cpp
index 20cd0bab9..7965dc931 100644
--- a/compiler/moco-tf/src/Canonicalization/ReluCanonicalizer.cpp
+++ b/compiler/moco-tf/src/Canonicalization/ReluCanonicalizer.cpp
@@ -16,17 +16,14 @@
#include "ReluCanonicalizer.h"
-#include "Dialect/TFDialect.h"
-#include "Dialect/TFNodes.h"
-#include "Dialect/TFNodeVisitor.h"
-#include "Dialect/TFNodeImpl.h"
+#include <moco/IR/TFDialect.h>
#include <stdex/Memory.h>
namespace
{
-bool canonicalize_relu(loco::Graph *graph, moco::tf::TFRelu *node)
+bool canonicalize_relu(loco::Graph *graph, moco::TFRelu *node)
{
/**
* @note This will replace TFRelu node with Canonical ReLU
@@ -64,25 +61,9 @@ namespace moco
namespace tf
{
-bool ReluCanonicalizer::run(loco::Graph *graph)
+bool ReluCanonicalizer::transform(TFRelu *node) const
{
- auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
- bool changed = false;
-
- for (auto node : active_nodes)
- {
- if (node->dialect() == TFDialect::get())
- {
- auto tf_node = dynamic_cast<moco::tf::TFRelu *>(node);
- if (tf_node != nullptr)
- {
- if (canonicalize_relu(graph, tf_node))
- changed = true;
- }
- }
- }
-
- return changed;
+ return canonicalize_relu(node->graph(), node);
}
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/ReluCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/ReluCanonicalizer.h
index 97adba308..e27abe158 100644
--- a/compiler/moco-tf/src/Canonicalization/ReluCanonicalizer.h
+++ b/compiler/moco-tf/src/Canonicalization/ReluCanonicalizer.h
@@ -18,6 +18,9 @@
#define __MOCO_TF_RELU_CANONICALIZER_H__
#include "Transform.h"
+#include "SimpleNodeTransform.h"
+
+#include <moco/IR/TFNodes.h>
#include <loco.h>
@@ -29,13 +32,13 @@ namespace tf
/**
* @brief Convert TFRelu to Canonical ReLU
*/
-class ReluCanonicalizer : public Transform
+class ReluCanonicalizer : public SimpleNodeTransform<moco::TFRelu>
{
public:
const char *name(void) const final { return "ReluCanonicalizer"; }
public:
- bool run(loco::Graph *graph) override;
+ bool transform(moco::TFRelu *) const final;
};
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/ReshapeCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/ReshapeCanonicalizer.cpp
index 3771d549a..b944568e0 100644
--- a/compiler/moco-tf/src/Canonicalization/ReshapeCanonicalizer.cpp
+++ b/compiler/moco-tf/src/Canonicalization/ReshapeCanonicalizer.cpp
@@ -16,11 +16,11 @@
#include "ReshapeCanonicalizer.h"
-#include "Dialect/TFDialect.h"
-#include "Dialect/TFNodeImpl.h"
+#include <moco/IR/TFDialect.h>
#include <moco/Log.h>
#include <plier/tf/Convert.h>
+#include <oops/UserExn.h>
#include <cassert>
@@ -31,7 +31,7 @@ using plier::tf::DataLayout;
/**
* @brief Check whether given 'new shape' arg is a fixed shape input for Reshape
*
- * ConstNode can be moco::tf::TFConst or loco::ConstGen
+ * ConstNode can be moco::TFConst or loco::ConstGen
*/
template <typename ConstNode> bool is_fixed_shape_input(ConstNode *const_shape_input)
{
@@ -54,13 +54,16 @@ template <typename ConstNode> bool is_fixed_shape_input(ConstNode *const_shape_i
// has wildcard dimension, i.e. dynamic reshape
return false;
}
- assert(shape_dim >= 1 && "Unknown behavior: New shape of Reshape has invalid dimension");
+ if (!(shape_dim >= 1))
+ {
+ throw oops::UserExn("New shape of Reshape has invalid dimension");
+ }
}
return true;
}
/// @note Currently only supports to canonicalize Fixed Reshape
-bool canonicalize_reshape(loco::Graph *graph, moco::tf::TFReshape *node)
+bool canonicalize_reshape(loco::Graph *graph, moco::TFReshape *node)
{
LOGGER(l);
INFO(l) << "TFNodeCanonicalize TFReshape begin";
@@ -99,14 +102,17 @@ bool canonicalize_reshape(loco::Graph *graph, moco::tf::TFReshape *node)
// Supports 2 cases for Reshape's shape input:
// TF-dialect TFConst or Canonical ConstGen
loco::Node *shape_input = node->shape();
- auto tfconst_shape_input = dynamic_cast<moco::tf::TFConst *>(shape_input);
+ auto tfconst_shape_input = dynamic_cast<moco::TFConst *>(shape_input);
auto constgen_shape_input = dynamic_cast<loco::ConstGen *>(shape_input);
if (tfconst_shape_input)
{
// Only support fixed reshape
// TODO support dynamic reshape
- assert(is_fixed_shape_input(tfconst_shape_input));
+ if (!(is_fixed_shape_input(tfconst_shape_input)))
+ {
+ throw oops::UserExn("Supports only fixed reshape", node->name());
+ }
auto rank = tfconst_shape_input->dim(0).value();
fixed_reshape->rank(rank);
@@ -118,7 +124,10 @@ bool canonicalize_reshape(loco::Graph *graph, moco::tf::TFReshape *node)
else if (constgen_shape_input)
{
// ditto
- assert(is_fixed_shape_input(constgen_shape_input));
+ if (!(is_fixed_shape_input(constgen_shape_input)))
+ {
+ throw oops::UserExn("Supports only fixed reshape", node->name());
+ }
auto rank = constgen_shape_input->dim(0).value();
fixed_reshape->rank(rank);
@@ -130,7 +139,7 @@ bool canonicalize_reshape(loco::Graph *graph, moco::tf::TFReshape *node)
else
{
// TODO support dynamic reshape from not const node
- throw std::runtime_error("ReshapeCanonicalizer: only support const node as input shape");
+ throw oops::UserExn("Supports only const node as input shape", node->name());
}
// replace
@@ -151,25 +160,9 @@ namespace moco
namespace tf
{
-bool ReshapeCanonicalizer::run(loco::Graph *graph)
+bool ReshapeCanonicalizer::transform(TFReshape *node) const
{
- auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
- bool changed = false;
-
- for (auto node : active_nodes)
- {
- if (node->dialect() == TFDialect::get())
- {
- auto tf_reshape = dynamic_cast<moco::tf::TFReshape *>(node);
- if (tf_reshape != nullptr)
- {
- if (canonicalize_reshape(graph, tf_reshape))
- changed = true;
- }
- }
- }
-
- return changed;
+ return canonicalize_reshape(node->graph(), node);
}
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/ReshapeCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/ReshapeCanonicalizer.h
index c9deee7a4..1a792024e 100644
--- a/compiler/moco-tf/src/Canonicalization/ReshapeCanonicalizer.h
+++ b/compiler/moco-tf/src/Canonicalization/ReshapeCanonicalizer.h
@@ -18,6 +18,9 @@
#define __MOCO_TF_RESHAPE_CANONICALIZER_H__
#include "Transform.h"
+#include "SimpleNodeTransform.h"
+
+#include <moco/IR/TFNodes.h>
#include <loco.h>
@@ -29,13 +32,13 @@ namespace tf
/**
* @brief Convert TFReshape to Canonical Reshape
*/
-class ReshapeCanonicalizer : public Transform
+class ReshapeCanonicalizer : public SimpleNodeTransform<moco::TFReshape>
{
public:
const char *name(void) const final { return "ReshapeCanonicalizer"; }
public:
- bool run(loco::Graph *graph) override;
+ bool transform(moco::TFReshape *) const final;
};
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/RsqrtCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/RsqrtCanonicalizer.cpp
index b4fbcac3c..c31dbf6d6 100644
--- a/compiler/moco-tf/src/Canonicalization/RsqrtCanonicalizer.cpp
+++ b/compiler/moco-tf/src/Canonicalization/RsqrtCanonicalizer.cpp
@@ -16,29 +16,25 @@
#include "RsqrtCanonicalizer.h"
-#include "Annotations/ShapeInferenceData.h"
-
-#include "Dialect/TFDialect.h"
-#include "Dialect/TFNodes.h"
-#include "Dialect/TFNodeVisitor.h"
-#include "Dialect/TFNodeImpl.h"
+#include <moco/IR/TFDialect.h>
+#include <moco/Support/TFShapeInferenceHelper.h>
#include <moco/Log.h>
#include <loco/Service/TypeInference.h>
#include <stdex/Memory.h>
+#include <oops/UserExn.h>
namespace
{
template <typename T>
-void prepare_const_gen(loco::ConstGen *const_node, const moco::tf::ShapeInferenceData *shapedata,
- T value);
+bool prepare_const_gen(loco::ConstGen *const_node, const loco::TensorShape &tensorshape, T value);
template <>
-void prepare_const_gen<float>(loco::ConstGen *const_node,
- const moco::tf::ShapeInferenceData *shapedata, float value)
+bool prepare_const_gen<float>(loco::ConstGen *const_node, const loco::TensorShape &tensorshape,
+ float value)
{
LOGGER(l);
@@ -47,18 +43,18 @@ void prepare_const_gen<float>(loco::ConstGen *const_node,
auto dtype = loco::DataType::FLOAT32;
const_node->dtype(dtype);
- auto rank = shapedata->rank();
+ auto rank = tensorshape.rank();
const_node->rank(rank);
for (uint32_t r = 0; r < rank; ++r)
{
- if (shapedata->dim(r).known())
- const_node->dim(r) = shapedata->dim(r);
+ if (tensorshape.dim(r).known())
+ const_node->dim(r) = tensorshape.dim(r);
else
- throw std::runtime_error("Cannot handle unknown shape");
+ return false;
- assert(shapedata->dim(r).value() > 0);
+ assert(tensorshape.dim(r).value() > 0);
- const_num_elements *= shapedata->dim(r).value();
+ const_num_elements *= tensorshape.dim(r).value();
}
INFO(l) << "prepare_const_gen : Elements = " << const_num_elements;
@@ -68,9 +64,11 @@ void prepare_const_gen<float>(loco::ConstGen *const_node,
{
const_node->at<loco::DataType::FLOAT32>(i) = value;
}
+
+ return true;
}
-bool canonicalize_rsqrt(loco::Graph *graph, moco::tf::TFRsqrt *node)
+bool canonicalize_rsqrt(loco::Graph *graph, moco::TFRsqrt *node)
{
/**
* @note This will replace TFRsqrt node with Canonical EltwiseSqrt + EltwiseRealDiv
@@ -91,13 +89,14 @@ bool canonicalize_rsqrt(loco::Graph *graph, moco::tf::TFRsqrt *node)
* TFRsqrt is converted to 1 / EltwiseSqrt
*/
- auto rsqrt_shapedata = node->annot<moco::tf::ShapeInferenceData>();
- if (rsqrt_shapedata == nullptr)
+ auto nodeshape = moco::node_shape(node);
+ if (nodeshape.domain() == loco::Domain::Unknown)
{
// We need this shape information
assert(false); // this shouldn't happen, let's add an alarm
return false;
}
+ auto tensorshape = nodeshape.as<loco::TensorShape>();
if (!loco::dtype_known(node))
{
@@ -114,11 +113,12 @@ bool canonicalize_rsqrt(loco::Graph *graph, moco::tf::TFRsqrt *node)
switch (dtype)
{
case loco::DataType::FLOAT32:
- prepare_const_gen<float>(const_node, rsqrt_shapedata, 1.0f);
+ if (!prepare_const_gen<float>(const_node, tensorshape, 1.0f))
+ throw oops::UserExn("Cannot handle unknown shape", node->name());
break;
default:
- throw std::runtime_error("NYI for this DataType");
+ throw oops::UserExn("Unsupported data type", node->name());
}
auto node_A = node->x();
@@ -141,25 +141,9 @@ namespace moco
namespace tf
{
-bool RsqrtCanonicalizer::run(loco::Graph *graph)
+bool RsqrtCanonicalizer::transform(TFRsqrt *node) const
{
- auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
- bool changed = false;
-
- for (auto node : active_nodes)
- {
- if (node->dialect() == TFDialect::get())
- {
- auto tf_node = dynamic_cast<moco::tf::TFRsqrt *>(node);
- if (tf_node != nullptr)
- {
- if (canonicalize_rsqrt(graph, tf_node))
- changed = true;
- }
- }
- }
-
- return changed;
+ return canonicalize_rsqrt(node->graph(), node);
}
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/RsqrtCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/RsqrtCanonicalizer.h
index a58c0adcb..7fd4ff697 100644
--- a/compiler/moco-tf/src/Canonicalization/RsqrtCanonicalizer.h
+++ b/compiler/moco-tf/src/Canonicalization/RsqrtCanonicalizer.h
@@ -18,6 +18,9 @@
#define __MOCO_TF_RSQRT_CANONICALIZER_H__
#include "Transform.h"
+#include "SimpleNodeTransform.h"
+
+#include <moco/IR/TFNodes.h>
#include <loco.h>
@@ -29,13 +32,13 @@ namespace tf
/**
* @brief Convert TFRsqrt to Canonical EltwiseDiv + EltwiseSqrt
*/
-class RsqrtCanonicalizer : public Transform
+class RsqrtCanonicalizer : public SimpleNodeTransform<moco::TFRsqrt>
{
public:
const char *name(void) const final { return "RsqrtCanonicalizer"; }
public:
- bool run(loco::Graph *graph) override;
+ bool transform(moco::TFRsqrt *) const final;
};
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/SoftmaxCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/SoftmaxCanonicalizer.cpp
index 3b5043fa7..98af7b693 100644
--- a/compiler/moco-tf/src/Canonicalization/SoftmaxCanonicalizer.cpp
+++ b/compiler/moco-tf/src/Canonicalization/SoftmaxCanonicalizer.cpp
@@ -16,19 +16,15 @@
#include "SoftmaxCanonicalizer.h"
-#include "Annotations/ShapeInferenceData.h"
-
-#include "Dialect/TFDialect.h"
-#include "Dialect/TFNodes.h"
-#include "Dialect/TFNodeVisitor.h"
-#include "Dialect/TFNodeImpl.h"
+#include <moco/IR/TFDialect.h>
+#include <moco/Support/TFShapeInferenceHelper.h>
#include <moco/Log.h>
namespace
{
-bool canonicalize_softmax(loco::Graph *graph, moco::tf::TFSoftmax *node)
+bool canonicalize_softmax(loco::Graph *graph, moco::TFSoftmax *node)
{
LOGGER(l);
@@ -46,12 +42,11 @@ bool canonicalize_softmax(loco::Graph *graph, moco::tf::TFSoftmax *node)
* In ---- TensorSoftmax ----- Out(s)
*/
- auto softmax_shape = node->annot<moco::tf::ShapeInferenceData>();
-
+ auto nodeshape = moco::node_shape(node);
// Canonicalization into TensorSoftmax is valid when softmax has shape info
- assert(softmax_shape);
+ assert(nodeshape.domain() != loco::Domain::Unknown);
- auto softmax_tensor_shape = softmax_shape->tensor_shape();
+ auto softmax_tensor_shape = nodeshape.as<loco::TensorShape>();
// Create loco node to replace
auto softmax = graph->nodes()->create<loco::TensorSoftmax>();
@@ -74,25 +69,9 @@ namespace moco
namespace tf
{
-bool SoftmaxCanonicalizer::run(loco::Graph *graph)
+bool SoftmaxCanonicalizer::transform(TFSoftmax *node) const
{
- auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
- bool changed = false;
-
- for (auto node : active_nodes)
- {
- if (node->dialect() == TFDialect::get())
- {
- auto tf_softmax = dynamic_cast<moco::tf::TFSoftmax *>(node);
- if (tf_softmax != nullptr)
- {
- if (canonicalize_softmax(graph, tf_softmax))
- changed = true;
- }
- }
- }
-
- return changed;
+ return canonicalize_softmax(node->graph(), node);
}
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/SoftmaxCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/SoftmaxCanonicalizer.h
index 6debf4194..ebaf04cfe 100644
--- a/compiler/moco-tf/src/Canonicalization/SoftmaxCanonicalizer.h
+++ b/compiler/moco-tf/src/Canonicalization/SoftmaxCanonicalizer.h
@@ -18,6 +18,9 @@
#define __MOCO_TF_SOFTMAx_CANONICALIZER_H__
#include "Transform.h"
+#include "SimpleNodeTransform.h"
+
+#include <moco/IR/TFNodes.h>
#include <loco.h>
@@ -29,13 +32,13 @@ namespace tf
/**
* @brief Canonicalize TF-dialect TFSoftmax into canonical Softmax node
*/
-class SoftmaxCanonicalizer : public Transform
+class SoftmaxCanonicalizer : public SimpleNodeTransform<moco::TFSoftmax>
{
public:
const char *name(void) const final { return "SoftmaxCanonicalizer"; }
public:
- bool run(loco::Graph *graph) override;
+ bool transform(moco::TFSoftmax *) const final;
};
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/SqrtCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/SqrtCanonicalizer.cpp
index 347265121..89b9b8a44 100644
--- a/compiler/moco-tf/src/Canonicalization/SqrtCanonicalizer.cpp
+++ b/compiler/moco-tf/src/Canonicalization/SqrtCanonicalizer.cpp
@@ -16,15 +16,12 @@
#include "SqrtCanonicalizer.h"
-#include "Dialect/TFDialect.h"
-#include "Dialect/TFNodes.h"
-#include "Dialect/TFNodeVisitor.h"
-#include "Dialect/TFNodeImpl.h"
+#include <moco/IR/TFDialect.h>
namespace
{
-bool canonicalize_sqrt(loco::Graph *graph, moco::tf::TFSqrt *node)
+bool canonicalize_sqrt(loco::Graph *graph, moco::TFSqrt *node)
{
/**
* @note This will replace TFSqrt node with Canonical EltwiseSqrt
@@ -62,25 +59,9 @@ namespace moco
namespace tf
{
-bool SqrtCanonicalizer::run(loco::Graph *graph)
+bool SqrtCanonicalizer::transform(TFSqrt *node) const
{
- auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
- bool changed = false;
-
- for (auto node : active_nodes)
- {
- if (node->dialect() == TFDialect::get())
- {
- auto tf_node = dynamic_cast<moco::tf::TFSqrt *>(node);
- if (tf_node != nullptr)
- {
- if (canonicalize_sqrt(graph, tf_node))
- changed = true;
- }
- }
- }
-
- return changed;
+ return canonicalize_sqrt(node->graph(), node);
}
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/SqrtCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/SqrtCanonicalizer.h
index b4e6da09a..3f7ffead8 100644
--- a/compiler/moco-tf/src/Canonicalization/SqrtCanonicalizer.h
+++ b/compiler/moco-tf/src/Canonicalization/SqrtCanonicalizer.h
@@ -18,6 +18,9 @@
#define __MOCO_TF_SQRT_CANONICALIZER_H__
#include "Transform.h"
+#include "SimpleNodeTransform.h"
+
+#include <moco/IR/TFNodes.h>
#include <loco.h>
@@ -29,13 +32,13 @@ namespace tf
/**
* @brief Convert TFsqrt to Canonical EltwiseSqrt
*/
-class SqrtCanonicalizer : public Transform
+class SqrtCanonicalizer : public SimpleNodeTransform<moco::TFSqrt>
{
public:
const char *name(void) const final { return "SqrtCanonicalizer"; }
public:
- bool run(loco::Graph *graph) override;
+ bool transform(moco::TFSqrt *) const final;
};
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/SquaredDifferenceCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/SquaredDifferenceCanonicalizer.cpp
deleted file mode 100644
index 4eb7a7217..000000000
--- a/compiler/moco-tf/src/Canonicalization/SquaredDifferenceCanonicalizer.cpp
+++ /dev/null
@@ -1,115 +0,0 @@
-/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "SquaredDifferenceCanonicalizer.h"
-
-#include "Dialect/TFDialect.h"
-#include "Dialect/TFNodes.h"
-#include "Dialect/TFNodeVisitor.h"
-#include "Dialect/TFNodeImpl.h"
-
-#include <loco/IR/NodeShape.h>
-#include <loco/Service/ShapeInference.h>
-
-#include <stdex/Memory.h>
-
-namespace
-{
-
-bool canonicalize_sqdiff(loco::Graph *graph, moco::tf::TFSquaredDifference *node)
-{
- /**
- * @note This will replace TFSquaredDifference node with Canonical EltwiseSub and EltwiseMul
- *
- * Before
- * A --- TFSquaredDifference -- C
- * B --/
- * After
- * A --- TFSquaredDifference --
- * B --/
- * A --- EltwiseSub == EltwiseMul -- C
- * B --/
- * Where
- * A : x of TFSquaredDifference
- * B : y of TFSquaredDifference
- * C : a node that uses TFSquaredDifference as an input
- * TFSquaredDifference is disconnected from C
- * A and B are drawn multiple times to simplify the diagram
- */
-
- auto node_A = node->x();
- auto node_B = node->y();
-
- if (!loco::shape_known(node_A) || !loco::shape_known(node_B))
- {
- // Wait for shape inference
- return false;
- }
-
- const auto &x_shape = loco::shape_get(node_A);
- const auto &y_shape = loco::shape_get(node_B);
-
- if (!(x_shape == y_shape))
- {
- // TODO support broadcast
- return false;
- }
-
- auto sub_node = graph->nodes()->create<loco::EltwiseSub>();
- auto mul_node = graph->nodes()->create<loco::EltwiseMul>();
-
- // update connections
- sub_node->lhs(node_A);
- sub_node->rhs(node_B);
- mul_node->lhs(sub_node);
- mul_node->rhs(sub_node);
-
- // replace node
- replace(node).with(mul_node);
-
- return true;
-}
-
-} // namespace
-
-namespace moco
-{
-namespace tf
-{
-
-bool SquaredDifferenceCanonicalizer::run(loco::Graph *graph)
-{
- auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
- bool changed = false;
-
- for (auto node : active_nodes)
- {
- if (node->dialect() == TFDialect::get())
- {
- auto tf_node = dynamic_cast<moco::tf::TFSquaredDifference *>(node);
- if (tf_node != nullptr)
- {
- if (canonicalize_sqdiff(graph, tf_node))
- changed = true;
- }
- }
- }
-
- return changed;
-}
-
-} // namespace tf
-} // namespace moco
diff --git a/compiler/moco-tf/src/Canonicalization/SqueezeCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/SqueezeCanonicalizer.cpp
index a3fcc3b47..f5b991206 100644
--- a/compiler/moco-tf/src/Canonicalization/SqueezeCanonicalizer.cpp
+++ b/compiler/moco-tf/src/Canonicalization/SqueezeCanonicalizer.cpp
@@ -16,19 +16,15 @@
#include "SqueezeCanonicalizer.h"
-#include "Annotations/ShapeInferenceData.h"
-
-#include "Dialect/TFDialect.h"
-#include "Dialect/TFNodes.h"
-#include "Dialect/TFNodeVisitor.h"
-#include "Dialect/TFNodeImpl.h"
+#include <moco/IR/TFDialect.h>
+#include <moco/Support/TFShapeInferenceHelper.h>
#include <moco/Log.h>
namespace
{
-bool canonicalize_squeeze_to_reshape(loco::Graph *graph, moco::tf::TFSqueeze *node)
+bool canonicalize_squeeze_to_reshape(loco::Graph *graph, moco::TFSqueeze *node)
{
LOGGER(l);
@@ -46,12 +42,12 @@ bool canonicalize_squeeze_to_reshape(loco::Graph *graph, moco::tf::TFSqueeze *no
* In ---- FixedReshape ----- Out(s)
*/
- auto squeeze_shape = node->annot<moco::tf::ShapeInferenceData>();
+ auto nodeshape = moco::node_shape(node);
// canonicalize into FixedReshape is valid when squeeze has shape info
// TODO Support general Squeeze case
- assert(squeeze_shape);
+ assert(nodeshape.domain() != loco::Domain::Unknown);
- auto squeeze_tensor_shape = squeeze_shape->tensor_shape();
+ auto squeeze_tensor_shape = nodeshape.as<loco::TensorShape>();
// Create loco node to replace
auto reshape = graph->nodes()->create<loco::FixedReshape>();
@@ -81,25 +77,9 @@ namespace moco
namespace tf
{
-bool SqueezeCanonicalizer::run(loco::Graph *graph)
+bool SqueezeCanonicalizer::transform(TFSqueeze *node) const
{
- auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
- bool changed = false;
-
- for (auto node : active_nodes)
- {
- if (node->dialect() == TFDialect::get())
- {
- auto tf_squeeze = dynamic_cast<moco::tf::TFSqueeze *>(node);
- if (tf_squeeze != nullptr)
- {
- if (canonicalize_squeeze_to_reshape(graph, tf_squeeze))
- changed = true;
- }
- }
- }
-
- return changed;
+ return canonicalize_squeeze_to_reshape(node->graph(), node);
}
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/SqueezeCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/SqueezeCanonicalizer.h
index dc5b2d7b1..28a1442bd 100644
--- a/compiler/moco-tf/src/Canonicalization/SqueezeCanonicalizer.h
+++ b/compiler/moco-tf/src/Canonicalization/SqueezeCanonicalizer.h
@@ -18,6 +18,9 @@
#define __MOCO_TF_SQUEEZE_CANONICALIZER_H__
#include "Transform.h"
+#include "SimpleNodeTransform.h"
+
+#include <moco/IR/TFNodes.h>
#include <loco.h>
@@ -31,13 +34,13 @@ namespace tf
*
* @note There is no canonical Squeeze node
*/
-class SqueezeCanonicalizer : public Transform
+class SqueezeCanonicalizer : public SimpleNodeTransform<moco::TFSqueeze>
{
public:
const char *name(void) const final { return "SqueezeCanonicalizer"; }
public:
- bool run(loco::Graph *graph) override;
+ bool transform(moco::TFSqueeze *) const final;
};
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/StopGradientCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/StopGradientCanonicalizer.cpp
index a52af05a5..574fa3993 100644
--- a/compiler/moco-tf/src/Canonicalization/StopGradientCanonicalizer.cpp
+++ b/compiler/moco-tf/src/Canonicalization/StopGradientCanonicalizer.cpp
@@ -16,17 +16,14 @@
#include "StopGradientCanonicalizer.h"
-#include "Dialect/TFDialect.h"
-#include "Dialect/TFNodes.h"
-#include "Dialect/TFNodeVisitor.h"
-#include "Dialect/TFNodeImpl.h"
+#include <moco/IR/TFDialect.h>
#include <moco/Log.h>
namespace
{
-bool canonicalize_stopgradient(loco::Graph *graph, moco::tf::TFStopGradient *node)
+bool canonicalize_stopgradient(loco::Graph *graph, moco::TFStopGradient *node)
{
LOGGER(l);
@@ -65,25 +62,9 @@ namespace moco
namespace tf
{
-bool StopGradientCanonicalizer::run(loco::Graph *graph)
+bool StopGradientCanonicalizer::transform(TFStopGradient *node) const
{
- auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
- bool changed = false;
-
- for (auto node : active_nodes)
- {
- if (node->dialect() == TFDialect::get())
- {
- auto tf_stopgradient = dynamic_cast<moco::tf::TFStopGradient *>(node);
- if (tf_stopgradient != nullptr)
- {
- if (canonicalize_stopgradient(graph, tf_stopgradient))
- changed = true;
- }
- }
- }
-
- return changed;
+ return canonicalize_stopgradient(node->graph(), node);
}
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/StopGradientCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/StopGradientCanonicalizer.h
index a23a801f0..6a17728a6 100644
--- a/compiler/moco-tf/src/Canonicalization/StopGradientCanonicalizer.h
+++ b/compiler/moco-tf/src/Canonicalization/StopGradientCanonicalizer.h
@@ -18,6 +18,9 @@
#define __MOCO_TF_STOPGRADIENT_CANONICALIZER_H__
#include "Transform.h"
+#include "SimpleNodeTransform.h"
+
+#include <moco/IR/TFNodes.h>
#include <loco.h>
@@ -29,13 +32,13 @@ namespace tf
/**
* @brief Canonicalize TF-dialect TFStopGradient into canonical Forward node
*/
-class StopGradientCanonicalizer : public Transform
+class StopGradientCanonicalizer : public SimpleNodeTransform<moco::TFStopGradient>
{
public:
const char *name(void) const final { return "StopGradientCanonicalizer"; }
public:
- bool run(loco::Graph *graph) override;
+ bool transform(moco::TFStopGradient *) const final;
};
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/SubCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/SubCanonicalizer.cpp
index 21f4210eb..c518b7d64 100644
--- a/compiler/moco-tf/src/Canonicalization/SubCanonicalizer.cpp
+++ b/compiler/moco-tf/src/Canonicalization/SubCanonicalizer.cpp
@@ -16,8 +16,7 @@
#include "SubCanonicalizer.h"
-#include "Dialect/TFDialect.h"
-#include "Dialect/TFNodes.h"
+#include <moco/IR/TFDialect.h>
#include "TFEltwiseBinaryCanonicalzeHelper.h"
@@ -26,25 +25,9 @@ namespace moco
namespace tf
{
-bool SubCanonicalizer::run(loco::Graph *graph)
+bool SubCanonicalizer::transform(moco::TFSub *node) const
{
- auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
- bool changed = false;
-
- for (auto node : active_nodes)
- {
- if (node->dialect() == TFDialect::get())
- {
- auto tf_node = dynamic_cast<moco::tf::TFSub *>(node);
- if (tf_node != nullptr)
- {
- if (canonicalize_eltwise_binary_node(tf_node))
- changed = true;
- }
- }
- }
-
- return changed;
+ return canonicalize_eltwise_binary_node(node);
}
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/SubCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/SubCanonicalizer.h
index 4ab470685..f715cc86c 100644
--- a/compiler/moco-tf/src/Canonicalization/SubCanonicalizer.h
+++ b/compiler/moco-tf/src/Canonicalization/SubCanonicalizer.h
@@ -18,6 +18,9 @@
#define __MOCO_TF_SUB_CANONICALIZER_H__
#include "Transform.h"
+#include "SimpleNodeTransform.h"
+
+#include <moco/IR/TFNodes.h>
#include <loco.h>
@@ -29,13 +32,13 @@ namespace tf
/**
* @brief Convert TFSub to Canonical EltwiseSub
*/
-class SubCanonicalizer : public Transform
+class SubCanonicalizer : public SimpleNodeTransform<moco::TFSub>
{
public:
const char *name(void) const final { return "SubCanonicalizer"; }
public:
- bool run(loco::Graph *graph) override;
+ bool transform(moco::TFSub *) const final;
};
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/TFPushCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/TFPushCanonicalizer.cpp
new file mode 100644
index 000000000..081e0e5f9
--- /dev/null
+++ b/compiler/moco-tf/src/Canonicalization/TFPushCanonicalizer.cpp
@@ -0,0 +1,74 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TFPushCanonicalizer.h"
+
+#include <moco/IR/TFDialect.h>
+
+#include <stdex/Memory.h>
+
+namespace
+{
+
+bool canonicalize_push(loco::Graph *graph, moco::TFPush *node)
+{
+ /**
+ * @note This will replace TFRelu node with Canonical ReLU
+ *
+ * Before
+ * A --- TFPush
+ * After
+ * +- TFPush
+ * |
+ * A -+- Push
+ *
+ * Where
+ * A : from of TFPush
+ * TFPush will have no GraphOutputIndex
+ * Push will have GraphOutputIndex that from TFPush
+ */
+
+ auto push_node = graph->nodes()->create<loco::Push>();
+
+ auto node_A = node->from();
+
+ // update connections
+ push_node->from(node_A);
+
+ // update output index
+ push_node->index(node->index());
+ node->index_reset();
+
+ // replace node
+ replace(node).with(push_node);
+
+ return true;
+}
+
+} // namespace
+
+namespace moco
+{
+namespace tf
+{
+
+bool TFPushCanonicalizer::transform(TFPush *node) const
+{
+ return canonicalize_push(node->graph(), node);
+}
+
+} // namespace tf
+} // namespace moco
diff --git a/compiler/moco-tf/src/Canonicalization/TFPushCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/TFPushCanonicalizer.h
new file mode 100644
index 000000000..569a71f82
--- /dev/null
+++ b/compiler/moco-tf/src/Canonicalization/TFPushCanonicalizer.h
@@ -0,0 +1,47 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_PUSH_CANONICALIZER_H__
+#define __MOCO_TF_PUSH_CANONICALIZER_H__
+
+#include "Transform.h"
+#include "SimpleNodeTransform.h"
+
+#include <moco/IR/TFNodes.h>
+
+#include <loco.h>
+
+namespace moco
+{
+namespace tf
+{
+
+/**
+ * @brief Convert TFPush to Canonical Push
+ */
+class TFPushCanonicalizer : public SimpleNodeTransform<moco::TFPush>
+{
+public:
+ const char *name(void) const final { return "TFPushCanonicalizer"; }
+
+public:
+ bool transform(moco::TFPush *) const final;
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __MOCO_TF_PUSH_CANONICALIZER_H__
diff --git a/compiler/moco-tf/src/Canonicalization/TanhCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/TanhCanonicalizer.cpp
index 9b7b073e1..3f48a50fc 100644
--- a/compiler/moco-tf/src/Canonicalization/TanhCanonicalizer.cpp
+++ b/compiler/moco-tf/src/Canonicalization/TanhCanonicalizer.cpp
@@ -16,17 +16,14 @@
#include "TanhCanonicalizer.h"
-#include "Dialect/TFDialect.h"
-#include "Dialect/TFNodes.h"
-#include "Dialect/TFNodeVisitor.h"
-#include "Dialect/TFNodeImpl.h"
+#include <moco/IR/TFDialect.h>
#include <stdex/Memory.h>
namespace
{
-bool canonicalize_tanh(loco::Graph *graph, moco::tf::TFTanh *node)
+bool canonicalize_tanh(loco::Graph *graph, moco::TFTanh *node)
{
/**
* @note This will replace TFTanh node with Canonical Tanh
@@ -64,25 +61,9 @@ namespace moco
namespace tf
{
-bool TanhCanonicalizer::run(loco::Graph *graph)
+bool TanhCanonicalizer::transform(TFTanh *node) const
{
- auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
- bool changed = false;
-
- for (auto node : active_nodes)
- {
- if (node->dialect() == TFDialect::get())
- {
- auto tf_node = dynamic_cast<moco::tf::TFTanh *>(node);
- if (tf_node != nullptr)
- {
- if (canonicalize_tanh(graph, tf_node))
- changed = true;
- }
- }
- }
-
- return changed;
+ return canonicalize_tanh(node->graph(), node);
}
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/TanhCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/TanhCanonicalizer.h
index cf566a4d4..af5e79fb5 100644
--- a/compiler/moco-tf/src/Canonicalization/TanhCanonicalizer.h
+++ b/compiler/moco-tf/src/Canonicalization/TanhCanonicalizer.h
@@ -18,6 +18,9 @@
#define __MOCO_TF_TANH_CANONICALIZER_H__
#include "Transform.h"
+#include "SimpleNodeTransform.h"
+
+#include <moco/IR/TFNodes.h>
#include <loco.h>
@@ -29,13 +32,13 @@ namespace tf
/**
* @brief Convert TFTanh to Canonical Tanh
*/
-class TanhCanonicalizer : public Transform
+class TanhCanonicalizer : public SimpleNodeTransform<moco::TFTanh>
{
public:
const char *name(void) const final { return "TanhCanonicalizer"; }
public:
- bool run(loco::Graph *graph) override;
+ bool transform(moco::TFTanh *) const override;
};
} // namespace tf