diff options
Diffstat (limited to 'compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp')
-rw-r--r-- | compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp | 25 |
1 files changed, 17 insertions, 8 deletions
diff --git a/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp b/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp index f8abee751..b335a53b4 100644 --- a/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp +++ b/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp @@ -138,7 +138,8 @@ bool is_quantized(const CircleNode *node) node->dtype() == loco::DataType::S32; // bias } -void sym_wquant_per_channel(CircleConst *node, std::vector<float> &scaling_factor) +void sym_wquant_per_channel(CircleConst *node, std::vector<float> &scaling_factor, + int32_t &channel_dim_index) { assert(node->dtype() == loco::DataType::FLOAT32); @@ -153,7 +154,6 @@ void sym_wquant_per_channel(CircleConst *node, std::vector<float> &scaling_facto uint32_t indices[4] = { 0, }; - int channel_dim_index{0}; if (!get_channel_dim_index(node, dimension, channel_dim_index)) { @@ -189,7 +189,7 @@ void sym_wquant_per_channel(CircleConst *node, std::vector<float> &scaling_facto } void asym_wquant_per_channel(CircleConst *node, std::vector<float> &min, - std::vector<float> &scaling_factor) + std::vector<float> &scaling_factor, int32_t &channel_dim_index) { assert(node->dtype() == loco::DataType::FLOAT32); @@ -204,7 +204,6 @@ void asym_wquant_per_channel(CircleConst *node, std::vector<float> &min, uint32_t indices[4] = { 0, }; - int channel_dim_index{0}; if (!get_channel_dim_index(node, dimension, channel_dim_index)) { @@ -282,6 +281,10 @@ bool is_weights(CircleNode *node) if (dw_conv != nullptr && dw_conv->filter() == circle_const) return true; + auto t_conv = dynamic_cast<CircleTransposeConv *>(out); + if (t_conv != nullptr && t_conv->filter() == circle_const && circle_const->rank() == 4) + return true; + auto fc = dynamic_cast<CircleFullyConnected *>(out); if (fc != nullptr && fc->weights() == circle_const) return true; @@ -350,8 +353,8 @@ struct QuantizeActivation final : public luci::CircleNodeMutableVisitor<bool> circle_node->dtype(loco::DataType::S16); } - circle_node->quantparam()->max[0] = nudged_max; - circle_node->quantparam()->min[0] = nudged_min; + circle_node->quantparam()->min.clear(); + circle_node->quantparam()->max.clear(); circle_node->quantparam()->scale.push_back(scaling_factor); circle_node->quantparam()->zerop.push_back(zp); } @@ -472,15 +475,19 @@ struct QuantizeWeights final : public luci::CircleNodeMutableVisitor<bool> assert(quantparam != nullptr); auto min = quantparam->min; auto scaling_factor = quantparam->scale; + int32_t channel_dim_index = 0; if (output_type == loco::DataType::U8) { - asym_wquant_per_channel(circle_const, min, scaling_factor); + asym_wquant_per_channel(circle_const, min, scaling_factor, channel_dim_index); } else { - sym_wquant_per_channel(circle_const, scaling_factor); + sym_wquant_per_channel(circle_const, scaling_factor, channel_dim_index); } + quantparam->min.clear(); + quantparam->max.clear(); + quantparam->quantized_dimension = channel_dim_index; } // Find min/max per layer-wise else @@ -493,6 +500,8 @@ struct QuantizeWeights final : public luci::CircleNodeMutableVisitor<bool> auto min = quantparam->min[0]; auto scaling_factor = quantparam->scale[0]; asym_wquant_per_layer(circle_const, min, scaling_factor); + quantparam->min.clear(); + quantparam->max.clear(); } } } |