diff options
Diffstat (limited to 'compiler/luci/pass/src/FoldDequantizePass.cpp')
-rw-r--r-- | compiler/luci/pass/src/FoldDequantizePass.cpp | 206 |
1 files changed, 206 insertions, 0 deletions
diff --git a/compiler/luci/pass/src/FoldDequantizePass.cpp b/compiler/luci/pass/src/FoldDequantizePass.cpp new file mode 100644 index 000000000..01c04f478 --- /dev/null +++ b/compiler/luci/pass/src/FoldDequantizePass.cpp @@ -0,0 +1,206 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/FoldDequantizePass.h" + +#include <luci/IR/CircleNodes.h> + +#include <loco/Service/TypeInference.h> + +namespace +{ + +bool is_hybrid_kernel_supported(loco::Node *node) +{ + if (dynamic_cast<luci::CircleFullyConnected *>(node) != nullptr) + return true; + + return false; +} + +bool is_foldable_const(luci::CircleConst *node) +{ + if (node->quantparam() == nullptr) + return false; + + if (node->dtype() == loco::DataType::S8) + return true; + if (node->dtype() == loco::DataType::U8) + return true; + + return false; +} + +luci::CircleConst *dequantized_const_node(luci::CircleConst *const_node) +{ + if (const_node->quantparam() == nullptr) + { + throw std::runtime_error("Given constant node has no quantization parameter"); + } + + auto g = const_node->graph(); + auto new_const_node = g->nodes()->create<luci::CircleConst>(); + + new_const_node->dtype(loco::DataType::FLOAT32); + new_const_node->rank(const_node->rank()); + uint32_t dim_size = 1; + for (uint32_t i = 0; i < new_const_node->rank(); ++i) + { + new_const_node->dim(i) = const_node->dim(i); + dim_size *= const_node->dim(i).value(); + } + new_const_node->size<loco::DataType::FLOAT32>(dim_size); + new_const_node->shape_status(luci::ShapeStatus::VALID); + + const int32_t q_dim = const_node->quantparam()->quantized_dimension; + const int32_t q_dim_value = const_node->dim(q_dim).value(); + + int32_t right_count = q_dim_value; + for (uint32_t i = q_dim + 1; i < const_node->rank(); ++i) + right_count *= const_node->dim(i).value(); + + if (const_node->dtype() == loco::DataType::S8) + { + for (uint32_t i = 0; i < const_node->size<loco::DataType::S8>(); ++i) + { + uint32_t qd = (i % right_count) / (right_count / q_dim_value); + if (qd >= const_node->quantparam()->zerop.size()) + qd = 0; + + new_const_node->at<loco::DataType::FLOAT32>(i) = + (float)(const_node->at<loco::DataType::S8>(i) - const_node->quantparam()->zerop.at(qd)) * + const_node->quantparam()->scale.at(qd); + } + } + else + { + for (uint32_t i = 0; i < const_node->size<loco::DataType::U8>(); ++i) + { + uint32_t qd = (i % right_count) / (right_count / q_dim_value); + if (qd >= const_node->quantparam()->zerop.size()) + qd = 0; + + new_const_node->at<loco::DataType::FLOAT32>(i) = + (float)((int)const_node->at<loco::DataType::U8>(i) - + const_node->quantparam()->zerop.at(qd)) * + const_node->quantparam()->scale.at(qd); + } + } + + return new_const_node; +} + +bool replace_const_node(loco::Node *node, luci::CircleConst *const_node) +{ + if (auto gather = dynamic_cast<luci::CircleGather *>(node)) + { + gather->params(dequantized_const_node(const_node)); + gather->dtype(loco::DataType::FLOAT32); + return true; + } + else + { + // TODO Support more ops + return false; + } +} + +} // namespace + +namespace luci +{ + +/** + * + * Folding pattern 1 - When input of Dequantize is foldable constant + * + * [Before] + * quantized_const_input ---------- Dequantize ---------- Op --- + * +-- Op1_with_quant_input --- + * +-- Op2_with_quant_input --- + * + * [After] + * dequantized_const_input -------------------------------- Op --- + * + * quantized_const_input ----- Op1_with_quant_input --- + * +-- Op2_with_quant_input --- + * + * + * Folding pattern 2 - When input of Dequantize uses quantized output value + * + * [Before] + * quantized_const_input ----- Gather ----- Dequantize --- Op --- + * +-- Op1_with_quant_input --- + * +-- Op2_with_quant_input --- + * + * [After] + * dequantized_const_input ------Gather -------------------- Op --- + * + * quantized_const_input ----- Op1_with_quant_input --- + * +-- Op2_with_quant_input --- + * + * + */ +bool FoldDequantizePass::run(loco::Graph *g) +{ + bool changed = false; + + for (auto node : loco::all_nodes(g)) + { + if (auto circle_dequant = dynamic_cast<luci::CircleDequantize *>(node)) + { + if (auto const_input = dynamic_cast<luci::CircleConst *>(circle_dequant->input())) + { + // Pattern 1 - When input of Dequantize is foldable constant + if (is_foldable_const(const_input)) + { + loco::replace(circle_dequant).with(dequantized_const_node(const_input)); + changed = true; + } + } + } + else if (auto const_node = dynamic_cast<luci::CircleConst *>(node)) + { + if (is_foldable_const(const_node)) + { + for (auto const_node_user : loco::succs(const_node)) + { + // If user is hybrid kernel supported operation, do not dequantize + if (is_hybrid_kernel_supported(const_node_user)) + continue; + + auto users = loco::succs(const_node_user); + if (users.size() > 1) + continue; + + // Pattern 2 - When input of Dequantize uses quantized output value + if (auto dequant = dynamic_cast<luci::CircleDequantize *>(*users.begin())) + { + if (replace_const_node(const_node_user, const_node)) + { + loco::replace(dequant).with(const_node_user); + changed = true; + } + } + } + } + } + } + + return changed; +} + +} // namespace luci |