diff options
Diffstat (limited to 'compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp')
-rw-r--r-- | compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp | 91 |
1 files changed, 80 insertions, 11 deletions
diff --git a/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp b/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp index d9a9d4db7..005144516 100644 --- a/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp +++ b/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp @@ -41,10 +41,28 @@ namespace { using namespace luci; + +bool use_predefined_values(ActivationQType qtype) +{ + switch (qtype) + { + case ActivationQType::PreDefinedLogistic: + case ActivationQType::PreDefinedTanh: + case ActivationQType::PreDefinedSoftmax: + return true; + default: + // This ensures this switch-statement handles all ActivationQTypes + assert(qtype == ActivationQType::IntScale or qtype == ActivationQType::MinMax); + break; + } + + return false; +} + // Create a Quantize Op whose // dtype is out_type // shape is the same with node -// qparam is computed using node's min/max +// qparam is computed according to node's qtype luci::CircleQuantize *create_quantize_op(luci::CircleNode *node, loco::DataType out_type) { auto quantize = node->graph()->nodes()->create<CircleQuantize>(); @@ -60,9 +78,9 @@ luci::CircleQuantize *create_quantize_op(luci::CircleNode *node, loco::DataType assert(qparam); // FIX_CALLER_UNLESS auto qtype = luci::activation_qtype(node); - if (qtype == ActivationQType::PreDefinedValue) + if (use_predefined_values(qtype)) { - quantize->quantparam(luci::make_predefined_qparam(node->opcode(), out_type)); + quantize->quantparam(luci::make_predefined_qparam(qtype, out_type)); return quantize; } @@ -105,6 +123,23 @@ luci::CircleQuantize *create_quantize_op(luci::CircleNode *node, loco::DataType return quantize; } +// Create Dequantize Op whose shape is the same with node +luci::CircleDequantize *create_dequantize(luci::CircleNode *node) +{ + auto dequantize = node->graph()->nodes()->create<luci::CircleDequantize>(); + dequantize->name(node->name() + "_Dequantize"); + dequantize->dtype(loco::DataType::FLOAT32); + dequantize->rank(node->rank()); + for (uint32_t i = 0; i < node->rank(); i++) + dequantize->dim(i).set(node->dim(i).value()); + + dequantize->shape_status(luci::ShapeStatus::VALID); + + luci::add_origin(dequantize, luci::get_origin(node)); + + return dequantize; +} + } // namespace namespace luci @@ -229,11 +264,13 @@ private: INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleFullyConnected, input) INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleGather, params) INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleInstanceNorm, input) + INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleLeakyRelu, features) INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleLocalResponseNormalization, input) INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleLogistic, x) INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleMaxPool2D, value) INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleMean, input) INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleMirrorPad, input) + INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleNeg, x) INSERT_QUANTIZE_TO_UNARY_OP(luci::CirclePad, input) INSERT_QUANTIZE_TO_UNARY_OP(luci::CirclePadV2, input) INSERT_QUANTIZE_TO_UNARY_OP(luci::CirclePRelu, input) @@ -241,6 +278,7 @@ private: INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleReduceMax, input) INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleReduceMin, input) INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleRelu, features) + INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleRelu6, features) INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleReshape, tensor) INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleResizeBilinear, input) INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleResizeNearestNeighbor, input) @@ -250,6 +288,7 @@ private: INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleSoftmax, logits) INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleSpaceToBatchND, input) INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleSpaceToDepth, input) + INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleSqueeze, input) INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleSqrt, x) INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleStridedSlice, input) INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleSum, input) @@ -353,7 +392,9 @@ void QuantizeWithMinMaxPass::set_input_type(loco::Graph *g) const luci::add_origin(quant_op, luci::get_origin(succ)); } - // Requantize input + // Update qparam of input + // This step is skipped if input_type is float32 + if (_ctx->input_type != loco::DataType::FLOAT32) { auto quantparam = input->quantparam(); assert(quantparam); @@ -376,11 +417,13 @@ void QuantizeWithMinMaxPass::set_input_type(loco::Graph *g) const assert(_ctx->input_type == loco::DataType::S16); compute_sym_scale_zp(min, max, scaling_factor, zp, nudged_min, nudged_max); } - input->dtype(_ctx->input_type); input->quantparam()->scale[0] = scaling_factor; input->quantparam()->zerop[0] = zp; } + // Update dtype of input + input->dtype(_ctx->input_type); + auto graph_input = inputs->at(input->index()); graph_input->dtype(_ctx->input_type); } @@ -405,13 +448,26 @@ void QuantizeWithMinMaxPass::set_output_type(loco::Graph *g) const if (not from->quantparam()) continue; - // Insert Quantize Op - auto quant_op = create_quantize_op(from, _ctx->output_type); - loco::replace(from).with(quant_op); - quant_op->input(from); + // Insert Dequantize Op for float32 output_type + if (_ctx->output_type == loco::DataType::FLOAT32) + { + auto dequant_op = create_dequantize(from); + loco::replace(from).with(dequant_op); + dequant_op->input(from); + } + else + { + // Insert Quantize Op for non-float32 output_type + auto quant_op = create_quantize_op(from, _ctx->output_type); + loco::replace(from).with(quant_op); + quant_op->input(from); - // TODO Set a proper origin (Quantize should have its own Origin) - luci::add_origin(quant_op, luci::get_origin(from)); + // TODO Set a proper origin (Quantize should have its own Origin) + luci::add_origin(quant_op, luci::get_origin(from)); + } + + // Update dtype of output + output->dtype(_ctx->output_type); auto graph_output = outputs->at(output->index()); graph_output->dtype(_ctx->output_type); @@ -594,12 +650,25 @@ bool QuantizeWithMinMaxPass::run(loco::Graph *g) // Set output type set_output_type(g); + // Remove redundant Quantize Op + { + logo::Phase phase; + + phase.emplace_back(std::make_unique<luci::RemoveRedundantQuantizePass>()); + + ProgressReporter prog(g, logo::PhaseStrategy::Saturate); + logo::PhaseRunner<logo::PhaseStrategy::Saturate> phase_runner{g}; + phase_runner.attach(&prog); + phase_runner.run(phase); + } + // Remove min/max values for (auto node : loco::active_nodes(loco::output_nodes(g))) { auto circle_node = loco::must_cast<luci::CircleNode *>(node); if (auto qparam = circle_node->quantparam()) { + warn_accuracy_with_range(circle_node); qparam->min.clear(); qparam->max.clear(); } |