summary | refs | log | tree | commit | diff
path: root/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp')
-rw-r--r-- compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp 16
1 files changed, 12 insertions, 4 deletions
diff --git a/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp b/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp
index c68e06712..4f4edaf36 100644
--- a/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp
+++ b/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp
@@ -101,7 +101,7 @@ luci::CircleQuantize *create_quantize_op(luci::CircleNode *node, loco::DataType
else
{
assert(out_type == loco::DataType::S16);
- compute_sym_scale_zp(min, max, scaling_factor, zp, nudged_min, nudged_max);
+ compute_sym_scale(min, max, scaling_factor, nudged_min, nudged_max);
}
auto quantparam = std::make_unique<CircleQuantParam>();
@@ -271,6 +271,7 @@ private:
INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleFloor, x)
INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleFullyConnected, input)
INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleGather, params)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleGelu, features)
INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleInstanceNorm, input)
INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleLeakyRelu, features)
INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleLocalResponseNormalization, input)
@@ -433,7 +434,7 @@ void QuantizeWithMinMaxPass::set_input_type(loco::Graph *g) const
else
{
assert(user_given_dtype == loco::DataType::S16);
- compute_sym_scale_zp(min, max, scaling_factor, zp, nudged_min, nudged_max);
+ compute_sym_scale(min, max, scaling_factor, nudged_min, nudged_max);
}
input->quantparam()->scale[0] = scaling_factor;
input->quantparam()->zerop[0] = zp;
@@ -479,15 +480,15 @@ void QuantizeWithMinMaxPass::set_output_type(loco::Graph *g) const
if (user_given_dtype == loco::DataType::FLOAT32)
{
auto dequant_op = create_dequantize(from);
- loco::replace(from).with(dequant_op);
dequant_op->input(from);
+ output->from(dequant_op);
}
else
{
// Insert Quantize Op for non-float32 output_type
auto quant_op = create_quantize_op(from, user_given_dtype);
- loco::replace(from).with(quant_op);
quant_op->input(from);
+ output->from(quant_op);
// TODO Set a proper origin (Quantize should have its own Origin)
luci::add_origin(quant_op, luci::get_origin(from));
@@ -629,6 +630,13 @@ bool QuantizeWithMinMaxPass::run(loco::Graph *g)
for (auto node : loco::active_nodes(loco::output_nodes(g)))
{
auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+
+ // At this point, all activations have to be quantized.
+ // Un-quantized nodes are not the quantization target (ex: int32 tensor),
+ // so we skip them
+ if (circle_node->quantparam() == nullptr)
+ continue;
+
QuantizeSpecialActivation qsa(_ctx->input_model_dtype, quantize_dtype(circle_node));
circle_node->accept(&qsa);
}