diff options
author | Chunseok Lee <chunseok.lee@samsung.com> | 2020-04-23 14:45:49 +0900 |
---|---|---|
committer | Chunseok Lee <chunseok.lee@samsung.com> | 2020-04-23 14:45:49 +0900 |
commit | e2ef8438a24f7c56a0744eb579a6e293ee2fbf8e (patch) | |
tree | 44a1a7951d168dd4370e13593ed03f4bc6d920c5 /compiler/exo/src/Pass | |
parent | 302e6564a7a76109e1178207e44e45a58631c477 (diff) | |
download | nnfw-e2ef8438a24f7c56a0744eb579a6e293ee2fbf8e.tar.gz nnfw-e2ef8438a24f7c56a0744eb579a6e293ee2fbf8e.tar.bz2 nnfw-e2ef8438a24f7c56a0744eb579a6e293ee2fbf8e.zip |
Imported Upstream version 1.4.0upstream/1.4.0submit/tizen/20200423.054851
Diffstat (limited to 'compiler/exo/src/Pass')
22 files changed, 2565 insertions, 0 deletions
diff --git a/compiler/exo/src/Pass/FoldReshapeOfConstPass.cpp b/compiler/exo/src/Pass/FoldReshapeOfConstPass.cpp new file mode 100644 index 000000000..0fdcea939 --- /dev/null +++ b/compiler/exo/src/Pass/FoldReshapeOfConstPass.cpp @@ -0,0 +1,116 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "FoldReshapeOfConstPass.h" + +#include "Check.h" + +#include "Dialect/IR/TFLNodes.h" +#include "Dialect/IR/TFLNodeVisitor.h" + +#include <loco/Service/ShapeInference.h> + +#include <oops/InternalExn.h> + +namespace +{ + +/** + * @brief Check if node is TFLReshape and its input is TFLConst + * @return Casted TFLReshape for foldable candidate, nullptr otherwise + */ +locoex::TFLReshape *as_candidate(loco::Node *node) +{ + auto reshape = dynamic_cast<locoex::TFLReshape *>(node); + if (not reshape) + return nullptr; + + // Only accept Constant input of Reshape + if (not dynamic_cast<locoex::TFLConst *>(reshape->tensor())) + return nullptr; + + return reshape; +} + +uint32_t volume(loco::Node *tensor_node) +{ + auto shape = loco::shape_get(tensor_node).as<loco::TensorShape>(); + + uint32_t vol = 1; + for (uint32_t axis = 0; axis < shape.rank(); ++axis) + vol *= shape.dim(axis).value(); + + return vol; +} + +void fold_reshape_of_const(locoex::TFLReshape *reshape) +{ + const loco::DataType FLOAT32 = loco::DataType::FLOAT32; + + auto const_orig = dynamic_cast<locoex::TFLConst 
*>(reshape->tensor()); + + // Exceptions + { + EXO_ASSERT(const_orig, "Only support for Reshape-Const pair"); + // TODO support other data types + if (const_orig->dtype() != FLOAT32) + INTERNAL_EXN_V("NYI for this data type", oops::to_uint32(const_orig->dtype())); + + if (volume(const_orig) != volume(reshape)) + INTERNAL_EXN("New shape of Reshape is not matched"); + } + + auto new_shape = loco::shape_get(reshape).as<loco::TensorShape>(); + + // TFLConst to replace + auto const_new = reshape->graph()->nodes()->create<locoex::TFLConst>(); + + const_new->dtype(FLOAT32); + const_new->rank(new_shape.rank()); + const_new->size<FLOAT32>(const_orig->size<FLOAT32>()); + for (uint32_t axis = 0; axis < new_shape.rank(); ++axis) + const_new->dim(axis) = new_shape.dim(axis); + + for (uint32_t i = 0; i < const_new->size<FLOAT32>(); ++i) + { + const_new->at<FLOAT32>(i) = const_orig->at<FLOAT32>(i); + } + + // replace + loco::replace(reshape).with(const_new); +} + +} // namespace + +namespace exo +{ + +bool FoldReshapeOfConstPass::run(loco::Graph *g) +{ + bool changed = false; + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + if (auto reshape = as_candidate(node)) + { + fold_reshape_of_const(reshape); + changed = true; + } + } + + return changed; +} + +} // namespace exo diff --git a/compiler/exo/src/Pass/FoldReshapeOfConstPass.h b/compiler/exo/src/Pass/FoldReshapeOfConstPass.h new file mode 100644 index 000000000..10f8004bf --- /dev/null +++ b/compiler/exo/src/Pass/FoldReshapeOfConstPass.h @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __PASS_FOLD_RESHAPE_OF_CONST_PASS_H__ +#define __PASS_FOLD_RESHAPE_OF_CONST_PASS_H__ + +#include <logo/Pass.h> + +namespace exo +{ + +/** + * @brief Class to fuse TFLReshape + TFLConst into one equivalent TFLConst + * + * <before> + * TFLConst --- TFLReshape --- Out + * + * <after> + * TFLConst --- TFLReshape --- + * TFLConst (new) ------------ Out + * + * TODO This pass is for temporary. Deprecate this pass. + */ +struct FoldReshapeOfConstPass final : public logo::Pass +{ + const char *name(void) const final { return "exo::FoldReshapeOfConstPass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace exo + +#endif // __PASS_FOLD_RESHAPE_OF_CONST_PASS_H__ diff --git a/compiler/exo/src/Pass/FoldTransposeOfConstPass.cpp b/compiler/exo/src/Pass/FoldTransposeOfConstPass.cpp new file mode 100644 index 000000000..005c42944 --- /dev/null +++ b/compiler/exo/src/Pass/FoldTransposeOfConstPass.cpp @@ -0,0 +1,154 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "FoldTransposeOfConstPass.h" + +#include "Check.h" + +#include "Dialect/IR/TFLNodes.h" +#include "Dialect/IR/TFLNodeVisitor.h" + +// TODO remove dependency to angkor +#include <nncc/core/ADT/tensor/IndexEnumerator.h> +#include <nncc/core/ADT/tensor/LexicalLayout.h> + +#include <oops/InternalExn.h> + +namespace +{ + +/** + * @brief Check if node is TFLTranspose and its input is TFLConst + * @return Casted TFLTranspose for foldable candidate, nullptr otherwise + */ +locoex::TFLTranspose *as_candidate(loco::Node *node) +{ + auto transpose = dynamic_cast<locoex::TFLTranspose *>(node); + if (not transpose) + return nullptr; + + // Only accept Constant input of Transpose + if (not dynamic_cast<locoex::TFLConst *>(transpose->a())) + return nullptr; + + // Only accept Constant permutation of Transpose + if (not dynamic_cast<locoex::TFLConst *>(transpose->perm())) + return nullptr; + + return transpose; +} + +nncc::core::ADT::tensor::Shape angkor_shape(locoex::TFLConst *node) +{ + nncc::core::ADT::tensor::Shape ret; + + ret.resize(node->rank()); + for (uint32_t axis = 0; axis < node->rank(); ++axis) + { + ret.dim(axis) = node->dim(axis).value(); + } + + return ret; +} + +void fold_transpose_of_const(locoex::TFLTranspose *transpose) +{ + const loco::DataType FLOAT32 = loco::DataType::FLOAT32; + const loco::DataType S32 = loco::DataType::S32; + + auto const_orig = dynamic_cast<locoex::TFLConst *>(transpose->a()); + auto perm = dynamic_cast<locoex::TFLConst *>(transpose->perm()); + + // Exceptions + { + EXO_ASSERT(const_orig, "Only support for Transpose-Const pair"); + // TODO support other data types + if (const_orig->dtype() != FLOAT32) + INTERNAL_EXN_V("NYI for this data type", oops::to_uint32(const_orig->dtype())); + + EXO_ASSERT(perm, "Only support for constant permutation for Transpose"); + // TODO support other data types + if (perm->dtype() != 
S32) + INTERNAL_EXN_V("NYI for this data type", oops::to_uint32(perm->dtype())); + + auto okay = [&]() { + if (perm->rank() != 1) + return false; + if (perm->dim(0).value() != const_orig->rank()) + return false; + return true; + }; + if (not okay()) + INTERNAL_EXN("Input and permutation for Transpose is not congruent"); + } + + uint32_t rank = const_orig->rank(); + + // TFLConst to replace + auto const_new = transpose->graph()->nodes()->create<locoex::TFLConst>(); + + const_new->dtype(FLOAT32); + const_new->rank(rank); + const_new->size<FLOAT32>(const_orig->size<FLOAT32>()); + for (uint32_t axis = 0; axis < rank; ++axis) + const_new->dim(axis) = const_orig->dim(perm->at<S32>(axis)).value(); + + // TODO remove dependency to angkor + auto shape_orig = angkor_shape(const_orig); + auto shape_new = angkor_shape(const_new); + + nncc::core::ADT::tensor::LexicalLayout l; + nncc::core::ADT::tensor::IndexEnumerator e{shape_new}; + + for (; e.valid(); e.advance()) + { + loco::TensorIndex index_new = e.current(); + loco::TensorIndex index_orig; + + // Set original index from matching new index + index_orig.resize(rank); + for (uint32_t axis = 0; axis < rank; ++axis) + index_orig.at(perm->at<S32>(axis)) = index_new.at(axis); + + const_new->at<FLOAT32>(l.offset(shape_new, index_new)) = + const_orig->at<FLOAT32>(l.offset(shape_orig, index_orig)); + } + + // replace + loco::replace(transpose).with(const_new); +} + +} // namespace + +namespace exo +{ + +bool FoldTransposeOfConstPass::run(loco::Graph *g) +{ + bool changed = false; + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + if (auto transpose = as_candidate(node)) + { + fold_transpose_of_const(transpose); + changed = true; + } + } + + return changed; +} + +} // namespace exo diff --git a/compiler/exo/src/Pass/FoldTransposeOfConstPass.h b/compiler/exo/src/Pass/FoldTransposeOfConstPass.h new file mode 100644 index 000000000..26656a118 --- /dev/null +++ b/compiler/exo/src/Pass/FoldTransposeOfConstPass.h @@ -0,0 
+1,46 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __PASS_FOLD_TRANSPOSE_OF_CONST_PASS_H__ +#define __PASS_FOLD_TRANSPOSE_OF_CONST_PASS_H__ + +#include <logo/Pass.h> + +namespace exo +{ + +/** + * @brief Class to fuse TFLTranspose + TFLConst into one equivalent TFLConst + * + * <before> + * TFLConst --- TFLTranspose --- Out + * + * <after> + * TFLConst --- TFLTranspose --- + * TFLConst (new) -------------- Out + * + * TODO This pass is for temporary. Deprecate this pass. + */ +struct FoldTransposeOfConstPass final : public logo::Pass +{ + const char *name(void) const final { return "exo::FoldTransposeOfConstPass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace exo + +#endif // __PASS_FOLD_TRANSPOSE_OF_CONST_PASS_H__ diff --git a/compiler/exo/src/Pass/FuseBiasAddPass.cpp b/compiler/exo/src/Pass/FuseBiasAddPass.cpp new file mode 100644 index 000000000..aab820995 --- /dev/null +++ b/compiler/exo/src/Pass/FuseBiasAddPass.cpp @@ -0,0 +1,362 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "FuseBiasAddPass.h" + +#include "Dialect/IR/TFLNodes.h" +#include "Dialect/IR/TFLDialect.h" +#include "Dialect/IR/TFLNodeVisitor.h" + +#include <loco/Service/TypeInference.h> +#include <loco/Service/ShapeInference.h> + +#include <oops/InternalExn.h> + +#include <set> + +/* + Note: Terms for variables in this implementation is as follows: + + ex) subgraph handled: TFLConv2D -------- TFLAdd + (or TFLDepthwiseConv2D) (or TFLSub) + | | + \|/ \|/ + variable name : former latter + Type : FormerT LatterT + (shortened name from Mixin) (template type) +*/ +namespace +{ + +using FormerT = locoex::TFLNodeMixin<locoex::TFLNodeTrait::Bias>; + +loco::Node *as_loco_node(FormerT *former) +{ + auto loco_node = dynamic_cast<loco::Node *>(former); + assert(loco_node != nullptr); + + return loco_node; +} + +locoex::TFLConst *get_const(loco::Node *x, loco::Node *y) +{ + if (auto const_node = dynamic_cast<locoex::TFLConst *>(x)) + return const_node; + else if (auto const_node = dynamic_cast<locoex::TFLConst *>(y)) + return const_node; + + return nullptr; +} + +FormerT *get_former(loco::Node *x, loco::Node *y) +{ + if (auto node = dynamic_cast<FormerT *>(x)) + return node; + else if (auto node = dynamic_cast<FormerT *>(y)) + return node; + + return nullptr; +} + +/// @brief Finds input that is TFLConst and set it to new_input +void set_const_input(locoex::TFLNode *node, locoex::TFLConst *new_input) +{ + if (auto add = dynamic_cast<locoex::TFLAdd *>(node)) + { + if (dynamic_cast<locoex::TFLConst *>(add->x())) + add->x(new_input); + else if 
(dynamic_cast<locoex::TFLConst *>(add->y())) + add->y(new_input); + else + assert(false and "One node should be TFLConst"); + + return; + } + + if (auto sub = dynamic_cast<locoex::TFLSub *>(node)) + { + if (dynamic_cast<locoex::TFLConst *>(sub->x())) + sub->x(new_input); + else if (dynamic_cast<locoex::TFLConst *>(sub->y())) + sub->y(new_input); + else + assert(false and "One node should be TFLConst"); + + return; + } + + assert(false and "Param should be TFLAdd or TFLSub"); +} + +/** + * @brief Creates a TFLConst whose shape is [to] and values are all const_node->at(0), + * where const_node has only one element(a scalar or a tensor of shape [1]) + */ +locoex::TFLConst *create_widened(locoex::TFLConst *const_node, uint32_t to) +{ + auto const_shape = loco::shape_get(const_node).as<loco::TensorShape>(); + + assert(const_shape.rank() == 0 or (const_shape.rank() == 1 and const_shape.dim(0) == 1)); + + auto g = const_node->graph(); + + auto widened_const = g->nodes()->create<locoex::TFLConst>(); + { + widened_const->dtype(loco::DataType::FLOAT32); + widened_const->rank(1); + widened_const->dim(0) = to; + widened_const->size<loco::DataType::FLOAT32>(to); + for (uint32_t x = 0; x < to; x++) + widened_const->at<loco::DataType::FLOAT32>(x) = const_node->at<loco::DataType::FLOAT32>(0); + } + return widened_const; +} + +template <typename TFLType> float calc(float, float); + +template <> float calc<locoex::TFLAdd>(float x, float y) { return x + y; } +template <> float calc<locoex::TFLSub>(float x, float y) { return x - y; } + +template <class LatterT> class Fuser +{ +public: + Fuser(LatterT *latter) + { + static_assert(std::is_same<LatterT, locoex::TFLAdd>::value || + std::is_same<LatterT, locoex::TFLSub>::value, + "wrong template type"); + + _latter = latter; + _graph = _latter->graph(); + _const_node = get_const(_latter->x(), _latter->y()); + _former = get_former(_latter->x(), _latter->y()); + + assert(_const_node && _former); + } + + void fuse(void); + +private: + 
loco::Graph *_graph; + LatterT *_latter; + locoex::TFLConst *_const_node; + FormerT *_former; + + locoex::TFLConst *create_fused_bias_const(); +}; + +// instantiation +template class Fuser<locoex::TFLAdd>; +template class Fuser<locoex::TFLSub>; + +template <class LatterT> locoex::TFLConst *Fuser<LatterT>::create_fused_bias_const() +{ + // we have to create a new bias const by adding/substracting bias and const node (of TFLAdd or + // TFLSub) + auto bias = dynamic_cast<locoex::TFLConst *>(_former->bias()); + assert(bias->dtype() == loco::DataType::FLOAT32 && + _const_node->dtype() == loco::DataType::FLOAT32); + + assert(bias->rank() == 1 && _const_node->rank() == 1); + assert(bias->dim(0) == _const_node->dim(0)); + + // build a new bias const + auto new_bias = _graph->nodes()->create<locoex::TFLConst>(); + { + new_bias->dtype(loco::DataType::FLOAT32); + + new_bias->rank(1); + new_bias->dim(0) = bias->dim(0); + + new_bias->size<loco::DataType::FLOAT32>(bias->dim(0).value()); + + for (uint32_t x = 0; x < bias->dim(0).value(); x++) + new_bias->at<loco::DataType::FLOAT32>(x) = calc<LatterT>( + bias->at<loco::DataType::FLOAT32>(x), _const_node->at<loco::DataType::FLOAT32>(x)); + } + + return new_bias; +} + +// FuseBiasAddPass works when former->fusedActivationFunction() == NONE +bool check_act_func(FormerT *former) +{ + using FusedActFuncMixin = locoex::TFLNodeMixin<locoex::TFLNodeTrait::FusedActFunc>; + + if (auto node = dynamic_cast<FusedActFuncMixin *>(former)) + return node->fusedActivationFunction() == locoex::FusedActFunc::NONE; + else + return true; +} + +template <class LatterT> void set_act_func(FormerT *former, LatterT *latter) +{ + using FusedActFuncMixin = locoex::TFLNodeMixin<locoex::TFLNodeTrait::FusedActFunc>; + + if (auto node = dynamic_cast<FusedActFuncMixin *>(former)) + node->fusedActivationFunction(latter->fusedActivationFunction()); +} + +// instantiation +template void set_act_func(FormerT *, locoex::TFLAdd *); +template void set_act_func(FormerT *, 
locoex::TFLSub *); + +/** + * @brief Fuse TFLAdd or TFLSub (latter) into TFLConv2d or TFLDepthwiseConv2D (former). + * All conditions should be checked before calling this. + * + * @note TFLAdd can have fused activation function (let's call this FAF for simplicity). + * + * Conv2D's FAF | TFLAdd's FAF => FAF after fusing TFLAdd into TFLConv2D + * ----------------|--------------- -------------------------------------- + * NONE | NONE, RELU or RELU6 => TFLAdd's FAF + * other than NONE | anything => cannot be fused + */ +template <class LatterT> void Fuser<LatterT>::fuse(void) +{ + // check fused activation function + { + assert(check_act_func(_former)); + + set_act_func<LatterT>(_former, _latter); + } + + auto new_bias = create_fused_bias_const(); + + // replace node with new_bias + // note that loco::replace() is not used because bias could be input of other op just in case + _former->bias(new_bias); + + // remove TFLAdd or TFLSub node + loco::replace(_latter).with(as_loco_node(_former)); + _latter->x(nullptr); + _latter->y(nullptr); +} + +struct Collector final : public locoex::TFLNodeMutableVisitor<void> +{ + template <class LatterT> + void setCandidate(FormerT *former, LatterT *latter, locoex::TFLConst *const_node) + { + static_assert(std::is_same<LatterT, locoex::TFLAdd>::value || + std::is_same<LatterT, locoex::TFLSub>::value, + "wrong template type"); + + if (!check_act_func(former)) + return; + + auto depth = + loco::shape_get(as_loco_node(former)).template as<loco::TensorShape>().dim(3).value(); + auto const_shape = loco::shape_get(const_node).template as<loco::TensorShape>(); + + if (const_shape.rank() == 1 and const_shape.dim(0) == depth) + { + candidates.insert(latter); + } + // when Const has only one value, create a new const with shape [depth] + else if (const_shape.rank() == 0 or (const_shape.rank() == 1 and const_shape.dim(0) == 1)) + { + if (!(loco::dtype_get(as_loco_node(former)) == loco::DataType::FLOAT32)) + INTERNAL_EXN_V("Unsupported data 
type", + oops::to_uint32(loco::dtype_get(as_loco_node(former)))); + if (!(const_node->dtype() == loco::DataType::FLOAT32)) + INTERNAL_EXN_V("Unsupported data type", oops::to_uint32(const_node->dtype())); + + auto new_bias_node = create_widened(const_node, depth); + + // Replacing TFLConst input of TFLAdd or TFLSub. + // Note that calling loco::replace(const_node).with(new_bias_node) could be dangerous + // because const_node could be the input of many nodes + set_const_input(latter, new_bias_node); + + candidates.insert(latter); + } + } + + void visit(locoex::TFLAdd *latter) final + { + auto former = get_former(latter->x(), latter->y()); + auto const_node = get_const(latter->x(), latter->y()); + + if (former && const_node) + setCandidate<locoex::TFLAdd>(former, latter, const_node); + } + + void visit(locoex::TFLSub *latter) final + { + // TFLSub, of which x() = TFLConv2D or TFLDepthwiseConv2D, y() = TFLConst, is fusing target + auto former = dynamic_cast<FormerT *>(latter->x()); + auto const_node = dynamic_cast<locoex::TFLConst *>(latter->y()); + + if (former && const_node) + setCandidate<locoex::TFLSub>(former, latter, const_node); + } + + void visit(locoex::TFLNode *) final { return; } + + std::set<locoex::TFLNode *> candidates; +}; + +struct Performer final : public locoex::TFLNodeMutableVisitor<void> +{ + void visit(locoex::TFLAdd *latter) final + { + assert(get_former(latter->x(), latter->y())); + + Fuser<locoex::TFLAdd> fuser(latter); + fuser.fuse(); + } + + void visit(locoex::TFLSub *latter) final + { + assert(get_former(latter->x(), latter->y())); + + Fuser<locoex::TFLSub> fuser(latter); + fuser.fuse(); + } + + void visit(locoex::TFLNode *) final { assert(false && "should not be called"); } +}; + +} // namespace + +namespace exo +{ + +bool FuseBiasAddPass::run(loco::Graph *g) +{ + Collector collector; + + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + if (node->dialect() == locoex::TFLDialect::get()) + { + auto tfl_node = 
dynamic_cast<locoex::TFLNode *>(node); + tfl_node->accept(&collector); + } + } + + Performer performer; + + for (auto node : collector.candidates) + { + node->accept(&performer); + } + + return collector.candidates.size() > 0; +} + +} // namespace exo diff --git a/compiler/exo/src/Pass/FuseBiasAddPass.h b/compiler/exo/src/Pass/FuseBiasAddPass.h new file mode 100644 index 000000000..68e624c6b --- /dev/null +++ b/compiler/exo/src/Pass/FuseBiasAddPass.h @@ -0,0 +1,61 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __PASS_FUSE_BIASADD_PASS_H__ +#define __PASS_FUSE_BIASADD_PASS_H__ + +#include <logo/Pass.h> + +namespace exo +{ + +/** + * @brief Class to fuse TFLAdd or TFLSub into Bias input of the following ops: + * - TFLConv2D, TFLDepthwiseConv2D + * - TODO Consider to add FullyConnected, etc. + * + * Case 1. Conv2D and TFLAdd + * + * BEFORE: + * + * TFLConst A (a scalar or a tensor of shape [1] or [depth of TFLConv2D]) + * | + * Foo -- TFLConv2D -- TFLAdd (or TFLSub) -- Bar + * | + * TFLConst B --+ (bias) + * + * AFTER: + * Foo ----- TFLConv2D ----- Bar + * | + * TFLConst A' --+ (bias) + * + * TFLConst B (dead node) + * + * TFLAdd (or TFLSub) (dead node) + * + * @note TFLSub, of which x() == TFLConv2D and y() == TFLConst, will be fused. + * If x() == TFLConst and y() == TFLConv2D, it won't be fused. 
+ */ +struct FuseBiasAddPass final : public logo::Pass +{ + const char *name(void) const final { return "exo::FuseBiasAddPass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace exo + +#endif // __PASS_FUSE_BIASADD_PASS_H__ diff --git a/compiler/exo/src/Pass/FuseBiasAddPass.test.cpp b/compiler/exo/src/Pass/FuseBiasAddPass.test.cpp new file mode 100644 index 000000000..6ba728de0 --- /dev/null +++ b/compiler/exo/src/Pass/FuseBiasAddPass.test.cpp @@ -0,0 +1,361 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "FuseBiasAddPass.h" + +#include "Dialect/IR/TFLNodes.h" +#include "TestGraph.h" +#include "TestHelper.h" + +#include <loco.h> + +#include <gtest/gtest.h> + +namespace +{ + +void init(loco::Pull *pull) +{ + pull->dtype(loco::DataType::FLOAT32); + pull->shape({2, 3, 3, 2}); +} + +/// @brief Initializes TFLConv2D and related filter and bias +void init(locoex::TFLConv2D *conv2d, locoex::TFLConst *filter, locoex::TFLConst *bias) +{ + // set conv2d + { + conv2d->fusedActivationFunction(locoex::FusedActFunc::NONE); + conv2d->padding(locoex::Padding::VALID); + } + + // set filter + { + filter->dtype(loco::DataType::FLOAT32); + filter->shape({2, 3, 3, 2}); + filter->size<loco::DataType::FLOAT32>(2 * 3 * 3 * 2); + + for (uint32_t x = 0; x < 2 * 3 * 3 * 2; x++) + filter->at<loco::DataType::FLOAT32>(x) = 0.0; + } + + // set bias + { + bias->dtype(loco::DataType::FLOAT32); + bias->shape({2}); + bias->size<loco::DataType::FLOAT32>(2); + + for (uint32_t x = 0; x < 2; x++) + bias->at<loco::DataType::FLOAT32>(x) = 0.0; + } +} + +template <class T> void init(T *node, locoex::FusedActFunc f) +{ + static_assert(std::is_same<T, locoex::TFLAdd>::value || std::is_same<T, locoex::TFLSub>::value, + "wrong template type"); + + node->fusedActivationFunction(f); +} + +/// @brief Initializes one param of TFLAdd or TFLSub +void init(locoex::TFLConst *addsub_param) +{ + // set addsub_param : y() value of TFLAdd or TFLSub + addsub_param->dtype(loco::DataType::FLOAT32); + addsub_param->shape({2}); + addsub_param->size<loco::DataType::FLOAT32>(2); + + for (uint32_t x = 0; x < 2; x++) + addsub_param->at<loco::DataType::FLOAT32>(x) = (x + 1) * 1.5; // 1.5, 3 +} + +} // namespace + +// A case when +// - TFLConv2D has bias (0, 0) +// - TFLAdd, of which x() or y() == TFLConv2D +// - Another param of TFLAdd is TFLConst, (1.5, 3) +// +// After fusion, bias shold be (1.5, 3) +TEST(FuseBiasAddPassTest, Conv2D_Add_01_basic) +{ + exo::test::TestGraph g; + auto filter = 
g.append<locoex::TFLConst>(); + auto bias = g.append<locoex::TFLConst>(); + auto conv2d = g.append<locoex::TFLConv2D>(g.pull, filter, bias); + + auto add_y = g.append<locoex::TFLConst>(); + auto add = g.append<locoex::TFLAdd>(conv2d, add_y); + + g.complete(add); + + init(g.pull); + init(conv2d, filter, bias); + init(add, locoex::FusedActFunc::NONE); + init(add_y); + + // let's run fusion + { + exo::test::TypeShapeReadyPhase test_phase; + + test_phase.add_pass<exo::FuseBiasAddPass>(); + test_phase.run(g.graph()); + } + + auto a_conv2d = exo::test::find_first_node_bytype<locoex::TFLConv2D>(g.graph()); + ASSERT_TRUE(a_conv2d != nullptr); + + auto a_bias = dynamic_cast<locoex::TFLConst *>(a_conv2d->bias()); + ASSERT_TRUE(a_bias != nullptr); + + ASSERT_TRUE(a_bias->dim(0) == 2); + ASSERT_FLOAT_EQ(a_bias->at<loco::DataType::FLOAT32>(0), + bias->at<loco::DataType::FLOAT32>(0) + add_y->at<loco::DataType::FLOAT32>(0)); + ASSERT_FLOAT_EQ(a_bias->at<loco::DataType::FLOAT32>(1), + bias->at<loco::DataType::FLOAT32>(1) + add_y->at<loco::DataType::FLOAT32>(1)); +} + +// A case when +// - TFLConv2D has bias (0, 0) +// - TFLAdd, of which x() or y() == TFLConv2D +// - Another param of TFLAdd is TFLConst, (1.5) <-- scalar +// +// After fusion, bias shold be (1.5, 1.5) +TEST(FuseBiasAddPassTest, Conv2D_Add_02_TFLAdd_y_is_scalar) +{ + exo::test::TestGraph g; + auto filter = g.append<locoex::TFLConst>(); + auto bias = g.append<locoex::TFLConst>(); + auto conv2d = g.append<locoex::TFLConv2D>(g.pull, filter, bias); + + auto add_y = g.append<locoex::TFLConst>(); + auto add = g.append<locoex::TFLAdd>(conv2d, add_y); + + g.complete(add); + + init(g.pull); + init(conv2d, filter, bias); // channel of conv2d is 2 + + { + // Size of this TFLConst is 1. 
// A case when
// - TFLConv2D has bias (0, 0)
// - TFLSub.x() == TFLConv2D
// - TFLSub.y() == TFLConst, (1.5, 3)
//
// After fusion, bias should be (-1.5, -3)
//
// The assertions below do not hard-code those numbers; they check that each fused bias element
// equals "bias[i] - sub_y[i]" as read back from the nodes after the pass has run.
TEST(FuseBiasAddPassTest, Conv2D_Sub_01_basic)
{
  // Graph under test: pull -> TFLConv2D(filter, bias) -> TFLSub(conv2d, sub_y) -> output
  exo::test::TestGraph g;
  auto filter = g.append<locoex::TFLConst>();
  auto bias = g.append<locoex::TFLConst>();
  auto conv2d = g.append<locoex::TFLConv2D>(g.pull, filter, bias);

  auto sub_y = g.append<locoex::TFLConst>();
  auto sub = g.append<locoex::TFLSub>(conv2d, sub_y);

  g.complete(sub);

  // Fill in dtype/shape/values through the init() fixtures of this test file
  init(g.pull);
  init(conv2d, filter, bias);
  init(sub, locoex::FusedActFunc::NONE);
  init(sub_y);

  // let's run fusion
  {
    exo::test::TypeShapeReadyPhase test_phase;

    test_phase.add_pass<exo::FuseBiasAddPass>();
    test_phase.run(g.graph());
  }

  auto a_conv2d = exo::test::find_first_node_bytype<locoex::TFLConv2D>(g.graph());
  ASSERT_TRUE(a_conv2d != nullptr);

  // After fusion the bias input must still be a constant ...
  auto a_bias = dynamic_cast<locoex::TFLConst *>(a_conv2d->bias());
  ASSERT_TRUE(a_bias != nullptr);

  // ... with two elements, each reduced by the matching element of sub_y
  ASSERT_TRUE(a_bias->dim(0) == 2);
  ASSERT_FLOAT_EQ(a_bias->at<loco::DataType::FLOAT32>(0),
                  bias->at<loco::DataType::FLOAT32>(0) - sub_y->at<loco::DataType::FLOAT32>(0));
  ASSERT_FLOAT_EQ(a_bias->at<loco::DataType::FLOAT32>(1),
                  bias->at<loco::DataType::FLOAT32>(1) - sub_y->at<loco::DataType::FLOAT32>(1));
}

// A case when TFLConv2D is input of TFLSub but fusion cannot be performed.
// - TFLSub.x() == TFLConst
// - TFLSub.y() == TFLConv2D
//
// Here, TFLSub cannot be fused into the bias of TFLConv2D. To be fused, TFLSub.x() should be
// TFLConv2D and TFLSub.y() should be TFLConst. So fusion will NOT happen.
TEST(FuseBiasAddPassTest, Conv2D_Sub_02_fusing_will_not_performed)
{
  exo::test::TestGraph g;
  auto filter = g.append<locoex::TFLConst>();
  auto bias = g.append<locoex::TFLConst>();
  auto conv2d = g.append<locoex::TFLConv2D>(g.pull, filter, bias);

  auto sub_y = g.append<locoex::TFLConst>();
  auto sub = g.append<locoex::TFLSub>(sub_y, conv2d); // This WON'T be fused

  g.complete(sub);

  init(g.pull);
  init(conv2d, filter, bias);
  init(sub, locoex::FusedActFunc::NONE);
  init(sub_y);

  // let's run fusion
  {
    exo::test::TypeShapeReadyPhase test_phase;

    test_phase.add_pass<exo::FuseBiasAddPass>();
    test_phase.run(g.graph());
  }

  auto a_conv2d = exo::test::find_first_node_bytype<locoex::TFLConv2D>(g.graph());
  ASSERT_TRUE(a_conv2d != nullptr);

  auto a_bias = dynamic_cast<locoex::TFLConst *>(a_conv2d->bias());
  ASSERT_TRUE(a_bias != nullptr);

  // The bias is expected to be left at zero, i.e. untouched by the pass
  ASSERT_TRUE(a_bias->dim(0) == 2);
  ASSERT_FLOAT_EQ(a_bias->at<loco::DataType::FLOAT32>(0), 0);
  ASSERT_FLOAT_EQ(a_bias->at<loco::DataType::FLOAT32>(1), 0);

  // The TFLSub node must still exist and still consume the convolution as its second operand
  auto a_sub = exo::test::find_first_node_bytype<locoex::TFLSub>(g.graph());
  ASSERT_TRUE(a_sub != nullptr);
  ASSERT_TRUE(a_sub->y() == a_conv2d); // Checking 'not-fused' state
}

// A case when
// - TFLConv2D has an activation function with Relu
// - TFLAdd, has no activation function
//
// No fusion should happen
TEST(FuseBiasAddPassTest, Regression_Conv2D_Add_fused_action_00)
{
  exo::test::TestGraph g;
  auto filter = g.append<locoex::TFLConst>();
  auto bias = g.append<locoex::TFLConst>();
  auto conv2d = g.append<locoex::TFLConv2D>(g.pull, filter, bias);

  auto add_y = g.append<locoex::TFLConst>();
  auto add = g.append<locoex::TFLAdd>(conv2d, add_y);

  g.complete(add);

  init(g.pull);
  init(conv2d, filter, bias);
  init(add, locoex::FusedActFunc::NONE);
  init(add_y);

  // Updating Fused Activation for this test
  conv2d->fusedActivationFunction(locoex::FusedActFunc::RELU);

  // let's run fusion
  {
    exo::test::TypeShapeReadyPhase test_phase;

    test_phase.add_pass<exo::FuseBiasAddPass>();
    test_phase.run(g.graph());
  }

  // The convolution keeps its own activation ...
  auto a_conv2d = exo::test::find_first_node_bytype<locoex::TFLConv2D>(g.graph());
  ASSERT_TRUE(a_conv2d != nullptr);
  ASSERT_TRUE(a_conv2d->fusedActivationFunction() == locoex::FusedActFunc::RELU);

  // ... and the TFLAdd survives with its activation unchanged
  auto an_add = exo::test::find_first_node_bytype<locoex::TFLAdd>(g.graph());
  ASSERT_TRUE(an_add != nullptr);
  ASSERT_TRUE(an_add->fusedActivationFunction() == locoex::FusedActFunc::NONE);

  // The add must still be wired to the convolution on either operand
  ASSERT_TRUE(an_add->x() == a_conv2d or an_add->y() == a_conv2d);
}

// A case when
// - TFLConv2D has NONE activation function
// - TFLAdd has Relu activation function
//
// TFLConv2D should have Relu activation function, TFLAdd is fused into bias input
TEST(FuseBiasAddPassTest, Regression_Conv2D_Add_fused_action_01)
{
  exo::test::TestGraph g;
  auto filter = g.append<locoex::TFLConst>();
  auto bias = g.append<locoex::TFLConst>();
  auto conv2d = g.append<locoex::TFLConv2D>(g.pull, filter, bias);

  auto add_y = g.append<locoex::TFLConst>();
  auto add = g.append<locoex::TFLAdd>(conv2d, add_y);

  g.complete(add);

  init(g.pull);
  init(conv2d, filter, bias);
  init(add, locoex::FusedActFunc::RELU);
  init(add_y);

  // let's run fusion
  {
    exo::test::TypeShapeReadyPhase test_phase;

    test_phase.add_pass<exo::FuseBiasAddPass>();
    test_phase.run(g.graph());
  }

  auto a_conv2d = exo::test::find_first_node_bytype<locoex::TFLConv2D>(g.graph());
  ASSERT_TRUE(a_conv2d != nullptr);

  auto a_bias = dynamic_cast<locoex::TFLConst *>(a_conv2d->bias());
  ASSERT_TRUE(a_bias != nullptr);

  // Each fused bias element equals "bias[i] + add_y[i]"
  ASSERT_TRUE(a_bias->dim(0) == 2);
  ASSERT_FLOAT_EQ(a_bias->at<loco::DataType::FLOAT32>(0),
                  bias->at<loco::DataType::FLOAT32>(0) + add_y->at<loco::DataType::FLOAT32>(0));
  ASSERT_FLOAT_EQ(a_bias->at<loco::DataType::FLOAT32>(1),
                  bias->at<loco::DataType::FLOAT32>(1) + add_y->at<loco::DataType::FLOAT32>(1));

  // The activation of the fused TFLAdd must have moved onto the convolution
  ASSERT_TRUE(a_conv2d->fusedActivationFunction() == locoex::FusedActFunc::RELU);
}
+ * These helpers make it easy to find commutative arguemnts of commtative node. + * + * HOW TO USE + * COMM_NODE *node; + * ARG_TYPE_1 *arg1; + * ARG_TYPE_2 *arg2; + * + * bool ok = fill(&arg1, &arg2).with_commutative_args_of(node); + * + * Result + * If 'node's commutative argument types are actually {ARG_TYPE_1, ARG_TYPE_2} + * (as a set), 'arg1' and 'arg2' set as actual 'node's arguemnts with matching + * type, and return value 'ok' is true. + * Otherwise, 'arg1' and 'arg2' not changed, 'ok' is false. + */ + +template <class ARG_TYPE_1, class ARG_TYPE_2> class NodeFiller final +{ +public: + NodeFiller(ARG_TYPE_1 **arg_1, ARG_TYPE_2 **arg_2) : _arg_1(arg_1), _arg_2(arg_2) + { + // DO NOTHING + } + + /** + * @return true When 'node's argument types are 'ARG_TYPE_1' and 'ARG_TYPE_2' + * In such case, it assign '_arg_1' and '_arg_2' to actual arguments + * + * @return false When 'node's argument types are NOT matched with 'ARG_TYPE_*' + * In such case, it does not amend '_arg_1' and '_arg_2' + * + * @require COMM_NODE has member x() and y() + */ + template <class COMM_NODE> bool with_commutative_args_of(const COMM_NODE *node); + +private: + ARG_TYPE_1 **_arg_1; + ARG_TYPE_2 **_arg_2; +}; + +template <class ARG_TYPE_1, class ARG_TYPE_2> +inline NodeFiller<ARG_TYPE_1, ARG_TYPE_2> fill(ARG_TYPE_1 **arg_1, ARG_TYPE_2 **arg_2) +{ + return NodeFiller<ARG_TYPE_1, ARG_TYPE_2>{arg_1, arg_2}; +} + +template <class ARG_TYPE_1, class ARG_TYPE_2> +template <class COMM_NODE> +bool NodeFiller<ARG_TYPE_1, ARG_TYPE_2>::with_commutative_args_of(const COMM_NODE *node) +{ + // Case 1) X == ARG_TYPE_1 / Y == ARG_TYPE_2 + { + auto x = dynamic_cast<ARG_TYPE_1 *>(node->x()); + auto y = dynamic_cast<ARG_TYPE_2 *>(node->y()); + + if (x && y) + { + *_arg_1 = x; + *_arg_2 = y; + return true; + } + } + + // Case 2) X == ARG_TYPE_2 / Y == ARG_TYPE_1 + { + auto x = dynamic_cast<ARG_TYPE_2 *>(node->x()); + auto y = dynamic_cast<ARG_TYPE_1 *>(node->y()); + + if (x && y) + { + *_arg_1 = y; + 
*_arg_2 = x; + return true; + } + } + + return false; +} + +} // namespace + +// Helper to check detail +namespace +{ + +/// @return true When node has shape of '1 x .. x 1 x depth' +bool is_1D_with_dummy_dim(locoex::TFLConst *node, uint32_t depth) +{ + auto rank = node->rank(); + uint32_t axis; + for (axis = 0; axis < rank - 1; ++axis) + { + if (node->dim(axis).value() != 1) + return false; + } + return node->dim(axis).value() == depth; +} + +bool is_instance_mean(locoex::TFLMean *mean) +{ + // + // CHECK 1) input is rank 4 + // + auto input = mean->input(); + if (not loco::shape_known(input)) + return false; + auto input_shape = loco::shape_get(input).as<loco::TensorShape>(); + if (input_shape.rank() != 4) + return false; + + // + // CHECK 2) 'reduction indices' is TFLConst of value [1,2], that is HW of NHWC + // + // TODO Support equivalent case, like [-3,-2] + // TODO Support non-Const case? + // TODO What if input is NCHW format in Circle? + auto red_indices = dynamic_cast<locoex::TFLConst *>(mean->reduction_indices()); + if (not red_indices) + return false; + if (red_indices->rank() != 1) + return false; + std::set<int32_t> red_indices_set; + { + // TODO Currently only support S32, support other types + assert(red_indices->dtype() == loco::DataType::S32); + for (uint32_t i = 0; i < red_indices->dim(0).value(); ++i) + red_indices_set.insert(red_indices->at<loco::DataType::S32>(i)); + } + if (red_indices_set.size() != 2) + return false; + if (red_indices_set.find(1) == red_indices_set.end()) + return false; + if (red_indices_set.find(2) == red_indices_set.end()) + return false; + + // + // CHECK 3) keep_dims == true (?) 
+ // + // We only have case of 'keep_dims == true' so far, but it might be okay with 'keep_dims == false' + // TODO Check this fact, and if true, return true regardless of keep_dims + return mean->keep_dims(); +} + +} // namespace + +// Helper to fuse Instance Norm +namespace +{ + +/** + * SUBGRAPH PATTERN + * + * - Below diagram shows Instance Norm pattern to fuse. + * - Execution dependency order is top to the bottom. + * - Node name is matched with variable name of InstanceNormPattern class. + * - Usually, first word of node name (variable name) is node type. For e.g. + * variable 'mean_as_variance' is pointer to TFLMean. + * - (Item in parenthesis) means actually exist, but not having a name and + * not a variable of InstanceNormPattern class. + * + * TODO support other semantically same patterns for instance norm + * + * [In] + * | + * V + * +----------- ifm -----+ (reduction indicies) + * | | | | + * | | V V + * | | mean_of_ifm ----------------+ + * | V | | + * | sqdiff <--+ (reduction indicies) | + * | | | | + * | V | | + * | mean_as_variance <---+ const_as_epsilon | + * | | | | + * | V | | + * | add_as_variance <--------+ | + * | | | + * | V | + * | rsqrt const_as_gamma | + * | | | | + * | V | | + * | mul_gamma <--+ | + * | | | | + * V V V | + * mul_as_scaled_ifm mul_as_scaled_mean <-------------+ + * | | + * | const_as_beta | + * | | V + * | +------> sub + * V | + * add_as_terminal <----------+ + * | + * V + * [Out] + */ +class InstanceNormPattern final +{ +public: + InstanceNormPattern(locoex::TFLAdd *candidate) + { + assert(candidate); + add_as_terminal = candidate; + } + +public: + bool matched(); + bool matched() const { return _matched; } + +public: + // Context + loco::Node *ifm = nullptr; + locoex::TFLMean *mean_of_ifm = nullptr; + locoex::TFLSquaredDifference *sqdiff = nullptr; + locoex::TFLMean *mean_as_variance = nullptr; + locoex::TFLConst *const_as_epsilon = nullptr; + locoex::TFLAdd *add_as_variance = nullptr; + locoex::TFLRsqrt *rsqrt = 
nullptr; + locoex::TFLConst *const_as_gamma = nullptr; + locoex::TFLMul *mul_gamma = nullptr; + locoex::TFLMul *mul_as_scaled_ifm = nullptr; + locoex::TFLMul *mul_as_scaled_mean = nullptr; + locoex::TFLConst *const_as_beta = nullptr; + locoex::TFLSub *sub = nullptr; + locoex::TFLAdd *add_as_terminal = nullptr; + +private: + bool _matched = false; +}; + +bool InstanceNormPattern::matched() +{ + if (_matched) + return true; + +#define CHECK_OR_FALSE(condition) \ + if (not(condition)) \ + return false; + + // Check order is DFS + + CHECK_OR_FALSE(fill(&mul_as_scaled_ifm, &sub).with_commutative_args_of(add_as_terminal)); + CHECK_OR_FALSE(fill(&ifm, &mul_gamma).with_commutative_args_of(mul_as_scaled_ifm)); + + CHECK_OR_FALSE(loco::shape_known(ifm)); + auto ifm_shape = loco::shape_get(ifm); + CHECK_OR_FALSE(ifm_shape.domain() == loco::Domain::Tensor); + auto ifm_tensor_shape = ifm_shape.as<loco::TensorShape>(); + CHECK_OR_FALSE(ifm_tensor_shape.rank() == 4); + uint32_t ifm_channel_depth = ifm_tensor_shape.dim(3).value(); + + CHECK_OR_FALSE(fill(&rsqrt, &const_as_gamma).with_commutative_args_of(mul_gamma)); + CHECK_OR_FALSE(is_1D_with_dummy_dim(const_as_gamma, ifm_channel_depth)); + + add_as_variance = dynamic_cast<locoex::TFLAdd *>(rsqrt->x()); + CHECK_OR_FALSE(add_as_variance); + + CHECK_OR_FALSE( + fill(&mean_as_variance, &const_as_epsilon).with_commutative_args_of(add_as_variance)); + + CHECK_OR_FALSE(const_as_epsilon->dtype() == loco::DataType::FLOAT32); + // TODO Support regarding broadcast + CHECK_OR_FALSE(const_as_epsilon->size<loco::DataType::FLOAT32>() == 1); + + CHECK_OR_FALSE(is_instance_mean(mean_as_variance)); + sqdiff = dynamic_cast<locoex::TFLSquaredDifference *>(mean_as_variance->input()); + CHECK_OR_FALSE(sqdiff); + + loco::Node *ifm_should_be = nullptr; + CHECK_OR_FALSE(fill(&ifm_should_be, &mean_of_ifm).with_commutative_args_of(sqdiff)); + CHECK_OR_FALSE(ifm == ifm_should_be); + CHECK_OR_FALSE(is_instance_mean(mean_of_ifm)); + CHECK_OR_FALSE(ifm == 
mean_of_ifm->input()); + + const_as_beta = dynamic_cast<locoex::TFLConst *>(sub->x()); + CHECK_OR_FALSE(const_as_beta); + CHECK_OR_FALSE(is_1D_with_dummy_dim(const_as_beta, ifm_channel_depth)); + + mul_as_scaled_mean = dynamic_cast<locoex::TFLMul *>(sub->y()); + CHECK_OR_FALSE(mul_as_scaled_mean); + + locoex::TFLMul *mul_gamma_should_be = nullptr; + locoex::TFLMean *mean_of_ifm_should_be = nullptr; + CHECK_OR_FALSE(fill(&mul_gamma_should_be, &mean_of_ifm_should_be) + .with_commutative_args_of(mul_as_scaled_mean)); + CHECK_OR_FALSE(mul_gamma == mul_gamma_should_be); + CHECK_OR_FALSE(mean_of_ifm == mean_of_ifm_should_be); +#undef CHECK_OR_FALSE + _matched = true; + return true; +} + +/** + * Instance norm pattern would be fused like following diagram: + * + * [In] --------------------------- CircleInstanceNorm --- [Out] + * / / + * const_as_gamma --- TFLReshape --- / + * / + * const_as_beta ---- TFLReshape --- + * + * Note + * - 'const_as_gamma' and 'const_as_beta' are from original graph + * - Value of 'const_as_epsilon' would be copied to CircleInstanceNorm's attribute + * - TFLReshape is added as CircleInstanceNorm only accept 1D tensor + * - 'TFLConst --- TFLReshape' is expected to be fused in constant folding for Reshape + */ +void fuse_instance_norm(const InstanceNormPattern &p) +{ + assert(p.matched()); + + auto graph = p.add_as_terminal->graph(); + + // Make reshape for gamma & beta + auto reshape_gamma = graph->nodes()->create<locoex::TFLReshape>(); + auto reshape_beta = graph->nodes()->create<locoex::TFLReshape>(); + { + auto ifm_shape = loco::shape_get(p.ifm).as<loco::TensorShape>(); + uint32_t ifm_channel_depth = ifm_shape.dim(3).value(); + + int32_t new_shape[1] = {static_cast<int32_t>(ifm_channel_depth)}; + + reshape_gamma->tensor(p.const_as_gamma); + reshape_beta->tensor(p.const_as_beta); + + locoex::set_new_shape(reshape_gamma, new_shape, 1); + locoex::set_new_shape(reshape_beta, new_shape, 1); + } + + // Make Instance Norm to replace + auto 
instance_norm = graph->nodes()->create<locoex::CircleInstanceNorm>(); + instance_norm->input(p.ifm); + instance_norm->gamma(reshape_gamma); + instance_norm->beta(reshape_beta); + float epsilon = p.const_as_epsilon->at<loco::DataType::FLOAT32>(0); + instance_norm->epsilon(epsilon); + instance_norm->fusedActivationFunction(p.add_as_terminal->fusedActivationFunction()); + + replace(p.add_as_terminal).with(instance_norm); +} + +} // namespace + +namespace exo +{ + +bool FuseInstanceNormPass::run(loco::Graph *g) +{ + bool changed = false; + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + auto add = dynamic_cast<locoex::TFLAdd *>(node); + if (not add) + continue; + + InstanceNormPattern pattern(add); + if (not pattern.matched()) + continue; + + fuse_instance_norm(pattern); + changed = true; + } + + return changed; +} + +} // namespace exo diff --git a/compiler/exo/src/Pass/FuseInstanceNormPass.h b/compiler/exo/src/Pass/FuseInstanceNormPass.h new file mode 100644 index 000000000..e6361021c --- /dev/null +++ b/compiler/exo/src/Pass/FuseInstanceNormPass.h @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
#ifndef __FUSE_INSTANCE_NORM_PASS_H__
#define __FUSE_INSTANCE_NORM_PASS_H__

#include <logo/Pass.h>

namespace exo
{

/**
 * @brief Class to fuse certain pattern of subgraph into CircleInstanceNorm
 *        with auxiliary nodes
 *
 * For detailed subgraph pattern to be fused, please check its implementation.
 */
struct FuseInstanceNormPass final : public logo::Pass
{
  /// @brief Name of this pass, as a string literal
  const char *name(void) const final { return "exo::FuseInstanceNormPass"; }

  /// @brief Run the pass over graph 'g' (defined in FuseInstanceNormPass.cpp)
  bool run(loco::Graph *g) final;
};

} // namespace exo

#endif // __FUSE_INSTANCE_NORM_PASS_H__
+ */ + +#include "FuseReluPass.h" + +#include "Dialect/IR/TFLNodes.h" +#include "Dialect/IR/TFLDialect.h" +#include "Dialect/IR/TFLNodeVisitor.h" + +#include <set> + +namespace +{ + +bool is_pred_fusable(loco::Node *node) +{ + using namespace locoex; + + auto fusable_node = dynamic_cast<TFLNodeMixin<TFLNodeTrait::FusedActFunc> *>(node); + + return (fusable_node and fusable_node->fusedActivationFunction() == FusedActFunc::NONE); +}; + +struct Collector final : public locoex::TFLNodeMutableVisitor<void> +{ + void visit(locoex::TFLRelu *node) final + { + if (is_pred_fusable(node->features())) + candidates.insert(node); + } + + void visit(locoex::TFLRelu6 *node) final + { + if (is_pred_fusable(node->features())) + candidates.insert(node); + } + + void visit(locoex::TFLNode *) final { return; } + + std::set<locoex::TFLNode *> candidates; +}; + +void set_activation_fusion(loco::Node *node, locoex::FusedActFunc f) +{ + using namespace locoex; + + if (auto fusable_node = dynamic_cast<TFLNodeMixin<TFLNodeTrait::FusedActFunc> *>(node)) + fusable_node->fusedActivationFunction(f); + else + assert(false); +} + +struct Performer final : public locoex::TFLNodeMutableVisitor<void> +{ + void visit(locoex::TFLRelu *the_relu) final + { + set_activation_fusion(the_relu->features(), locoex::FusedActFunc::RELU); + + loco::replace(the_relu).with(the_relu->features()); + the_relu->features(nullptr); + } + + void visit(locoex::TFLRelu6 *the_relu6) final + { + set_activation_fusion(the_relu6->features(), locoex::FusedActFunc::RELU6); + + loco::replace(the_relu6).with(the_relu6->features()); + the_relu6->features(nullptr); + } + + void visit(locoex::TFLNode *) final { assert(false && "should not be called"); } +}; + +} // namespace + +namespace exo +{ + +bool FuseReluPass::run(loco::Graph *g) +{ + Collector collector; + + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + if (node->dialect() == locoex::TFLDialect::get()) + { + auto tfl_node = dynamic_cast<locoex::TFLNode 
*>(node); + tfl_node->accept(&collector); + } + } + + Performer performer; + + for (auto node : collector.candidates) + { + node->accept(&performer); + } + + return collector.candidates.size() > 0; +} + +} // namespace exo diff --git a/compiler/exo/src/Pass/FuseReluPass.h b/compiler/exo/src/Pass/FuseReluPass.h new file mode 100644 index 000000000..1cd276b29 --- /dev/null +++ b/compiler/exo/src/Pass/FuseReluPass.h @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __PASS_FUSE_RELU_PASS_H__ +#define __PASS_FUSE_RELU_PASS_H__ + +#include <logo/Pass.h> + +namespace exo +{ + +/** + * @brief Class to fuse TFLRelu or TFLRelu6 into the TensorFlow Lite ops below: + * + * ADD, AVERAGE_POOL_2D, CONCATENATION, CONV_2D, DEPTHWISE_CONV_2D, + * FULLY_CONNECTED, L2_NORMALIZATION, L2_POOL_2D, MAX_POOL_2D, MUL + */ +struct FuseReluPass final : public logo::Pass +{ + const char *name(void) const final { return "exo::FuseReluPass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace exo + +#endif // __PASS_FUSE_RELU_PASS_H__ diff --git a/compiler/exo/src/Pass/FuseReluPass.test.cpp b/compiler/exo/src/Pass/FuseReluPass.test.cpp new file mode 100644 index 000000000..6f83d4dd0 --- /dev/null +++ b/compiler/exo/src/Pass/FuseReluPass.test.cpp @@ -0,0 +1,115 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. 
All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "FuseReluPass.h" + +#include "Dialect/IR/TFLNodes.h" +#include "TestGraph.h" + +#include <loco.h> +#include <logo/RemoveDeadNodePass.h> + +#include <gtest/gtest.h> + +#include <type_traits> // for std::is_same + +namespace +{ + +void init(loco::Pull *pull) +{ + pull->dtype(loco::DataType::FLOAT32); + pull->shape({2, 3, 3, 2}); +} + +/// @brief Initializes TFLConv2D and related filter and bias +void init(locoex::TFLConv2D *conv2d, locoex::TFLConst *filter, locoex::TFLConst *bias) +{ + // set conv2d + { + conv2d->fusedActivationFunction(locoex::FusedActFunc::NONE); + conv2d->padding(locoex::Padding::VALID); + } + + // set filter + { + filter->dtype(loco::DataType::FLOAT32); + filter->shape({2, 3, 3, 2}); + filter->size<loco::DataType::FLOAT32>(2 * 3 * 3 * 2); + + for (uint32_t x = 0; x < 2 * 3 * 3 * 2; x++) + filter->at<loco::DataType::FLOAT32>(x) = 0.0; + } + + // set bias + { + bias->dtype(loco::DataType::FLOAT32); + bias->shape({2}); + bias->size<loco::DataType::FLOAT32>(2); + + for (uint32_t x = 0; x < 2; x++) + bias->at<loco::DataType::FLOAT32>(x) = 0.0; + } +} + +} // namespace + +/// Test code called by TEST(..) +/// This tests whether Conv2D - FusedTFLType is fused. 
// Shared driver for the TEST cases below: builds pull -> TFLConv2D -> FusedTFLType -> output,
// runs FuseReluPass followed by dead-node removal, then checks that the convolution carries
// 'FusedActFunc' and that no FusedTFLType node is left in the graph.
template <class FusedTFLType, locoex::FusedActFunc FusedActFunc> void test()
{
  // Only the (TFLRelu, RELU) and (TFLRelu6, RELU6) pairings are accepted at compile time
  static_assert((std::is_same<FusedTFLType, locoex::TFLRelu>::value &&
                 FusedActFunc == locoex::FusedActFunc::RELU) ||
                    (std::is_same<FusedTFLType, locoex::TFLRelu6>::value &&
                     FusedActFunc == locoex::FusedActFunc::RELU6),
                "wrong template type");

  exo::test::TestGraph g;
  {
    auto filter = g.append<locoex::TFLConst>();
    auto bias = g.append<locoex::TFLConst>();
    auto conv2d = g.append<locoex::TFLConv2D>(g.pull, filter, bias);

    // The activation node under test consumes the convolution output
    auto fusable_node = g.append<FusedTFLType>(conv2d);

    g.complete(fusable_node);

    init(g.pull);
    init(conv2d, filter, bias);
  }

  // let's run fusion
  {
    exo::test::TypeShapeReadyPhase test_phase;

    test_phase.add_pass<exo::FuseReluPass>();
    test_phase.add_pass<logo::RemoveDeadNodePass>(); // to remove TFLRelu
    test_phase.run(g.graph());
  }

  // The activation must now live on the convolution itself ...
  auto a_conv2d = exo::test::find_first_node_bytype<locoex::TFLConv2D>(g.graph());
  ASSERT_TRUE(a_conv2d != nullptr);
  ASSERT_TRUE(a_conv2d->fusedActivationFunction() == FusedActFunc);

  // ... and the stand-alone activation node must be gone
  auto removed_fusable_node = exo::test::find_first_node_bytype<FusedTFLType>(g.graph());
  ASSERT_TRUE(removed_fusable_node == nullptr);
}

// A case with Conv2D-Relu
TEST(FuseReluTest, Conv2D_Relu_basic) { test<locoex::TFLRelu, locoex::FusedActFunc::RELU>(); }

// A case with Conv2D-Relu6
TEST(FuseReluTest, Conv2D_Relu6_basic) { test<locoex::TFLRelu6, locoex::FusedActFunc::RELU6>(); }
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "FuseRsqrtPass.h" + +#include "Check.h" + +#include "Dialect/IR/TFLNodes.h" + +namespace +{ + +/** + * @return Casted TFLDiv for fusable candidate, nullptr otherwise + * + * This helper checkes fusability with following conditions: + * - TFLDiv has no activation + * - TFLDiv's first argument is TFLConst with all value 1 + * - TFLDiv's second argument is TFLSqrt + */ +locoex::TFLDiv *as_candidate(loco::Node *node) +{ + auto div = dynamic_cast<locoex::TFLDiv *>(node); + if (not div) + return nullptr; + + // Cannot fuse Div with activation function + if (div->fusedActivationFunction() != locoex::FusedActFunc::NONE) + return nullptr; + + auto const_one = dynamic_cast<locoex::TFLConst *>(div->x()); + if (not const_one) + return nullptr; + + const loco::DataType FLOAT32 = loco::DataType::FLOAT32; + // TODO Support other dtype + EXO_ASSERT(const_one->dtype() == FLOAT32, "Only support FLOAT32 now"); + for (uint32_t i = 0; i < const_one->size<FLOAT32>(); ++i) + if (const_one->at<FLOAT32>(i) != 1.0f) + return nullptr; + + auto sqrt = dynamic_cast<locoex::TFLSqrt *>(div->y()); + if (not sqrt) + return nullptr; + + return div; +} + +void fuse_rsqrt(locoex::TFLDiv *div) +{ + auto sqrt = dynamic_cast<locoex::TFLSqrt *>(div->y()); + EXO_ASSERT(sqrt, "sqrt should be valid at this point"); + + // TFLRsqrt to replace + auto rsqrt = div->graph()->nodes()->create<locoex::TFLRsqrt>(); + rsqrt->x(sqrt->x()); + + // replace + loco::replace(div).with(rsqrt); +} + +} // namespace + +namespace exo +{ + +bool FuseRsqrtPass::run(loco::Graph *g) +{ + bool 
changed = false; + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + if (auto div = as_candidate(node)) + { + fuse_rsqrt(div); + changed = true; + } + } + + return changed; +} + +} // namespace exo diff --git a/compiler/exo/src/Pass/FuseRsqrtPass.h b/compiler/exo/src/Pass/FuseRsqrtPass.h new file mode 100644 index 000000000..1e60e4a49 --- /dev/null +++ b/compiler/exo/src/Pass/FuseRsqrtPass.h @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __FUSE_RSQRT_PASS_H__ +#define __FUSE_RSQRT_PASS_H__ + +#include <logo/Pass.h> + +namespace exo +{ + +/** + * @brief Class to fuse TFLSqrt that is divided(TFLDiv) by 1, into TFLRsqrt + * + * <BEFORE> + * + * TFLConst(1) ------ + * \ + * A --- TFLSqrt --- TFLDiv --- B + * + * <AFTER> + * + * A --- TFLRsqrt --- B + */ +struct FuseRsqrtPass final : public logo::Pass +{ + const char *name(void) const final { return "exo::FuseRsqrtPass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace exo + +#endif // __FUSE_RSQRT_PASS_H__ diff --git a/compiler/exo/src/Pass/FuseSquaredDifferencePass.cpp b/compiler/exo/src/Pass/FuseSquaredDifferencePass.cpp new file mode 100644 index 000000000..3f985a505 --- /dev/null +++ b/compiler/exo/src/Pass/FuseSquaredDifferencePass.cpp @@ -0,0 +1,86 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. 
All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "FuseSquaredDifferencePass.h" + +#include "Check.h" + +#include "Dialect/IR/TFLNodes.h" + +namespace +{ + +/** + * @return Casted TFLMul for fusable candidate, nullptr otherwise + * + * This helper checkes fusability with following conditions: + * - TFLMul has no activation + * - TFLMul's first and second arguments are equal and TFLSub + */ +locoex::TFLMul *as_candidate(loco::Node *node) +{ + auto mul = dynamic_cast<locoex::TFLMul *>(node); + if (not mul) + return nullptr; + + // Cannot fuse mul with activation function + if (mul->fusedActivationFunction() != locoex::FusedActFunc::NONE) + return nullptr; + + if (mul->x() != mul->y()) + return nullptr; + + if (not dynamic_cast<locoex::TFLSub *>(mul->x())) + return nullptr; + + return mul; +} + +void fuse_squared_difference(locoex::TFLMul *mul) +{ + auto sub = dynamic_cast<locoex::TFLSub *>(mul->x()); + EXO_ASSERT(sub, "sub should be valid at this point"); + + // TFLSquaredDifference to replace + auto sq_diff = mul->graph()->nodes()->create<locoex::TFLSquaredDifference>(); + sq_diff->x(sub->x()); + sq_diff->y(sub->y()); + + // replace + loco::replace(mul).with(sq_diff); +} + +} // namespace + +namespace exo +{ + +bool FuseSquaredDifferencePass::run(loco::Graph *g) +{ + bool changed = false; + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + if (auto mul = as_candidate(node)) + { + 
fuse_squared_difference(mul); + changed = true; + } + } + + return changed; +} + +} // namespace exo diff --git a/compiler/exo/src/Pass/FuseSquaredDifferencePass.h b/compiler/exo/src/Pass/FuseSquaredDifferencePass.h new file mode 100644 index 000000000..dbc15149f --- /dev/null +++ b/compiler/exo/src/Pass/FuseSquaredDifferencePass.h @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __FUSE_SQUARED_DIFFERENCE_PASS_H__ +#define __FUSE_SQUARED_DIFFERENCE_PASS_H__ + +#include <logo/Pass.h> + +namespace exo +{ + +/** + * @brief Class to fuse SquaredDifference pattern + * + * <BEFORE> + * + * A --- TFLSub --- TFLMul --- C + * / \ / + * B ---- ----- + * + * <AFTER> + * + * A --- TFLSquaredDifference --- C + * / + * B ---- + */ +struct FuseSquaredDifferencePass final : public logo::Pass +{ + const char *name(void) const final { return "exo::FuseSquaredDifferencePass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace exo + +#endif // __FUSE_SQUARED_DIFFERENCE_PASS_H__ diff --git a/compiler/exo/src/Pass/MergeConcatNodesPass.cpp b/compiler/exo/src/Pass/MergeConcatNodesPass.cpp new file mode 100644 index 000000000..8945fcfce --- /dev/null +++ b/compiler/exo/src/Pass/MergeConcatNodesPass.cpp @@ -0,0 +1,191 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. 
All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "MergeConcatNodesPass.h" +#include "Dialect/IR/TFLNodes.h" + +#include <oops/InternalExn.h> + +#include <vector> + +namespace +{ + +bool canMerge(locoex::TFLConcatenation *node1, locoex::TFLConcatenation *node2) +{ + if (node1->fusedActivationFunction() != node2->fusedActivationFunction()) + return false; + + if (node1->axis() != node2->axis()) + return false; + + switch (node1->fusedActivationFunction()) + { + case locoex::FusedActFunc::NONE: + case locoex::FusedActFunc::RELU: + case locoex::FusedActFunc::RELU6: + return true; + + // case locoex::FusedActFunc::TANH: + // return false; + + default: + INTERNAL_EXN_V("Unknown FusedActFunc", oops::to_uint32(node1->fusedActivationFunction())); + } +} + +/** + * @brief Collect all the inputs of newly created TFLConcatenation nodes + * + * in:0 -------------------------------\ + * in:1 ---- TFLConcatenation:0 -------- TFLConcatenation:3 --- C + * (axis = 0, NONE) (axis = 0, NONE) + * in:2 ---/ / + * in:3 ---- TFLConcatenation:1 ------/ + * (axis = 1, NONE) / + * in:4 ---/ / + * in:5 ---- TFLConcatenation:2 ---/ + * (axis = 0, RELU) + * in:6 ---/ + * + * For example, if graph is like above, dfs(TFLConcatenation:3) will + * return [in:0, in:1, in:2, TFLConcatenation:1, TFLConcatenation:2] + * + * TFLConcatenation:0 can be merged to TFLConcatenation:3, + * because axis and fusedActivationFunction are same. 
+ * It means that [in:1, in:2] will be linked as inputs of new TFLConcatenation. + * + * However, TFLConcatenation:1 and TFLConcatenation:2 cannot be merged to + * TFLConcatenation:3 because axis and fusedActivationFunction of each are different. + * So [in:3, in:4, in:5, in:6] will not be linked as inputs of new TFLConcatenation + * and [TFLConcatenation:1, TFLConcatenation:2] will be linked instead. + * + * Therefore, inputs of newly created TFLConcatenation node for merging + * TFLConcatenation:3 will be [in:0, in:1, in:2, TFLConcatenation:1, TFLConcatenation:2] + * and dfs(TFLConcatenation:3) will return it. + * + * + * @note The input nodes should be traversed by LRV, + * which is from left to right (input:0 --> input:N) + */ +std::vector<loco::Node *> dfs(locoex::TFLConcatenation *root) +{ + std::vector<loco::Node *> res; + + for (uint32_t i = 0; i < root->numValues(); ++i) + { + auto input = dynamic_cast<locoex::TFLConcatenation *>(root->values(i)); + if (input != nullptr && canMerge(input, root)) + { + auto children = dfs(input); + for (auto child : children) + res.push_back(child); + } + else + { + res.push_back(root->values(i)); + } + } + + return res; +} + +} // namespace + +namespace exo +{ + +/** + * @brief Merge TFLConcatenation nodes whose axis and fusedActivationFunction are same + * + * [Before] + * in:0 -------------------------------\ + * in:1 ---- TFLConcatenation:0 -------- TFLConcatenation:3 --- C + * (axis = 0, NONE) (axis = 0, NONE) + * in:2 ---/ / + * in:3 ---- TFLConcatenation:1 ------/ + * (axis = 1, NONE) / + * in:4 ---/ / + * in:5 ---- TFLConcatenation:2 ---/ + * (axis = 0, RELU) + * in:6 ---/ + * + * [After] + * in:0 -------------------------------\ + * in:1 -------------------------------- TFLConcatenation:4 --- C + * (axis = 0, NONE) + * in:2 -------------------------------/ + * in:3 ---- TFLConcatenation:1 ------/ + * (axis = 1, NONE) / + * in:4 ---/ / + * in:5 ---- TFLConcatenation:2 ---/ + * (axis = 0, RELU) + * in:6 ---/ + * + * + 
* in:1 ---- TFLConcatenation:0 ---- + * (axis = 0, NONE) + * in:2 ---/ + * + * + * ---- TFLConcatenation:3 ---- + * (axis = 0, NONE) + */ +bool MergeConcatNodesPass::run(loco::Graph *graph) +{ + // Let's enumerate nodes required to compute output nodes + auto active_nodes = loco::active_nodes(loco::output_nodes(graph)); + + // Find TFLConcatenation nodes which have another TFLConcatenation nodes + // as inputs, with same axis and same fusedActivationFunction + std::vector<locoex::TFLConcatenation *> candidates; + for (auto node : active_nodes) + { + if (auto concat = dynamic_cast<locoex::TFLConcatenation *>(node)) + { + for (uint32_t i = 0; i < concat->numValues(); ++i) + { + auto input = dynamic_cast<locoex::TFLConcatenation *>(concat->values(i)); + if (input != nullptr && canMerge(input, concat)) + { + candidates.push_back(concat); + break; + } + } + } + } + + // Merge multiple TFLConcatenation nodes as one TFLConcatenation node + for (auto node : candidates) + { + auto inputs = dfs(node); + + auto new_concat = graph->nodes()->create<locoex::TFLConcatenation>(inputs.size()); + new_concat->axis(node->axis()); + new_concat->fusedActivationFunction(node->fusedActivationFunction()); + + for (uint32_t i = 0; i < inputs.size(); ++i) + new_concat->values(i, inputs.at(i)); + + loco::replace(node).with(new_concat); + for (uint32_t i = 0; i < node->numValues(); ++i) + node->values(i, nullptr); + } + + return candidates.size() > 0; +} + +} // namespace exo diff --git a/compiler/exo/src/Pass/MergeConcatNodesPass.h b/compiler/exo/src/Pass/MergeConcatNodesPass.h new file mode 100644 index 000000000..823214f43 --- /dev/null +++ b/compiler/exo/src/Pass/MergeConcatNodesPass.h @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __PASS_MERGE_CONCAT_NODES_H__ +#define __PASS_MERGE_CONCAT_NODES_H__ + +#include <loco.h> +#include <logo/Pass.h> + +namespace exo +{ + +/** + * @brief Merge concat nodes whose axis and fusedActivationFunction are same + * + */ +class MergeConcatNodesPass : public logo::Pass +{ +public: + virtual const char *name(void) const { return "exo::MergeConcatNodesPass"; } + +public: + bool run(loco::Graph *graph); +}; + +} // namespace exo + +#endif // __PASS_MERGE_CONCAT_NODES_H__ diff --git a/compiler/exo/src/Pass/ShapeInferencePass.cpp b/compiler/exo/src/Pass/ShapeInferencePass.cpp new file mode 100644 index 000000000..bc60f91c4 --- /dev/null +++ b/compiler/exo/src/Pass/ShapeInferencePass.cpp @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "ShapeInferencePass.h" + +#include "Dialect/IR/TFLDialect.h" +#include "Dialect/Service/TFLShapeInferenceRule.h" + +#include "Dialect/IR/CircleDialect.h" +#include "Dialect/Service/CircleShapeInferenceRule.h" + +#include <loco.h> +#include <loco/IR/CanonicalDialect.h> +#include <loco/Service/CanonicalShapeInferenceRule.h> +#include <loco/Service/ShapeInference.h> +#include <loco/Service/MultiDialectShapeInferenceRule.h> + +#include <locoex/COpDialect.h> +#include <locoex/Service/COpShapeInferenceRule.h> + +namespace exo +{ + +/** + * @note Currently, TFL and Circle backend share this inference. However, TFL + * backend does not require rule for Circle dialect. + * TODO Make dedicated inference pass for Circle Dialect. + */ +bool ShapeInferencePass::run(loco::Graph *g) +{ + loco::CanonicalShapeInferenceRule canonical_rule; + locoex::TFLShapeInferenceRule tfl_rule; + locoex::CircleShapeInferenceRule circle_rule; + locoex::COpShapeInferenceRule cop_rule; + + loco::MultiDialectShapeInferenceRule rules; + + rules.bind(loco::CanonicalDialect::get(), &canonical_rule) + .bind(locoex::TFLDialect::get(), &tfl_rule) + .bind(locoex::CircleDialect::get(), &circle_rule) + .bind(locoex::COpDialect::get(), &cop_rule); + + return loco::apply(&rules).to(g); +} + +} // namespace exo diff --git a/compiler/exo/src/Pass/ShapeInferencePass.h b/compiler/exo/src/Pass/ShapeInferencePass.h new file mode 100644 index 000000000..518c87403 --- /dev/null +++ b/compiler/exo/src/Pass/ShapeInferencePass.h @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __PASS_SHAPE_INFERENCE_PASS_H__ +#define __PASS_SHAPE_INFERENCE_PASS_H__ + +#include <loco.h> +#include <logo/Pass.h> + +namespace exo +{ + +/** + * @brief Pass to infer shape of nodes + */ +class ShapeInferencePass : public logo::Pass +{ +public: + virtual const char *name(void) const { return "exo::ShapeInferencePass"; } + +public: + bool run(loco::Graph *graph); +}; + +} // namespace exo + +#endif //__PASS_SHAPE_INFERENCE_PASS_H__ diff --git a/compiler/exo/src/Pass/TypeInferencePass.cpp b/compiler/exo/src/Pass/TypeInferencePass.cpp new file mode 100644 index 000000000..31d4f13b6 --- /dev/null +++ b/compiler/exo/src/Pass/TypeInferencePass.cpp @@ -0,0 +1,57 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "TypeInferencePass.h" + +#include "Dialect/IR/TFLDialect.h" +#include "Dialect/Service/TFLTypeInferenceRule.h" + +#include "Dialect/IR/CircleDialect.h" +#include "Dialect/Service/CircleTypeInferenceRule.h" + +#include <loco.h> +#include <loco/IR/CanonicalDialect.h> +#include <loco/Service/TypeInference.h> + +#include <locoex/COpDialect.h> +#include <locoex/Service/COpTypeInference.h> + +namespace exo +{ + +/** + * @note Currently, TFL and Circle backend share this inference. However, TFL + * backend does not require rule for Circle dialect. + * TODO Make dedicated inference pass for Circle Dialect. + */ +bool TypeInferencePass::run(loco::Graph *g) +{ + loco::CanonicalTypeInferenceRule canonical_rule; + locoex::TFLTypeInferenceRule tfl_rule; + locoex::CircleTypeInferenceRule circle_rule; + locoex::COpTypeInferenceRule cop_rule; + + loco::MultiDialectTypeInferenceRule rules; + + rules.bind(loco::CanonicalDialect::get(), &canonical_rule) + .bind(locoex::TFLDialect::get(), &tfl_rule) + .bind(locoex::CircleDialect::get(), &circle_rule) + .bind(locoex::COpDialect::get(), &cop_rule); + + return loco::apply(&rules).to(g); +} + +} // namespace exo diff --git a/compiler/exo/src/Pass/TypeInferencePass.h b/compiler/exo/src/Pass/TypeInferencePass.h new file mode 100644 index 000000000..3ede587a0 --- /dev/null +++ b/compiler/exo/src/Pass/TypeInferencePass.h @@ -0,0 +1,42 @@ + +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __PASS_TYPE_INFERENCE_PASS_H__ +#define __PASS_TYPE_INFERENCE_PASS_H__ + +#include <loco.h> + +#include <logo/Pass.h> + +namespace exo +{ + +/** + * @brief Pass to infer type of nodes + */ +class TypeInferencePass : public logo::Pass +{ +public: + virtual const char *name(void) const { return "exo::TypeInferencePass"; } + +public: + bool run(loco::Graph *graph); +}; + +} // namespace exo + +#endif //__PASS_TYPE_INFERENCE_PASS_H__ |