diff options
Diffstat (limited to 'compiler/luci/pass/src/FoldDequantizePass.cpp')
-rw-r--r-- | compiler/luci/pass/src/FoldDequantizePass.cpp | 96 |
1 files changed, 67 insertions, 29 deletions
diff --git a/compiler/luci/pass/src/FoldDequantizePass.cpp b/compiler/luci/pass/src/FoldDequantizePass.cpp index 3dd4f8cea..b6526deb0 100644 --- a/compiler/luci/pass/src/FoldDequantizePass.cpp +++ b/compiler/luci/pass/src/FoldDequantizePass.cpp @@ -19,6 +19,8 @@ #include <luci/IR/CircleNodes.h> #include <luci/Profile/CircleNodeOrigin.h> +#include <fp16.h> + namespace { @@ -32,6 +34,9 @@ bool is_hybrid_kernel_supported(loco::Node *node) bool is_foldable_const(luci::CircleConst *node) { + if (node->dtype() == loco::DataType::FLOAT16) + return true; + if (node->quantparam() == nullptr) return false; @@ -39,17 +44,18 @@ bool is_foldable_const(luci::CircleConst *node) return true; if (node->dtype() == loco::DataType::U8) return true; + if (node->dtype() == loco::DataType::S16) + return true; + if (node->dtype() == loco::DataType::S32) + return true; + if (node->dtype() == loco::DataType::S64) + return true; return false; } luci::CircleConst *dequantized_const_node(luci::CircleConst *const_node) { - if (const_node->quantparam() == nullptr) - { - throw std::runtime_error("Given constant node has no quantization parameter"); - } - auto name = const_node->name(); assert(name.length() > 0); auto g = const_node->graph(); @@ -67,38 +73,70 @@ luci::CircleConst *dequantized_const_node(luci::CircleConst *const_node) new_const_node->shape_status(luci::ShapeStatus::VALID); new_const_node->name(name + "_DQ"); + if (const_node->dtype() == loco::DataType::FLOAT16) + { + for (uint32_t i = 0; i < new_const_node->size<loco::DataType::FLOAT32>(); ++i) + { + auto raw = const_node->at<loco::DataType::FLOAT16>(i); + new_const_node->at<loco::DataType::FLOAT32>(i) = fp16_ieee_to_fp32_value(raw); + } + return new_const_node; + } + + if (const_node->quantparam() == nullptr) + { + throw std::runtime_error("Given constant node has no quantization parameter"); + } + const int32_t q_dim = const_node->quantparam()->quantized_dimension; - const int32_t q_dim_value = const_node->dim(q_dim).value(); + // For scalar, q_dim_value is 1 + // For non-scalar, q_dim_value is the size of quantized dimension + const int32_t q_dim_value = const_node->rank() == 0 ? 1 : const_node->dim(q_dim).value(); int32_t right_count = q_dim_value; for (uint32_t i = q_dim + 1; i < const_node->rank(); ++i) right_count *= const_node->dim(i).value(); - if (const_node->dtype() == loco::DataType::S8) + for (uint32_t i = 0; i < new_const_node->size<loco::DataType::FLOAT32>(); ++i) { - for (uint32_t i = 0; i < const_node->size<loco::DataType::S8>(); ++i) - { - uint32_t qd = (i % right_count) / (right_count / q_dim_value); - if (qd >= const_node->quantparam()->zerop.size()) - qd = 0; + uint32_t qd = (i % right_count) / (right_count / q_dim_value); + if (qd >= const_node->quantparam()->zerop.size()) + qd = 0; - new_const_node->at<loco::DataType::FLOAT32>(i) = - (float)(const_node->at<loco::DataType::S8>(i) - const_node->quantparam()->zerop.at(qd)) * - const_node->quantparam()->scale.at(qd); - } - } - else - { - for (uint32_t i = 0; i < const_node->size<loco::DataType::U8>(); ++i) + switch (const_node->dtype()) { - uint32_t qd = (i % right_count) / (right_count / q_dim_value); - if (qd >= const_node->quantparam()->zerop.size()) - qd = 0; - - new_const_node->at<loco::DataType::FLOAT32>(i) = - (float)((int)const_node->at<loco::DataType::U8>(i) - - const_node->quantparam()->zerop.at(qd)) * - const_node->quantparam()->scale.at(qd); + case loco::DataType::S8: + new_const_node->at<loco::DataType::FLOAT32>(i) = + static_cast<float>(const_node->at<loco::DataType::S8>(i) - + const_node->quantparam()->zerop.at(qd)) * + const_node->quantparam()->scale.at(qd); + break; + case loco::DataType::S16: + new_const_node->at<loco::DataType::FLOAT32>(i) = + static_cast<float>(const_node->at<loco::DataType::S16>(i) - + const_node->quantparam()->zerop.at(qd)) * + const_node->quantparam()->scale.at(qd); + break; + case loco::DataType::S32: + new_const_node->at<loco::DataType::FLOAT32>(i) = + static_cast<float>(const_node->at<loco::DataType::S32>(i) - + const_node->quantparam()->zerop.at(qd)) * + const_node->quantparam()->scale.at(qd); + break; + case loco::DataType::S64: + new_const_node->at<loco::DataType::FLOAT32>(i) = + static_cast<float>(const_node->at<loco::DataType::S64>(i) - + const_node->quantparam()->zerop.at(qd)) * + const_node->quantparam()->scale.at(qd); + break; + case loco::DataType::U8: + new_const_node->at<loco::DataType::FLOAT32>(i) = + static_cast<float>(const_node->at<loco::DataType::U8>(i) - + const_node->quantparam()->zerop.at(qd)) * + const_node->quantparam()->scale.at(qd); + break; + default: + throw std::runtime_error("Not supported dtype for FoldDequantizePass"); } } @@ -160,7 +198,7 @@ bool FoldDequantizePass::run(loco::Graph *g) { bool changed = false; - for (auto node : loco::all_nodes(g)) + for (auto node : loco::active_nodes(loco::output_nodes(g))) { if (auto circle_dequant = dynamic_cast<luci::CircleDequantize *>(node)) { |