diff options
Diffstat (limited to 'compiler/luci/pass/src/FuseBCQPass.cpp')
-rw-r--r-- | compiler/luci/pass/src/FuseBCQPass.cpp | 560 |
1 files changed, 295 insertions, 265 deletions
diff --git a/compiler/luci/pass/src/FuseBCQPass.cpp b/compiler/luci/pass/src/FuseBCQPass.cpp index 7aa2e3e80..ebf28779b 100644 --- a/compiler/luci/pass/src/FuseBCQPass.cpp +++ b/compiler/luci/pass/src/FuseBCQPass.cpp @@ -17,163 +17,139 @@ #include "luci/Pass/FuseBCQPass.h" #include <luci/IR/CircleNodes.h> +#include <luci/Log.h> #include <cassert> -#include <string> #include <set> namespace { -/** - * @brief Circle nodes including BCQ information and a circle node to which BCQ will be applied - * are connected with their name. And their names include common prefix. - * However, after pb file is converted to tflite file, some nodes' name are changed. - * Thus this function will return original common prefix. - * - * @note All the re-naming rule of TFLite converter is not figured out. - * Therefore, if new naming rule is detected, this function should be updated. - */ -const std::string node_name_prefix(luci::NodeName node_name) -{ - std::string prefix = node_name; - - if (prefix.find("/ReadVariableOp/resource") != std::string::npos) - { - const auto start_index = prefix.find("/ReadVariableOp/resource"); - - const auto left_prefix = prefix.substr(0, start_index); - const auto right_prefix = prefix.substr(start_index + 24); - - prefix = left_prefix + right_prefix; - } - - if (prefix.find("Tensordot/") != std::string::npos) - { - const auto index = prefix.find("Tensordot/"); - prefix = prefix.substr(0, index - 1); - } - else if (prefix.find("/MatMul") != std::string::npos) - { - const auto index = prefix.find("/MatMul"); - prefix = prefix.substr(0, index); - } - else if (prefix.find("kernel/") != std::string::npos) - { - const auto index = prefix.find("kernel/"); - prefix = prefix.substr(0, index - 1); - } - else if (prefix.find("/bcqinfo_") != std::string::npos) - { - const auto index = prefix.find("/bcqinfo_"); - prefix = prefix.substr(0, index); - } - - return prefix; -} - -/** - * @brief Create CircleOutputExclude operation, which has same shape and dtype with - * original circle_node. - */ -luci::CircleOutputExclude *createNoOp(luci::CircleNode *circle_node) -{ - auto graph = circle_node->graph(); - auto noOp = graph->nodes()->create<luci::CircleOutputExclude>(); - - if (circle_node->shape_status() == luci::ShapeStatus::VALID) - { - noOp->dtype(circle_node->dtype()); - noOp->rank(circle_node->rank()); - for (uint32_t i = 0; i < circle_node->rank(); ++i) - noOp->dim(i) = circle_node->dim(i); - } - else - { - // For type inference - noOp->dtype(loco::DataType::FLOAT32); - } - - return noOp; -}; - -} // namespace - -namespace -{ - // V means the version of BCQ. template <int32_t V> class BCQFuser; template <> class BCQFuser<1> { public: + BCQFuser<1>(int32_t original_output_cnt, int32_t bundle_cnt) + : _original_output_cnt{original_output_cnt}, _bundle_cnt{bundle_cnt} + { + // Do nothing + } + +public: bool fuseBCQ(loco::Graph *g) { - bool changed = false; - for (auto node : loco::all_nodes(g)) + const auto output_nodes = loco::output_nodes(g); + for (auto node : output_nodes) { - if (auto circle_const = dynamic_cast<luci::CircleConst *>(node)) + auto output_node = loco::must_cast<luci::CircleOutput *>(node); + + /** + * First output of model is metadata for BCQ. Please refer to following example. + * + * When original_output_cnt is 2, + * BCQ_METADATA, original_output_1, original_output_2, BCQ_INFO_1, ... + */ + if ((int)output_node->index() > _original_output_cnt) { - add_BCQ_info_node(circle_const); + const auto prefix = (output_node->index() - (_original_output_cnt + 1)) / (_bundle_cnt); + const MetadataType metadata_type = static_cast<MetadataType>( + (output_node->index() - (_original_output_cnt + 1)) % (_bundle_cnt)); + const auto circle_node = loco::must_cast<luci::CircleNode *>(output_node->from()); + add_BCQ_info_node(prefix, metadata_type, circle_node); } } if (!is_bcqinfo_valid()) return false; - for (auto node : loco::active_nodes(loco::output_nodes(g))) + for (auto f : _fusable_op) { + auto prefix = f.first; + luci::CircleNode *node = f.second; + + if (!is_valid_prefix(prefix)) + continue; + + // Fuse Gather to BCQGather if (auto gather = dynamic_cast<luci::CircleGather *>(node)) { - auto params = dynamic_cast<luci::CircleConst *>(gather->params()); - if (params != nullptr && has_BCQ_info(params)) + if (auto params = dynamic_cast<luci::CircleConst *>(gather->params())) { auto bcq_gather = g->nodes()->create<luci::CircleBCQGather>(); bcq_gather->op_version(1); - bcq_gather->input_scales(get_alpha(params)); - bcq_gather->input_binary(get_packed_binary_code(params)); + bcq_gather->input_scales(_alpha[prefix]); + bcq_gather->input_binary(_packed_binary_code[prefix]); bcq_gather->indices(gather->indices()); - bcq_gather->input_clusters(packed_clusters(params)); + bcq_gather->input_clusters(packed_clusters(g, prefix)); - // input_binary shape : [output_size, hidden_size] - const auto binary_hidden_size = - loco::must_cast<luci::CircleConst *>(bcq_gather->input_binary())->dim(1).value() * 32; - bcq_gather->input_hidden_size(binary_hidden_size); - - if (do_w_x(params)) + if (_do_w_x[prefix]->at<loco::DataType::BOOL>(0)) { + bcq_gather->input_hidden_size(params->dim(1).value()); bcq_gather->axis(gather->axis()); + loco::replace(gather).with(bcq_gather); } else { + bcq_gather->input_hidden_size(params->dim(0).value()); const auto axis_transpose = (gather->axis() == 0) ? 1 : 0; bcq_gather->axis(axis_transpose); + + const auto indices_rank = + loco::must_cast<luci::CircleNode *>(gather->indices())->rank(); + + auto perm = g->nodes()->create<luci::CircleConst>(); + perm->dtype(loco::DataType::S32); + perm->size<loco::DataType::S32>(1 + indices_rank); + perm->rank(1); + perm->dim(0) = 1 + indices_rank; + for (uint32_t idx = 0; idx < indices_rank; ++idx) + perm->at<loco::DataType::S32>(idx) = idx + 1; + perm->at<loco::DataType::S32>(indices_rank) = 0; + perm->shape_status(luci::ShapeStatus::VALID); + + auto output_transpose = g->nodes()->create<luci::CircleTranspose>(); + output_transpose->a(bcq_gather); + output_transpose->perm(perm); + + loco::replace(gather).with(output_transpose); } - loco::replace(gather).with(bcq_gather); + return true; + } + } - changed = true; + // Einsum is unpacked to FullyConnected, Pack and Reshape + if (auto reshape = dynamic_cast<luci::CircleReshape *>(node)) + { + node = dynamic_cast<luci::CircleNode *>(reshape->tensor()); + } + if (auto pack = dynamic_cast<luci::CirclePack *>(node)) + { + if (pack->values_count() == 1 && pack->rank() == 3) + { + node = dynamic_cast<luci::CircleNode *>(pack->values(0)); } } - else if (auto fully_connected = dynamic_cast<luci::CircleFullyConnected *>(node)) + + // Fuse FullyConnected to BCQFullyConnected + if (auto fully_connected = dynamic_cast<luci::CircleFullyConnected *>(node)) { - auto weights = dynamic_cast<luci::CircleConst *>(fully_connected->weights()); - if (weights != nullptr && has_BCQ_info(weights)) + if (auto weights = dynamic_cast<luci::CircleConst *>(fully_connected->weights())) { auto bcq_fc = g->nodes()->create<luci::CircleBCQFullyConnected>(); bcq_fc->op_version(1); - bcq_fc->weights_scales(get_alpha(weights)); - bcq_fc->weights_binary(get_packed_binary_code(weights)); + bcq_fc->weights_scales(_alpha[prefix]); + bcq_fc->weights_binary(_packed_binary_code[prefix]); bcq_fc->bias(fully_connected->bias()); - bcq_fc->weights_clusters(packed_clusters(weights)); + bcq_fc->weights_clusters(packed_clusters(g, prefix)); bcq_fc->fusedActivationFunction(fully_connected->fusedActivationFunction()); loco::Node *bcq_input = fully_connected->input(); - int32_t batch_rank = 0; // If input of BCQFullyConnected has more than rank 2, we should reshape it as rank 2 const auto original_input = loco::must_cast<luci::CircleNode *>(fully_connected->input()); @@ -200,27 +176,18 @@ public: reshape->shape(new_shape); bcq_input = reshape; - batch_rank = original_input->rank() - 2; } // If x_w formation, we should insert Transpose in front and back of BCQFullyConnected - if (do_w_x(weights)) + if (_do_w_x[prefix]->at<loco::DataType::BOOL>(0)) { - const auto binary_hidden_size = - loco::must_cast<luci::CircleNode *>(fully_connected->input()) - ->dim(batch_rank) - .value(); - bcq_fc->weights_hidden_size(binary_hidden_size); + bcq_fc->weights_hidden_size(weights->dim(0).value()); bcq_fc->input(bcq_input); loco::replace(fully_connected).with(bcq_fc); } else { - const auto binary_hidden_size = - loco::must_cast<luci::CircleNode *>(fully_connected->input()) - ->dim(1 + batch_rank) - .value(); - bcq_fc->weights_hidden_size(binary_hidden_size); + bcq_fc->weights_hidden_size(weights->dim(1).value()); auto perm = g->nodes()->create<luci::CircleConst>(); perm->dtype(loco::DataType::S32); @@ -244,159 +211,183 @@ public: loco::replace(fully_connected).with(output_transpose); } - changed = true; + return true; + } + else + { + // TODO Is there any case that input() is constant, instead of weights()? } } } - if (changed) - clear_BCQ_nodes(); - - return changed; + return false; } private: - void add_BCQ_info_node(luci::CircleConst *node) + enum MetadataType { - const auto node_name = node->name(); - const auto prefix = node_name_prefix(node_name); - - // If bcqinfo_* nodes are held by Reshape operation, - // shape of bcqinfo_* nodes are copied to `shape` input of Reshape operation. - // Then the name becomes bcqinfo_*_copy_shape. - // We should prevent this node not to added to bcq information. - if (node_name.find("_copy_shape") != std::string::npos) + DO_W_X, + ALPHA, + BINARY_CODE, + NUM_OF_CLUSTERS, + SIZE_OF_CLUSTERS, + QBITS_OF_CLUSTERS, + FUSABLE_OP, + DEQUANT_WEIGHT, + }; + + void add_BCQ_info_node(int32_t prefix, MetadataType metadata_type, luci::CircleNode *node) + { + if (metadata_type == MetadataType::FUSABLE_OP) + { + _fusable_op[prefix] = node; return; + } - if (node_name.find("bcqinfo_do_w_x") != std::string::npos) - _do_w_x[prefix] = node; - else if (node_name.find("bcqinfo_alpha") != std::string::npos) - _alpha[prefix] = node; - else if (node_name.find("bcqinfo_packed_binary_code") != std::string::npos) - _packed_binary_code[prefix] = node; - else if (node_name.find("bcqinfo_number_of_clusters") != std::string::npos) - _number_of_clusters[prefix] = node; - else if (node_name.find("bcqinfo_size_of_clusters") != std::string::npos) - _size_of_clusters[prefix] = node; - else if (node_name.find("bcqinfo_qbits_of_clusters") != std::string::npos) - _qbits_of_clusters[prefix] = node; - else if (node_name.find("bcqinfo_dequant_weight") != std::string::npos) - _dequant_weight[prefix] = node; - } + luci::CircleConst *const_node; - bool has_BCQ_info(luci::CircleConst *node) - { - const auto prefix = node_name_prefix(node->name()); - bool has_info = true; - - has_info &= (_do_w_x.find(prefix) != _do_w_x.end()); - has_info &= (_alpha.find(prefix) != _alpha.end()); - has_info &= (_packed_binary_code.find(prefix) != _packed_binary_code.end()); - has_info &= (_number_of_clusters.find(prefix) != _number_of_clusters.end()); - has_info &= (_size_of_clusters.find(prefix) != _size_of_clusters.end()); - has_info &= (_qbits_of_clusters.find(prefix) != _qbits_of_clusters.end()); - // bcqinfo_dequant_weight is just for validation, so not always exists. - - return has_info; + // Converter in TensorFlow v1.x sometimes generate Reshape op + if (auto reshape = dynamic_cast<luci::CircleReshape *>(node)) + const_node = loco::must_cast<luci::CircleConst *>(reshape->tensor()); + else + const_node = loco::must_cast<luci::CircleConst *>(node); + + if (metadata_type == MetadataType::DO_W_X) + _do_w_x[prefix] = const_node; + else if (metadata_type == MetadataType::ALPHA) + _alpha[prefix] = const_node; + else if (metadata_type == MetadataType::BINARY_CODE) + _packed_binary_code[prefix] = const_node; + else if (metadata_type == MetadataType::NUM_OF_CLUSTERS) + _number_of_clusters[prefix] = const_node; + else if (metadata_type == MetadataType::SIZE_OF_CLUSTERS) + _size_of_clusters[prefix] = const_node; + else if (metadata_type == MetadataType::QBITS_OF_CLUSTERS) + _qbits_of_clusters[prefix] = const_node; + else + _dequant_weight[prefix] = const_node; } - /** - * @brief Exclude BCQ information nodes which are used for fusing BCQ operations - * from graph output by using CircleOutputExclude - */ - void clear_BCQ_nodes() + bool is_bcqinfo_valid() { - auto clear_nodes = [](std::map<std::string, luci::CircleConst *> &nodes) { - for (auto &n : nodes) + LOGGER(l); + + for (auto n : _do_w_x) + { + // do_w_x should be BOOL type + if (n.second->dtype() != loco::DataType::BOOL) { - auto node = n.second; + WARN(l) << "FuseBCQPass : do_w_x has wrong type" << std::endl; + return false; + } + } - for (auto s : loco::succs(node)) - { - if (auto outnode = dynamic_cast<luci::CircleOutput *>(s)) - { - outnode->from(createNoOp(node)); - } - else if (auto reshape_node = dynamic_cast<luci::CircleReshape *>(s)) - { - for (auto o : loco::succs(reshape_node)) - { - auto circle_output = loco::must_cast<luci::CircleOutput *>(o); - circle_output->from(createNoOp(reshape_node)); - } - } - } + for (auto n : _alpha) + { + // alpha should be FLOAT32 type + if (n.second->dtype() != loco::DataType::FLOAT32) + { + WARN(l) << "FuseBCQPass : alpha has wrong type" << std::endl; + return false; } - }; - - clear_nodes(_do_w_x); - clear_nodes(_alpha); - clear_nodes(_packed_binary_code); - clear_nodes(_number_of_clusters); - clear_nodes(_size_of_clusters); - clear_nodes(_qbits_of_clusters); - clear_nodes(_dequant_weight); - } + } - bool is_bcqinfo_valid() - { - // do_w_x should be int32 or bool type - for (auto n : _do_w_x) + for (auto n : _packed_binary_code) + { + // packed_binary_code should be INT32 type + if (n.second->dtype() != loco::DataType::S32) + { + WARN(l) << "FuseBCQPass : packed_binary_code has wrong type" << std::endl; + return false; + } + } + + for (auto n : _number_of_clusters) { - if (n.second->dtype() != loco::DataType::BOOL && n.second->dtype() != loco::DataType::S32) + // number_of_clusters should be INT32 type + if (n.second->dtype() != loco::DataType::S32) + { + WARN(l) << "FuseBCQPass : number_of_clusters has wrong type" << std::endl; return false; + } } + for (auto n : _size_of_clusters) + { + // size_of_clusters should be INT32 type + if (n.second->dtype() != loco::DataType::S32) + { + WARN(l) << "FuseBCQPass : size_of_clusters has wrong type" << std::endl; + return false; + } + } + + for (auto n : _qbits_of_clusters) + { + // qbits_of_clusters should be INT32 type + if (n.second->dtype() != loco::DataType::S32) + { + WARN(l) << "FuseBCQPass : qbits_of_clusters has wrong type" << std::endl; + return false; + } + } + + // As dequant_weight is not used for fusing, skip validation. + return true; } -private: - bool do_w_x(luci::CircleConst *node) + bool is_valid_prefix(int32_t prefix) { - const auto prefix = node_name_prefix(node->name()); + LOGGER(l); - if (_do_w_x[prefix]->dtype() == loco::DataType::S32) - return _do_w_x[prefix]->at<loco::DataType::S32>(0) == 1; - else - return _do_w_x[prefix]->at<loco::DataType::BOOL>(0); - } + if (_do_w_x.find(prefix) == _do_w_x.end()) + { + WARN(l) << "do_w_x is not found" << std::endl; + return false; + } - luci::CircleConst *get_alpha(luci::CircleConst *node) - { - const auto prefix = node_name_prefix(node->name()); - return _alpha[prefix]; - } + if (_alpha.find(prefix) == _alpha.end()) + { + WARN(l) << "alpha is not found" << std::endl; + return false; + } - luci::CircleConst *get_packed_binary_code(luci::CircleConst *node) - { - const auto prefix = node_name_prefix(node->name()); - return _packed_binary_code[prefix]; - } + if (_packed_binary_code.find(prefix) == _packed_binary_code.end()) + { + WARN(l) << "packed_binary_code is not found" << std::endl; + return false; + } - luci::CircleConst *get_number_of_clusters(luci::CircleConst *node) - { - const auto prefix = node_name_prefix(node->name()); - return _number_of_clusters[prefix]; - } + if (_number_of_clusters.find(prefix) == _number_of_clusters.end()) + { + WARN(l) << "number_of_clusters is not found" << std::endl; + return false; + } - luci::CircleConst *get_size_of_clusters(luci::CircleConst *node) - { - const auto prefix = node_name_prefix(node->name()); - return _size_of_clusters[prefix]; - } + if (_size_of_clusters.find(prefix) == _size_of_clusters.end()) + { + WARN(l) << "size_of_clusters is not found" << std::endl; + return false; + } - luci::CircleConst *get_qbits_of_clusters(luci::CircleConst *node) - { - const auto prefix = node_name_prefix(node->name()); - return _qbits_of_clusters[prefix]; + if (_qbits_of_clusters.find(prefix) == _qbits_of_clusters.end()) + { + WARN(l) << "qbits_of_clusters is not found" << std::endl; + return false; + } + + // As dequant_weight is not used for fusing, skip validation. + + return true; } - luci::CircleConst *packed_clusters(luci::CircleConst *node) +private: + luci::CircleConst *packed_clusters(loco::Graph *graph, int32_t prefix) { - auto graph = node->graph(); - auto qbits_of_clusters = get_qbits_of_clusters(node); - auto size_of_clusters = get_size_of_clusters(node); - const auto number_of_clusters = get_number_of_clusters(node)->at<loco::DataType::S32>(0); + auto qbits_of_clusters = _qbits_of_clusters[prefix]; + auto size_of_clusters = _size_of_clusters[prefix]; + const auto number_of_clusters = _number_of_clusters[prefix]->at<loco::DataType::S32>(0); auto packed_clusters = graph->nodes()->create<luci::CircleConst>(); packed_clusters->dtype(loco::DataType::S32); @@ -418,13 +409,18 @@ private: } private: - std::map<std::string, luci::CircleConst *> _do_w_x; - std::map<std::string, luci::CircleConst *> _alpha; - std::map<std::string, luci::CircleConst *> _packed_binary_code; - std::map<std::string, luci::CircleConst *> _number_of_clusters; - std::map<std::string, luci::CircleConst *> _size_of_clusters; - std::map<std::string, luci::CircleConst *> _qbits_of_clusters; - std::map<std::string, luci::CircleConst *> _dequant_weight; + std::map<int32_t, luci::CircleConst *> _do_w_x; + std::map<int32_t, luci::CircleConst *> _alpha; + std::map<int32_t, luci::CircleConst *> _packed_binary_code; + std::map<int32_t, luci::CircleConst *> _number_of_clusters; + std::map<int32_t, luci::CircleConst *> _size_of_clusters; + std::map<int32_t, luci::CircleConst *> _qbits_of_clusters; + std::map<int32_t, luci::CircleConst *> _dequant_weight; + std::map<int32_t, luci::CircleNode *> _fusable_op; + +private: + int32_t _original_output_cnt = 0; + int32_t _bundle_cnt = 0; }; } // namespace @@ -436,38 +432,72 @@ bool FuseBCQPass::run(loco::Graph *g) { bool changed = false; - // Find BCQ version information and check validity. - luci::CircleConst *version_node = nullptr; - for (auto node : loco::all_nodes(g)) + const int32_t start_magicnum = -2e9 + 27; + const int32_t end_magicnum = 2e9 - 27; + + luci::CircleConst *metadata_node = nullptr; + for (auto node : loco::output_nodes(g)) { - if (auto circle_const = dynamic_cast<luci::CircleConst *>(node)) + auto output_node = loco::must_cast<luci::CircleOutput *>(node); + + // Metadata node should be first output + if (output_node->index() != 0) + continue; + + // Metadata should be constant and dtype should be S32 + auto const_node = dynamic_cast<luci::CircleConst *>(output_node->from()); + if (const_node == nullptr || const_node->dtype() != loco::DataType::S32) + continue; + + // Metadata has at least four elements + const auto element_cnt = const_node->size<loco::DataType::S32>(); + if (element_cnt < 4) + continue; + + // Metadata has magic numbers at first and at last + const auto start_value = const_node->at<loco::DataType::S32>(0); + const auto end_value = const_node->at<loco::DataType::S32>(element_cnt - 1); + if (start_value == start_magicnum && end_value == end_magicnum) { - if (circle_const->name().find("/bcqinfo_version") != std::string::npos) - { - // There should be only one bcqinfo_version in the model - if (version_node != nullptr) - { - assert(false && "Multiple version information found"); - return false; - } - - version_node = circle_const; - } + metadata_node = const_node; + break; } } - // If version node is not found, regard it as version 1. - int32_t bcq_version = (version_node != nullptr) ? version_node->at<loco::DataType::S32>(0) : 1; + if (metadata_node != nullptr) + { + const auto bcq_version = metadata_node->at<loco::DataType::S32>(1); + const auto original_output_cnt = metadata_node->at<loco::DataType::S32>(2); - if (bcq_version == 1) - changed = BCQFuser<1>().fuseBCQ(g); - else - assert(false && "Not supported BCQ version"); + if (bcq_version == 1) + { + const auto bundle_cnt = metadata_node->at<loco::DataType::S32>(3); - if (changed && version_node != nullptr) - { - // If BCQ is applied and version node was found, remove the node. - loco::replace(version_node).with(createNoOp(version_node)); + BCQFuser<1> fuser{original_output_cnt, bundle_cnt}; + if (fuser.fuseBCQ(g)) + changed = true; + } + else + { + LOGGER(l); + WARN(l) << "Not supported BCQ version is found." << std::endl; + } + + // Remove all of BCQ information nodes iff there is no change + if (changed == false) + { + for (auto node : loco::output_nodes(g)) + { + auto output_node = loco::must_cast<luci::CircleOutput *>(node); + if (output_node->index() == 0 || (int)output_node->index() > original_output_cnt) + { + auto noOp = g->nodes()->create<luci::CircleOutputExclude>(); + noOp->dtype(loco::DataType::FLOAT32); // TODO Remove this setting + output_node->from(noOp); + changed = true; + } + } + } } return changed; |