diff options
Diffstat (limited to 'compiler/luci/pass/src/ForwardTransposeOpPass.cpp')
-rw-r--r-- | compiler/luci/pass/src/ForwardTransposeOpPass.cpp | 366 |
1 files changed, 366 insertions, 0 deletions
diff --git a/compiler/luci/pass/src/ForwardTransposeOpPass.cpp b/compiler/luci/pass/src/ForwardTransposeOpPass.cpp new file mode 100644 index 000000000..c76d73344 --- /dev/null +++ b/compiler/luci/pass/src/ForwardTransposeOpPass.cpp @@ -0,0 +1,366 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/ForwardTransposeOpPass.h" + +#include <luci/IR/CircleNodes.h> +#include <luci/IR/CircleNodeVisitor.h> +#include <luci/Profile/CircleNodeOrigin.h> +#include <luci/Service/Nodes/CircleConst.h> +#include <luci/Service/CircleNodeClone.h> + +using namespace luci; + +namespace +{ + +// Create new Transpose Op including perm +// Return nullptr if failed +CircleTranspose *create_cloned_transpose(CircleTranspose *transpose) +{ + assert(transpose != nullptr); // FIX_CALLER_UNLESS + + auto perm = dynamic_cast<CircleConst *>(transpose->perm()); + if (not perm) + return nullptr; + + CircleConst *cloned_perm = clone(perm); + if (cloned_perm == nullptr) + return nullptr; + + cloned_perm->name(perm->name() + "_C"); + luci::add_origin(cloned_perm, luci::get_origin(perm)); + + auto cloned_node = clone_node(transpose, transpose->graph()); + if (cloned_node == nullptr) + return nullptr; + + auto new_transpose = loco::must_cast<luci::CircleTranspose *>(cloned_node); + new_transpose->perm(cloned_perm); + new_transpose->name(transpose->name() + "_C"); + luci::add_origin(new_transpose, luci::get_origin(transpose)); + + return new_transpose; +} + +uint32_t cal_offset(const std::vector<uint32_t> &shape, const std::vector<uint32_t> &indices) +{ + assert(shape.size() == indices.size()); // FIX_CALLER_UNLESS + + uint32_t offset = 0; + for (uint32_t i = 0; i < indices.size(); i++) + { + uint32_t index = indices[i]; + for (uint32_t j = shape.size() - 1; j > i; j--) + { + index *= shape[j]; + } + offset += index; + } + return offset; +} + +// Return reverse-transpose of 'node' +// i.e., Transpose(return value) = node +CircleConst *reverse_transposed(CircleConst *node, std::vector<uint32_t> &t) +{ + assert(node->rank() == t.size()); // FIX_CALLER_UNLESS + assert(node->rank() == 4); // FIX_CALLER_UNLESS + + std::vector<uint32_t> orig_shape(node->rank()); + std::vector<uint32_t> new_shape(node->rank()); + + for (uint32_t i = 0; i < node->rank(); i++) + { + assert(t[i] < node->rank()); // FIX_CALLER_UNLESS + + orig_shape[i] = node->dim(i).value(); + new_shape[t[i]] = node->dim(i).value(); + } + + auto clone_const = clone(node); + for (uint32_t i = 0; i < node->rank(); i++) + clone_const->dim(i).set(new_shape[i]); + + clone_const->name(clone_const->name() + "_r_transposed"); + add_origin(clone_const, luci::get_origin(node)); + + for (uint32_t n = 0; n < clone_const->dim(0).value(); n++) + { + for (uint32_t h = 0; h < clone_const->dim(1).value(); h++) + { + for (uint32_t w = 0; w < clone_const->dim(2).value(); w++) + { + for (uint32_t c = 0; c < clone_const->dim(3).value(); c++) + { + std::vector<uint32_t> new_indices{n, h, w, c}; + std::vector<uint32_t> orig_indices{new_indices[t[0]], new_indices[t[1]], + new_indices[t[2]], new_indices[t[3]]}; + + const auto data = node->at<loco::DataType::FLOAT32>(cal_offset(orig_shape, orig_indices)); + clone_const->at<loco::DataType::FLOAT32>(cal_offset(new_shape, new_indices)) = data; + } + } + } + } + + return clone_const; +} + +bool check_rank_four(const CircleConst *c) { return c->rank() == 4; } + +// Return true if below conditions are met +// 1. t->perm() is CircleConst +// 2. t->perm() is S32 +bool check_perm(const CircleTranspose *t) +{ + auto perm = dynamic_cast<CircleConst *>(t->perm()); + if (not perm) + return false; + + switch (perm->dtype()) + { + case loco::DataType::S32: + for (uint32_t i = 0; i < perm->size<loco::DataType::S32>(); i++) + { + auto data = perm->at<loco::DataType::S32>(i); + // TODO Support not normalized index + if (data < 0 or data >= static_cast<int32_t>(t->rank())) + return false; + } + break; + // TODO Support S64 data type + default: + return false; + } + + return true; +} + +#define RETURN_FALSE_UNLESS(COND) \ + if (not(COND)) \ + return false; + +// Elementwise Binary Operator with const +class EBOWithConstPattern final : public CircleNodeMutableVisitor<bool> +{ +private: + template <typename CIRCLE_OP_PTR> bool has_pattern(CIRCLE_OP_PTR node) + { + if (auto x = dynamic_cast<luci::CircleConst *>(node->x())) + { + if (auto y = dynamic_cast<luci::CircleTranspose *>(node->y())) + { + RETURN_FALSE_UNLESS(check_rank_four(x)); + RETURN_FALSE_UNLESS(check_perm(y)); + + auto new_const = gen_new_const(y, x); + assert(new_const); // FIX_ME_UNLESS + + auto new_transpose = create_cloned_transpose(y); + assert(new_transpose); // FIX_ME_UNLESS + + // Reconnect network + node->x(new_const); + node->y(y->a()); + loco::replace(node).with(new_transpose); + new_transpose->a(node); + + // Do shape inference for this node again. + node->shape_status(luci::ShapeStatus::UNDEFINED); + + return true; + } + } + + if (auto y = dynamic_cast<luci::CircleConst *>(node->y())) + { + if (auto x = dynamic_cast<luci::CircleTranspose *>(node->x())) + { + RETURN_FALSE_UNLESS(check_rank_four(y)); + RETURN_FALSE_UNLESS(check_perm(x)); + + auto new_const = gen_new_const(x, y); + assert(new_const); // FIX_ME_UNLESS + + auto new_transpose = create_cloned_transpose(x); + assert(new_transpose); // FIX_ME_UNLESS + + // Reconnect network + node->y(new_const); + node->x(x->a()); + loco::replace(node).with(new_transpose); + new_transpose->a(node); + + // Do shape inference for this node again. + node->shape_status(luci::ShapeStatus::UNDEFINED); + + return true; + } + } + + return false; + } + +public: + // Default + bool visit(luci::CircleNode *) { return false; } + + bool visit(luci::CircleAdd *node) { return has_pattern(node); } + + bool visit(luci::CircleMul *node) { return has_pattern(node); } + +private: + // Return a new const node after Tranpose Op is forwarded + // Return nullptr if unsupported cases + CircleConst *gen_new_const(CircleTranspose *t, CircleConst *c) + { + const auto perm = dynamic_cast<CircleConst *>(t->perm()); + + // Only support constant perm + if (not perm) + return nullptr; + + std::vector<uint32_t> perm_data; + switch (perm->dtype()) + { + case loco::DataType::S32: + for (uint32_t i = 0; i < perm->size<loco::DataType::S32>(); i++) + { + auto data = perm->at<loco::DataType::S32>(i); + assert(data >= 0 and data < static_cast<int32_t>(t->rank())); + perm_data.emplace_back(static_cast<uint32_t>(data)); + } + break; + // TODO Support S64 data type + default: + return nullptr; + } + + assert(perm_data.size() == t->rank()); // FIX_CALLER_UNLESS + + return reverse_transposed(c, perm_data); + } +}; + +// Elementwise Unary Operator +class EwUnaryPattern final : public CircleNodeMutableVisitor<bool> +{ +private: + // input is 'x' + template <typename CIRCLE_OP_PTR> bool has_pattern_x(CIRCLE_OP_PTR node) + { + if (auto x = dynamic_cast<luci::CircleTranspose *>(node->x())) + { + RETURN_FALSE_UNLESS(check_perm(x)); + + auto new_transpose = create_cloned_transpose(x); + assert(new_transpose); // FIX_ME_UNLESS + + // Reconnect network + node->x(x->a()); + loco::replace(node).with(new_transpose); + new_transpose->a(node); + + // Do shape inference for this node again. + node->shape_status(luci::ShapeStatus::UNDEFINED); + + return true; + } + + return false; + } + +public: + // Default + bool visit(luci::CircleNode *) { return false; } + + bool visit(luci::CircleAbs *node) { return has_pattern_x(node); } +}; + +} // namespace + +namespace luci +{ + +/** + * BEFORE + * | + * [CircleNode] [CircleConst] + * | / + * [CircleTranspose] [CircleConst] + * / | / + * [CircleNode] [(BinaryOp)] + * | | \ + * | | [CircleNode] + * | | | + * + * BinaryOp: CircleAdd, CircleMul, ... + * + * | + * [CircleNode] [CircleConst] + * | / + * [CircleTranspose] + * / | + * [CircleNode] [(UnaryOp)] + * | | \ + * | | [CircleNode] + * | | | + * + * UnaryOp: CircleAbs, ... + * + * AFTER + * | + * [CircleConst] [CircleNode] [CircleConst(updated)] + * | / | / + * [CircleTranspose] [(BinaryOp)] [CircleConst] + * | | / + * [CircleNode] [CircleTranspose] + * | | \ + * | | [CircleNode] + * | | | + * + * | + * [CircleConst] [CircleNode] + * | / | + * [CircleTranspose] [(UnaryOp)] [CircleConst] + * | | / + * [CircleNode] [CircleTranspose] + * | | \ + * | | [CircleNode] + * | | | + * + * Note: new [CircleTranspose] is added after [(BinaryOp)] + */ +bool ForwardTransposeOpPass::run(loco::Graph *g) +{ + bool changed = false; + EBOWithConstPattern eboc; + EwUnaryPattern ewu; + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + auto circle_node = loco::must_cast<luci::CircleNode *>(node); + if (circle_node->accept(&eboc)) + changed = true; + else if (circle_node->accept(&ewu)) + changed = true; + } + return changed; +} + +#undef RETURN_FALSE_UNLESS + +} // namespace luci |