summary refs log tree commit diff
path: root/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp')
-rw-r--r-- compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp 91
1 files changed, 80 insertions, 11 deletions
diff --git a/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp b/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp
index d9a9d4db7..005144516 100644
--- a/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp
+++ b/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp
@@ -41,10 +41,28 @@ namespace
{
using namespace luci;
+
+bool use_predefined_values(ActivationQType qtype) // true only for the PreDefined{Logistic,Tanh,Softmax} qtypes
+{
+ switch (qtype)
+ {
+ case ActivationQType::PreDefinedLogistic:
+ case ActivationQType::PreDefinedTanh:
+ case ActivationQType::PreDefinedSoftmax:
+ return true;
+ default:
+ // Any other qtype must be IntScale or MinMax; the assert (active only when NDEBUG is not defined) flags a value this switch does not know about
+ assert(qtype == ActivationQType::IntScale or qtype == ActivationQType::MinMax);
+ break;
+ }
+
+ return false;
+}
+
// Create a Quantize Op whose
// dtype is out_type
// shape is the same with node
-// qparam is computed using node's min/max
+// qparam is computed according to node's qtype
luci::CircleQuantize *create_quantize_op(luci::CircleNode *node, loco::DataType out_type)
{
auto quantize = node->graph()->nodes()->create<CircleQuantize>();
@@ -60,9 +78,9 @@ luci::CircleQuantize *create_quantize_op(luci::CircleNode *node, loco::DataType
assert(qparam); // FIX_CALLER_UNLESS
auto qtype = luci::activation_qtype(node);
- if (qtype == ActivationQType::PreDefinedValue)
+ if (use_predefined_values(qtype))
{
- quantize->quantparam(luci::make_predefined_qparam(node->opcode(), out_type));
+ quantize->quantparam(luci::make_predefined_qparam(qtype, out_type));
return quantize;
}
@@ -105,6 +123,23 @@ luci::CircleQuantize *create_quantize_op(luci::CircleNode *node, loco::DataType
return quantize;
}
+// Create a Dequantize Op in node's graph: name is "<node name>_Dequantize", dtype is FLOAT32, rank/dims are copied from node. The input is not connected here.
+luci::CircleDequantize *create_dequantize(luci::CircleNode *node)
+{
+ auto dequantize = node->graph()->nodes()->create<luci::CircleDequantize>();
+ dequantize->name(node->name() + "_Dequantize");
+ dequantize->dtype(loco::DataType::FLOAT32);
+ dequantize->rank(node->rank());
+ for (uint32_t i = 0; i < node->rank(); i++)
+ dequantize->dim(i).set(node->dim(i).value()); // NOTE(review): value() is copied without checking whether the dim is known; presumably all dims are static here, confirm
+
+ dequantize->shape_status(luci::ShapeStatus::VALID);
+
+ luci::add_origin(dequantize, luci::get_origin(node)); // the new op gets the origin obtained from node
+
+ return dequantize;
+}
+
} // namespace
namespace luci
@@ -229,11 +264,13 @@ private:
INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleFullyConnected, input)
INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleGather, params)
INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleInstanceNorm, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleLeakyRelu, features)
INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleLocalResponseNormalization, input)
INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleLogistic, x)
INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleMaxPool2D, value)
INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleMean, input)
INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleMirrorPad, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleNeg, x)
INSERT_QUANTIZE_TO_UNARY_OP(luci::CirclePad, input)
INSERT_QUANTIZE_TO_UNARY_OP(luci::CirclePadV2, input)
INSERT_QUANTIZE_TO_UNARY_OP(luci::CirclePRelu, input)
@@ -241,6 +278,7 @@ private:
INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleReduceMax, input)
INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleReduceMin, input)
INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleRelu, features)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleRelu6, features)
INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleReshape, tensor)
INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleResizeBilinear, input)
INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleResizeNearestNeighbor, input)
@@ -250,6 +288,7 @@ private:
INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleSoftmax, logits)
INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleSpaceToBatchND, input)
INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleSpaceToDepth, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleSqueeze, input)
INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleSqrt, x)
INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleStridedSlice, input)
INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleSum, input)
@@ -353,7 +392,9 @@ void QuantizeWithMinMaxPass::set_input_type(loco::Graph *g) const
luci::add_origin(quant_op, luci::get_origin(succ));
}
- // Requantize input
+ // Update qparam of input
+ // This step is skipped if input_type is float32
+ if (_ctx->input_type != loco::DataType::FLOAT32)
{
auto quantparam = input->quantparam();
assert(quantparam);
@@ -376,11 +417,13 @@ void QuantizeWithMinMaxPass::set_input_type(loco::Graph *g) const
assert(_ctx->input_type == loco::DataType::S16);
compute_sym_scale_zp(min, max, scaling_factor, zp, nudged_min, nudged_max);
}
- input->dtype(_ctx->input_type);
input->quantparam()->scale[0] = scaling_factor;
input->quantparam()->zerop[0] = zp;
}
+ // Update dtype of input
+ input->dtype(_ctx->input_type);
+
auto graph_input = inputs->at(input->index());
graph_input->dtype(_ctx->input_type);
}
@@ -405,13 +448,26 @@ void QuantizeWithMinMaxPass::set_output_type(loco::Graph *g) const
if (not from->quantparam())
continue;
- // Insert Quantize Op
- auto quant_op = create_quantize_op(from, _ctx->output_type);
- loco::replace(from).with(quant_op);
- quant_op->input(from);
+ // Insert Dequantize Op for float32 output_type
+ if (_ctx->output_type == loco::DataType::FLOAT32)
+ {
+ auto dequant_op = create_dequantize(from);
+ loco::replace(from).with(dequant_op);
+ dequant_op->input(from);
+ }
+ else
+ {
+ // Insert Quantize Op for non-float32 output_type
+ auto quant_op = create_quantize_op(from, _ctx->output_type);
+ loco::replace(from).with(quant_op);
+ quant_op->input(from);
- // TODO Set a proper origin (Quantize should have its own Origin)
- luci::add_origin(quant_op, luci::get_origin(from));
+ // TODO Set a proper origin (Quantize should have its own Origin)
+ luci::add_origin(quant_op, luci::get_origin(from));
+ }
+
+ // Update dtype of output
+ output->dtype(_ctx->output_type);
auto graph_output = outputs->at(output->index());
graph_output->dtype(_ctx->output_type);
@@ -594,12 +650,25 @@ bool QuantizeWithMinMaxPass::run(loco::Graph *g)
// Set output type
set_output_type(g);
+ // Remove redundant Quantize Op
+ {
+ logo::Phase phase;
+
+ phase.emplace_back(std::make_unique<luci::RemoveRedundantQuantizePass>());
+
+ ProgressReporter prog(g, logo::PhaseStrategy::Saturate);
+ logo::PhaseRunner<logo::PhaseStrategy::Saturate> phase_runner{g};
+ phase_runner.attach(&prog);
+ phase_runner.run(phase);
+ }
+
// Remove min/max values
for (auto node : loco::active_nodes(loco::output_nodes(g)))
{
auto circle_node = loco::must_cast<luci::CircleNode *>(node);
if (auto qparam = circle_node->quantparam())
{
+ warn_accuracy_with_range(circle_node);
qparam->min.clear();
qparam->max.clear();
}