diff options
Diffstat (limited to 'compiler/luci/pass/src/QuantizationUtils.cpp')
-rw-r--r-- | compiler/luci/pass/src/QuantizationUtils.cpp | 126 |
1 files changed, 118 insertions, 8 deletions
diff --git a/compiler/luci/pass/src/QuantizationUtils.cpp b/compiler/luci/pass/src/QuantizationUtils.cpp index ad86cedf4..06a4ae9f6 100644 --- a/compiler/luci/pass/src/QuantizationUtils.cpp +++ b/compiler/luci/pass/src/QuantizationUtils.cpp @@ -20,6 +20,7 @@ #include <iostream> #include <cmath> +#include <limits> namespace luci { @@ -276,31 +277,70 @@ uint32_t cal_offset(loco::TensorShape &dimension, uint32_t *indices) indices[2] * dimension.dim(3).value() + indices[3]; } +// Activation (ofm) qtype is determined in different ways. +// 1. Pre-defined values: Some Ops have pre-defined qparams (ex: LOGISTIC, TANH) +// 2. Integer scale: Output of some Ops should be integers (ex: FLOOR, CEIL) +// 3. Activation qtype of input: Some Ops propagate qparam from input to output (ex: QUANTIZE, +// TRANSPOSE, etc. See PropagateQParamForwardPass.cpp for more details). ActivationQType activation_qtype(const CircleNode *node) { auto fused_act_node = dynamic_cast<const CircleNodeMixin<CircleNodeTrait::FusedActFunc> *>(node); if (fused_act_node && fused_act_node->fusedActivationFunction() == FusedActFunc::TANH) - return ActivationQType::PreDefinedValue; + return ActivationQType::PreDefinedTanh; + +#define RETURN_INPUT_ACTIVATION_QTYPE(CLASS, INPUT) \ + { \ + auto n = loco::must_cast<const CLASS *>(node); \ + auto input = loco::must_cast<CircleNode *>(n->INPUT()); \ + return activation_qtype(input); \ + } switch (node->opcode()) { case CircleOpcode::LOGISTIC: + return ActivationQType::PreDefinedLogistic; case CircleOpcode::TANH: + return ActivationQType::PreDefinedTanh; case CircleOpcode::SOFTMAX: - return ActivationQType::PreDefinedValue; + return ActivationQType::PreDefinedSoftmax; case CircleOpcode::FLOOR: case CircleOpcode::FLOOR_DIV: case CircleOpcode::FLOOR_MOD: case CircleOpcode::CEIL: return ActivationQType::IntScale; + case CircleOpcode::GATHER: + RETURN_INPUT_ACTIVATION_QTYPE(CircleGather, params); + case CircleOpcode::RESHAPE: + RETURN_INPUT_ACTIVATION_QTYPE(CircleReshape, tensor); + case CircleOpcode::TRANSPOSE: + RETURN_INPUT_ACTIVATION_QTYPE(CircleTranspose, a); + case CircleOpcode::STRIDED_SLICE: + RETURN_INPUT_ACTIVATION_QTYPE(CircleStridedSlice, input); + case CircleOpcode::SPLIT: + RETURN_INPUT_ACTIVATION_QTYPE(CircleSplit, input); + case CircleOpcode::CIRCLESPLITOUT: + RETURN_INPUT_ACTIVATION_QTYPE(CircleSplitOut, input); + case CircleOpcode::SPLIT_V: + RETURN_INPUT_ACTIVATION_QTYPE(CircleSplitV, input); + case CircleOpcode::CIRCLESPLITVOUT: + RETURN_INPUT_ACTIVATION_QTYPE(CircleSplitVOut, input); + case CircleOpcode::UNPACK: + RETURN_INPUT_ACTIVATION_QTYPE(CircleUnpack, value); + case CircleOpcode::CIRCLEUNPACKOUT: + RETURN_INPUT_ACTIVATION_QTYPE(CircleUnpackOut, input); + case CircleOpcode::QUANTIZE: + RETURN_INPUT_ACTIVATION_QTYPE(CircleQuantize, input); default: break; } +#undef RETURN_INPUT_ACTIVATION_QTYPE + return ActivationQType::MinMax; } -std::unique_ptr<CircleQuantParam> make_predefined_qparam(CircleOpcode opcode, loco::DataType dtype) +std::unique_ptr<CircleQuantParam> make_predefined_qparam(ActivationQType qtype, + loco::DataType dtype) { auto qparam = std::make_unique<CircleQuantParam>(); @@ -309,9 +349,9 @@ std::unique_ptr<CircleQuantParam> make_predefined_qparam(CircleOpcode opcode, lo qparam->zerop.emplace_back(zp); }; - switch (opcode) + switch (qtype) { - case CircleOpcode::LOGISTIC: + case ActivationQType::PreDefinedLogistic: if (dtype == loco::DataType::U8) set_qparam(1.0f / 256.0f, 0); else @@ -320,7 +360,7 @@ std::unique_ptr<CircleQuantParam> make_predefined_qparam(CircleOpcode opcode, lo set_qparam(1.0f / 32768.0f, 0); } break; - case CircleOpcode::TANH: + case ActivationQType::PreDefinedTanh: if (dtype == loco::DataType::U8) set_qparam(2.0f / 256.0f, 128); else @@ -329,7 +369,7 @@ std::unique_ptr<CircleQuantParam> make_predefined_qparam(CircleOpcode opcode, lo set_qparam(1.0f / 32768.0f, 0); } break; - case CircleOpcode::SOFTMAX: + case ActivationQType::PreDefinedSoftmax: if (dtype == loco::DataType::U8) set_qparam(1.0f / 255.0f, 0); else @@ -341,7 +381,7 @@ std::unique_ptr<CircleQuantParam> make_predefined_qparam(CircleOpcode opcode, lo default: throw std::runtime_error("Unsupported opcode with pre-defined qparam"); } - return std::move(qparam); + return qparam; } // For nodes with integer output, we use integer scale @@ -395,4 +435,74 @@ void quant_const(luci::CircleConst *node, loco::DataType quant_type) node->quantparam(std::move(quantparam)); } +namespace +{ + +// TODO move this to a more global helper file +int nbits(loco::DataType dt) noexcept +{ + switch (dt) + { + case loco::DataType::S8: + case loco::DataType::U8: + return 8; + case loco::DataType::S16: + case loco::DataType::U16: + case loco::DataType::FLOAT16: + return 16; + case loco::DataType::S32: + case loco::DataType::U32: + case loco::DataType::FLOAT32: + return 32; + case loco::DataType::S64: + return 64; + default: + return 64; // a safe large default + } +} + +// TODO Check if the metric is valid +// Returns true if [min,max] is poorly representable +bool range_check(float min, float max, loco::DataType dtype) +{ + float thresh = 1.5f; + return log2f(max) - log2f(min) > nbits(dtype) * thresh; +} + +bool warn_scale_zp(float scale, int64_t zp, luci::CircleNode *n) +{ + float min, max; + // estimate min/max + switch (n->dtype()) + { + case loco::DataType::U8: + min = scale * (0 - zp); + max = scale * (255 - zp); + break; + case loco::DataType::S16: + min = scale * (-32767); + max = scale * (32767); + break; + default: + return false; + } + return range_check(min, max, n->dtype()); +} + +} // namespace + +void warn_accuracy_with_range(luci::CircleNode *n) +{ + LOGGER(l); + auto qp = n->quantparam(); + auto k = qp->zerop.size(); + for (uint32_t i = 0; i < k; i++) + { + if (warn_scale_zp(qp->scale[i], qp->zerop[i], n)) + WARN(l) << "Quantization of " << i << "-th channel of " << n->name() + << "'s quantization may cause accuracy issues" << std::endl; + ; + } +} + } // namespace luci |