diff options
Diffstat (limited to 'compiler/luci/pass/src')
-rw-r--r-- | compiler/luci/pass/src/CircleOptimizer.cpp | 4 | ||||
-rw-r--r-- | compiler/luci/pass/src/FuseBCQPass.cpp | 435 | ||||
-rw-r--r-- | compiler/luci/pass/src/QuantizationUtils.cpp | 20 | ||||
-rw-r--r-- | compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp | 25 |
4 files changed, 289 insertions, 195 deletions
diff --git a/compiler/luci/pass/src/CircleOptimizer.cpp b/compiler/luci/pass/src/CircleOptimizer.cpp index 90fbe9009..2edf7a9c6 100644 --- a/compiler/luci/pass/src/CircleOptimizer.cpp +++ b/compiler/luci/pass/src/CircleOptimizer.cpp @@ -145,7 +145,7 @@ void CircleOptimizer::quantize(loco::Graph *g) const { static const std::vector<std::string> fakeq_supported_input_dtype{"float32"}; static const std::vector<std::string> fakeq_supported_output_dtype{"uint8"}; - static const std::vector<std::string> fakeq_supported_granularity{"layer"}; + static const std::vector<std::string> fakeq_supported_granularity{"layer", "channel"}; auto input_dtype = _options->param(Options::AlgorithmParameters::Quantize_input_dtype); auto output_dtype = _options->param(Options::AlgorithmParameters::Quantize_output_dtype); @@ -173,7 +173,7 @@ void CircleOptimizer::quantize(loco::Graph *g) const { static const std::vector<std::string> qwmm_supported_input_dtype{"float32"}; static const std::vector<std::string> qwmm_supported_output_dtype{"uint8"}; - static const std::vector<std::string> qwmm_supported_granularity{"layer"}; + static const std::vector<std::string> qwmm_supported_granularity{"layer", "channel"}; auto input_dtype = _options->param(Options::AlgorithmParameters::Quantize_input_dtype); auto output_dtype = _options->param(Options::AlgorithmParameters::Quantize_output_dtype); diff --git a/compiler/luci/pass/src/FuseBCQPass.cpp b/compiler/luci/pass/src/FuseBCQPass.cpp index b81db8827..260de5b30 100644 --- a/compiler/luci/pass/src/FuseBCQPass.cpp +++ b/compiler/luci/pass/src/FuseBCQPass.cpp @@ -53,6 +53,11 @@ const std::string node_name_prefix(luci::NodeName node_name) const auto index = prefix.find("Tensordot/"); prefix = prefix.substr(0, index - 1); } + else if (prefix.find("/MatMul") != std::string::npos) + { + const auto index = prefix.find("/MatMul"); + prefix = prefix.substr(0, index); + } else if (prefix.find("kernel/") != std::string::npos) { const auto index = prefix.find("kernel/"); @@ -67,14 +72,190 @@ const std::string node_name_prefix(luci::NodeName node_name) return prefix; } +/** + * @brief Create CircleOutputExclude operation, which has same shape and dtype with + * original circle_node. + */ +luci::CircleOutputExclude *createNoOp(luci::CircleNode *circle_node) +{ + auto graph = circle_node->graph(); + auto noOp = graph->nodes()->create<luci::CircleOutputExclude>(); + + if (circle_node->shape_status() == luci::ShapeStatus::VALID) + { + noOp->dtype(circle_node->dtype()); + noOp->rank(circle_node->rank()); + for (uint32_t i = 0; i < circle_node->rank(); ++i) + noOp->dim(i) = circle_node->dim(i); + } + else + { + // For type inference + noOp->dtype(loco::DataType::FLOAT32); + } + + return noOp; +}; + } // namespace namespace { -class BCQConverter final +// V means the version of BCQ. +template <int32_t V> class BCQFuser; + +template <> class BCQFuser<1> { public: + bool fuseBCQ(loco::Graph *g) + { + bool changed = false; + + for (auto node : loco::all_nodes(g)) + { + if (auto circle_const = dynamic_cast<luci::CircleConst *>(node)) + { + add_BCQ_info_node(circle_const); + } + } + + if (!is_bcqinfo_valid()) + return false; + + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + if (auto gather = dynamic_cast<luci::CircleGather *>(node)) + { + auto params = dynamic_cast<luci::CircleConst *>(gather->params()); + if (params != nullptr && has_BCQ_info(params)) + { + auto bcq_gather = g->nodes()->create<luci::CircleBCQGather>(); + + bcq_gather->op_version(1); + bcq_gather->input_scales(get_alpha(params)); + bcq_gather->input_binary(get_packed_binary_code(params)); + bcq_gather->indices(gather->indices()); + bcq_gather->input_clusters(packed_clusters(params)); + + // input_binary shape : [output_size, hidden_size] + const auto binary_hidden_size = + loco::must_cast<luci::CircleConst *>(bcq_gather->input_binary())->dim(1).value() * 32; + bcq_gather->input_hidden_size(binary_hidden_size); + + if (do_w_x(params)) + { + bcq_gather->axis(gather->axis()); + } + else + { + const auto axis_transpose = (gather->axis() == 0) ? 1 : 0; + bcq_gather->axis(axis_transpose); + } + + loco::replace(gather).with(bcq_gather); + + changed = true; + } + } + else if (auto fully_connected = dynamic_cast<luci::CircleFullyConnected *>(node)) + { + auto weights = dynamic_cast<luci::CircleConst *>(fully_connected->weights()); + if (weights != nullptr && has_BCQ_info(weights)) + { + auto bcq_fc = g->nodes()->create<luci::CircleBCQFullyConnected>(); + + bcq_fc->op_version(1); + bcq_fc->weights_scales(get_alpha(weights)); + bcq_fc->weights_binary(get_packed_binary_code(weights)); + bcq_fc->bias(fully_connected->bias()); + bcq_fc->weights_clusters(packed_clusters(weights)); + bcq_fc->fusedActivationFunction(fully_connected->fusedActivationFunction()); + + loco::Node *bcq_input = fully_connected->input(); + int32_t batch_rank = 0; + + // If input of BCQFullyConnected has more than rank 2, we should reshape it as rank 2 + const auto original_input = loco::must_cast<luci::CircleNode *>(fully_connected->input()); + if (original_input->shape_status() == luci::ShapeStatus::VALID && + original_input->rank() > 2) + { + auto new_shape = g->nodes()->create<luci::CircleConst>(); + new_shape->dtype(loco::DataType::S32); + new_shape->size<loco::DataType::S32>(2); + new_shape->rank(1); + new_shape->dim(0) = 2; + + auto batch_size = 1; + for (uint32_t i = 0; i < original_input->rank() - 1; ++i) + batch_size *= original_input->dim(i).value(); + + new_shape->at<loco::DataType::S32>(0) = batch_size; + new_shape->at<loco::DataType::S32>(1) = + original_input->dim(original_input->rank() - 1).value(); + new_shape->shape_status(luci::ShapeStatus::VALID); + + auto reshape = g->nodes()->create<luci::CircleReshape>(); + reshape->tensor(original_input); + reshape->shape(new_shape); + + bcq_input = reshape; + batch_rank = original_input->rank() - 2; + } + + // If x_w formation, we should insert Transpose in front and back of BCQFullyConnected + if (do_w_x(weights)) + { + const auto binary_hidden_size = + loco::must_cast<luci::CircleNode *>(fully_connected->input()) + ->dim(batch_rank) + .value(); + bcq_fc->weights_hidden_size(binary_hidden_size); + bcq_fc->input(bcq_input); + loco::replace(fully_connected).with(bcq_fc); + } + else + { + const auto binary_hidden_size = + loco::must_cast<luci::CircleNode *>(fully_connected->input()) + ->dim(1 + batch_rank) + .value(); + bcq_fc->weights_hidden_size(binary_hidden_size); + + auto perm = g->nodes()->create<luci::CircleConst>(); + perm->dtype(loco::DataType::S32); + perm->size<loco::DataType::S32>(2); + perm->rank(1); + perm->dim(0) = 2; + perm->at<loco::DataType::S32>(0) = 1; + perm->at<loco::DataType::S32>(1) = 0; + perm->shape_status(luci::ShapeStatus::VALID); + + auto input_transpose = g->nodes()->create<luci::CircleTranspose>(); + input_transpose->a(bcq_input); + input_transpose->perm(perm); + + bcq_fc->input(input_transpose); + + auto output_transpose = g->nodes()->create<luci::CircleTranspose>(); + output_transpose->a(bcq_fc); + output_transpose->perm(perm); + + loco::replace(fully_connected).with(output_transpose); + } + + changed = true; + } + } + } + + if (changed) + clear_BCQ_nodes(); + + return changed; + } + +private: void add_BCQ_info_node(luci::CircleConst *node) { const auto node_name = node->name(); @@ -119,16 +300,65 @@ public: return has_info; } + /** + * @brief Exclude BCQ information nodes which are used for fusing BCQ operations + * from graph output by using CircleOutputExclude + */ + void clear_BCQ_nodes() + { + auto clear_nodes = [](std::map<std::string, luci::CircleConst *> &nodes) { + for (auto &n : nodes) + { + auto node = n.second; + + for (auto s : loco::succs(node)) + { + if (auto outnode = dynamic_cast<luci::CircleOutput *>(s)) + { + outnode->from(createNoOp(node)); + } + else if (auto reshape_node = dynamic_cast<luci::CircleReshape *>(s)) + { + for (auto o : loco::succs(reshape_node)) + { + auto circle_output = loco::must_cast<luci::CircleOutput *>(o); + circle_output->from(createNoOp(reshape_node)); + } + } + } + } + }; + + clear_nodes(_do_w_x); + clear_nodes(_alpha); + clear_nodes(_packed_binary_code); + clear_nodes(_number_of_clusters); + clear_nodes(_size_of_clusters); + clear_nodes(_qbits_of_clusters); + clear_nodes(_dequant_weight); + } + + bool is_bcqinfo_valid() + { + // do_w_x should be int32 or bool type + for (auto n : _do_w_x) + { + if (n.second->dtype() != loco::DataType::BOOL && n.second->dtype() != loco::DataType::S32) + return false; + } + + return true; + } + +private: bool do_w_x(luci::CircleConst *node) { const auto prefix = node_name_prefix(node->name()); if (_do_w_x[prefix]->dtype() == loco::DataType::S32) return _do_w_x[prefix]->at<loco::DataType::S32>(0) == 1; - else if (_do_w_x[prefix]->dtype() == loco::DataType::BOOL) - return _do_w_x[prefix]->at<loco::DataType::BOOL>(0); else - throw std::runtime_error("do_w_x should be int or bool"); + return _do_w_x[prefix]->at<loco::DataType::BOOL>(0); } luci::CircleConst *get_alpha(luci::CircleConst *node) @@ -187,64 +417,6 @@ public: return packed_clusters; } - /** - * @brief Exclude BCQ information nodes which are used for fusing BCQ operations - * from graph output by using CircleOutputExclude - */ - void clear_BCQ_nodes() - { - auto createNoOp = [](luci::CircleNode *circle_node) { - auto graph = circle_node->graph(); - auto noOp = graph->nodes()->create<luci::CircleOutputExclude>(); - - if (circle_node->shape_status() == luci::ShapeStatus::VALID) - { - noOp->dtype(circle_node->dtype()); - noOp->rank(circle_node->rank()); - for (uint32_t i = 0; i < circle_node->rank(); ++i) - noOp->dim(i) = circle_node->dim(i); - } - else - { - // For type inference - noOp->dtype(loco::DataType::FLOAT32); - } - - return noOp; - }; - - auto clear_nodes = [createNoOp](std::map<std::string, luci::CircleConst *> &nodes) { - for (auto &n : nodes) - { - auto node = n.second; - - for (auto s : loco::succs(node)) - { - if (auto outnode = dynamic_cast<luci::CircleOutput *>(s)) - { - outnode->from(createNoOp(node)); - } - else if (auto reshape_node = dynamic_cast<luci::CircleReshape *>(s)) - { - for (auto o : loco::succs(reshape_node)) - { - auto circle_output = loco::must_cast<luci::CircleOutput *>(o); - circle_output->from(createNoOp(reshape_node)); - } - } - } - } - }; - - clear_nodes(_do_w_x); - clear_nodes(_alpha); - clear_nodes(_packed_binary_code); - clear_nodes(_number_of_clusters); - clear_nodes(_size_of_clusters); - clear_nodes(_qbits_of_clusters); - clear_nodes(_dequant_weight); - } - private: std::map<std::string, luci::CircleConst *> _do_w_x; std::map<std::string, luci::CircleConst *> _alpha; @@ -262,143 +434,42 @@ namespace luci bool FuseBCQPass::run(loco::Graph *g) { - BCQConverter converter; - bool changed = false; + // Find BCQ version information and check validity. + luci::CircleConst *version_node = nullptr; for (auto node : loco::all_nodes(g)) { if (auto circle_const = dynamic_cast<luci::CircleConst *>(node)) { - converter.add_BCQ_info_node(circle_const); - } - } - - for (auto node : loco::active_nodes(loco::output_nodes(g))) - { - if (auto gather = dynamic_cast<luci::CircleGather *>(node)) - { - auto params = dynamic_cast<luci::CircleConst *>(gather->params()); - if (params != nullptr && converter.has_BCQ_info(params)) + if (circle_const->name().find("/bcqinfo_version") != std::string::npos) { - auto bcq_gather = g->nodes()->create<luci::CircleBCQGather>(); - - bcq_gather->input_scales(converter.get_alpha(params)); - bcq_gather->input_binary(converter.get_packed_binary_code(params)); - bcq_gather->indices(gather->indices()); - bcq_gather->input_clusters(converter.packed_clusters(params)); - - const auto binary_hidden_size = - loco::must_cast<luci::CircleConst *>(bcq_gather->input_binary())->dim(1).value() * 32; - bcq_gather->input_hidden_size(binary_hidden_size); - - if (converter.do_w_x(params)) - { - bcq_gather->axis(gather->axis()); - } - else + // There should be only one bcqinfo_version in the model + if (version_node != nullptr) { - const auto axis_transpose = (gather->axis() == 0) ? 1 : 0; - bcq_gather->axis(axis_transpose); + assert(false && "Multiple version information found"); + return false; } - loco::replace(gather).with(bcq_gather); - - changed = true; + version_node = circle_const; } } - else if (auto fully_connected = dynamic_cast<luci::CircleFullyConnected *>(node)) - { - auto weights = dynamic_cast<luci::CircleConst *>(fully_connected->weights()); - if (weights != nullptr && converter.has_BCQ_info(weights)) - { - auto bcq_fc = g->nodes()->create<luci::CircleBCQFullyConnected>(); - - bcq_fc->weights_scales(converter.get_alpha(weights)); - bcq_fc->weights_binary(converter.get_packed_binary_code(weights)); - bcq_fc->bias(fully_connected->bias()); - bcq_fc->weights_clusters(converter.packed_clusters(weights)); - bcq_fc->fusedActivationFunction(fully_connected->fusedActivationFunction()); - - loco::Node *bcq_input = fully_connected->input(); - int32_t batch_rank = 0; + } - // If input of BCQFullyConnected has more than rank 2, we should reshape it as rank 2 - const auto original_input = loco::must_cast<luci::CircleNode *>(fully_connected->input()); - if (original_input->shape_status() == ShapeStatus::VALID && original_input->rank() > 2) - { - auto new_shape = g->nodes()->create<luci::CircleConst>(); - new_shape->dtype(loco::DataType::S32); - new_shape->size<loco::DataType::S32>(2); - new_shape->rank(1); - new_shape->dim(0) = 2; - - auto batch_size = 1; - for (uint32_t i = 0; i < original_input->rank() - 1; ++i) - batch_size *= original_input->dim(i).value(); - - new_shape->at<loco::DataType::S32>(0) = batch_size; - new_shape->at<loco::DataType::S32>(1) = - original_input->dim(original_input->rank() - 1).value(); - new_shape->shape_status(ShapeStatus::VALID); - - auto reshape = g->nodes()->create<luci::CircleReshape>(); - reshape->tensor(original_input); - reshape->shape(new_shape); - - bcq_input = reshape; - batch_rank = original_input->rank() - 2; - } + // If version node is not found, regard it as version 1. + int32_t bcq_version = (version_node != nullptr) ? version_node->at<loco::DataType::S32>(0) : 1; - // If x_w formation, we should insert Transpose in front and back of BCQFullyConnected - if (converter.do_w_x(weights)) - { - const auto binary_hidden_size = - loco::must_cast<luci::CircleNode *>(fully_connected->input()) - ->dim(batch_rank) - .value(); - bcq_fc->weights_hidden_size(binary_hidden_size); - bcq_fc->input(bcq_input); - loco::replace(fully_connected).with(bcq_fc); - } - else - { - const auto binary_hidden_size = - loco::must_cast<luci::CircleNode *>(fully_connected->input()) - ->dim(1 + batch_rank) - .value(); - bcq_fc->weights_hidden_size(binary_hidden_size); - - auto perm = g->nodes()->create<luci::CircleConst>(); - perm->dtype(loco::DataType::S32); - perm->size<loco::DataType::S32>(2); - perm->rank(1); - perm->dim(0) = 2; - perm->at<loco::DataType::S32>(0) = 1; - perm->at<loco::DataType::S32>(1) = 0; - perm->shape_status(ShapeStatus::VALID); - - auto input_transpose = g->nodes()->create<luci::CircleTranspose>(); - input_transpose->a(bcq_input); - input_transpose->perm(perm); - - bcq_fc->input(input_transpose); - - auto output_transpose = g->nodes()->create<luci::CircleTranspose>(); - output_transpose->a(bcq_fc); - output_transpose->perm(perm); - - loco::replace(fully_connected).with(output_transpose); - } + if (bcq_version == 1) + changed = BCQFuser<1>().fuseBCQ(g); + else + assert(false && "Not supported BCQ version"); - changed = true; - } - } + if (changed && version_node != nullptr) + { + // If BCQ is applied and version node was found, remove the node. + loco::replace(version_node).with(createNoOp(version_node)); } - if (changed) - converter.clear_BCQ_nodes(); - return changed; } diff --git a/compiler/luci/pass/src/QuantizationUtils.cpp b/compiler/luci/pass/src/QuantizationUtils.cpp index 6726ce746..e18690605 100644 --- a/compiler/luci/pass/src/QuantizationUtils.cpp +++ b/compiler/luci/pass/src/QuantizationUtils.cpp @@ -24,6 +24,13 @@ namespace luci { +uint8_t fp32_to_uint8_cast(float f) +{ + assert(std::numeric_limits<uint8_t>::min() <= f); + assert(f <= std::numeric_limits<uint8_t>::max()); + return static_cast<uint8_t>(f); +} + void compute_sym_scale_zp(float min, float max, float &scaling_factor, int64_t &zp, float &nudged_min, float &nudged_max) { @@ -78,7 +85,7 @@ void compute_asym_scale_zp(float min, float max, float &scaling_factor, int64_t } else zero_point_double = qmin_double - rmin / scale; - if (zero_point_double <= qmin_double) + if (min >= 0) { assert(min >= 0 && max >= 0); nudged_zero_point = kMinScale; @@ -86,7 +93,7 @@ void compute_asym_scale_zp(float min, float max, float &scaling_factor, int64_t if (min > 0 && max > 0) WARN(l) << "The minimum and maximum values are all positive." << std::endl; } - else if (zero_point_double >= qmax_double) + else if (max < 0) { assert(min < 0 && max < 0); nudged_zero_point = kMaxScale; @@ -96,7 +103,14 @@ void compute_asym_scale_zp(float min, float max, float &scaling_factor, int64_t else { assert(min < 0 && max >= 0); - nudged_zero_point = static_cast<uint8_t>(std::round(zero_point_double)); + nudged_zero_point = fp32_to_uint8_cast(std::round(zero_point_double)); + } + + // protect scale from being very low due to overflow + if (scale < 1e-5) + { + scale = 1e-5; + nudged_zero_point = fp32_to_uint8_cast(std::round(qmin_double - rmin / scale)); } nudged_min = static_cast<float>((qmin_double - nudged_zero_point) * scale); diff --git a/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp b/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp index f8abee751..b335a53b4 100644 --- a/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp +++ b/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp @@ -138,7 +138,8 @@ bool is_quantized(const CircleNode *node) node->dtype() == loco::DataType::S32; // bias } -void sym_wquant_per_channel(CircleConst *node, std::vector<float> &scaling_factor) +void sym_wquant_per_channel(CircleConst *node, std::vector<float> &scaling_factor, + int32_t &channel_dim_index) { assert(node->dtype() == loco::DataType::FLOAT32); @@ -153,7 +154,6 @@ void sym_wquant_per_channel(CircleConst *node, std::vector<float> &scaling_facto uint32_t indices[4] = { 0, }; - int channel_dim_index{0}; if (!get_channel_dim_index(node, dimension, channel_dim_index)) { @@ -189,7 +189,7 @@ void sym_wquant_per_channel(CircleConst *node, std::vector<float> &scaling_facto } void asym_wquant_per_channel(CircleConst *node, std::vector<float> &min, - std::vector<float> &scaling_factor) + std::vector<float> &scaling_factor, int32_t &channel_dim_index) { assert(node->dtype() == loco::DataType::FLOAT32); @@ -204,7 +204,6 @@ void asym_wquant_per_channel(CircleConst *node, std::vector<float> &min, uint32_t indices[4] = { 0, }; - int channel_dim_index{0}; if (!get_channel_dim_index(node, dimension, channel_dim_index)) { @@ -282,6 +281,10 @@ bool is_weights(CircleNode *node) if (dw_conv != nullptr && dw_conv->filter() == circle_const) return true; + auto t_conv = dynamic_cast<CircleTransposeConv *>(out); + if (t_conv != nullptr && t_conv->filter() == circle_const && circle_const->rank() == 4) + return true; + auto fc = dynamic_cast<CircleFullyConnected *>(out); if (fc != nullptr && fc->weights() == circle_const) return true; @@ -350,8 +353,8 @@ struct QuantizeActivation final : public luci::CircleNodeMutableVisitor<bool> circle_node->dtype(loco::DataType::S16); } - circle_node->quantparam()->max[0] = nudged_max; - circle_node->quantparam()->min[0] = nudged_min; + circle_node->quantparam()->min.clear(); + circle_node->quantparam()->max.clear(); circle_node->quantparam()->scale.push_back(scaling_factor); circle_node->quantparam()->zerop.push_back(zp); } @@ -472,15 +475,19 @@ struct QuantizeWeights final : public luci::CircleNodeMutableVisitor<bool> assert(quantparam != nullptr); auto min = quantparam->min; auto scaling_factor = quantparam->scale; + int32_t channel_dim_index = 0; if (output_type == loco::DataType::U8) { - asym_wquant_per_channel(circle_const, min, scaling_factor); + asym_wquant_per_channel(circle_const, min, scaling_factor, channel_dim_index); } else { - sym_wquant_per_channel(circle_const, scaling_factor); + sym_wquant_per_channel(circle_const, scaling_factor, channel_dim_index); } + quantparam->min.clear(); + quantparam->max.clear(); + quantparam->quantized_dimension = channel_dim_index; } // Find min/max per layer-wise else @@ -493,6 +500,8 @@ struct QuantizeWeights final : public luci::CircleNodeMutableVisitor<bool> auto min = quantparam->min[0]; auto scaling_factor = quantparam->scale[0]; asym_wquant_per_layer(circle_const, min, scaling_factor); + quantparam->min.clear(); + quantparam->max.clear(); } } } |