summaryrefslogtreecommitdiff
path: root/compiler/luci/pass/src/FuseBCQPass.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/luci/pass/src/FuseBCQPass.cpp')
-rw-r--r--compiler/luci/pass/src/FuseBCQPass.cpp560
1 files changed, 295 insertions, 265 deletions
diff --git a/compiler/luci/pass/src/FuseBCQPass.cpp b/compiler/luci/pass/src/FuseBCQPass.cpp
index 7aa2e3e80..ebf28779b 100644
--- a/compiler/luci/pass/src/FuseBCQPass.cpp
+++ b/compiler/luci/pass/src/FuseBCQPass.cpp
@@ -17,163 +17,139 @@
#include "luci/Pass/FuseBCQPass.h"
#include <luci/IR/CircleNodes.h>
+#include <luci/Log.h>
#include <cassert>
-#include <string>
#include <set>
namespace
{
-/**
- * @brief Circle nodes including BCQ information and a circle node to which BCQ will be applied
- * are connected with their name. And their names include common prefix.
- * However, after pb file is converted to tflite file, some nodes' name are changed.
- * Thus this function will return original common prefix.
- *
- * @note All the re-naming rule of TFLite converter is not figured out.
- * Therefore, if new naming rule is detected, this function should be updated.
- */
-const std::string node_name_prefix(luci::NodeName node_name)
-{
- std::string prefix = node_name;
-
- if (prefix.find("/ReadVariableOp/resource") != std::string::npos)
- {
- const auto start_index = prefix.find("/ReadVariableOp/resource");
-
- const auto left_prefix = prefix.substr(0, start_index);
- const auto right_prefix = prefix.substr(start_index + 24);
-
- prefix = left_prefix + right_prefix;
- }
-
- if (prefix.find("Tensordot/") != std::string::npos)
- {
- const auto index = prefix.find("Tensordot/");
- prefix = prefix.substr(0, index - 1);
- }
- else if (prefix.find("/MatMul") != std::string::npos)
- {
- const auto index = prefix.find("/MatMul");
- prefix = prefix.substr(0, index);
- }
- else if (prefix.find("kernel/") != std::string::npos)
- {
- const auto index = prefix.find("kernel/");
- prefix = prefix.substr(0, index - 1);
- }
- else if (prefix.find("/bcqinfo_") != std::string::npos)
- {
- const auto index = prefix.find("/bcqinfo_");
- prefix = prefix.substr(0, index);
- }
-
- return prefix;
-}
-
-/**
- * @brief Create CircleOutputExclude operation, which has same shape and dtype with
- * original circle_node.
- */
-luci::CircleOutputExclude *createNoOp(luci::CircleNode *circle_node)
-{
- auto graph = circle_node->graph();
- auto noOp = graph->nodes()->create<luci::CircleOutputExclude>();
-
- if (circle_node->shape_status() == luci::ShapeStatus::VALID)
- {
- noOp->dtype(circle_node->dtype());
- noOp->rank(circle_node->rank());
- for (uint32_t i = 0; i < circle_node->rank(); ++i)
- noOp->dim(i) = circle_node->dim(i);
- }
- else
- {
- // For type inference
- noOp->dtype(loco::DataType::FLOAT32);
- }
-
- return noOp;
-};
-
-} // namespace
-
-namespace
-{
-
// V means the version of BCQ.
template <int32_t V> class BCQFuser;
template <> class BCQFuser<1>
{
public:
+ BCQFuser<1>(int32_t original_output_cnt, int32_t bundle_cnt)
+ : _original_output_cnt{original_output_cnt}, _bundle_cnt{bundle_cnt}
+ {
+ // Do nothing
+ }
+
+public:
bool fuseBCQ(loco::Graph *g)
{
- bool changed = false;
- for (auto node : loco::all_nodes(g))
+ const auto output_nodes = loco::output_nodes(g);
+ for (auto node : output_nodes)
{
- if (auto circle_const = dynamic_cast<luci::CircleConst *>(node))
+ auto output_node = loco::must_cast<luci::CircleOutput *>(node);
+
+ /**
+ * First output of model is metadata for BCQ. Please refer to following example.
+ *
+ * When original_output_cnt is 2,
+ * BCQ_METADATA, original_output_1, original_output_2, BCQ_INFO_1, ...
+ */
+ if ((int)output_node->index() > _original_output_cnt)
{
- add_BCQ_info_node(circle_const);
+ const auto prefix = (output_node->index() - (_original_output_cnt + 1)) / (_bundle_cnt);
+ const MetadataType metadata_type = static_cast<MetadataType>(
+ (output_node->index() - (_original_output_cnt + 1)) % (_bundle_cnt));
+ const auto circle_node = loco::must_cast<luci::CircleNode *>(output_node->from());
+ add_BCQ_info_node(prefix, metadata_type, circle_node);
}
}
if (!is_bcqinfo_valid())
return false;
- for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ for (auto f : _fusable_op)
{
+ auto prefix = f.first;
+ luci::CircleNode *node = f.second;
+
+ if (!is_valid_prefix(prefix))
+ continue;
+
+ // Fuse Gather to BCQGather
if (auto gather = dynamic_cast<luci::CircleGather *>(node))
{
- auto params = dynamic_cast<luci::CircleConst *>(gather->params());
- if (params != nullptr && has_BCQ_info(params))
+ if (auto params = dynamic_cast<luci::CircleConst *>(gather->params()))
{
auto bcq_gather = g->nodes()->create<luci::CircleBCQGather>();
bcq_gather->op_version(1);
- bcq_gather->input_scales(get_alpha(params));
- bcq_gather->input_binary(get_packed_binary_code(params));
+ bcq_gather->input_scales(_alpha[prefix]);
+ bcq_gather->input_binary(_packed_binary_code[prefix]);
bcq_gather->indices(gather->indices());
- bcq_gather->input_clusters(packed_clusters(params));
+ bcq_gather->input_clusters(packed_clusters(g, prefix));
- // input_binary shape : [output_size, hidden_size]
- const auto binary_hidden_size =
- loco::must_cast<luci::CircleConst *>(bcq_gather->input_binary())->dim(1).value() * 32;
- bcq_gather->input_hidden_size(binary_hidden_size);
-
- if (do_w_x(params))
+ if (_do_w_x[prefix]->at<loco::DataType::BOOL>(0))
{
+ bcq_gather->input_hidden_size(params->dim(1).value());
bcq_gather->axis(gather->axis());
+ loco::replace(gather).with(bcq_gather);
}
else
{
+ bcq_gather->input_hidden_size(params->dim(0).value());
const auto axis_transpose = (gather->axis() == 0) ? 1 : 0;
bcq_gather->axis(axis_transpose);
+
+ const auto indices_rank =
+ loco::must_cast<luci::CircleNode *>(gather->indices())->rank();
+
+ auto perm = g->nodes()->create<luci::CircleConst>();
+ perm->dtype(loco::DataType::S32);
+ perm->size<loco::DataType::S32>(1 + indices_rank);
+ perm->rank(1);
+ perm->dim(0) = 1 + indices_rank;
+ for (uint32_t idx = 0; idx < indices_rank; ++idx)
+ perm->at<loco::DataType::S32>(idx) = idx + 1;
+ perm->at<loco::DataType::S32>(indices_rank) = 0;
+ perm->shape_status(luci::ShapeStatus::VALID);
+
+ auto output_transpose = g->nodes()->create<luci::CircleTranspose>();
+ output_transpose->a(bcq_gather);
+ output_transpose->perm(perm);
+
+ loco::replace(gather).with(output_transpose);
}
- loco::replace(gather).with(bcq_gather);
+ return true;
+ }
+ }
- changed = true;
+ // Einsum is unpacked to FullyConnected, Pack and Reshape
+ if (auto reshape = dynamic_cast<luci::CircleReshape *>(node))
+ {
+ node = dynamic_cast<luci::CircleNode *>(reshape->tensor());
+ }
+ if (auto pack = dynamic_cast<luci::CirclePack *>(node))
+ {
+ if (pack->values_count() == 1 && pack->rank() == 3)
+ {
+ node = dynamic_cast<luci::CircleNode *>(pack->values(0));
}
}
- else if (auto fully_connected = dynamic_cast<luci::CircleFullyConnected *>(node))
+
+ // Fuse FullyConnected to BCQFullyConnected
+ if (auto fully_connected = dynamic_cast<luci::CircleFullyConnected *>(node))
{
- auto weights = dynamic_cast<luci::CircleConst *>(fully_connected->weights());
- if (weights != nullptr && has_BCQ_info(weights))
+ if (auto weights = dynamic_cast<luci::CircleConst *>(fully_connected->weights()))
{
auto bcq_fc = g->nodes()->create<luci::CircleBCQFullyConnected>();
bcq_fc->op_version(1);
- bcq_fc->weights_scales(get_alpha(weights));
- bcq_fc->weights_binary(get_packed_binary_code(weights));
+ bcq_fc->weights_scales(_alpha[prefix]);
+ bcq_fc->weights_binary(_packed_binary_code[prefix]);
bcq_fc->bias(fully_connected->bias());
- bcq_fc->weights_clusters(packed_clusters(weights));
+ bcq_fc->weights_clusters(packed_clusters(g, prefix));
bcq_fc->fusedActivationFunction(fully_connected->fusedActivationFunction());
loco::Node *bcq_input = fully_connected->input();
- int32_t batch_rank = 0;
// If input of BCQFullyConnected has more than rank 2, we should reshape it as rank 2
const auto original_input = loco::must_cast<luci::CircleNode *>(fully_connected->input());
@@ -200,27 +176,18 @@ public:
reshape->shape(new_shape);
bcq_input = reshape;
- batch_rank = original_input->rank() - 2;
}
// If x_w formation, we should insert Transpose in front and back of BCQFullyConnected
- if (do_w_x(weights))
+ if (_do_w_x[prefix]->at<loco::DataType::BOOL>(0))
{
- const auto binary_hidden_size =
- loco::must_cast<luci::CircleNode *>(fully_connected->input())
- ->dim(batch_rank)
- .value();
- bcq_fc->weights_hidden_size(binary_hidden_size);
+ bcq_fc->weights_hidden_size(weights->dim(0).value());
bcq_fc->input(bcq_input);
loco::replace(fully_connected).with(bcq_fc);
}
else
{
- const auto binary_hidden_size =
- loco::must_cast<luci::CircleNode *>(fully_connected->input())
- ->dim(1 + batch_rank)
- .value();
- bcq_fc->weights_hidden_size(binary_hidden_size);
+ bcq_fc->weights_hidden_size(weights->dim(1).value());
auto perm = g->nodes()->create<luci::CircleConst>();
perm->dtype(loco::DataType::S32);
@@ -244,159 +211,183 @@ public:
loco::replace(fully_connected).with(output_transpose);
}
- changed = true;
+ return true;
+ }
+ else
+ {
+ // TODO Is there any case that input() is constant, instead of weights()?
}
}
}
- if (changed)
- clear_BCQ_nodes();
-
- return changed;
+ return false;
}
private:
- void add_BCQ_info_node(luci::CircleConst *node)
+ enum MetadataType
{
- const auto node_name = node->name();
- const auto prefix = node_name_prefix(node_name);
-
- // If bcqinfo_* nodes are held by Reshape operation,
- // shape of bcqinfo_* nodes are copied to `shape` input of Reshape operation.
- // Then the name becomes bcqinfo_*_copy_shape.
- // We should prevent this node not to added to bcq information.
- if (node_name.find("_copy_shape") != std::string::npos)
+ DO_W_X,
+ ALPHA,
+ BINARY_CODE,
+ NUM_OF_CLUSTERS,
+ SIZE_OF_CLUSTERS,
+ QBITS_OF_CLUSTERS,
+ FUSABLE_OP,
+ DEQUANT_WEIGHT,
+ };
+
+ void add_BCQ_info_node(int32_t prefix, MetadataType metadata_type, luci::CircleNode *node)
+ {
+ if (metadata_type == MetadataType::FUSABLE_OP)
+ {
+ _fusable_op[prefix] = node;
return;
+ }
- if (node_name.find("bcqinfo_do_w_x") != std::string::npos)
- _do_w_x[prefix] = node;
- else if (node_name.find("bcqinfo_alpha") != std::string::npos)
- _alpha[prefix] = node;
- else if (node_name.find("bcqinfo_packed_binary_code") != std::string::npos)
- _packed_binary_code[prefix] = node;
- else if (node_name.find("bcqinfo_number_of_clusters") != std::string::npos)
- _number_of_clusters[prefix] = node;
- else if (node_name.find("bcqinfo_size_of_clusters") != std::string::npos)
- _size_of_clusters[prefix] = node;
- else if (node_name.find("bcqinfo_qbits_of_clusters") != std::string::npos)
- _qbits_of_clusters[prefix] = node;
- else if (node_name.find("bcqinfo_dequant_weight") != std::string::npos)
- _dequant_weight[prefix] = node;
- }
+ luci::CircleConst *const_node;
- bool has_BCQ_info(luci::CircleConst *node)
- {
- const auto prefix = node_name_prefix(node->name());
- bool has_info = true;
-
- has_info &= (_do_w_x.find(prefix) != _do_w_x.end());
- has_info &= (_alpha.find(prefix) != _alpha.end());
- has_info &= (_packed_binary_code.find(prefix) != _packed_binary_code.end());
- has_info &= (_number_of_clusters.find(prefix) != _number_of_clusters.end());
- has_info &= (_size_of_clusters.find(prefix) != _size_of_clusters.end());
- has_info &= (_qbits_of_clusters.find(prefix) != _qbits_of_clusters.end());
- // bcqinfo_dequant_weight is just for validation, so not always exists.
-
- return has_info;
+ // Converter in TensorFlow v1.x sometimes generate Reshape op
+ if (auto reshape = dynamic_cast<luci::CircleReshape *>(node))
+ const_node = loco::must_cast<luci::CircleConst *>(reshape->tensor());
+ else
+ const_node = loco::must_cast<luci::CircleConst *>(node);
+
+ if (metadata_type == MetadataType::DO_W_X)
+ _do_w_x[prefix] = const_node;
+ else if (metadata_type == MetadataType::ALPHA)
+ _alpha[prefix] = const_node;
+ else if (metadata_type == MetadataType::BINARY_CODE)
+ _packed_binary_code[prefix] = const_node;
+ else if (metadata_type == MetadataType::NUM_OF_CLUSTERS)
+ _number_of_clusters[prefix] = const_node;
+ else if (metadata_type == MetadataType::SIZE_OF_CLUSTERS)
+ _size_of_clusters[prefix] = const_node;
+ else if (metadata_type == MetadataType::QBITS_OF_CLUSTERS)
+ _qbits_of_clusters[prefix] = const_node;
+ else
+ _dequant_weight[prefix] = const_node;
}
- /**
- * @brief Exclude BCQ information nodes which are used for fusing BCQ operations
- * from graph output by using CircleOutputExclude
- */
- void clear_BCQ_nodes()
+ bool is_bcqinfo_valid()
{
- auto clear_nodes = [](std::map<std::string, luci::CircleConst *> &nodes) {
- for (auto &n : nodes)
+ LOGGER(l);
+
+ for (auto n : _do_w_x)
+ {
+ // do_w_x should be BOOL type
+ if (n.second->dtype() != loco::DataType::BOOL)
{
- auto node = n.second;
+ WARN(l) << "FuseBCQPass : do_w_x has wrong type" << std::endl;
+ return false;
+ }
+ }
- for (auto s : loco::succs(node))
- {
- if (auto outnode = dynamic_cast<luci::CircleOutput *>(s))
- {
- outnode->from(createNoOp(node));
- }
- else if (auto reshape_node = dynamic_cast<luci::CircleReshape *>(s))
- {
- for (auto o : loco::succs(reshape_node))
- {
- auto circle_output = loco::must_cast<luci::CircleOutput *>(o);
- circle_output->from(createNoOp(reshape_node));
- }
- }
- }
+ for (auto n : _alpha)
+ {
+ // alpha should be FLOAT32 type
+ if (n.second->dtype() != loco::DataType::FLOAT32)
+ {
+ WARN(l) << "FuseBCQPass : alpha has wrong type" << std::endl;
+ return false;
}
- };
-
- clear_nodes(_do_w_x);
- clear_nodes(_alpha);
- clear_nodes(_packed_binary_code);
- clear_nodes(_number_of_clusters);
- clear_nodes(_size_of_clusters);
- clear_nodes(_qbits_of_clusters);
- clear_nodes(_dequant_weight);
- }
+ }
- bool is_bcqinfo_valid()
- {
- // do_w_x should be int32 or bool type
- for (auto n : _do_w_x)
+ for (auto n : _packed_binary_code)
+ {
+ // packed_binary_code should be INT32 type
+ if (n.second->dtype() != loco::DataType::S32)
+ {
+ WARN(l) << "FuseBCQPass : packed_binary_code has wrong type" << std::endl;
+ return false;
+ }
+ }
+
+ for (auto n : _number_of_clusters)
{
- if (n.second->dtype() != loco::DataType::BOOL && n.second->dtype() != loco::DataType::S32)
+ // number_of_clusters should be INT32 type
+ if (n.second->dtype() != loco::DataType::S32)
+ {
+ WARN(l) << "FuseBCQPass : number_of_clusters has wrong type" << std::endl;
return false;
+ }
}
+ for (auto n : _size_of_clusters)
+ {
+ // size_of_clusters should be INT32 type
+ if (n.second->dtype() != loco::DataType::S32)
+ {
+ WARN(l) << "FuseBCQPass : size_of_clusters has wrong type" << std::endl;
+ return false;
+ }
+ }
+
+ for (auto n : _qbits_of_clusters)
+ {
+ // qbits_of_clusters should be INT32 type
+ if (n.second->dtype() != loco::DataType::S32)
+ {
+ WARN(l) << "FuseBCQPass : qbits_of_clusters has wrong type" << std::endl;
+ return false;
+ }
+ }
+
+ // As dequant_weight is not used for fusing, skip validation.
+
return true;
}
-private:
- bool do_w_x(luci::CircleConst *node)
+ bool is_valid_prefix(int32_t prefix)
{
- const auto prefix = node_name_prefix(node->name());
+ LOGGER(l);
- if (_do_w_x[prefix]->dtype() == loco::DataType::S32)
- return _do_w_x[prefix]->at<loco::DataType::S32>(0) == 1;
- else
- return _do_w_x[prefix]->at<loco::DataType::BOOL>(0);
- }
+ if (_do_w_x.find(prefix) == _do_w_x.end())
+ {
+ WARN(l) << "do_w_x is not found" << std::endl;
+ return false;
+ }
- luci::CircleConst *get_alpha(luci::CircleConst *node)
- {
- const auto prefix = node_name_prefix(node->name());
- return _alpha[prefix];
- }
+ if (_alpha.find(prefix) == _alpha.end())
+ {
+ WARN(l) << "alpha is not found" << std::endl;
+ return false;
+ }
- luci::CircleConst *get_packed_binary_code(luci::CircleConst *node)
- {
- const auto prefix = node_name_prefix(node->name());
- return _packed_binary_code[prefix];
- }
+ if (_packed_binary_code.find(prefix) == _packed_binary_code.end())
+ {
+ WARN(l) << "packed_binary_code is not found" << std::endl;
+ return false;
+ }
- luci::CircleConst *get_number_of_clusters(luci::CircleConst *node)
- {
- const auto prefix = node_name_prefix(node->name());
- return _number_of_clusters[prefix];
- }
+ if (_number_of_clusters.find(prefix) == _number_of_clusters.end())
+ {
+ WARN(l) << "number_of_clusters is not found" << std::endl;
+ return false;
+ }
- luci::CircleConst *get_size_of_clusters(luci::CircleConst *node)
- {
- const auto prefix = node_name_prefix(node->name());
- return _size_of_clusters[prefix];
- }
+ if (_size_of_clusters.find(prefix) == _size_of_clusters.end())
+ {
+ WARN(l) << "size_of_clusters is not found" << std::endl;
+ return false;
+ }
- luci::CircleConst *get_qbits_of_clusters(luci::CircleConst *node)
- {
- const auto prefix = node_name_prefix(node->name());
- return _qbits_of_clusters[prefix];
+ if (_qbits_of_clusters.find(prefix) == _qbits_of_clusters.end())
+ {
+ WARN(l) << "qbits_of_clusters is not found" << std::endl;
+ return false;
+ }
+
+ // As dequant_weight is not used for fusing, skip validation.
+
+ return true;
}
- luci::CircleConst *packed_clusters(luci::CircleConst *node)
+private:
+ luci::CircleConst *packed_clusters(loco::Graph *graph, int32_t prefix)
{
- auto graph = node->graph();
- auto qbits_of_clusters = get_qbits_of_clusters(node);
- auto size_of_clusters = get_size_of_clusters(node);
- const auto number_of_clusters = get_number_of_clusters(node)->at<loco::DataType::S32>(0);
+ auto qbits_of_clusters = _qbits_of_clusters[prefix];
+ auto size_of_clusters = _size_of_clusters[prefix];
+ const auto number_of_clusters = _number_of_clusters[prefix]->at<loco::DataType::S32>(0);
auto packed_clusters = graph->nodes()->create<luci::CircleConst>();
packed_clusters->dtype(loco::DataType::S32);
@@ -418,13 +409,18 @@ private:
}
private:
- std::map<std::string, luci::CircleConst *> _do_w_x;
- std::map<std::string, luci::CircleConst *> _alpha;
- std::map<std::string, luci::CircleConst *> _packed_binary_code;
- std::map<std::string, luci::CircleConst *> _number_of_clusters;
- std::map<std::string, luci::CircleConst *> _size_of_clusters;
- std::map<std::string, luci::CircleConst *> _qbits_of_clusters;
- std::map<std::string, luci::CircleConst *> _dequant_weight;
+ std::map<int32_t, luci::CircleConst *> _do_w_x;
+ std::map<int32_t, luci::CircleConst *> _alpha;
+ std::map<int32_t, luci::CircleConst *> _packed_binary_code;
+ std::map<int32_t, luci::CircleConst *> _number_of_clusters;
+ std::map<int32_t, luci::CircleConst *> _size_of_clusters;
+ std::map<int32_t, luci::CircleConst *> _qbits_of_clusters;
+ std::map<int32_t, luci::CircleConst *> _dequant_weight;
+ std::map<int32_t, luci::CircleNode *> _fusable_op;
+
+private:
+ int32_t _original_output_cnt = 0;
+ int32_t _bundle_cnt = 0;
};
} // namespace
@@ -436,38 +432,72 @@ bool FuseBCQPass::run(loco::Graph *g)
{
bool changed = false;
- // Find BCQ version information and check validity.
- luci::CircleConst *version_node = nullptr;
- for (auto node : loco::all_nodes(g))
+ const int32_t start_magicnum = -2e9 + 27;
+ const int32_t end_magicnum = 2e9 - 27;
+
+ luci::CircleConst *metadata_node = nullptr;
+ for (auto node : loco::output_nodes(g))
{
- if (auto circle_const = dynamic_cast<luci::CircleConst *>(node))
+ auto output_node = loco::must_cast<luci::CircleOutput *>(node);
+
+ // Metadata node should be first output
+ if (output_node->index() != 0)
+ continue;
+
+ // Metadata should be constant and dtype should be S32
+ auto const_node = dynamic_cast<luci::CircleConst *>(output_node->from());
+ if (const_node == nullptr || const_node->dtype() != loco::DataType::S32)
+ continue;
+
+ // Metadata has at least four elements
+ const auto element_cnt = const_node->size<loco::DataType::S32>();
+ if (element_cnt < 4)
+ continue;
+
+ // Metadata has magic numbers at first and at last
+ const auto start_value = const_node->at<loco::DataType::S32>(0);
+ const auto end_value = const_node->at<loco::DataType::S32>(element_cnt - 1);
+ if (start_value == start_magicnum && end_value == end_magicnum)
{
- if (circle_const->name().find("/bcqinfo_version") != std::string::npos)
- {
- // There should be only one bcqinfo_version in the model
- if (version_node != nullptr)
- {
- assert(false && "Multiple version information found");
- return false;
- }
-
- version_node = circle_const;
- }
+ metadata_node = const_node;
+ break;
}
}
- // If version node is not found, regard it as version 1.
- int32_t bcq_version = (version_node != nullptr) ? version_node->at<loco::DataType::S32>(0) : 1;
+ if (metadata_node != nullptr)
+ {
+ const auto bcq_version = metadata_node->at<loco::DataType::S32>(1);
+ const auto original_output_cnt = metadata_node->at<loco::DataType::S32>(2);
- if (bcq_version == 1)
- changed = BCQFuser<1>().fuseBCQ(g);
- else
- assert(false && "Not supported BCQ version");
+ if (bcq_version == 1)
+ {
+ const auto bundle_cnt = metadata_node->at<loco::DataType::S32>(3);
- if (changed && version_node != nullptr)
- {
- // If BCQ is applied and version node was found, remove the node.
- loco::replace(version_node).with(createNoOp(version_node));
+ BCQFuser<1> fuser{original_output_cnt, bundle_cnt};
+ if (fuser.fuseBCQ(g))
+ changed = true;
+ }
+ else
+ {
+ LOGGER(l);
+ WARN(l) << "Not supported BCQ version is found." << std::endl;
+ }
+
+ // Remove all of BCQ information nodes iff there is no change
+ if (changed == false)
+ {
+ for (auto node : loco::output_nodes(g))
+ {
+ auto output_node = loco::must_cast<luci::CircleOutput *>(node);
+ if (output_node->index() == 0 || (int)output_node->index() > original_output_cnt)
+ {
+ auto noOp = g->nodes()->create<luci::CircleOutputExclude>();
+ noOp->dtype(loco::DataType::FLOAT32); // TODO Remove this setting
+ output_node->from(noOp);
+ changed = true;
+ }
+ }
+ }
}
return changed;