summaryrefslogtreecommitdiff
path: root/compiler/luci/pass/src/FoldDequantizePass.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/luci/pass/src/FoldDequantizePass.cpp')
-rw-r--r--compiler/luci/pass/src/FoldDequantizePass.cpp96
1 files changed, 67 insertions, 29 deletions
diff --git a/compiler/luci/pass/src/FoldDequantizePass.cpp b/compiler/luci/pass/src/FoldDequantizePass.cpp
index 3dd4f8cea..b6526deb0 100644
--- a/compiler/luci/pass/src/FoldDequantizePass.cpp
+++ b/compiler/luci/pass/src/FoldDequantizePass.cpp
@@ -19,6 +19,8 @@
#include <luci/IR/CircleNodes.h>
#include <luci/Profile/CircleNodeOrigin.h>
+#include <fp16.h>
+
namespace
{
@@ -32,6 +34,9 @@ bool is_hybrid_kernel_supported(loco::Node *node)
bool is_foldable_const(luci::CircleConst *node)
{
+ if (node->dtype() == loco::DataType::FLOAT16)
+ return true;
+
if (node->quantparam() == nullptr)
return false;
@@ -39,17 +44,18 @@ bool is_foldable_const(luci::CircleConst *node)
return true;
if (node->dtype() == loco::DataType::U8)
return true;
+ if (node->dtype() == loco::DataType::S16)
+ return true;
+ if (node->dtype() == loco::DataType::S32)
+ return true;
+ if (node->dtype() == loco::DataType::S64)
+ return true;
return false;
}
luci::CircleConst *dequantized_const_node(luci::CircleConst *const_node)
{
- if (const_node->quantparam() == nullptr)
- {
- throw std::runtime_error("Given constant node has no quantization parameter");
- }
-
auto name = const_node->name();
assert(name.length() > 0);
auto g = const_node->graph();
@@ -67,38 +73,70 @@ luci::CircleConst *dequantized_const_node(luci::CircleConst *const_node)
new_const_node->shape_status(luci::ShapeStatus::VALID);
new_const_node->name(name + "_DQ");
+ if (const_node->dtype() == loco::DataType::FLOAT16)
+ {
+ for (uint32_t i = 0; i < new_const_node->size<loco::DataType::FLOAT32>(); ++i)
+ {
+ auto raw = const_node->at<loco::DataType::FLOAT16>(i);
+ new_const_node->at<loco::DataType::FLOAT32>(i) = fp16_ieee_to_fp32_value(raw);
+ }
+ return new_const_node;
+ }
+
+ if (const_node->quantparam() == nullptr)
+ {
+ throw std::runtime_error("Given constant node has no quantization parameter");
+ }
+
const int32_t q_dim = const_node->quantparam()->quantized_dimension;
- const int32_t q_dim_value = const_node->dim(q_dim).value();
+ // For scalar, q_dim_value is 1
+ // For non-scalar, q_dim_value is the size of quantized dimension
+ const int32_t q_dim_value = const_node->rank() == 0 ? 1 : const_node->dim(q_dim).value();
int32_t right_count = q_dim_value;
for (uint32_t i = q_dim + 1; i < const_node->rank(); ++i)
right_count *= const_node->dim(i).value();
- if (const_node->dtype() == loco::DataType::S8)
+ for (uint32_t i = 0; i < new_const_node->size<loco::DataType::FLOAT32>(); ++i)
{
- for (uint32_t i = 0; i < const_node->size<loco::DataType::S8>(); ++i)
- {
- uint32_t qd = (i % right_count) / (right_count / q_dim_value);
- if (qd >= const_node->quantparam()->zerop.size())
- qd = 0;
+ uint32_t qd = (i % right_count) / (right_count / q_dim_value);
+ if (qd >= const_node->quantparam()->zerop.size())
+ qd = 0;
- new_const_node->at<loco::DataType::FLOAT32>(i) =
- (float)(const_node->at<loco::DataType::S8>(i) - const_node->quantparam()->zerop.at(qd)) *
- const_node->quantparam()->scale.at(qd);
- }
- }
- else
- {
- for (uint32_t i = 0; i < const_node->size<loco::DataType::U8>(); ++i)
+ switch (const_node->dtype())
{
- uint32_t qd = (i % right_count) / (right_count / q_dim_value);
- if (qd >= const_node->quantparam()->zerop.size())
- qd = 0;
-
- new_const_node->at<loco::DataType::FLOAT32>(i) =
- (float)((int)const_node->at<loco::DataType::U8>(i) -
- const_node->quantparam()->zerop.at(qd)) *
- const_node->quantparam()->scale.at(qd);
+ case loco::DataType::S8:
+ new_const_node->at<loco::DataType::FLOAT32>(i) =
+ static_cast<float>(const_node->at<loco::DataType::S8>(i) -
+ const_node->quantparam()->zerop.at(qd)) *
+ const_node->quantparam()->scale.at(qd);
+ break;
+ case loco::DataType::S16:
+ new_const_node->at<loco::DataType::FLOAT32>(i) =
+ static_cast<float>(const_node->at<loco::DataType::S16>(i) -
+ const_node->quantparam()->zerop.at(qd)) *
+ const_node->quantparam()->scale.at(qd);
+ break;
+ case loco::DataType::S32:
+ new_const_node->at<loco::DataType::FLOAT32>(i) =
+ static_cast<float>(const_node->at<loco::DataType::S32>(i) -
+ const_node->quantparam()->zerop.at(qd)) *
+ const_node->quantparam()->scale.at(qd);
+ break;
+ case loco::DataType::S64:
+ new_const_node->at<loco::DataType::FLOAT32>(i) =
+ static_cast<float>(const_node->at<loco::DataType::S64>(i) -
+ const_node->quantparam()->zerop.at(qd)) *
+ const_node->quantparam()->scale.at(qd);
+ break;
+ case loco::DataType::U8:
+ new_const_node->at<loco::DataType::FLOAT32>(i) =
+ static_cast<float>(const_node->at<loco::DataType::U8>(i) -
+ const_node->quantparam()->zerop.at(qd)) *
+ const_node->quantparam()->scale.at(qd);
+ break;
+ default:
+ throw std::runtime_error("Not supported dtype for FoldDequantizePass");
}
}
@@ -160,7 +198,7 @@ bool FoldDequantizePass::run(loco::Graph *g)
{
bool changed = false;
- for (auto node : loco::all_nodes(g))
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
{
if (auto circle_dequant = dynamic_cast<luci::CircleDequantize *>(node))
{