author     Chunseok Lee <chunseok.lee@samsung.com>  2020-04-23 14:45:49 +0900
committer  Chunseok Lee <chunseok.lee@samsung.com>  2020-04-23 14:45:49 +0900
commit     e2ef8438a24f7c56a0744eb579a6e293ee2fbf8e (patch)
tree       44a1a7951d168dd4370e13593ed03f4bc6d920c5 /compiler/exo/src/Pass
parent     302e6564a7a76109e1178207e44e45a58631c477 (diff)
download   nnfw-e2ef8438a24f7c56a0744eb579a6e293ee2fbf8e.tar.gz
           nnfw-e2ef8438a24f7c56a0744eb579a6e293ee2fbf8e.tar.bz2
           nnfw-e2ef8438a24f7c56a0744eb579a6e293ee2fbf8e.zip

Imported Upstream version 1.4.0 (refs: upstream/1.4.0, submit/tizen/20200423.054851)
Diffstat (limited to 'compiler/exo/src/Pass')
-rw-r--r--  compiler/exo/src/Pass/FoldReshapeOfConstPass.cpp      116
-rw-r--r--  compiler/exo/src/Pass/FoldReshapeOfConstPass.h         46
-rw-r--r--  compiler/exo/src/Pass/FoldTransposeOfConstPass.cpp    154
-rw-r--r--  compiler/exo/src/Pass/FoldTransposeOfConstPass.h       46
-rw-r--r--  compiler/exo/src/Pass/FuseBiasAddPass.cpp             362
-rw-r--r--  compiler/exo/src/Pass/FuseBiasAddPass.h                61
-rw-r--r--  compiler/exo/src/Pass/FuseBiasAddPass.test.cpp        361
-rw-r--r--  compiler/exo/src/Pass/FuseInstanceNormPass.cpp        402
-rw-r--r--  compiler/exo/src/Pass/FuseInstanceNormPass.h           40
-rw-r--r--  compiler/exo/src/Pass/FuseReluPass.cpp                115
-rw-r--r--  compiler/exo/src/Pass/FuseReluPass.h                   40
-rw-r--r--  compiler/exo/src/Pass/FuseReluPass.test.cpp           115
-rw-r--r--  compiler/exo/src/Pass/FuseRsqrtPass.cpp                95
-rw-r--r--  compiler/exo/src/Pass/FuseRsqrtPass.h                  47
-rw-r--r--  compiler/exo/src/Pass/FuseSquaredDifferencePass.cpp    86
-rw-r--r--  compiler/exo/src/Pass/FuseSquaredDifferencePass.h      49
-rw-r--r--  compiler/exo/src/Pass/MergeConcatNodesPass.cpp        191
-rw-r--r--  compiler/exo/src/Pass/MergeConcatNodesPass.h           41
-rw-r--r--  compiler/exo/src/Pass/ShapeInferencePass.cpp           59
-rw-r--r--  compiler/exo/src/Pass/ShapeInferencePass.h             40
-rw-r--r--  compiler/exo/src/Pass/TypeInferencePass.cpp            57
-rw-r--r--  compiler/exo/src/Pass/TypeInferencePass.h              42
22 files changed, 2565 insertions(+), 0 deletions(-)
diff --git a/compiler/exo/src/Pass/FoldReshapeOfConstPass.cpp b/compiler/exo/src/Pass/FoldReshapeOfConstPass.cpp
new file mode 100644
index 000000000..0fdcea939
--- /dev/null
+++ b/compiler/exo/src/Pass/FoldReshapeOfConstPass.cpp
@@ -0,0 +1,116 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "FoldReshapeOfConstPass.h"
+
+#include "Check.h"
+
+#include "Dialect/IR/TFLNodes.h"
+#include "Dialect/IR/TFLNodeVisitor.h"
+
+#include <loco/Service/ShapeInference.h>
+
+#include <oops/InternalExn.h>
+
+namespace
+{
+
+/**
+ * @brief Check if node is TFLReshape and its input is TFLConst
+ * @return Casted TFLReshape for foldable candidate, nullptr otherwise
+ */
+locoex::TFLReshape *as_candidate(loco::Node *node)
+{
+ auto reshape = dynamic_cast<locoex::TFLReshape *>(node);
+ if (not reshape)
+ return nullptr;
+
+ // Only accept Constant input of Reshape
+ if (not dynamic_cast<locoex::TFLConst *>(reshape->tensor()))
+ return nullptr;
+
+ return reshape;
+}
+
+uint32_t volume(loco::Node *tensor_node)
+{
+ auto shape = loco::shape_get(tensor_node).as<loco::TensorShape>();
+
+ uint32_t vol = 1;
+ for (uint32_t axis = 0; axis < shape.rank(); ++axis)
+ vol *= shape.dim(axis).value();
+
+ return vol;
+}
+
+void fold_reshape_of_const(locoex::TFLReshape *reshape)
+{
+ const loco::DataType FLOAT32 = loco::DataType::FLOAT32;
+
+ auto const_orig = dynamic_cast<locoex::TFLConst *>(reshape->tensor());
+
+ // Exceptions
+ {
+ EXO_ASSERT(const_orig, "Only support for Reshape-Const pair");
+ // TODO support other data types
+ if (const_orig->dtype() != FLOAT32)
+ INTERNAL_EXN_V("NYI for this data type", oops::to_uint32(const_orig->dtype()));
+
+ if (volume(const_orig) != volume(reshape))
+ INTERNAL_EXN("New shape of Reshape is not matched");
+ }
+
+ auto new_shape = loco::shape_get(reshape).as<loco::TensorShape>();
+
+ // TFLConst to replace
+ auto const_new = reshape->graph()->nodes()->create<locoex::TFLConst>();
+
+ const_new->dtype(FLOAT32);
+ const_new->rank(new_shape.rank());
+ const_new->size<FLOAT32>(const_orig->size<FLOAT32>());
+ for (uint32_t axis = 0; axis < new_shape.rank(); ++axis)
+ const_new->dim(axis) = new_shape.dim(axis);
+
+ for (uint32_t i = 0; i < const_new->size<FLOAT32>(); ++i)
+ {
+ const_new->at<FLOAT32>(i) = const_orig->at<FLOAT32>(i);
+ }
+
+ // replace
+ loco::replace(reshape).with(const_new);
+}
+
+} // namespace
+
+namespace exo
+{
+
+bool FoldReshapeOfConstPass::run(loco::Graph *g)
+{
+ bool changed = false;
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ if (auto reshape = as_candidate(node))
+ {
+ fold_reshape_of_const(reshape);
+ changed = true;
+ }
+ }
+
+ return changed;
+}
+
+} // namespace exo
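
For reference, a standalone sketch (not part of the pass) of why folding a Reshape of a constant is just a buffer copy: in row-major layout, a reshape changes only the shape metadata, never the element order. All names below are illustrative.

#include <cassert>
#include <cstdint>
#include <vector>

struct ConstTensor
{
  std::vector<uint32_t> shape;
  std::vector<float> data; // row-major contents
};

ConstTensor fold_reshape(const ConstTensor &orig, const std::vector<uint32_t> &new_shape)
{
  uint32_t vol = 1;
  for (auto d : new_shape)
    vol *= d;
  assert(vol == orig.data.size()); // mirrors the volume() check in the pass

  return ConstTensor{new_shape, orig.data}; // same elements, new shape
}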
diff --git a/compiler/exo/src/Pass/FoldReshapeOfConstPass.h b/compiler/exo/src/Pass/FoldReshapeOfConstPass.h
new file mode 100644
index 000000000..10f8004bf
--- /dev/null
+++ b/compiler/exo/src/Pass/FoldReshapeOfConstPass.h
@@ -0,0 +1,46 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __PASS_FOLD_RESHAPE_OF_CONST_PASS_H__
+#define __PASS_FOLD_RESHAPE_OF_CONST_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace exo
+{
+
+/**
+ * @brief Class to fuse TFLReshape + TFLConst into one equivalent TFLConst
+ *
+ * <before>
+ * TFLConst --- TFLReshape --- Out
+ *
+ * <after>
+ * TFLConst --- TFLReshape ---
+ * TFLConst (new) ------------ Out
+ *
+ * TODO This pass is temporary. Deprecate this pass.
+ */
+struct FoldReshapeOfConstPass final : public logo::Pass
+{
+ const char *name(void) const final { return "exo::FoldReshapeOfConstPass"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace exo
+
+#endif // __PASS_FOLD_RESHAPE_OF_CONST_PASS_H__
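
A minimal sketch of how such a pass is typically driven, assuming the logo::Phase / logo::PhaseRunner machinery that logo::Pass implies; the wiring below is illustrative, not exo's actual driver.

#include "FoldReshapeOfConstPass.h"
#include "FoldTransposeOfConstPass.h"

#include <logo/Phase.h>

#include <memory>

void fold_constants(loco::Graph *g)
{
  logo::Phase phase;
  phase.emplace_back(std::make_unique<exo::FoldReshapeOfConstPass>());
  phase.emplace_back(std::make_unique<exo::FoldTransposeOfConstPass>());

  // Saturate strategy re-runs the phase until no pass reports a change
  logo::PhaseRunner<logo::PhaseStrategy::Saturate> runner{g};
  runner.run(phase);
}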
diff --git a/compiler/exo/src/Pass/FoldTransposeOfConstPass.cpp b/compiler/exo/src/Pass/FoldTransposeOfConstPass.cpp
new file mode 100644
index 000000000..005c42944
--- /dev/null
+++ b/compiler/exo/src/Pass/FoldTransposeOfConstPass.cpp
@@ -0,0 +1,154 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "FoldTransposeOfConstPass.h"
+
+#include "Check.h"
+
+#include "Dialect/IR/TFLNodes.h"
+#include "Dialect/IR/TFLNodeVisitor.h"
+
+// TODO remove dependency to angkor
+#include <nncc/core/ADT/tensor/IndexEnumerator.h>
+#include <nncc/core/ADT/tensor/LexicalLayout.h>
+
+#include <oops/InternalExn.h>
+
+namespace
+{
+
+/**
+ * @brief Check if node is TFLTranspose and its input is TFLConst
+ * @return Casted TFLTranspose for foldable candidate, nullptr otherwise
+ */
+locoex::TFLTranspose *as_candidate(loco::Node *node)
+{
+ auto transpose = dynamic_cast<locoex::TFLTranspose *>(node);
+ if (not transpose)
+ return nullptr;
+
+ // Only accept Constant input of Transpose
+ if (not dynamic_cast<locoex::TFLConst *>(transpose->a()))
+ return nullptr;
+
+ // Only accept Constant permutation of Transpose
+ if (not dynamic_cast<locoex::TFLConst *>(transpose->perm()))
+ return nullptr;
+
+ return transpose;
+}
+
+nncc::core::ADT::tensor::Shape angkor_shape(locoex::TFLConst *node)
+{
+ nncc::core::ADT::tensor::Shape ret;
+
+ ret.resize(node->rank());
+ for (uint32_t axis = 0; axis < node->rank(); ++axis)
+ {
+ ret.dim(axis) = node->dim(axis).value();
+ }
+
+ return ret;
+}
+
+void fold_transpose_of_const(locoex::TFLTranspose *transpose)
+{
+ const loco::DataType FLOAT32 = loco::DataType::FLOAT32;
+ const loco::DataType S32 = loco::DataType::S32;
+
+ auto const_orig = dynamic_cast<locoex::TFLConst *>(transpose->a());
+ auto perm = dynamic_cast<locoex::TFLConst *>(transpose->perm());
+
+ // Exceptions
+ {
+ EXO_ASSERT(const_orig, "Only support for Transpose-Const pair");
+ // TODO support other data types
+ if (const_orig->dtype() != FLOAT32)
+ INTERNAL_EXN_V("NYI for this data type", oops::to_uint32(const_orig->dtype()));
+
+ EXO_ASSERT(perm, "Only support for constant permutation for Transpose");
+ // TODO support other data types
+ if (perm->dtype() != S32)
+ INTERNAL_EXN_V("NYI for this data type", oops::to_uint32(perm->dtype()));
+
+ auto okay = [&]() {
+ if (perm->rank() != 1)
+ return false;
+ if (perm->dim(0).value() != const_orig->rank())
+ return false;
+ return true;
+ };
+ if (not okay())
+ INTERNAL_EXN("Input and permutation for Transpose is not congruent");
+ }
+
+ uint32_t rank = const_orig->rank();
+
+ // TFLConst to replace
+ auto const_new = transpose->graph()->nodes()->create<locoex::TFLConst>();
+
+ const_new->dtype(FLOAT32);
+ const_new->rank(rank);
+ const_new->size<FLOAT32>(const_orig->size<FLOAT32>());
+ for (uint32_t axis = 0; axis < rank; ++axis)
+ const_new->dim(axis) = const_orig->dim(perm->at<S32>(axis)).value();
+
+ // TODO remove dependency to angkor
+ auto shape_orig = angkor_shape(const_orig);
+ auto shape_new = angkor_shape(const_new);
+
+ nncc::core::ADT::tensor::LexicalLayout l;
+ nncc::core::ADT::tensor::IndexEnumerator e{shape_new};
+
+ for (; e.valid(); e.advance())
+ {
+ loco::TensorIndex index_new = e.current();
+ loco::TensorIndex index_orig;
+
+ // Set original index from matching new index
+ index_orig.resize(rank);
+ for (uint32_t axis = 0; axis < rank; ++axis)
+ index_orig.at(perm->at<S32>(axis)) = index_new.at(axis);
+
+ const_new->at<FLOAT32>(l.offset(shape_new, index_new)) =
+ const_orig->at<FLOAT32>(l.offset(shape_orig, index_orig));
+ }
+
+ // replace
+ loco::replace(transpose).with(const_new);
+}
+
+} // namespace
+
+namespace exo
+{
+
+bool FoldTransposeOfConstPass::run(loco::Graph *g)
+{
+ bool changed = false;
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ if (auto transpose = as_candidate(node))
+ {
+ fold_transpose_of_const(transpose);
+ changed = true;
+ }
+ }
+
+ return changed;
+}
+
+} // namespace exo
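
A standalone sketch (illustrative, not the pass itself) of the index mapping used above: for each output index I_new, the source element sits at I_orig where I_orig[perm[axis]] = I_new[axis]. Shown for a 2x3 -> 3x2 transpose with perm = {1, 0}.

#include <cassert>
#include <cstdint>
#include <vector>

int main()
{
  const std::vector<float> orig = {1, 2, 3, 4, 5, 6}; // shape {2, 3}, row-major
  std::vector<float> out(6);                          // shape {3, 2} after transpose

  for (uint32_t r = 0; r < 3; ++r)
    for (uint32_t c = 0; c < 2; ++c)
    {
      // index_new = {r, c}; with perm = {1, 0}, index_orig = {c, r}
      out[r * 2 + c] = orig[c * 3 + r];
    }

  assert(out[1 * 2 + 0] == 2.0f); // orig(0, 1) lands at out(1, 0)
  return 0;
}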
diff --git a/compiler/exo/src/Pass/FoldTransposeOfConstPass.h b/compiler/exo/src/Pass/FoldTransposeOfConstPass.h
new file mode 100644
index 000000000..26656a118
--- /dev/null
+++ b/compiler/exo/src/Pass/FoldTransposeOfConstPass.h
@@ -0,0 +1,46 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __PASS_FOLD_TRANSPOSE_OF_CONST_PASS_H__
+#define __PASS_FOLD_TRANSPOSE_OF_CONST_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace exo
+{
+
+/**
+ * @brief Class to fuse TFLTranspose + TFLConst into one equivalent TFLConst
+ *
+ * <before>
+ * TFLConst --- TFLTranspose --- Out
+ *
+ * <after>
+ * TFLConst --- TFLTranspose ---
+ * TFLConst (new) -------------- Out
+ *
+ * TODO This pass is temporary. Deprecate this pass.
+ */
+struct FoldTransposeOfConstPass final : public logo::Pass
+{
+ const char *name(void) const final { return "exo::FoldTransposeOfConstPass"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace exo
+
+#endif // __PASS_FOLD_TRANSPOSE_OF_CONST_PASS_H__
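
The folding above relies on a row-major ("lexical") offset. As a reference, a sketch of what LexicalLayout::offset is expected to compute (this helper is illustrative, not the angkor implementation):

#include <cstdint>
#include <vector>

// Row-major offset via Horner's rule: offset = ((i0 * d1 + i1) * d2 + i2) * ...
uint32_t lexical_offset(const std::vector<uint32_t> &shape, const std::vector<uint32_t> &index)
{
  uint32_t off = 0;
  for (size_t axis = 0; axis < shape.size(); ++axis)
    off = off * shape[axis] + index[axis];
  return off;
}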
diff --git a/compiler/exo/src/Pass/FuseBiasAddPass.cpp b/compiler/exo/src/Pass/FuseBiasAddPass.cpp
new file mode 100644
index 000000000..aab820995
--- /dev/null
+++ b/compiler/exo/src/Pass/FuseBiasAddPass.cpp
@@ -0,0 +1,362 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "FuseBiasAddPass.h"
+
+#include "Dialect/IR/TFLNodes.h"
+#include "Dialect/IR/TFLDialect.h"
+#include "Dialect/IR/TFLNodeVisitor.h"
+
+#include <loco/Service/TypeInference.h>
+#include <loco/Service/ShapeInference.h>
+
+#include <oops/InternalExn.h>
+
+#include <set>
+
+/*
+  Note: Terms for variables in this implementation are as follows:
+
+    ex) subgraph handled:  TFLConv2D --------------- TFLAdd
+                      (or TFLDepthwiseConv2D)      (or TFLSub)
+                                |                      |
+                               \|/                    \|/
+    variable name :          former                 latter
+    Type :                   FormerT                LatterT
+                  (shortened name from Mixin)   (template type)
+*/
+namespace
+{
+
+using FormerT = locoex::TFLNodeMixin<locoex::TFLNodeTrait::Bias>;
+
+loco::Node *as_loco_node(FormerT *former)
+{
+ auto loco_node = dynamic_cast<loco::Node *>(former);
+ assert(loco_node != nullptr);
+
+ return loco_node;
+}
+
+locoex::TFLConst *get_const(loco::Node *x, loco::Node *y)
+{
+ if (auto const_node = dynamic_cast<locoex::TFLConst *>(x))
+ return const_node;
+ else if (auto const_node = dynamic_cast<locoex::TFLConst *>(y))
+ return const_node;
+
+ return nullptr;
+}
+
+FormerT *get_former(loco::Node *x, loco::Node *y)
+{
+ if (auto node = dynamic_cast<FormerT *>(x))
+ return node;
+ else if (auto node = dynamic_cast<FormerT *>(y))
+ return node;
+
+ return nullptr;
+}
+
+/// @brief Finds the input that is TFLConst and replaces it with new_input
+void set_const_input(locoex::TFLNode *node, locoex::TFLConst *new_input)
+{
+ if (auto add = dynamic_cast<locoex::TFLAdd *>(node))
+ {
+ if (dynamic_cast<locoex::TFLConst *>(add->x()))
+ add->x(new_input);
+ else if (dynamic_cast<locoex::TFLConst *>(add->y()))
+ add->y(new_input);
+ else
+ assert(false and "One node should be TFLConst");
+
+ return;
+ }
+
+ if (auto sub = dynamic_cast<locoex::TFLSub *>(node))
+ {
+ if (dynamic_cast<locoex::TFLConst *>(sub->x()))
+ sub->x(new_input);
+ else if (dynamic_cast<locoex::TFLConst *>(sub->y()))
+ sub->y(new_input);
+ else
+ assert(false and "One node should be TFLConst");
+
+ return;
+ }
+
+ assert(false and "Param should be TFLAdd or TFLSub");
+}
+
+/**
+ * @brief Creates a TFLConst whose shape is [to] and whose values are all const_node->at(0),
+ * where const_node has only one element (a scalar or a tensor of shape [1])
+ */
+locoex::TFLConst *create_widened(locoex::TFLConst *const_node, uint32_t to)
+{
+ auto const_shape = loco::shape_get(const_node).as<loco::TensorShape>();
+
+ assert(const_shape.rank() == 0 or (const_shape.rank() == 1 and const_shape.dim(0) == 1));
+
+ auto g = const_node->graph();
+
+ auto widened_const = g->nodes()->create<locoex::TFLConst>();
+ {
+ widened_const->dtype(loco::DataType::FLOAT32);
+ widened_const->rank(1);
+ widened_const->dim(0) = to;
+ widened_const->size<loco::DataType::FLOAT32>(to);
+ for (uint32_t x = 0; x < to; x++)
+ widened_const->at<loco::DataType::FLOAT32>(x) = const_node->at<loco::DataType::FLOAT32>(0);
+ }
+ return widened_const;
+}
+
+template <typename TFLType> float calc(float, float);
+
+template <> float calc<locoex::TFLAdd>(float x, float y) { return x + y; }
+template <> float calc<locoex::TFLSub>(float x, float y) { return x - y; }
+
+template <class LatterT> class Fuser
+{
+public:
+ Fuser(LatterT *latter)
+ {
+ static_assert(std::is_same<LatterT, locoex::TFLAdd>::value ||
+ std::is_same<LatterT, locoex::TFLSub>::value,
+ "wrong template type");
+
+ _latter = latter;
+ _graph = _latter->graph();
+ _const_node = get_const(_latter->x(), _latter->y());
+ _former = get_former(_latter->x(), _latter->y());
+
+ assert(_const_node && _former);
+ }
+
+ void fuse(void);
+
+private:
+ loco::Graph *_graph;
+ LatterT *_latter;
+ locoex::TFLConst *_const_node;
+ FormerT *_former;
+
+ locoex::TFLConst *create_fused_bias_const();
+};
+
+// instantiation
+template class Fuser<locoex::TFLAdd>;
+template class Fuser<locoex::TFLSub>;
+
+template <class LatterT> locoex::TFLConst *Fuser<LatterT>::create_fused_bias_const()
+{
+ // we have to create a new bias const by adding/subtracting the bias and the const node (of
+ // TFLAdd or TFLSub)
+ auto bias = dynamic_cast<locoex::TFLConst *>(_former->bias());
+ assert(bias->dtype() == loco::DataType::FLOAT32 &&
+ _const_node->dtype() == loco::DataType::FLOAT32);
+
+ assert(bias->rank() == 1 && _const_node->rank() == 1);
+ assert(bias->dim(0) == _const_node->dim(0));
+
+ // build a new bias const
+ auto new_bias = _graph->nodes()->create<locoex::TFLConst>();
+ {
+ new_bias->dtype(loco::DataType::FLOAT32);
+
+ new_bias->rank(1);
+ new_bias->dim(0) = bias->dim(0);
+
+ new_bias->size<loco::DataType::FLOAT32>(bias->dim(0).value());
+
+ for (uint32_t x = 0; x < bias->dim(0).value(); x++)
+ new_bias->at<loco::DataType::FLOAT32>(x) = calc<LatterT>(
+ bias->at<loco::DataType::FLOAT32>(x), _const_node->at<loco::DataType::FLOAT32>(x));
+ }
+
+ return new_bias;
+}
+
+// FuseBiasAddPass works when former->fusedActivationFunction() == NONE
+bool check_act_func(FormerT *former)
+{
+ using FusedActFuncMixin = locoex::TFLNodeMixin<locoex::TFLNodeTrait::FusedActFunc>;
+
+ if (auto node = dynamic_cast<FusedActFuncMixin *>(former))
+ return node->fusedActivationFunction() == locoex::FusedActFunc::NONE;
+ else
+ return true;
+}
+
+template <class LatterT> void set_act_func(FormerT *former, LatterT *latter)
+{
+ using FusedActFuncMixin = locoex::TFLNodeMixin<locoex::TFLNodeTrait::FusedActFunc>;
+
+ if (auto node = dynamic_cast<FusedActFuncMixin *>(former))
+ node->fusedActivationFunction(latter->fusedActivationFunction());
+}
+
+// instantiation
+template void set_act_func(FormerT *, locoex::TFLAdd *);
+template void set_act_func(FormerT *, locoex::TFLSub *);
+
+/**
+ * @brief Fuse TFLAdd or TFLSub (latter) into TFLConv2D or TFLDepthwiseConv2D (former).
+ * All conditions should be checked before calling this.
+ *
+ * @note TFLAdd can have fused activation function (let's call this FAF for simplicity).
+ *
+ * Conv2D's FAF     | TFLAdd's FAF          =>  FAF after fusing TFLAdd into TFLConv2D
+ * -----------------|---------------------      --------------------------------------
+ * NONE             | NONE, RELU or RELU6   =>  TFLAdd's FAF
+ * other than NONE  | anything              =>  cannot be fused
+ */
+template <class LatterT> void Fuser<LatterT>::fuse(void)
+{
+ // check fused activation function
+ {
+ assert(check_act_func(_former));
+
+ set_act_func<LatterT>(_former, _latter);
+ }
+
+ auto new_bias = create_fused_bias_const();
+
+ // replace node with new_bias
+ // note that loco::replace() is not used because bias could be input of other op just in case
+ _former->bias(new_bias);
+
+ // remove TFLAdd or TFLSub node
+ loco::replace(_latter).with(as_loco_node(_former));
+ _latter->x(nullptr);
+ _latter->y(nullptr);
+}
+
+struct Collector final : public locoex::TFLNodeMutableVisitor<void>
+{
+ template <class LatterT>
+ void setCandidate(FormerT *former, LatterT *latter, locoex::TFLConst *const_node)
+ {
+ static_assert(std::is_same<LatterT, locoex::TFLAdd>::value ||
+ std::is_same<LatterT, locoex::TFLSub>::value,
+ "wrong template type");
+
+ if (!check_act_func(former))
+ return;
+
+ auto depth =
+ loco::shape_get(as_loco_node(former)).template as<loco::TensorShape>().dim(3).value();
+ auto const_shape = loco::shape_get(const_node).template as<loco::TensorShape>();
+
+ if (const_shape.rank() == 1 and const_shape.dim(0) == depth)
+ {
+ candidates.insert(latter);
+ }
+ // when Const has only one value, create a new const with shape [depth]
+ else if (const_shape.rank() == 0 or (const_shape.rank() == 1 and const_shape.dim(0) == 1))
+ {
+ if (!(loco::dtype_get(as_loco_node(former)) == loco::DataType::FLOAT32))
+ INTERNAL_EXN_V("Unsupported data type",
+ oops::to_uint32(loco::dtype_get(as_loco_node(former))));
+ if (!(const_node->dtype() == loco::DataType::FLOAT32))
+ INTERNAL_EXN_V("Unsupported data type", oops::to_uint32(const_node->dtype()));
+
+ auto new_bias_node = create_widened(const_node, depth);
+
+ // Replacing TFLConst input of TFLAdd or TFLSub.
+ // Note that calling loco::replace(const_node).with(new_bias_node) could be dangerous
+ // because const_node could be the input of many nodes
+ set_const_input(latter, new_bias_node);
+
+ candidates.insert(latter);
+ }
+ }
+
+ void visit(locoex::TFLAdd *latter) final
+ {
+ auto former = get_former(latter->x(), latter->y());
+ auto const_node = get_const(latter->x(), latter->y());
+
+ if (former && const_node)
+ setCandidate<locoex::TFLAdd>(former, latter, const_node);
+ }
+
+ void visit(locoex::TFLSub *latter) final
+ {
+ // TFLSub, of which x() = TFLConv2D or TFLDepthwiseConv2D, y() = TFLConst, is fusing target
+ auto former = dynamic_cast<FormerT *>(latter->x());
+ auto const_node = dynamic_cast<locoex::TFLConst *>(latter->y());
+
+ if (former && const_node)
+ setCandidate<locoex::TFLSub>(former, latter, const_node);
+ }
+
+ void visit(locoex::TFLNode *) final { return; }
+
+ std::set<locoex::TFLNode *> candidates;
+};
+
+struct Performer final : public locoex::TFLNodeMutableVisitor<void>
+{
+ void visit(locoex::TFLAdd *latter) final
+ {
+ assert(get_former(latter->x(), latter->y()));
+
+ Fuser<locoex::TFLAdd> fuser(latter);
+ fuser.fuse();
+ }
+
+ void visit(locoex::TFLSub *latter) final
+ {
+ assert(get_former(latter->x(), latter->y()));
+
+ Fuser<locoex::TFLSub> fuser(latter);
+ fuser.fuse();
+ }
+
+ void visit(locoex::TFLNode *) final { assert(false && "should not be called"); }
+};
+
+} // namespace
+
+namespace exo
+{
+
+bool FuseBiasAddPass::run(loco::Graph *g)
+{
+ Collector collector;
+
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ if (node->dialect() == locoex::TFLDialect::get())
+ {
+ auto tfl_node = dynamic_cast<locoex::TFLNode *>(node);
+ tfl_node->accept(&collector);
+ }
+ }
+
+ Performer performer;
+
+ for (auto node : collector.candidates)
+ {
+ node->accept(&performer);
+ }
+
+ return collector.candidates.size() > 0;
+}
+
+} // namespace exo
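
A scalar-level sketch (illustrative) of what this fusion computes: the constant operand of the Add/Sub is widened to the channel depth when it is a scalar or has shape [1], then folded into the convolution bias element-wise.

#include <cstdint>
#include <vector>

std::vector<float> fuse_bias(const std::vector<float> &bias, std::vector<float> addend,
                             bool is_sub, uint32_t depth)
{
  if (addend.size() == 1)            // scalar or shape [1]: mirrors create_widened()
    addend.assign(depth, addend[0]);

  std::vector<float> fused(depth);
  for (uint32_t c = 0; c < depth; ++c) // mirrors calc<TFLAdd> / calc<TFLSub>
    fused[c] = is_sub ? bias[c] - addend[c] : bias[c] + addend[c];
  return fused;
}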
diff --git a/compiler/exo/src/Pass/FuseBiasAddPass.h b/compiler/exo/src/Pass/FuseBiasAddPass.h
new file mode 100644
index 000000000..68e624c6b
--- /dev/null
+++ b/compiler/exo/src/Pass/FuseBiasAddPass.h
@@ -0,0 +1,61 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __PASS_FUSE_BIASADD_PASS_H__
+#define __PASS_FUSE_BIASADD_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace exo
+{
+
+/**
+ * @brief Class to fuse TFLAdd or TFLSub into Bias input of the following ops:
+ * - TFLConv2D, TFLDepthwiseConv2D
+ * - TODO Consider to add FullyConnected, etc.
+ *
+ * Case 1. Conv2D and TFLAdd
+ *
+ * BEFORE:
+ *
+ * TFLConst A (a scalar or a tensor of shape [1] or [depth of TFLConv2D])
+ * |
+ * Foo -- TFLConv2D -- TFLAdd (or TFLSub) -- Bar
+ * |
+ * TFLConst B --+ (bias)
+ *
+ * AFTER:
+ * Foo ----- TFLConv2D ----- Bar
+ * |
+ * TFLConst A' --+ (bias)
+ *
+ * TFLConst B (dead node)
+ *
+ * TFLAdd (or TFLSub) (dead node)
+ *
+ * @note TFLSub, of which x() == TFLConv2D and y() == TFLConst, will be fused.
+ * If x() == TFLConst and y() == TFLConv2D, it won't be fused.
+ */
+struct FuseBiasAddPass final : public logo::Pass
+{
+ const char *name(void) const final { return "exo::FuseBiasAddPass"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace exo
+
+#endif // __PASS_FUSE_BIASADD_PASS_H__
diff --git a/compiler/exo/src/Pass/FuseBiasAddPass.test.cpp b/compiler/exo/src/Pass/FuseBiasAddPass.test.cpp
new file mode 100644
index 000000000..6ba728de0
--- /dev/null
+++ b/compiler/exo/src/Pass/FuseBiasAddPass.test.cpp
@@ -0,0 +1,361 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "FuseBiasAddPass.h"
+
+#include "Dialect/IR/TFLNodes.h"
+#include "TestGraph.h"
+#include "TestHelper.h"
+
+#include <loco.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+void init(loco::Pull *pull)
+{
+ pull->dtype(loco::DataType::FLOAT32);
+ pull->shape({2, 3, 3, 2});
+}
+
+/// @brief Initializes TFLConv2D and related filter and bias
+void init(locoex::TFLConv2D *conv2d, locoex::TFLConst *filter, locoex::TFLConst *bias)
+{
+ // set conv2d
+ {
+ conv2d->fusedActivationFunction(locoex::FusedActFunc::NONE);
+ conv2d->padding(locoex::Padding::VALID);
+ }
+
+ // set filter
+ {
+ filter->dtype(loco::DataType::FLOAT32);
+ filter->shape({2, 3, 3, 2});
+ filter->size<loco::DataType::FLOAT32>(2 * 3 * 3 * 2);
+
+ for (uint32_t x = 0; x < 2 * 3 * 3 * 2; x++)
+ filter->at<loco::DataType::FLOAT32>(x) = 0.0;
+ }
+
+ // set bias
+ {
+ bias->dtype(loco::DataType::FLOAT32);
+ bias->shape({2});
+ bias->size<loco::DataType::FLOAT32>(2);
+
+ for (uint32_t x = 0; x < 2; x++)
+ bias->at<loco::DataType::FLOAT32>(x) = 0.0;
+ }
+}
+
+template <class T> void init(T *node, locoex::FusedActFunc f)
+{
+ static_assert(std::is_same<T, locoex::TFLAdd>::value || std::is_same<T, locoex::TFLSub>::value,
+ "wrong template type");
+
+ node->fusedActivationFunction(f);
+}
+
+/// @brief Initializes one param of TFLAdd or TFLSub
+void init(locoex::TFLConst *addsub_param)
+{
+ // set addsub_param : y() value of TFLAdd or TFLSub
+ addsub_param->dtype(loco::DataType::FLOAT32);
+ addsub_param->shape({2});
+ addsub_param->size<loco::DataType::FLOAT32>(2);
+
+ for (uint32_t x = 0; x < 2; x++)
+ addsub_param->at<loco::DataType::FLOAT32>(x) = (x + 1) * 1.5; // 1.5, 3
+}
+
+} // namespace
+
+// A case when
+// - TFLConv2D has bias (0, 0)
+// - TFLAdd, of which x() or y() == TFLConv2D
+// - Another param of TFLAdd is TFLConst, (1.5, 3)
+//
+// After fusion, bias should be (1.5, 3)
+TEST(FuseBiasAddPassTest, Conv2D_Add_01_basic)
+{
+ exo::test::TestGraph g;
+ auto filter = g.append<locoex::TFLConst>();
+ auto bias = g.append<locoex::TFLConst>();
+ auto conv2d = g.append<locoex::TFLConv2D>(g.pull, filter, bias);
+
+ auto add_y = g.append<locoex::TFLConst>();
+ auto add = g.append<locoex::TFLAdd>(conv2d, add_y);
+
+ g.complete(add);
+
+ init(g.pull);
+ init(conv2d, filter, bias);
+ init(add, locoex::FusedActFunc::NONE);
+ init(add_y);
+
+ // let's run fusion
+ {
+ exo::test::TypeShapeReadyPhase test_phase;
+
+ test_phase.add_pass<exo::FuseBiasAddPass>();
+ test_phase.run(g.graph());
+ }
+
+ auto a_conv2d = exo::test::find_first_node_bytype<locoex::TFLConv2D>(g.graph());
+ ASSERT_TRUE(a_conv2d != nullptr);
+
+ auto a_bias = dynamic_cast<locoex::TFLConst *>(a_conv2d->bias());
+ ASSERT_TRUE(a_bias != nullptr);
+
+ ASSERT_TRUE(a_bias->dim(0) == 2);
+ ASSERT_FLOAT_EQ(a_bias->at<loco::DataType::FLOAT32>(0),
+ bias->at<loco::DataType::FLOAT32>(0) + add_y->at<loco::DataType::FLOAT32>(0));
+ ASSERT_FLOAT_EQ(a_bias->at<loco::DataType::FLOAT32>(1),
+ bias->at<loco::DataType::FLOAT32>(1) + add_y->at<loco::DataType::FLOAT32>(1));
+}
+
+// A case when
+// - TFLConv2D has bias (0, 0)
+// - TFLAdd, of which x() or y() == TFLConv2D
+// - Another param of TFLAdd is TFLConst, (1.5) <-- scalar
+//
+// After fusion, bias should be (1.5, 1.5)
+TEST(FuseBiasAddPassTest, Conv2D_Add_02_TFLAdd_y_is_scalar)
+{
+ exo::test::TestGraph g;
+ auto filter = g.append<locoex::TFLConst>();
+ auto bias = g.append<locoex::TFLConst>();
+ auto conv2d = g.append<locoex::TFLConv2D>(g.pull, filter, bias);
+
+ auto add_y = g.append<locoex::TFLConst>();
+ auto add = g.append<locoex::TFLAdd>(conv2d, add_y);
+
+ g.complete(add);
+
+ init(g.pull);
+ init(conv2d, filter, bias); // channel of conv2d is 2
+
+ {
+ // Size of this TFLConst is 1.
+ // Note that this should be widened later to the shape of [channel of Conv2D], which is [2]
+ add_y->dtype(loco::DataType::FLOAT32);
+ add_y->shape({1});
+ add_y->size<loco::DataType::FLOAT32>(1);
+ add_y->at<loco::DataType::FLOAT32>(0) = 1.5;
+ }
+ init(add, locoex::FusedActFunc::NONE);
+
+ // let's run fusion
+ {
+ exo::test::TypeShapeReadyPhase test_phase;
+
+ test_phase.add_pass<exo::FuseBiasAddPass>();
+ test_phase.run(g.graph());
+ }
+
+ auto a_conv2d = exo::test::find_first_node_bytype<locoex::TFLConv2D>(g.graph());
+ ASSERT_TRUE(a_conv2d != nullptr);
+
+ auto a_bias = dynamic_cast<locoex::TFLConst *>(a_conv2d->bias());
+ ASSERT_TRUE(a_bias != nullptr);
+
+ ASSERT_TRUE(a_bias->dim(0) == 2);
+ ASSERT_FLOAT_EQ(a_bias->at<loco::DataType::FLOAT32>(0),
+ bias->at<loco::DataType::FLOAT32>(0) + 1.5);
+ ASSERT_FLOAT_EQ(a_bias->at<loco::DataType::FLOAT32>(1),
+ bias->at<loco::DataType::FLOAT32>(1) + 1.5);
+}
+
+// A case when
+// - TFLConv2D has bias (0, 0)
+// - TFLSub.x() == TFLConv2D
+// - TFLSub.y() == TFLConst, (1.5, 3)
+//
+// After fusion, bias should be (-1.5, -3)
+TEST(FuseBiasAddPassTest, Conv2D_Sub_01_basic)
+{
+ exo::test::TestGraph g;
+ auto filter = g.append<locoex::TFLConst>();
+ auto bias = g.append<locoex::TFLConst>();
+ auto conv2d = g.append<locoex::TFLConv2D>(g.pull, filter, bias);
+
+ auto sub_y = g.append<locoex::TFLConst>();
+ auto sub = g.append<locoex::TFLSub>(conv2d, sub_y);
+
+ g.complete(sub);
+
+ init(g.pull);
+ init(conv2d, filter, bias);
+ init(sub, locoex::FusedActFunc::NONE);
+ init(sub_y);
+
+ // let's run fusion
+ {
+ exo::test::TypeShapeReadyPhase test_phase;
+
+ test_phase.add_pass<exo::FuseBiasAddPass>();
+ test_phase.run(g.graph());
+ }
+
+ auto a_conv2d = exo::test::find_first_node_bytype<locoex::TFLConv2D>(g.graph());
+ ASSERT_TRUE(a_conv2d != nullptr);
+
+ auto a_bias = dynamic_cast<locoex::TFLConst *>(a_conv2d->bias());
+ ASSERT_TRUE(a_bias != nullptr);
+
+ ASSERT_TRUE(a_bias->dim(0) == 2);
+ ASSERT_FLOAT_EQ(a_bias->at<loco::DataType::FLOAT32>(0),
+ bias->at<loco::DataType::FLOAT32>(0) - sub_y->at<loco::DataType::FLOAT32>(0));
+ ASSERT_FLOAT_EQ(a_bias->at<loco::DataType::FLOAT32>(1),
+ bias->at<loco::DataType::FLOAT32>(1) - sub_y->at<loco::DataType::FLOAT32>(1));
+}
+
+// A case when TFLConv2D is input of TFLSub but fusion cannot be performed.
+// - TFLSub.x() == TFLConst
+// - TFLSub.y() == TFLConv2D
+//
+// Here, TFLSub cannot be fused because its operands are swapped. To be fused, TFLSub.x() should
+// be TFLConv2D and TFLSub.y() should be TFLConst. So fusion will NOT happen.
+TEST(FuseBiasAddPassTest, Conv2D_Sub_02_fusing_will_not_performed)
+{
+ exo::test::TestGraph g;
+ auto filter = g.append<locoex::TFLConst>();
+ auto bias = g.append<locoex::TFLConst>();
+ auto conv2d = g.append<locoex::TFLConv2D>(g.pull, filter, bias);
+
+ auto sub_y = g.append<locoex::TFLConst>();
+ auto sub = g.append<locoex::TFLSub>(sub_y, conv2d); // This WON'T be fused
+
+ g.complete(sub);
+
+ init(g.pull);
+ init(conv2d, filter, bias);
+ init(sub, locoex::FusedActFunc::NONE);
+ init(sub_y);
+
+ // let's run fusion
+ {
+ exo::test::TypeShapeReadyPhase test_phase;
+
+ test_phase.add_pass<exo::FuseBiasAddPass>();
+ test_phase.run(g.graph());
+ }
+
+ auto a_conv2d = exo::test::find_first_node_bytype<locoex::TFLConv2D>(g.graph());
+ ASSERT_TRUE(a_conv2d != nullptr);
+
+ auto a_bias = dynamic_cast<locoex::TFLConst *>(a_conv2d->bias());
+ ASSERT_TRUE(a_bias != nullptr);
+
+ ASSERT_TRUE(a_bias->dim(0) == 2);
+ ASSERT_FLOAT_EQ(a_bias->at<loco::DataType::FLOAT32>(0), 0);
+ ASSERT_FLOAT_EQ(a_bias->at<loco::DataType::FLOAT32>(1), 0);
+
+ auto a_sub = exo::test::find_first_node_bytype<locoex::TFLSub>(g.graph());
+ ASSERT_TRUE(a_sub != nullptr);
+ ASSERT_TRUE(a_sub->y() == a_conv2d); // Checking 'not-fused' state
+}
+
+// A case when
+// - TFLConv2D has an activation function with Relu
+// - TFLAdd, has no activation function
+//
+// No fusion should happen
+TEST(FuseBiasAddPassTest, Regression_Conv2D_Add_fused_action_00)
+{
+ exo::test::TestGraph g;
+ auto filter = g.append<locoex::TFLConst>();
+ auto bias = g.append<locoex::TFLConst>();
+ auto conv2d = g.append<locoex::TFLConv2D>(g.pull, filter, bias);
+
+ auto add_y = g.append<locoex::TFLConst>();
+ auto add = g.append<locoex::TFLAdd>(conv2d, add_y);
+
+ g.complete(add);
+
+ init(g.pull);
+ init(conv2d, filter, bias);
+ init(add, locoex::FusedActFunc::NONE);
+ init(add_y);
+
+ // Updating Fused Activation for this test
+ conv2d->fusedActivationFunction(locoex::FusedActFunc::RELU);
+
+ // let's run fusion
+ {
+ exo::test::TypeShapeReadyPhase test_phase;
+
+ test_phase.add_pass<exo::FuseBiasAddPass>();
+ test_phase.run(g.graph());
+ }
+
+ auto a_conv2d = exo::test::find_first_node_bytype<locoex::TFLConv2D>(g.graph());
+ ASSERT_TRUE(a_conv2d != nullptr);
+ ASSERT_TRUE(a_conv2d->fusedActivationFunction() == locoex::FusedActFunc::RELU);
+
+ auto an_add = exo::test::find_first_node_bytype<locoex::TFLAdd>(g.graph());
+ ASSERT_TRUE(an_add != nullptr);
+ ASSERT_TRUE(an_add->fusedActivationFunction() == locoex::FusedActFunc::NONE);
+
+ ASSERT_TRUE(an_add->x() == a_conv2d or an_add->y() == a_conv2d);
+}
+
+// A case when
+// - TFLConv2D has NONE activation function
+// - TFLAdd has Relu activation function
+//
+// TFLConv2D should have Relu activation function, TFLAdd is fused into bias input
+TEST(FuseBiasAddPassTest, Regression_Conv2D_Add_fused_action_01)
+{
+ exo::test::TestGraph g;
+ auto filter = g.append<locoex::TFLConst>();
+ auto bias = g.append<locoex::TFLConst>();
+ auto conv2d = g.append<locoex::TFLConv2D>(g.pull, filter, bias);
+
+ auto add_y = g.append<locoex::TFLConst>();
+ auto add = g.append<locoex::TFLAdd>(conv2d, add_y);
+
+ g.complete(add);
+
+ init(g.pull);
+ init(conv2d, filter, bias);
+ init(add, locoex::FusedActFunc::RELU);
+ init(add_y);
+
+ // let's run fusion
+ {
+ exo::test::TypeShapeReadyPhase test_phase;
+
+ test_phase.add_pass<exo::FuseBiasAddPass>();
+ test_phase.run(g.graph());
+ }
+
+ auto a_conv2d = exo::test::find_first_node_bytype<locoex::TFLConv2D>(g.graph());
+ ASSERT_TRUE(a_conv2d != nullptr);
+
+ auto a_bias = dynamic_cast<locoex::TFLConst *>(a_conv2d->bias());
+ ASSERT_TRUE(a_bias != nullptr);
+
+ ASSERT_TRUE(a_bias->dim(0) == 2);
+ ASSERT_FLOAT_EQ(a_bias->at<loco::DataType::FLOAT32>(0),
+ bias->at<loco::DataType::FLOAT32>(0) + add_y->at<loco::DataType::FLOAT32>(0));
+ ASSERT_FLOAT_EQ(a_bias->at<loco::DataType::FLOAT32>(1),
+ bias->at<loco::DataType::FLOAT32>(1) + add_y->at<loco::DataType::FLOAT32>(1));
+
+ ASSERT_TRUE(a_conv2d->fusedActivationFunction() == locoex::FusedActFunc::RELU);
+}
diff --git a/compiler/exo/src/Pass/FuseInstanceNormPass.cpp b/compiler/exo/src/Pass/FuseInstanceNormPass.cpp
new file mode 100644
index 000000000..04d4a62cd
--- /dev/null
+++ b/compiler/exo/src/Pass/FuseInstanceNormPass.cpp
@@ -0,0 +1,402 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "FuseInstanceNormPass.h"
+
+#include "Dialect/IR/TFLNodes.h"
+#include "Dialect/IR/CircleNodes.h"
+
+#include <loco/Service/ShapeInference.h>
+
+#include <cassert>
+#include <set>
+
+// Helper to find commutative node's arguments
+namespace
+{
+
+/**
+ * INTRODUCTION
+ * Binary operation f(x,y) is 'commutative' when
+ * f(x,y) == f(y,x) holds for all x, y.
+ * For example, ADD, MUL and SQUARED_DIFFERENCE are commutative.
+ * These helpers make it easy to find the arguments of a commutative node.
+ *
+ * HOW TO USE
+ * COMM_NODE *node;
+ * ARG_TYPE_1 *arg1;
+ * ARG_TYPE_2 *arg2;
+ *
+ * bool ok = fill(&arg1, &arg2).with_commutative_args_of(node);
+ *
+ * Result
+ * If 'node's commutative argument types are actually {ARG_TYPE_1, ARG_TYPE_2}
+ * (as a set), 'arg1' and 'arg2' are set to 'node's actual arguments with the
+ * matching types, and the return value 'ok' is true.
+ * Otherwise, 'arg1' and 'arg2' are not changed and 'ok' is false.
+ */
+
+template <class ARG_TYPE_1, class ARG_TYPE_2> class NodeFiller final
+{
+public:
+ NodeFiller(ARG_TYPE_1 **arg_1, ARG_TYPE_2 **arg_2) : _arg_1(arg_1), _arg_2(arg_2)
+ {
+ // DO NOTHING
+ }
+
+ /**
+ * @return true When 'node's argument types are 'ARG_TYPE_1' and 'ARG_TYPE_2'
+ * In such case, it assigns '_arg_1' and '_arg_2' to the actual arguments
+ *
+ * @return false When 'node's argument types are NOT matched with 'ARG_TYPE_*'
+ * In such case, it does not amend '_arg_1' and '_arg_2'
+ *
+ * @require COMM_NODE has member x() and y()
+ */
+ template <class COMM_NODE> bool with_commutative_args_of(const COMM_NODE *node);
+
+private:
+ ARG_TYPE_1 **_arg_1;
+ ARG_TYPE_2 **_arg_2;
+};
+
+template <class ARG_TYPE_1, class ARG_TYPE_2>
+inline NodeFiller<ARG_TYPE_1, ARG_TYPE_2> fill(ARG_TYPE_1 **arg_1, ARG_TYPE_2 **arg_2)
+{
+ return NodeFiller<ARG_TYPE_1, ARG_TYPE_2>{arg_1, arg_2};
+}
+
+template <class ARG_TYPE_1, class ARG_TYPE_2>
+template <class COMM_NODE>
+bool NodeFiller<ARG_TYPE_1, ARG_TYPE_2>::with_commutative_args_of(const COMM_NODE *node)
+{
+ // Case 1) X == ARG_TYPE_1 / Y == ARG_TYPE_2
+ {
+ auto x = dynamic_cast<ARG_TYPE_1 *>(node->x());
+ auto y = dynamic_cast<ARG_TYPE_2 *>(node->y());
+
+ if (x && y)
+ {
+ *_arg_1 = x;
+ *_arg_2 = y;
+ return true;
+ }
+ }
+
+ // Case 2) X == ARG_TYPE_2 / Y == ARG_TYPE_1
+ {
+ auto x = dynamic_cast<ARG_TYPE_2 *>(node->x());
+ auto y = dynamic_cast<ARG_TYPE_1 *>(node->y());
+
+ if (x && y)
+ {
+ *_arg_1 = y;
+ *_arg_2 = x;
+ return true;
+ }
+ }
+
+ return false;
+}
+
+} // namespace
+
+// Helper to check detail
+namespace
+{
+
+/// @return true When node has shape of '1 x .. x 1 x depth'
+bool is_1D_with_dummy_dim(locoex::TFLConst *node, uint32_t depth)
+{
+ auto rank = node->rank();
+ uint32_t axis;
+ for (axis = 0; axis < rank - 1; ++axis)
+ {
+ if (node->dim(axis).value() != 1)
+ return false;
+ }
+ return node->dim(axis).value() == depth;
+}
+
+bool is_instance_mean(locoex::TFLMean *mean)
+{
+ //
+ // CHECK 1) input is rank 4
+ //
+ auto input = mean->input();
+ if (not loco::shape_known(input))
+ return false;
+ auto input_shape = loco::shape_get(input).as<loco::TensorShape>();
+ if (input_shape.rank() != 4)
+ return false;
+
+ //
+ // CHECK 2) 'reduction indices' is TFLConst of value [1,2], that is HW of NHWC
+ //
+ // TODO Support equivalent case, like [-3,-2]
+ // TODO Support non-Const case?
+ // TODO What if input is NCHW format in Circle?
+ auto red_indices = dynamic_cast<locoex::TFLConst *>(mean->reduction_indices());
+ if (not red_indices)
+ return false;
+ if (red_indices->rank() != 1)
+ return false;
+ std::set<int32_t> red_indices_set;
+ {
+ // TODO Currently only support S32, support other types
+ assert(red_indices->dtype() == loco::DataType::S32);
+ for (uint32_t i = 0; i < red_indices->dim(0).value(); ++i)
+ red_indices_set.insert(red_indices->at<loco::DataType::S32>(i));
+ }
+ if (red_indices_set.size() != 2)
+ return false;
+ if (red_indices_set.find(1) == red_indices_set.end())
+ return false;
+ if (red_indices_set.find(2) == red_indices_set.end())
+ return false;
+
+ //
+ // CHECK 3) keep_dims == true (?)
+ //
+ // We only have case of 'keep_dims == true' so far, but it might be okay with 'keep_dims == false'
+ // TODO Check this fact, and if true, return true regardless of keep_dims
+ return mean->keep_dims();
+}
+
+} // namespace
+
+// Helper to fuse Instance Norm
+namespace
+{
+
+/**
+ * SUBGRAPH PATTERN
+ *
+ * - Below diagram shows Instance Norm pattern to fuse.
+ * - Execution dependency order is top to the bottom.
+ * - Node name is matched with variable name of InstanceNormPattern class.
+ * - Usually, the first word of a node name (variable name) is the node type.
+ *   For example, variable 'mean_as_variance' is a pointer to TFLMean.
+ * - (Items in parentheses) actually exist, but have no name and are not
+ *   variables of the InstanceNormPattern class.
+ *
+ * TODO support other semantically same patterns for instance norm
+ *
+ *                  [In]
+ *                    |
+ *                    V
+ *      +----------- ifm -----+   (reduction indices)
+ *      |             |       |       |
+ *      |             |       V       V
+ *      |             |      mean_of_ifm --------------------+
+ *      |             V       |                              |
+ *      |           sqdiff <--+   (reduction indices)        |
+ *      |             |                  |                   |
+ *      |             V                  |                   |
+ *      |           mean_as_variance <---+  const_as_epsilon |
+ *      |             |                          |           |
+ *      |             V                          |           |
+ *      |           add_as_variance <------------+           |
+ *      |             |                                      |
+ *      |             V                                      |
+ *      |           rsqrt     const_as_gamma                 |
+ *      |             |          |                           |
+ *      |             V          |                           |
+ *      |         mul_gamma <----+                           |
+ *      |           |    |                                   |
+ *      V           V    V                                   |
+ *    mul_as_scaled_ifm  mul_as_scaled_mean <----------------+
+ *            |                  |
+ *            |   const_as_beta  |
+ *            |         |        V
+ *            |         +------> sub
+ *            V                   |
+ *     add_as_terminal <----------+
+ *            |
+ *            V
+ *          [Out]
+ */
+class InstanceNormPattern final
+{
+public:
+ InstanceNormPattern(locoex::TFLAdd *candidate)
+ {
+ assert(candidate);
+ add_as_terminal = candidate;
+ }
+
+public:
+ bool matched();
+ bool matched() const { return _matched; }
+
+public:
+ // Context
+ loco::Node *ifm = nullptr;
+ locoex::TFLMean *mean_of_ifm = nullptr;
+ locoex::TFLSquaredDifference *sqdiff = nullptr;
+ locoex::TFLMean *mean_as_variance = nullptr;
+ locoex::TFLConst *const_as_epsilon = nullptr;
+ locoex::TFLAdd *add_as_variance = nullptr;
+ locoex::TFLRsqrt *rsqrt = nullptr;
+ locoex::TFLConst *const_as_gamma = nullptr;
+ locoex::TFLMul *mul_gamma = nullptr;
+ locoex::TFLMul *mul_as_scaled_ifm = nullptr;
+ locoex::TFLMul *mul_as_scaled_mean = nullptr;
+ locoex::TFLConst *const_as_beta = nullptr;
+ locoex::TFLSub *sub = nullptr;
+ locoex::TFLAdd *add_as_terminal = nullptr;
+
+private:
+ bool _matched = false;
+};
+
+bool InstanceNormPattern::matched()
+{
+ if (_matched)
+ return true;
+
+#define CHECK_OR_FALSE(condition) \
+ if (not(condition)) \
+ return false;
+
+ // Check order is DFS
+
+ CHECK_OR_FALSE(fill(&mul_as_scaled_ifm, &sub).with_commutative_args_of(add_as_terminal));
+ CHECK_OR_FALSE(fill(&ifm, &mul_gamma).with_commutative_args_of(mul_as_scaled_ifm));
+
+ CHECK_OR_FALSE(loco::shape_known(ifm));
+ auto ifm_shape = loco::shape_get(ifm);
+ CHECK_OR_FALSE(ifm_shape.domain() == loco::Domain::Tensor);
+ auto ifm_tensor_shape = ifm_shape.as<loco::TensorShape>();
+ CHECK_OR_FALSE(ifm_tensor_shape.rank() == 4);
+ uint32_t ifm_channel_depth = ifm_tensor_shape.dim(3).value();
+
+ CHECK_OR_FALSE(fill(&rsqrt, &const_as_gamma).with_commutative_args_of(mul_gamma));
+ CHECK_OR_FALSE(is_1D_with_dummy_dim(const_as_gamma, ifm_channel_depth));
+
+ add_as_variance = dynamic_cast<locoex::TFLAdd *>(rsqrt->x());
+ CHECK_OR_FALSE(add_as_variance);
+
+ CHECK_OR_FALSE(
+ fill(&mean_as_variance, &const_as_epsilon).with_commutative_args_of(add_as_variance));
+
+ CHECK_OR_FALSE(const_as_epsilon->dtype() == loco::DataType::FLOAT32);
+ // TODO Support regarding broadcast
+ CHECK_OR_FALSE(const_as_epsilon->size<loco::DataType::FLOAT32>() == 1);
+
+ CHECK_OR_FALSE(is_instance_mean(mean_as_variance));
+ sqdiff = dynamic_cast<locoex::TFLSquaredDifference *>(mean_as_variance->input());
+ CHECK_OR_FALSE(sqdiff);
+
+ loco::Node *ifm_should_be = nullptr;
+ CHECK_OR_FALSE(fill(&ifm_should_be, &mean_of_ifm).with_commutative_args_of(sqdiff));
+ CHECK_OR_FALSE(ifm == ifm_should_be);
+ CHECK_OR_FALSE(is_instance_mean(mean_of_ifm));
+ CHECK_OR_FALSE(ifm == mean_of_ifm->input());
+
+ const_as_beta = dynamic_cast<locoex::TFLConst *>(sub->x());
+ CHECK_OR_FALSE(const_as_beta);
+ CHECK_OR_FALSE(is_1D_with_dummy_dim(const_as_beta, ifm_channel_depth));
+
+ mul_as_scaled_mean = dynamic_cast<locoex::TFLMul *>(sub->y());
+ CHECK_OR_FALSE(mul_as_scaled_mean);
+
+ locoex::TFLMul *mul_gamma_should_be = nullptr;
+ locoex::TFLMean *mean_of_ifm_should_be = nullptr;
+ CHECK_OR_FALSE(fill(&mul_gamma_should_be, &mean_of_ifm_should_be)
+ .with_commutative_args_of(mul_as_scaled_mean));
+ CHECK_OR_FALSE(mul_gamma == mul_gamma_should_be);
+ CHECK_OR_FALSE(mean_of_ifm == mean_of_ifm_should_be);
+#undef CHECK_OR_FALSE
+ _matched = true;
+ return true;
+}
+
+/**
+ * Instance norm pattern would be fused like following diagram:
+ *
+ * [In] --------------------------- CircleInstanceNorm --- [Out]
+ *                                  /       /
+ * const_as_gamma --- TFLReshape ---       /
+ *                                        /
+ * const_as_beta ---- TFLReshape ---------
+ *
+ * Note
+ * - 'const_as_gamma' and 'const_as_beta' are from original graph
+ * - Value of 'const_as_epsilon' would be copied to CircleInstanceNorm's attribute
+ * - TFLReshape is added because CircleInstanceNorm only accepts a 1D tensor
+ * - 'TFLConst --- TFLReshape' is expected to be fused in constant folding for Reshape
+ */
+void fuse_instance_norm(const InstanceNormPattern &p)
+{
+ assert(p.matched());
+
+ auto graph = p.add_as_terminal->graph();
+
+ // Make reshape for gamma & beta
+ auto reshape_gamma = graph->nodes()->create<locoex::TFLReshape>();
+ auto reshape_beta = graph->nodes()->create<locoex::TFLReshape>();
+ {
+ auto ifm_shape = loco::shape_get(p.ifm).as<loco::TensorShape>();
+ uint32_t ifm_channel_depth = ifm_shape.dim(3).value();
+
+ int32_t new_shape[1] = {static_cast<int32_t>(ifm_channel_depth)};
+
+ reshape_gamma->tensor(p.const_as_gamma);
+ reshape_beta->tensor(p.const_as_beta);
+
+ locoex::set_new_shape(reshape_gamma, new_shape, 1);
+ locoex::set_new_shape(reshape_beta, new_shape, 1);
+ }
+
+ // Make Instance Norm to replace
+ auto instance_norm = graph->nodes()->create<locoex::CircleInstanceNorm>();
+ instance_norm->input(p.ifm);
+ instance_norm->gamma(reshape_gamma);
+ instance_norm->beta(reshape_beta);
+ float epsilon = p.const_as_epsilon->at<loco::DataType::FLOAT32>(0);
+ instance_norm->epsilon(epsilon);
+ instance_norm->fusedActivationFunction(p.add_as_terminal->fusedActivationFunction());
+
+ replace(p.add_as_terminal).with(instance_norm);
+}
+
+} // namespace
+
+namespace exo
+{
+
+bool FuseInstanceNormPass::run(loco::Graph *g)
+{
+ bool changed = false;
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ auto add = dynamic_cast<locoex::TFLAdd *>(node);
+ if (not add)
+ continue;
+
+ InstanceNormPattern pattern(add);
+ if (not pattern.matched())
+ continue;
+
+ fuse_instance_norm(pattern);
+ changed = true;
+ }
+
+ return changed;
+}
+
+} // namespace exo
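
A scalar sketch (illustrative) checking that the matched subgraph is algebraically instance normalization: with s = gamma * rsqrt(var + eps), the terminal add computes x*s + (beta - mean*s), which equals gamma*(x - mean)/sqrt(var + eps) + beta.

#include <cassert>
#include <cmath>

int main()
{
  const float x = 0.7f, mean = 0.2f, var = 0.09f, eps = 1e-5f;
  const float gamma = 1.5f, beta = -0.3f;

  const float s = gamma * (1.0f / std::sqrt(var + eps)); // mul_gamma of rsqrt
  const float pattern = x * s + (beta - mean * s);       // mul_as_scaled_ifm + sub branch
  const float instnorm = gamma * (x - mean) / std::sqrt(var + eps) + beta;

  assert(std::fabs(pattern - instnorm) < 1e-5f);
  return 0;
}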
diff --git a/compiler/exo/src/Pass/FuseInstanceNormPass.h b/compiler/exo/src/Pass/FuseInstanceNormPass.h
new file mode 100644
index 000000000..e6361021c
--- /dev/null
+++ b/compiler/exo/src/Pass/FuseInstanceNormPass.h
@@ -0,0 +1,40 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __FUSE_INSTANCE_NORM_PASS_H__
+#define __FUSE_INSTANCE_NORM_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace exo
+{
+
+/**
+ * @brief Class to fuse certain pattern of subgraph into CircleInstanceNorm
+ * with auxiliary nodes
+ *
+ * For detailed subgraph pattern to be fused, please check its implementation.
+ */
+struct FuseInstanceNormPass final : public logo::Pass
+{
+ const char *name(void) const final { return "exo::FuseInstanceNormPass"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace exo
+
+#endif // __FUSE_INSTANCE_NORM_PASS_H__
diff --git a/compiler/exo/src/Pass/FuseReluPass.cpp b/compiler/exo/src/Pass/FuseReluPass.cpp
new file mode 100644
index 000000000..d7af0c506
--- /dev/null
+++ b/compiler/exo/src/Pass/FuseReluPass.cpp
@@ -0,0 +1,115 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "FuseReluPass.h"
+
+#include "Dialect/IR/TFLNodes.h"
+#include "Dialect/IR/TFLDialect.h"
+#include "Dialect/IR/TFLNodeVisitor.h"
+
+#include <set>
+
+namespace
+{
+
+bool is_pred_fusable(loco::Node *node)
+{
+ using namespace locoex;
+
+ auto fusable_node = dynamic_cast<TFLNodeMixin<TFLNodeTrait::FusedActFunc> *>(node);
+
+ return (fusable_node and fusable_node->fusedActivationFunction() == FusedActFunc::NONE);
+}
+
+struct Collector final : public locoex::TFLNodeMutableVisitor<void>
+{
+ void visit(locoex::TFLRelu *node) final
+ {
+ if (is_pred_fusable(node->features()))
+ candidates.insert(node);
+ }
+
+ void visit(locoex::TFLRelu6 *node) final
+ {
+ if (is_pred_fusable(node->features()))
+ candidates.insert(node);
+ }
+
+ void visit(locoex::TFLNode *) final { return; }
+
+ std::set<locoex::TFLNode *> candidates;
+};
+
+void set_activation_fusion(loco::Node *node, locoex::FusedActFunc f)
+{
+ using namespace locoex;
+
+ if (auto fusable_node = dynamic_cast<TFLNodeMixin<TFLNodeTrait::FusedActFunc> *>(node))
+ fusable_node->fusedActivationFunction(f);
+ else
+ assert(false);
+}
+
+struct Performer final : public locoex::TFLNodeMutableVisitor<void>
+{
+ void visit(locoex::TFLRelu *the_relu) final
+ {
+ set_activation_fusion(the_relu->features(), locoex::FusedActFunc::RELU);
+
+ loco::replace(the_relu).with(the_relu->features());
+ the_relu->features(nullptr);
+ }
+
+ void visit(locoex::TFLRelu6 *the_relu6) final
+ {
+ set_activation_fusion(the_relu6->features(), locoex::FusedActFunc::RELU6);
+
+ loco::replace(the_relu6).with(the_relu6->features());
+ the_relu6->features(nullptr);
+ }
+
+ void visit(locoex::TFLNode *) final { assert(false && "should not be called"); }
+};
+
+} // namespace
+
+namespace exo
+{
+
+bool FuseReluPass::run(loco::Graph *g)
+{
+ Collector collector;
+
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ if (node->dialect() == locoex::TFLDialect::get())
+ {
+ auto tfl_node = dynamic_cast<locoex::TFLNode *>(node);
+ tfl_node->accept(&collector);
+ }
+ }
+
+ Performer performer;
+
+ for (auto node : collector.candidates)
+ {
+ node->accept(&performer);
+ }
+
+ return collector.candidates.size() > 0;
+}
+
+} // namespace exo
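
For reference, the fused-activation semantics this pass relies on, as a scalar sketch (illustrative): a standalone Relu/Relu6 following an op is equivalent to clamping that op's output through its fused activation function.

#include <algorithm>

// RELU: max(0, x); RELU6: min(max(0, x), 6)
float apply_fused_act(float x, bool relu6)
{
  float y = std::max(0.0f, x);
  return relu6 ? std::min(y, 6.0f) : y;
}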
diff --git a/compiler/exo/src/Pass/FuseReluPass.h b/compiler/exo/src/Pass/FuseReluPass.h
new file mode 100644
index 000000000..1cd276b29
--- /dev/null
+++ b/compiler/exo/src/Pass/FuseReluPass.h
@@ -0,0 +1,40 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __PASS_FUSE_RELU_PASS_H__
+#define __PASS_FUSE_RELU_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace exo
+{
+
+/**
+ * @brief Class to fuse TFLRelu or TFLRelu6 into the TensorFlow Lite ops below:
+ *
+ * ADD, AVERAGE_POOL_2D, CONCATENATION, CONV_2D, DEPTHWISE_CONV_2D,
+ * FULLY_CONNECTED, L2_NORMALIZATION, L2_POOL_2D, MAX_POOL_2D, MUL
+ */
+struct FuseReluPass final : public logo::Pass
+{
+ const char *name(void) const final { return "exo::FuseReluPass"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace exo
+
+#endif // __PASS_FUSE_RELU_PASS_H__
diff --git a/compiler/exo/src/Pass/FuseReluPass.test.cpp b/compiler/exo/src/Pass/FuseReluPass.test.cpp
new file mode 100644
index 000000000..6f83d4dd0
--- /dev/null
+++ b/compiler/exo/src/Pass/FuseReluPass.test.cpp
@@ -0,0 +1,115 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "FuseReluPass.h"
+
+#include "Dialect/IR/TFLNodes.h"
+#include "TestGraph.h"
+
+#include <loco.h>
+#include <logo/RemoveDeadNodePass.h>
+
+#include <gtest/gtest.h>
+
+#include <type_traits> // for std::is_same
+
+namespace
+{
+
+void init(loco::Pull *pull)
+{
+ pull->dtype(loco::DataType::FLOAT32);
+ pull->shape({2, 3, 3, 2});
+}
+
+/// @brief Initializes TFLConv2D and related filter and bias
+void init(locoex::TFLConv2D *conv2d, locoex::TFLConst *filter, locoex::TFLConst *bias)
+{
+ // set conv2d
+ {
+ conv2d->fusedActivationFunction(locoex::FusedActFunc::NONE);
+ conv2d->padding(locoex::Padding::VALID);
+ }
+
+ // set filter
+ {
+ filter->dtype(loco::DataType::FLOAT32);
+ filter->shape({2, 3, 3, 2});
+ filter->size<loco::DataType::FLOAT32>(2 * 3 * 3 * 2);
+
+ for (uint32_t x = 0; x < 2 * 3 * 3 * 2; x++)
+ filter->at<loco::DataType::FLOAT32>(x) = 0.0;
+ }
+
+ // set bias
+ {
+ bias->dtype(loco::DataType::FLOAT32);
+ bias->shape({2});
+ bias->size<loco::DataType::FLOAT32>(2);
+
+ for (uint32_t x = 0; x < 2; x++)
+ bias->at<loco::DataType::FLOAT32>(x) = 0.0;
+ }
+}
+
+} // namespace
+
+/// Test code called by TEST(..)
+/// This tests whether Conv2D - FusedTFLType is fused.
+template <class FusedTFLType, locoex::FusedActFunc FusedActFunc> void test()
+{
+ static_assert((std::is_same<FusedTFLType, locoex::TFLRelu>::value &&
+ FusedActFunc == locoex::FusedActFunc::RELU) ||
+ (std::is_same<FusedTFLType, locoex::TFLRelu6>::value &&
+ FusedActFunc == locoex::FusedActFunc::RELU6),
+ "wrong template type");
+
+ exo::test::TestGraph g;
+ {
+ auto filter = g.append<locoex::TFLConst>();
+ auto bias = g.append<locoex::TFLConst>();
+ auto conv2d = g.append<locoex::TFLConv2D>(g.pull, filter, bias);
+
+ auto fusable_node = g.append<FusedTFLType>(conv2d);
+
+ g.complete(fusable_node);
+
+ init(g.pull);
+ init(conv2d, filter, bias);
+ }
+
+ // let's run fusion
+ {
+ exo::test::TypeShapeReadyPhase test_phase;
+
+ test_phase.add_pass<exo::FuseReluPass>();
+ test_phase.add_pass<logo::RemoveDeadNodePass>(); // to remove TFLRelu
+ test_phase.run(g.graph());
+ }
+
+ auto a_conv2d = exo::test::find_first_node_bytype<locoex::TFLConv2D>(g.graph());
+ ASSERT_TRUE(a_conv2d != nullptr);
+ ASSERT_TRUE(a_conv2d->fusedActivationFunction() == FusedActFunc);
+
+ auto removed_fusable_node = exo::test::find_first_node_bytype<FusedTFLType>(g.graph());
+ ASSERT_TRUE(removed_fusable_node == nullptr);
+}
+
+// A case with Conv2D-Relu
+TEST(FuseReluTest, Conv2D_Relu_basic) { test<locoex::TFLRelu, locoex::FusedActFunc::RELU>(); }
+
+// A case with Conv2D-Relu6
+TEST(FuseReluTest, Conv2D_Relu6_basic) { test<locoex::TFLRelu6, locoex::FusedActFunc::RELU6>(); }
diff --git a/compiler/exo/src/Pass/FuseRsqrtPass.cpp b/compiler/exo/src/Pass/FuseRsqrtPass.cpp
new file mode 100644
index 000000000..08d704139
--- /dev/null
+++ b/compiler/exo/src/Pass/FuseRsqrtPass.cpp
@@ -0,0 +1,95 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "FuseRsqrtPass.h"
+
+#include "Check.h"
+
+#include "Dialect/IR/TFLNodes.h"
+
+namespace
+{
+
+/**
+ * @return Casted TFLDiv for a fusable candidate, nullptr otherwise
+ *
+ * This helper checks fusability with the following conditions:
+ * - TFLDiv has no fused activation function
+ * - TFLDiv's first argument is a TFLConst whose values are all 1
+ * - TFLDiv's second argument is a TFLSqrt
+ */
+locoex::TFLDiv *as_candidate(loco::Node *node)
+{
+ auto div = dynamic_cast<locoex::TFLDiv *>(node);
+ if (not div)
+ return nullptr;
+
+ // Cannot fuse Div with activation function
+ if (div->fusedActivationFunction() != locoex::FusedActFunc::NONE)
+ return nullptr;
+
+ auto const_one = dynamic_cast<locoex::TFLConst *>(div->x());
+ if (not const_one)
+ return nullptr;
+
+ const loco::DataType FLOAT32 = loco::DataType::FLOAT32;
+ // TODO Support other dtype
+ EXO_ASSERT(const_one->dtype() == FLOAT32, "Only support FLOAT32 now");
+ for (uint32_t i = 0; i < const_one->size<FLOAT32>(); ++i)
+ if (const_one->at<FLOAT32>(i) != 1.0f)
+ return nullptr;
+
+ auto sqrt = dynamic_cast<locoex::TFLSqrt *>(div->y());
+ if (not sqrt)
+ return nullptr;
+
+ return div;
+}
+
+void fuse_rsqrt(locoex::TFLDiv *div)
+{
+ auto sqrt = dynamic_cast<locoex::TFLSqrt *>(div->y());
+ EXO_ASSERT(sqrt, "sqrt should be valid at this point");
+
+ // TFLRsqrt to replace
+ auto rsqrt = div->graph()->nodes()->create<locoex::TFLRsqrt>();
+ rsqrt->x(sqrt->x());
+
+ // replace
+ loco::replace(div).with(rsqrt);
+}
+
+} // namespace
+
+namespace exo
+{
+
+bool FuseRsqrtPass::run(loco::Graph *g)
+{
+ bool changed = false;
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ if (auto div = as_candidate(node))
+ {
+ fuse_rsqrt(div);
+ changed = true;
+ }
+ }
+
+ return changed;
+}
+
+} // namespace exo
diff --git a/compiler/exo/src/Pass/FuseRsqrtPass.h b/compiler/exo/src/Pass/FuseRsqrtPass.h
new file mode 100644
index 000000000..1e60e4a49
--- /dev/null
+++ b/compiler/exo/src/Pass/FuseRsqrtPass.h
@@ -0,0 +1,47 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __FUSE_RSQRT_PASS_H__
+#define __FUSE_RSQRT_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace exo
+{
+
+/**
+ * @brief Class to fuse a TFLDiv of constant 1 by TFLSqrt into a single TFLRsqrt
+ *
+ * <BEFORE>
+ *
+ * TFLConst(1) ------
+ * \
+ * A --- TFLSqrt --- TFLDiv --- B
+ *
+ * <AFTER>
+ *
+ * A --- TFLRsqrt --- B
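+ *
+ * The rewrite relies on the identity 1 / sqrt(x) == rsqrt(x); accordingly, the
+ * pass only fires when the TFLDiv numerator is a TFLConst whose elements are
+ * all 1 (checked element-wise in FuseRsqrtPass.cpp).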
+ */
+struct FuseRsqrtPass final : public logo::Pass
+{
+ const char *name(void) const final { return "exo::FuseRsqrtPass"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace exo
+
+#endif // __FUSE_RSQRT_PASS_H__
diff --git a/compiler/exo/src/Pass/FuseSquaredDifferencePass.cpp b/compiler/exo/src/Pass/FuseSquaredDifferencePass.cpp
new file mode 100644
index 000000000..3f985a505
--- /dev/null
+++ b/compiler/exo/src/Pass/FuseSquaredDifferencePass.cpp
@@ -0,0 +1,86 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "FuseSquaredDifferencePass.h"
+
+#include "Check.h"
+
+#include "Dialect/IR/TFLNodes.h"
+
+namespace
+{
+
+/**
+ * @return Casted TFLMul for a fusable candidate, nullptr otherwise
+ *
+ * This helper checks fusability with the following conditions:
+ * - TFLMul has no fused activation function
+ * - TFLMul's first and second arguments are the same node, and that node is a TFLSub
+ */
+locoex::TFLMul *as_candidate(loco::Node *node)
+{
+ auto mul = dynamic_cast<locoex::TFLMul *>(node);
+ if (not mul)
+ return nullptr;
+
+ // Cannot fuse mul with activation function
+ if (mul->fusedActivationFunction() != locoex::FusedActFunc::NONE)
+ return nullptr;
+
+ if (mul->x() != mul->y())
+ return nullptr;
+
+ if (not dynamic_cast<locoex::TFLSub *>(mul->x()))
+ return nullptr;
+
+ return mul;
+}
+
+void fuse_squared_difference(locoex::TFLMul *mul)
+{
+ auto sub = dynamic_cast<locoex::TFLSub *>(mul->x());
+ EXO_ASSERT(sub, "sub should be valid at this point");
+
+ // TFLSquaredDifference to replace
+ auto sq_diff = mul->graph()->nodes()->create<locoex::TFLSquaredDifference>();
+ sq_diff->x(sub->x());
+ sq_diff->y(sub->y());
+
+ // replace
+ loco::replace(mul).with(sq_diff);
+}
+
+} // namespace
+
+namespace exo
+{
+
+bool FuseSquaredDifferencePass::run(loco::Graph *g)
+{
+ bool changed = false;
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ if (auto mul = as_candidate(node))
+ {
+ fuse_squared_difference(mul);
+ changed = true;
+ }
+ }
+
+ return changed;
+}
+
+} // namespace exo
diff --git a/compiler/exo/src/Pass/FuseSquaredDifferencePass.h b/compiler/exo/src/Pass/FuseSquaredDifferencePass.h
new file mode 100644
index 000000000..dbc15149f
--- /dev/null
+++ b/compiler/exo/src/Pass/FuseSquaredDifferencePass.h
@@ -0,0 +1,49 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __FUSE_SQUARED_DIFFERENCE_PASS_H__
+#define __FUSE_SQUARED_DIFFERENCE_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace exo
+{
+
+/**
+ * @brief Class to fuse the SquaredDifference pattern, i.e. (A - B) * (A - B)
+ *
+ * <BEFORE>
+ *
+ * A --- TFLSub --- TFLMul --- C
+ * / \ /
+ * B ---- -----
+ *
+ * <AFTER>
+ *
+ * A --- TFLSquaredDifference --- C
+ * /
+ * B ----
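+ *
+ * Both TFLMul operands must be the very same TFLSub node (pointer equality),
+ * not merely structurally equal subtractions; see as_candidate() in
+ * FuseSquaredDifferencePass.cpp.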
+ */
+struct FuseSquaredDifferencePass final : public logo::Pass
+{
+ const char *name(void) const final { return "exo::FuseSquaredDifferencePass"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace exo
+
+#endif // __FUSE_SQUARED_DIFFERENCE_PASS_H__
diff --git a/compiler/exo/src/Pass/MergeConcatNodesPass.cpp b/compiler/exo/src/Pass/MergeConcatNodesPass.cpp
new file mode 100644
index 000000000..8945fcfce
--- /dev/null
+++ b/compiler/exo/src/Pass/MergeConcatNodesPass.cpp
@@ -0,0 +1,191 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "MergeConcatNodesPass.h"
+#include "Dialect/IR/TFLNodes.h"
+
+#include <oops/InternalExn.h>
+
+#include <vector>
+
+namespace
+{
+
+bool canMerge(locoex::TFLConcatenation *node1, locoex::TFLConcatenation *node2)
+{
+ if (node1->fusedActivationFunction() != node2->fusedActivationFunction())
+ return false;
+
+ if (node1->axis() != node2->axis())
+ return false;
+
+ switch (node1->fusedActivationFunction())
+ {
+ case locoex::FusedActFunc::NONE:
+ case locoex::FusedActFunc::RELU:
+ case locoex::FusedActFunc::RELU6:
+ return true;
+
+ // case locoex::FusedActFunc::TANH:
+ // return false;
+
+ default:
+ INTERNAL_EXN_V("Unknown FusedActFunc", oops::to_uint32(node1->fusedActivationFunction()));
+ }
+}
+
+/**
+ * @brief Collect all the inputs for the newly created TFLConcatenation node
+ *
+ * in:0 -------------------------------\
+ * in:1 ---- TFLConcatenation:0 -------- TFLConcatenation:3 --- C
+ * (axis = 0, NONE) (axis = 0, NONE)
+ * in:2 ---/ /
+ * in:3 ---- TFLConcatenation:1 ------/
+ * (axis = 1, NONE) /
+ * in:4 ---/ /
+ * in:5 ---- TFLConcatenation:2 ---/
+ * (axis = 0, RELU)
+ * in:6 ---/
+ *
+ * For example, if the graph is like the above, dfs(TFLConcatenation:3) will
+ * return [in:0, in:1, in:2, TFLConcatenation:1, TFLConcatenation:2].
+ *
+ * TFLConcatenation:0 can be merged into TFLConcatenation:3,
+ * because their axis and fusedActivationFunction are the same.
+ * This means [in:1, in:2] will be linked as inputs of the new TFLConcatenation.
+ *
+ * However, TFLConcatenation:1 and TFLConcatenation:2 cannot be merged into
+ * TFLConcatenation:3 because the axis or fusedActivationFunction of each differs.
+ * So [in:3, in:4, in:5, in:6] will not be linked as inputs of the new
+ * TFLConcatenation; [TFLConcatenation:1, TFLConcatenation:2] will be linked instead.
+ *
+ * Therefore, the inputs of the TFLConcatenation node newly created to merge
+ * TFLConcatenation:3 will be [in:0, in:1, in:2, TFLConcatenation:1, TFLConcatenation:2],
+ * which is exactly what dfs(TFLConcatenation:3) returns.
+ *
+ *
+ * @note The input nodes must be traversed left to right (input:0 --> input:N)
+ *       so that the merged node preserves the original input order.
+ */
+std::vector<loco::Node *> dfs(locoex::TFLConcatenation *root)
+{
+ std::vector<loco::Node *> res;
+
+ for (uint32_t i = 0; i < root->numValues(); ++i)
+ {
+ auto input = dynamic_cast<locoex::TFLConcatenation *>(root->values(i));
+ if (input != nullptr && canMerge(input, root))
+ {
+ auto children = dfs(input);
+ for (auto child : children)
+ res.push_back(child);
+ }
+ else
+ {
+ res.push_back(root->values(i));
+ }
+ }
+
+ return res;
+}
+
+} // namespace
+
+namespace exo
+{
+
+/**
+ * @brief Merge TFLConcatenation nodes whose axis and fusedActivationFunction are the same
+ *
+ * [Before]
+ * in:0 -------------------------------\
+ * in:1 ---- TFLConcatenation:0 -------- TFLConcatenation:3 --- C
+ * (axis = 0, NONE) (axis = 0, NONE)
+ * in:2 ---/ /
+ * in:3 ---- TFLConcatenation:1 ------/
+ * (axis = 1, NONE) /
+ * in:4 ---/ /
+ * in:5 ---- TFLConcatenation:2 ---/
+ * (axis = 0, RELU)
+ * in:6 ---/
+ *
+ * [After]
+ * in:0 -------------------------------\
+ * in:1 -------------------------------- TFLConcatenation:4 --- C
+ * (axis = 0, NONE)
+ * in:2 -------------------------------/
+ * in:3 ---- TFLConcatenation:1 ------/
+ * (axis = 1, NONE) /
+ * in:4 ---/ /
+ * in:5 ---- TFLConcatenation:2 ---/
+ * (axis = 0, RELU)
+ * in:6 ---/
+ *
+ *
+ * Note that the nodes below are left detached: they remain in the graph but
+ * are no longer reachable from the output, so a subsequent dead-node pass can
+ * remove them:
+ *
+ * in:1 ---- TFLConcatenation:0 ----
+ *          (axis = 0, NONE)
+ * in:2 ---/
+ *
+ *      ---- TFLConcatenation:3 ----
+ *          (axis = 0, NONE)
+bool MergeConcatNodesPass::run(loco::Graph *graph)
+{
+ // Let's enumerate nodes required to compute output nodes
+ auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
+
+ // Find TFLConcatenation nodes which have another TFLConcatenation nodes
+ // as inputs, with same axis and same fusedActivationFunction
+ std::vector<locoex::TFLConcatenation *> candidates;
+ for (auto node : active_nodes)
+ {
+ if (auto concat = dynamic_cast<locoex::TFLConcatenation *>(node))
+ {
+ for (uint32_t i = 0; i < concat->numValues(); ++i)
+ {
+ auto input = dynamic_cast<locoex::TFLConcatenation *>(concat->values(i));
+ if (input != nullptr && canMerge(input, concat))
+ {
+ candidates.push_back(concat);
+ break;
+ }
+ }
+ }
+ }
+
+ // Merge multiple TFLConcatenation nodes as one TFLConcatenation node
+ for (auto node : candidates)
+ {
+ auto inputs = dfs(node);
+
+ auto new_concat = graph->nodes()->create<locoex::TFLConcatenation>(inputs.size());
+ new_concat->axis(node->axis());
+ new_concat->fusedActivationFunction(node->fusedActivationFunction());
+
+ for (uint32_t i = 0; i < inputs.size(); ++i)
+ new_concat->values(i, inputs.at(i));
+
+ loco::replace(node).with(new_concat);
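+
+    // Detach the old node from its inputs: this leaves it (and any producer
+    // concats folded into new_concat) without active users, so a later
+    // dead-node pass can reclaim them.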
+ for (uint32_t i = 0; i < node->numValues(); ++i)
+ node->values(i, nullptr);
+ }
+
+ return candidates.size() > 0;
+}
+
+} // namespace exo
diff --git a/compiler/exo/src/Pass/MergeConcatNodesPass.h b/compiler/exo/src/Pass/MergeConcatNodesPass.h
new file mode 100644
index 000000000..823214f43
--- /dev/null
+++ b/compiler/exo/src/Pass/MergeConcatNodesPass.h
@@ -0,0 +1,41 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __PASS_MERGE_CONCAT_NODES_H__
+#define __PASS_MERGE_CONCAT_NODES_H__
+
+#include <loco.h>
+#include <logo/Pass.h>
+
+namespace exo
+{
+
+/**
+ * @brief Merge TFLConcatenation nodes whose axis and fusedActivationFunction are the same
+ */
+class MergeConcatNodesPass : public logo::Pass
+{
+public:
+ virtual const char *name(void) const { return "exo::MergeConcatNodesPass"; }
+
+public:
+ bool run(loco::Graph *graph);
+};
+
+} // namespace exo
+
+#endif // __PASS_MERGE_CONCAT_NODES_H__
diff --git a/compiler/exo/src/Pass/ShapeInferencePass.cpp b/compiler/exo/src/Pass/ShapeInferencePass.cpp
new file mode 100644
index 000000000..bc60f91c4
--- /dev/null
+++ b/compiler/exo/src/Pass/ShapeInferencePass.cpp
@@ -0,0 +1,59 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ShapeInferencePass.h"
+
+#include "Dialect/IR/TFLDialect.h"
+#include "Dialect/Service/TFLShapeInferenceRule.h"
+
+#include "Dialect/IR/CircleDialect.h"
+#include "Dialect/Service/CircleShapeInferenceRule.h"
+
+#include <loco.h>
+#include <loco/IR/CanonicalDialect.h>
+#include <loco/Service/CanonicalShapeInferenceRule.h>
+#include <loco/Service/ShapeInference.h>
+#include <loco/Service/MultiDialectShapeInferenceRule.h>
+
+#include <locoex/COpDialect.h>
+#include <locoex/Service/COpShapeInferenceRule.h>
+
+namespace exo
+{
+
+/**
+ * @note Currently, TFL and Circle backend share this inference. However, TFL
+ * backend does not require rule for Circle dialect.
+ * TODO Make dedicated inference pass for Circle Dialect.
+ */
+bool ShapeInferencePass::run(loco::Graph *g)
+{
+ loco::CanonicalShapeInferenceRule canonical_rule;
+ locoex::TFLShapeInferenceRule tfl_rule;
+ locoex::CircleShapeInferenceRule circle_rule;
+ locoex::COpShapeInferenceRule cop_rule;
+
+ loco::MultiDialectShapeInferenceRule rules;
+
+ rules.bind(loco::CanonicalDialect::get(), &canonical_rule)
+ .bind(locoex::TFLDialect::get(), &tfl_rule)
+ .bind(locoex::CircleDialect::get(), &circle_rule)
+ .bind(locoex::COpDialect::get(), &cop_rule);
+
+ return loco::apply(&rules).to(g);
+}
+
+} // namespace exo
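+
+// Illustrative usage sketch (assuming the logo phase API that exo uses
+// elsewhere, e.g. in its test helpers; not part of this pass):
+//
+//   logo::Phase phase;
+//   phase.emplace_back(std::make_unique<exo::ShapeInferencePass>());
+//   logo::PhaseRunner<logo::PhaseStrategy::Restart> runner{g};
+//   runner.run(phase);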
diff --git a/compiler/exo/src/Pass/ShapeInferencePass.h b/compiler/exo/src/Pass/ShapeInferencePass.h
new file mode 100644
index 000000000..518c87403
--- /dev/null
+++ b/compiler/exo/src/Pass/ShapeInferencePass.h
@@ -0,0 +1,40 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __PASS_SHAPE_INFERENCE_PASS_H__
+#define __PASS_SHAPE_INFERENCE_PASS_H__
+
+#include <loco.h>
+#include <logo/Pass.h>
+
+namespace exo
+{
+
+/**
+ * @brief Pass to infer shape of nodes
+ */
+class ShapeInferencePass : public logo::Pass
+{
+public:
+ virtual const char *name(void) const { return "exo::ShapeInferencePass"; }
+
+public:
+ bool run(loco::Graph *graph);
+};
+
+} // namespace exo
+
+#endif // __PASS_SHAPE_INFERENCE_PASS_H__
diff --git a/compiler/exo/src/Pass/TypeInferencePass.cpp b/compiler/exo/src/Pass/TypeInferencePass.cpp
new file mode 100644
index 000000000..31d4f13b6
--- /dev/null
+++ b/compiler/exo/src/Pass/TypeInferencePass.cpp
@@ -0,0 +1,57 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TypeInferencePass.h"
+
+#include "Dialect/IR/TFLDialect.h"
+#include "Dialect/Service/TFLTypeInferenceRule.h"
+
+#include "Dialect/IR/CircleDialect.h"
+#include "Dialect/Service/CircleTypeInferenceRule.h"
+
+#include <loco.h>
+#include <loco/IR/CanonicalDialect.h>
+#include <loco/Service/TypeInference.h>
+
+#include <locoex/COpDialect.h>
+#include <locoex/Service/COpTypeInference.h>
+
+namespace exo
+{
+
+/**
+ * @note Currently, TFL and Circle backend share this inference. However, TFL
+ * backend does not require rule for Circle dialect.
+ * TODO Make dedicated inference pass for Circle Dialect.
+ */
+bool TypeInferencePass::run(loco::Graph *g)
+{
+ loco::CanonicalTypeInferenceRule canonical_rule;
+ locoex::TFLTypeInferenceRule tfl_rule;
+ locoex::CircleTypeInferenceRule circle_rule;
+ locoex::COpTypeInferenceRule cop_rule;
+
+ loco::MultiDialectTypeInferenceRule rules;
+
+ rules.bind(loco::CanonicalDialect::get(), &canonical_rule)
+ .bind(locoex::TFLDialect::get(), &tfl_rule)
+ .bind(locoex::CircleDialect::get(), &circle_rule)
+ .bind(locoex::COpDialect::get(), &cop_rule);
+
+ return loco::apply(&rules).to(g);
+}
+
+} // namespace exo
diff --git a/compiler/exo/src/Pass/TypeInferencePass.h b/compiler/exo/src/Pass/TypeInferencePass.h
new file mode 100644
index 000000000..3ede587a0
--- /dev/null
+++ b/compiler/exo/src/Pass/TypeInferencePass.h
@@ -0,0 +1,42 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __PASS_TYPE_INFERENCE_PASS_H__
+#define __PASS_TYPE_INFERENCE_PASS_H__
+
+#include <loco.h>
+
+#include <logo/Pass.h>
+
+namespace exo
+{
+
+/**
+ * @brief Pass to infer type of nodes
+ */
+class TypeInferencePass : public logo::Pass
+{
+public:
+ virtual const char *name(void) const { return "exo::TypeInferencePass"; }
+
+public:
+ bool run(loco::Graph *graph);
+};
+
+} // namespace exo
+
+#endif // __PASS_TYPE_INFERENCE_PASS_H__