summaryrefslogtreecommitdiff
path: root/compiler/luci/pass/src/QuantizationUtils.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/luci/pass/src/QuantizationUtils.cpp')
-rw-r--r--compiler/luci/pass/src/QuantizationUtils.cpp126
1 files changed, 118 insertions, 8 deletions
diff --git a/compiler/luci/pass/src/QuantizationUtils.cpp b/compiler/luci/pass/src/QuantizationUtils.cpp
index ad86cedf4..06a4ae9f6 100644
--- a/compiler/luci/pass/src/QuantizationUtils.cpp
+++ b/compiler/luci/pass/src/QuantizationUtils.cpp
@@ -20,6 +20,7 @@
#include <iostream>
#include <cmath>
+#include <limits>
namespace luci
{
@@ -276,31 +277,70 @@ uint32_t cal_offset(loco::TensorShape &dimension, uint32_t *indices)
indices[2] * dimension.dim(3).value() + indices[3];
}
+// Activation (ofm) qtype is determined in different ways.
+// 1. Pre-defined values: Some Ops have pre-defined qparams (ex: LOGISTIC, TANH)
+// 2. Integer scale: Output of some Ops should be integers (ex: FLOOR, CEIL)
+// 3. Activation qtype of input: Some Ops propagate qparam from input to output (ex: QUANTIZE,
+// TRANSPOSE, etc. See PropagateQParamForwardPass.cpp for more details).
ActivationQType activation_qtype(const CircleNode *node)
{
auto fused_act_node = dynamic_cast<const CircleNodeMixin<CircleNodeTrait::FusedActFunc> *>(node);
if (fused_act_node && fused_act_node->fusedActivationFunction() == FusedActFunc::TANH)
- return ActivationQType::PreDefinedValue;
+ return ActivationQType::PreDefinedTanh;
+
+#define RETURN_INPUT_ACTIVATION_QTYPE(CLASS, INPUT) \
+ { \
+ auto n = loco::must_cast<const CLASS *>(node); \
+ auto input = loco::must_cast<CircleNode *>(n->INPUT()); \
+ return activation_qtype(input); \
+ }
switch (node->opcode())
{
case CircleOpcode::LOGISTIC:
+ return ActivationQType::PreDefinedLogistic;
case CircleOpcode::TANH:
+ return ActivationQType::PreDefinedTanh;
case CircleOpcode::SOFTMAX:
- return ActivationQType::PreDefinedValue;
+ return ActivationQType::PreDefinedSoftmax;
case CircleOpcode::FLOOR:
case CircleOpcode::FLOOR_DIV:
case CircleOpcode::FLOOR_MOD:
case CircleOpcode::CEIL:
return ActivationQType::IntScale;
+ case CircleOpcode::GATHER:
+ RETURN_INPUT_ACTIVATION_QTYPE(CircleGather, params);
+ case CircleOpcode::RESHAPE:
+ RETURN_INPUT_ACTIVATION_QTYPE(CircleReshape, tensor);
+ case CircleOpcode::TRANSPOSE:
+ RETURN_INPUT_ACTIVATION_QTYPE(CircleTranspose, a);
+ case CircleOpcode::STRIDED_SLICE:
+ RETURN_INPUT_ACTIVATION_QTYPE(CircleStridedSlice, input);
+ case CircleOpcode::SPLIT:
+ RETURN_INPUT_ACTIVATION_QTYPE(CircleSplit, input);
+ case CircleOpcode::CIRCLESPLITOUT:
+ RETURN_INPUT_ACTIVATION_QTYPE(CircleSplitOut, input);
+ case CircleOpcode::SPLIT_V:
+ RETURN_INPUT_ACTIVATION_QTYPE(CircleSplitV, input);
+ case CircleOpcode::CIRCLESPLITVOUT:
+ RETURN_INPUT_ACTIVATION_QTYPE(CircleSplitVOut, input);
+ case CircleOpcode::UNPACK:
+ RETURN_INPUT_ACTIVATION_QTYPE(CircleUnpack, value);
+ case CircleOpcode::CIRCLEUNPACKOUT:
+ RETURN_INPUT_ACTIVATION_QTYPE(CircleUnpackOut, input);
+ case CircleOpcode::QUANTIZE:
+ RETURN_INPUT_ACTIVATION_QTYPE(CircleQuantize, input);
default:
break;
}
+#undef RETURN_INPUT_ACTIVATION_QTYPE
+
return ActivationQType::MinMax;
}
-std::unique_ptr<CircleQuantParam> make_predefined_qparam(CircleOpcode opcode, loco::DataType dtype)
+std::unique_ptr<CircleQuantParam> make_predefined_qparam(ActivationQType qtype,
+ loco::DataType dtype)
{
auto qparam = std::make_unique<CircleQuantParam>();
@@ -309,9 +349,9 @@ std::unique_ptr<CircleQuantParam> make_predefined_qparam(CircleOpcode opcode, lo
qparam->zerop.emplace_back(zp);
};
- switch (opcode)
+ switch (qtype)
{
- case CircleOpcode::LOGISTIC:
+ case ActivationQType::PreDefinedLogistic:
if (dtype == loco::DataType::U8)
set_qparam(1.0f / 256.0f, 0);
else
@@ -320,7 +360,7 @@ std::unique_ptr<CircleQuantParam> make_predefined_qparam(CircleOpcode opcode, lo
set_qparam(1.0f / 32768.0f, 0);
}
break;
- case CircleOpcode::TANH:
+ case ActivationQType::PreDefinedTanh:
if (dtype == loco::DataType::U8)
set_qparam(2.0f / 256.0f, 128);
else
@@ -329,7 +369,7 @@ std::unique_ptr<CircleQuantParam> make_predefined_qparam(CircleOpcode opcode, lo
set_qparam(1.0f / 32768.0f, 0);
}
break;
- case CircleOpcode::SOFTMAX:
+ case ActivationQType::PreDefinedSoftmax:
if (dtype == loco::DataType::U8)
set_qparam(1.0f / 255.0f, 0);
else
@@ -341,7 +381,7 @@ std::unique_ptr<CircleQuantParam> make_predefined_qparam(CircleOpcode opcode, lo
default:
throw std::runtime_error("Unsupported opcode with pre-defined qparam");
}
- return std::move(qparam);
+ return qparam;
}
// For nodes with integer output, we use integer scale
@@ -395,4 +435,74 @@ void quant_const(luci::CircleConst *node, loco::DataType quant_type)
node->quantparam(std::move(quantparam));
}
+namespace
+{
+
+// TODO move this to a more global helper file
+int nbits(loco::DataType dt) noexcept
+{
+ switch (dt)
+ {
+ case loco::DataType::S8:
+ case loco::DataType::U8:
+ return 8;
+ case loco::DataType::S16:
+ case loco::DataType::U16:
+ case loco::DataType::FLOAT16:
+ return 16;
+ case loco::DataType::S32:
+ case loco::DataType::U32:
+ case loco::DataType::FLOAT32:
+ return 32;
+ case loco::DataType::S64:
+ return 64;
+ default:
+ return 64; // a safe large default
+ }
+}
+
+// TODO Check if the metric is valid
+// Returns true if [min,max] is poorly representable
+bool range_check(float min, float max, loco::DataType dtype)
+{
+ float thresh = 1.5f;
+ return log2f(max) - log2f(min) > nbits(dtype) * thresh;
+}
+
+bool warn_scale_zp(float scale, int64_t zp, luci::CircleNode *n)
+{
+ float min, max;
+ // estimate min/max
+ switch (n->dtype())
+ {
+ case loco::DataType::U8:
+ min = scale * (0 - zp);
+ max = scale * (255 - zp);
+ break;
+ case loco::DataType::S16:
+ min = scale * (-32767);
+ max = scale * (32767);
+ break;
+ default:
+ return false;
+ }
+ return range_check(min, max, n->dtype());
+}
+
+} // namespace
+
+void warn_accuracy_with_range(luci::CircleNode *n)
+{
+ LOGGER(l);
+ auto qp = n->quantparam();
+ auto k = qp->zerop.size();
+ for (uint32_t i = 0; i < k; i++)
+ {
+ if (warn_scale_zp(qp->scale[i], qp->zerop[i], n))
+ WARN(l) << "Quantization of " << i << "-th channel of " << n->name()
+ << "'s quantization may cause accuracy issues" << std::endl;
+ ;
+ }
+}
+
} // namespace luci