diff options
Diffstat (limited to 'compiler/luci/pass/src/QuantizeDequantizeWeightsPass.cpp')
-rw-r--r-- | compiler/luci/pass/src/QuantizeDequantizeWeightsPass.cpp | 259 |
1 files changed, 181 insertions, 78 deletions
diff --git a/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.cpp b/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.cpp index c8ad87e3d..c9b35e0be 100644 --- a/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.cpp +++ b/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.cpp @@ -16,9 +16,11 @@ #include "luci/Pass/QuantizeDequantizeWeightsPass.h" #include "QuantizationUtils.h" +#include "helpers/LayerInfoMap.h" #include <luci/IR/CircleNodes.h> #include <luci/IR/CircleNodeVisitor.h> +#include <luci/Service/Nodes/CircleConst.h> #include <luci/Log.h> #include <loco/IR/TensorShape.h> @@ -251,7 +253,7 @@ void asymmetric_wdequant_with_minmax_per_layer(CircleConst *node, float scaling_ * @brief QuantizeDequantizeWeights quantizes and dequantizes tensors for weights * @details Find min/max values on the fly, quantize the model, and dequantize the model */ -struct QuantizeDequantizeWeights final : public luci::CircleNodeMutableVisitor<bool> +struct QuantizeDequantizeWeights final : public luci::CircleNodeMutableVisitor<void> { QuantizeDequantizeWeights(loco::DataType input, loco::DataType output, QuantizationGranularity granularity) @@ -263,88 +265,164 @@ struct QuantizeDequantizeWeights final : public luci::CircleNodeMutableVisitor<b loco::DataType output_type; QuantizationGranularity granularity; - // Quantize and dequantize input tensors of each node - bool visit(luci::CircleNode *node) +private: + // Fake quantize weights (Only u8 quantization is supported for LWQ) + void fake_quantize_lwq(luci::CircleConst *weights) const { - assert(output_type == loco::DataType::U8 || output_type == loco::DataType::S16); - LOGGER(l); - INFO(l) << "QuantizeDequantizeWeights visit node: " << node->name() << std::endl; - auto arity = node->arity(); - for (uint32_t i = 0; i < arity; i++) + assert(output_type == loco::DataType::U8); // FIX_CALLER_UNLESS + + // Find min/max per layer + float min = std::numeric_limits<float>::max(); + float max = std::numeric_limits<float>::lowest(); + for (uint32_t i = 0; i < weights->size<loco::DataType::FLOAT32>(); i++) { - auto input_node = node->arg(i); - auto circle_node = loco::must_cast<luci::CircleNode *>(input_node); + auto data = weights->at<loco::DataType::FLOAT32>(i); + min = data < min ? data : min; + max = data > max ? data : max; + } + float scaling_factor{0}; + int64_t zp{0}; + float nudged_min{0}; + float nudged_max{0}; + + asymmetric_wquant_with_minmax_per_layer(weights, min, max, scaling_factor, zp, nudged_min, + nudged_max); + asymmetric_wdequant_with_minmax_per_layer(weights, scaling_factor, nudged_min); + auto quantparam = std::make_unique<CircleQuantParam>(); + quantparam->min.push_back(nudged_min); + quantparam->max.push_back(nudged_max); + quantparam->scale.push_back(scaling_factor); + quantparam->zerop.push_back(zp); + weights->quantparam(std::move(quantparam)); + } - // Check if this is already quantized - if (is_quantized(circle_node)) - continue; +private: + // Fake quantize weights (u8/s16 quantization are supported for CWQ) + void fake_quantize_cwq(luci::CircleConst *weights) const + { + assert(output_type == loco::DataType::U8 || + output_type == loco::DataType::S16); // FIX_CALLER_UNLESS - if (is_weights(circle_node)) - { - auto circle_const = loco::must_cast<luci::CircleConst *>(circle_node); + // Find min/max per channel + std::vector<float> min; + std::vector<float> max; - // Find min/max per channel-wise - if (granularity == QuantizationGranularity::ChannelWise) - { - std::vector<float> min; - std::vector<float> max; - - cal_minmax_per_channel(circle_const, min, max); - - std::vector<float> nudged_min(min.size()); - std::vector<float> nudged_max(min.size()); - std::vector<float> scaling_factor(min.size()); - std::vector<int64_t> zp(min.size()); - - if (output_type == loco::DataType::U8) - { - asymmetric_wquant_per_channel(circle_const, min, max, scaling_factor, zp, nudged_min, - nudged_max); - asymmetric_wdequant_per_channel(circle_const, scaling_factor, nudged_min); - } - else - { - sym_wquant_per_channel(circle_const, min, max, scaling_factor, zp, nudged_min, - nudged_max); - sym_wdequant_per_channel(circle_const, scaling_factor); - } - - auto quantparam = std::make_unique<CircleQuantParam>(); - quantparam->min = nudged_min; - quantparam->max = nudged_max; - quantparam->scale = scaling_factor; - quantparam->zerop = zp; - circle_node->quantparam(std::move(quantparam)); - } - // Find min/max per layer-wise - else - { - float min = std::numeric_limits<float>::max(); - float max = std::numeric_limits<float>::lowest(); - for (uint32_t i = 0; i < circle_const->size<loco::DataType::FLOAT32>(); i++) - { - auto data = circle_const->at<loco::DataType::FLOAT32>(i); - min = data < min ? data : min; - max = data > max ? data : max; - } - float scaling_factor{0}; - int64_t zp{0}; - float nudged_min{0}; - float nudged_max{0}; - - asymmetric_wquant_with_minmax_per_layer(circle_const, min, max, scaling_factor, zp, - nudged_min, nudged_max); - asymmetric_wdequant_with_minmax_per_layer(circle_const, scaling_factor, nudged_min); - auto quantparam = std::make_unique<CircleQuantParam>(); - quantparam->min.push_back(nudged_min); - quantparam->max.push_back(nudged_max); - quantparam->scale.push_back(scaling_factor); - quantparam->zerop.push_back(zp); - circle_node->quantparam(std::move(quantparam)); - } - } + cal_minmax_per_channel(weights, min, max); + + std::vector<float> nudged_min(min.size()); + std::vector<float> nudged_max(min.size()); + std::vector<float> scaling_factor(min.size()); + std::vector<int64_t> zp(min.size()); + + if (output_type == loco::DataType::U8) + { + asymmetric_wquant_per_channel(weights, min, max, scaling_factor, zp, nudged_min, nudged_max); + asymmetric_wdequant_per_channel(weights, scaling_factor, nudged_min); + } + else + { + sym_wquant_per_channel(weights, min, max, scaling_factor, zp, nudged_min, nudged_max); + sym_wdequant_per_channel(weights, scaling_factor); } - return false; + + auto quantparam = std::make_unique<CircleQuantParam>(); + quantparam->min = nudged_min; + quantparam->max = nudged_max; + quantparam->scale = scaling_factor; + quantparam->zerop = zp; + weights->quantparam(std::move(quantparam)); + } + +private: + void fake_quantize(luci::CircleConst *weights) const + { + switch (granularity) + { + case luci::QuantizationGranularity::ChannelWise: + fake_quantize_cwq(weights); + break; + case luci::QuantizationGranularity::LayerWise: + fake_quantize_lwq(weights); + break; + default: + throw std::invalid_argument("Unsupported granularity"); + } + } + +private: + // Check if + // 1. node is const + // 2. node was not quantized + bool is_quantizable(loco::Node *node) + { + auto const_node = dynamic_cast<luci::CircleConst *>(node); + if (not const_node) + return false; + + // Skip if this is already quantized + if (is_quantized(const_node)) + return false; + + return true; + } + + // Default behavior (Do nothing) + void visit(luci::CircleNode *) {} + + void visit(luci::CircleConv2D *node) + { + LOGGER(l); + INFO(l) << "QuantizeDequantizeWeights visit node: " << node->name() << std::endl; + + if (not is_quantizable(node->filter())) + return; + + auto weights = loco::must_cast<luci::CircleConst *>(node->filter()); + auto new_weights = luci::clone(weights); + node->filter(new_weights); + fake_quantize(new_weights); + } + + void visit(luci::CircleDepthwiseConv2D *node) + { + LOGGER(l); + INFO(l) << "QuantizeDequantizeWeights visit node: " << node->name() << std::endl; + + if (not is_quantizable(node->filter())) + return; + + auto weights = loco::must_cast<luci::CircleConst *>(node->filter()); + auto new_weights = luci::clone(weights); + node->filter(new_weights); + fake_quantize(new_weights); + } + + void visit(luci::CircleTransposeConv *node) + { + LOGGER(l); + INFO(l) << "QuantizeDequantizeWeights visit node: " << node->name() << std::endl; + + if (not is_quantizable(node->filter())) + return; + + auto weights = loco::must_cast<luci::CircleConst *>(node->filter()); + auto new_weights = luci::clone(weights); + node->filter(new_weights); + fake_quantize(new_weights); + } + + void visit(luci::CircleFullyConnected *node) + { + LOGGER(l); + INFO(l) << "QuantizeDequantizeWeights visit node: " << node->name() << std::endl; + + if (not is_quantizable(node->weights())) + return; + + auto weights = loco::must_cast<luci::CircleConst *>(node->weights()); + auto new_weights = luci::clone(weights); + node->weights(new_weights); + fake_quantize(new_weights); } }; @@ -355,11 +433,36 @@ bool QuantizeDequantizeWeightsPass::run(loco::Graph *g) LOGGER(l); INFO(l) << "QuantizeDequantizeWeightsPass Start" << std::endl; + auto info_by_name = layer_info_map(g, _ctx->layers_info); + + auto quantize_dtype = [&](const luci::CircleNode *node) { + auto iter = info_by_name.find(node->name()); + + // Return designated quantization dtype + if (iter != info_by_name.end()) + return iter->second.dtype; + + // Return default quantization dtype + return _ctx->output_model_dtype; + }; + + auto quantize_granularity = [&](const luci::CircleNode *node) { + auto iter = info_by_name.find(node->name()); + + // Return designated quantization granularity + if (iter != info_by_name.end()) + return iter->second.granularity; + + // Return default quantization granularity + return _ctx->granularity; + }; + // Quantize weights for (auto node : loco::active_nodes(loco::output_nodes(g))) { - QuantizeDequantizeWeights qw(_input_model_dtype, _output_model_dtype, _granularity); auto circle_node = loco::must_cast<luci::CircleNode *>(node); + QuantizeDequantizeWeights qw(_ctx->input_model_dtype, quantize_dtype(circle_node), + quantize_granularity(circle_node)); circle_node->accept(&qw); } |