summaryrefslogtreecommitdiff
path: root/compiler/luci/pass
diff options
context:
space:
mode:
authorChunseok Lee <chunseok.lee@samsung.com>2020-10-28 12:16:55 +0900
committerChunseok Lee <chunseok.lee@samsung.com>2020-10-28 12:16:55 +0900
commitc55f8a6db48cda9d3a78048338b7f18c4cca62b8 (patch)
tree761ee8e171e5203f5c598ad93b2e7e0bc2e31aa2 /compiler/luci/pass
parent74476a2d0296bdad70a2f7f90bc7419a8b05bffd (diff)
downloadnnfw-c55f8a6db48cda9d3a78048338b7f18c4cca62b8.tar.gz
nnfw-c55f8a6db48cda9d3a78048338b7f18c4cca62b8.tar.bz2
nnfw-c55f8a6db48cda9d3a78048338b7f18c4cca62b8.zip
Diffstat (limited to 'compiler/luci/pass')
-rw-r--r--compiler/luci/pass/include/luci/CircleOptimizer.h15
-rw-r--r--compiler/luci/pass/include/luci/Pass/FoldDequantizePass.h38
-rw-r--r--compiler/luci/pass/include/luci/Pass/FuseAddWithTConvPass.h37
-rw-r--r--compiler/luci/pass/include/luci/Pass/SparsifyTensorPass.h69
-rw-r--r--compiler/luci/pass/src/CircleOptimizer.cpp67
-rw-r--r--compiler/luci/pass/src/FoldDequantizePass.cpp206
-rw-r--r--compiler/luci/pass/src/FuseAddWithTConvPass.cpp120
-rw-r--r--compiler/luci/pass/src/FuseBCQPass.cpp560
-rw-r--r--compiler/luci/pass/src/FuseBatchNormWithTConv.cpp20
-rw-r--r--compiler/luci/pass/src/PropagateConcatenationQparam.test.cpp372
-rw-r--r--compiler/luci/pass/src/QuantizationUtils.cpp60
-rw-r--r--compiler/luci/pass/src/QuantizationUtils.h10
-rw-r--r--compiler/luci/pass/src/QuantizeDequantizeWeightsPass.cpp30
-rw-r--r--compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp294
-rw-r--r--compiler/luci/pass/src/RequantizePass.cpp4
-rw-r--r--compiler/luci/pass/src/Sparsifier.cpp229
-rw-r--r--compiler/luci/pass/src/Sparsifier.h87
-rw-r--r--compiler/luci/pass/src/SparsifyTensorPass.cpp123
18 files changed, 2029 insertions, 312 deletions
diff --git a/compiler/luci/pass/include/luci/CircleOptimizer.h b/compiler/luci/pass/include/luci/CircleOptimizer.h
index a832844f8..32ab85ef5 100644
--- a/compiler/luci/pass/include/luci/CircleOptimizer.h
+++ b/compiler/luci/pass/include/luci/CircleOptimizer.h
@@ -32,6 +32,7 @@ public:
{
enum Algorithm
{
+ FuseAddWithTConv,
FuseBatchNormWithTConv,
FuseBCQ,
FuseInstanceNorm,
@@ -41,13 +42,23 @@ public:
QuantizeDequantizeWeights,
QuantizeWithMinMax,
Requantize,
+ FoldDequantize,
+ SparsifyTensorPass,
};
enum AlgorithmParameters
{
+ // quantize
Quantize_input_dtype,
Quantize_output_dtype,
- Quantize_granularity // layer-wise or channel-wise
+ Quantize_granularity, // layer-wise or channel-wise
+
+ // sparsify
+ Sparsify_tensor_name,
+ Sparsify_traversal_order,
+ Sparsify_format,
+ Sparsify_block_size,
+ Sparsify_block_map,
};
virtual ~Options() = default;
@@ -67,6 +78,8 @@ public:
void quantize(loco::Graph *) const;
+ void sparsify(loco::Graph *) const;
+
private:
std::unique_ptr<Options> _options;
};
diff --git a/compiler/luci/pass/include/luci/Pass/FoldDequantizePass.h b/compiler/luci/pass/include/luci/Pass/FoldDequantizePass.h
new file mode 100644
index 000000000..07610d3e1
--- /dev/null
+++ b/compiler/luci/pass/include/luci/Pass/FoldDequantizePass.h
@@ -0,0 +1,38 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_FOLD_DEQUANTIZE_PASS_H__
+#define __LUCI_FOLD_DEQUANTIZE_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace luci
+{
+
+/**
+ * @brief Class to fold a Dequantize operation whose input is a foldable constant
+ *
+ */
+struct FoldDequantizePass final : public logo::Pass
+{
+ const char *name(void) const final { return "luci::FOLD_DEQUANTIZE"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace luci
+
+#endif // __LUCI_FOLD_DEQUANTIZE_PASS_H__
diff --git a/compiler/luci/pass/include/luci/Pass/FuseAddWithTConvPass.h b/compiler/luci/pass/include/luci/Pass/FuseAddWithTConvPass.h
new file mode 100644
index 000000000..89b120397
--- /dev/null
+++ b/compiler/luci/pass/include/luci/Pass/FuseAddWithTConvPass.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_FUSE_ADD_WITH_TCONV_PASS_H__
+#define __LUCI_FUSE_ADD_WITH_TCONV_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace luci
+{
+
+/**
+ * @brief Class to fuse Add into CircleTransposeConv
+ */
+struct FuseAddWithTConvPass final : public logo::Pass
+{
+ const char *name(void) const final { return "luci::FuseAddWithTConvPass"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace luci
+
+#endif // __LUCI_FUSE_ADD_WITH_TCONV_PASS_H__
diff --git a/compiler/luci/pass/include/luci/Pass/SparsifyTensorPass.h b/compiler/luci/pass/include/luci/Pass/SparsifyTensorPass.h
new file mode 100644
index 000000000..41f43bf88
--- /dev/null
+++ b/compiler/luci/pass/include/luci/Pass/SparsifyTensorPass.h
@@ -0,0 +1,69 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_SPARSIFY_TENSOR_PASS_H__
+#define __LUCI_SPARSIFY_TENSOR_PASS_H__
+
+#include <logo/Pass.h>
+
+#include <luci/IR/SparsityParam.h>
+
+namespace luci
+{
+
+class CircleConst;
+
+/**
+ * @brief Pass to sparsify tensor
+ */
+struct SparsifyTensorPass final : public logo::Pass
+{
+public:
+ SparsifyTensorPass(const std::string &tensor_name, const std::vector<int32_t> &traversal_order,
+ const std::vector<DimensionType> &format,
+ const std::vector<int32_t> &block_size, const std::vector<int32_t> &block_map)
+ : _tensor_name{tensor_name}, _traversal_order{traversal_order}, _format{format},
+ _block_size{block_size}, _block_map{block_map}
+ {
+ // DO NOTHING
+ }
+
+public:
+ const char *name(void) const final { return "luci::SparsifyTensorPass"; }
+
+ bool run(loco::Graph *g) final;
+
+ template <loco::DataType DT> void sparsify_tensor(luci::CircleConst *cop);
+
+private:
+ // Tensor name that the pass will sparsify
+ std::string _tensor_name;
+ std::vector<int32_t> _traversal_order;
+ std::vector<DimensionType> _format;
+ std::vector<int32_t> _block_size;
+ std::vector<int32_t> _block_map;
+};
+
+extern template void
+SparsifyTensorPass::sparsify_tensor<loco::DataType::S32>(luci::CircleConst *cop);
+extern template void
+SparsifyTensorPass::sparsify_tensor<loco::DataType::S8>(luci::CircleConst *cop);
+extern template void
+SparsifyTensorPass::sparsify_tensor<loco::DataType::FLOAT32>(luci::CircleConst *cop);
+
+} // namespace luci
+
+#endif // __LUCI_SPARSIFY_TENSOR_PASS_H__
diff --git a/compiler/luci/pass/src/CircleOptimizer.cpp b/compiler/luci/pass/src/CircleOptimizer.cpp
index 2ee759b4e..0e6056ffc 100644
--- a/compiler/luci/pass/src/CircleOptimizer.cpp
+++ b/compiler/luci/pass/src/CircleOptimizer.cpp
@@ -16,6 +16,8 @@
#include "luci/CircleOptimizer.h"
+#include "luci/Pass/FoldDequantizePass.h"
+#include "luci/Pass/FuseAddWithTConvPass.h"
#include "luci/Pass/FuseBatchNormWithTConv.h"
#include "luci/Pass/FuseBCQPass.h"
#include "luci/Pass/FuseInstanceNormPass.h"
@@ -25,6 +27,7 @@
#include "luci/Pass/RequantizePass.h"
#include "luci/Pass/QuantizeWithMinMaxPass.h"
#include "luci/Pass/QuantizeDequantizeWeightsPass.h"
+#include "luci/Pass/SparsifyTensorPass.h"
// TODO add more passes
#include "luci/Pass/ShapeInferencePass.h"
@@ -40,10 +43,25 @@
#include <logo/Phase.h>
#include <memory>
+#include <sstream>
namespace
{
+std::vector<int> parseIntFromCommadelimitedStr(std::string str)
+{
+ std::vector<int> ret;
+ std::istringstream is(str);
+ for (uint32_t i; is >> i;)
+ {
+ assert(i != ',');
+ ret.push_back(i);
+ if (is.peek() == ',')
+ is.ignore();
+ }
+ return ret;
+}
+
using namespace luci;
class OptimizeOptionsImpl final : public luci::CircleOptimizer::Options
@@ -132,6 +150,14 @@ void CircleOptimizer::optimize(loco::Graph *g) const
{
phase.emplace_back(std::make_unique<FuseBatchNormWithTConvPass>());
}
+ if (_options->query(Options::Algorithm::FuseAddWithTConv))
+ {
+ phase.emplace_back(std::make_unique<FuseAddWithTConvPass>());
+ }
+ if (_options->query(Options::Algorithm::FoldDequantize))
+ {
+ phase.emplace_back(std::make_unique<luci::FoldDequantizePass>());
+ }
// Shape inference is needed for added nodes doing above transformations
phase.emplace_back(std::make_unique<luci::ShapeInferencePass>());
@@ -151,7 +177,7 @@ void CircleOptimizer::quantize(loco::Graph *g) const
if (_options->query(Options::Algorithm::QuantizeDequantizeWeights))
{
static const std::vector<std::string> fakeq_supported_input_dtype{"float32"};
- static const std::vector<std::string> fakeq_supported_output_dtype{"uint8"};
+ static const std::vector<std::string> fakeq_supported_output_dtype{"uint8", "int16"};
static const std::vector<std::string> fakeq_supported_granularity{"layer", "channel"};
auto input_dtype = _options->param(Options::AlgorithmParameters::Quantize_input_dtype);
@@ -187,7 +213,7 @@ void CircleOptimizer::quantize(loco::Graph *g) const
if (_options->query(Options::Algorithm::QuantizeWithMinMax))
{
static const std::vector<std::string> qwmm_supported_input_dtype{"float32"};
- static const std::vector<std::string> qwmm_supported_output_dtype{"uint8"};
+ static const std::vector<std::string> qwmm_supported_output_dtype{"uint8", "int16"};
static const std::vector<std::string> qwmm_supported_granularity{"layer", "channel"};
auto input_dtype = _options->param(Options::AlgorithmParameters::Quantize_input_dtype);
@@ -244,4 +270,41 @@ void CircleOptimizer::quantize(loco::Graph *g) const
phase_runner.run(phase);
}
+void CircleOptimizer::sparsify(loco::Graph *g) const
+{
+ if (_options->query(Options::Algorithm::SparsifyTensorPass))
+ {
+ std::string tensor_name = _options->param(Options::AlgorithmParameters::Sparsify_tensor_name);
+ std::string str_tarversal_order =
+ _options->param(Options::AlgorithmParameters::Sparsify_traversal_order);
+ std::string str_format = _options->param(Options::AlgorithmParameters::Sparsify_format);
+ std::string str_block_size = _options->param(Options::AlgorithmParameters::Sparsify_block_size);
+ std::string str_block_map = _options->param(Options::AlgorithmParameters::Sparsify_block_map);
+
+ // traversal order
+ std::vector<int32_t> traversal_order = parseIntFromCommadelimitedStr(str_tarversal_order);
+ // format
+ std::vector<DimensionType> format;
+ std::istringstream is(str_format);
+ for (char c; is >> c;)
+ {
+ assert(c != ',');
+ if (c == 'd')
+ format.push_back(DimensionType::DENSE);
+ else if (c == 's')
+ format.push_back(DimensionType::SPARSE_CSR);
+ if (is.peek() == ',')
+ is.ignore();
+ }
+ // block size
+ std::vector<int32_t> block_size = parseIntFromCommadelimitedStr(str_block_size);
+ // block map
+ std::vector<int32_t> block_map = parseIntFromCommadelimitedStr(str_block_map);
+
+ luci::SparsifyTensorPass sparsifier{tensor_name, traversal_order, format, block_size,
+ block_map};
+ sparsifier.run(g);
+ }
+}
+
} // namespace luci
diff --git a/compiler/luci/pass/src/FoldDequantizePass.cpp b/compiler/luci/pass/src/FoldDequantizePass.cpp
new file mode 100644
index 000000000..01c04f478
--- /dev/null
+++ b/compiler/luci/pass/src/FoldDequantizePass.cpp
@@ -0,0 +1,206 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/FoldDequantizePass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <loco/Service/TypeInference.h>
+
+namespace
+{
+
+bool is_hybrid_kernel_supported(loco::Node *node)
+{
+ if (dynamic_cast<luci::CircleFullyConnected *>(node) != nullptr)
+ return true;
+
+ return false;
+}
+
+bool is_foldable_const(luci::CircleConst *node)
+{
+ if (node->quantparam() == nullptr)
+ return false;
+
+ if (node->dtype() == loco::DataType::S8)
+ return true;
+ if (node->dtype() == loco::DataType::U8)
+ return true;
+
+ return false;
+}
+
+luci::CircleConst *dequantized_const_node(luci::CircleConst *const_node)
+{
+ if (const_node->quantparam() == nullptr)
+ {
+ throw std::runtime_error("Given constant node has no quantization parameter");
+ }
+
+ auto g = const_node->graph();
+ auto new_const_node = g->nodes()->create<luci::CircleConst>();
+
+ new_const_node->dtype(loco::DataType::FLOAT32);
+ new_const_node->rank(const_node->rank());
+ uint32_t dim_size = 1;
+ for (uint32_t i = 0; i < new_const_node->rank(); ++i)
+ {
+ new_const_node->dim(i) = const_node->dim(i);
+ dim_size *= const_node->dim(i).value();
+ }
+ new_const_node->size<loco::DataType::FLOAT32>(dim_size);
+ new_const_node->shape_status(luci::ShapeStatus::VALID);
+
+ const int32_t q_dim = const_node->quantparam()->quantized_dimension;
+ const int32_t q_dim_value = const_node->dim(q_dim).value();
+
+ int32_t right_count = q_dim_value;
+ for (uint32_t i = q_dim + 1; i < const_node->rank(); ++i)
+ right_count *= const_node->dim(i).value();
+
+ if (const_node->dtype() == loco::DataType::S8)
+ {
+ for (uint32_t i = 0; i < const_node->size<loco::DataType::S8>(); ++i)
+ {
+ uint32_t qd = (i % right_count) / (right_count / q_dim_value);
+ if (qd >= const_node->quantparam()->zerop.size())
+ qd = 0;
+
+ new_const_node->at<loco::DataType::FLOAT32>(i) =
+ (float)(const_node->at<loco::DataType::S8>(i) - const_node->quantparam()->zerop.at(qd)) *
+ const_node->quantparam()->scale.at(qd);
+ }
+ }
+ else
+ {
+ for (uint32_t i = 0; i < const_node->size<loco::DataType::U8>(); ++i)
+ {
+ uint32_t qd = (i % right_count) / (right_count / q_dim_value);
+ if (qd >= const_node->quantparam()->zerop.size())
+ qd = 0;
+
+ new_const_node->at<loco::DataType::FLOAT32>(i) =
+ (float)((int)const_node->at<loco::DataType::U8>(i) -
+ const_node->quantparam()->zerop.at(qd)) *
+ const_node->quantparam()->scale.at(qd);
+ }
+ }
+
+ return new_const_node;
+}
+
+bool replace_const_node(loco::Node *node, luci::CircleConst *const_node)
+{
+ if (auto gather = dynamic_cast<luci::CircleGather *>(node))
+ {
+ gather->params(dequantized_const_node(const_node));
+ gather->dtype(loco::DataType::FLOAT32);
+ return true;
+ }
+ else
+ {
+ // TODO Support more ops
+ return false;
+ }
+}
+
+} // namespace
+
+namespace luci
+{
+
+/**
+ *
+ * Folding pattern 1 - When input of Dequantize is foldable constant
+ *
+ * [Before]
+ * quantized_const_input ---------- Dequantize ---------- Op ---
+ * +-- Op1_with_quant_input ---
+ * +-- Op2_with_quant_input ---
+ *
+ * [After]
+ * dequantized_const_input -------------------------------- Op ---
+ *
+ * quantized_const_input ----- Op1_with_quant_input ---
+ * +-- Op2_with_quant_input ---
+ *
+ *
+ * Folding pattern 2 - When input of Dequantize uses quantized output value
+ *
+ * [Before]
+ * quantized_const_input ----- Gather ----- Dequantize --- Op ---
+ * +-- Op1_with_quant_input ---
+ * +-- Op2_with_quant_input ---
+ *
+ * [After]
+ * dequantized_const_input ------Gather -------------------- Op ---
+ *
+ * quantized_const_input ----- Op1_with_quant_input ---
+ * +-- Op2_with_quant_input ---
+ *
+ *
+ */
+bool FoldDequantizePass::run(loco::Graph *g)
+{
+ bool changed = false;
+
+ for (auto node : loco::all_nodes(g))
+ {
+ if (auto circle_dequant = dynamic_cast<luci::CircleDequantize *>(node))
+ {
+ if (auto const_input = dynamic_cast<luci::CircleConst *>(circle_dequant->input()))
+ {
+ // Pattern 1 - When input of Dequantize is foldable constant
+ if (is_foldable_const(const_input))
+ {
+ loco::replace(circle_dequant).with(dequantized_const_node(const_input));
+ changed = true;
+ }
+ }
+ }
+ else if (auto const_node = dynamic_cast<luci::CircleConst *>(node))
+ {
+ if (is_foldable_const(const_node))
+ {
+ for (auto const_node_user : loco::succs(const_node))
+ {
+ // If the user is an operation that supports hybrid kernels, do not dequantize
+ if (is_hybrid_kernel_supported(const_node_user))
+ continue;
+
+ auto users = loco::succs(const_node_user);
+ if (users.size() > 1)
+ continue;
+
+ // Pattern 2 - When input of Dequantize uses quantized output value
+ if (auto dequant = dynamic_cast<luci::CircleDequantize *>(*users.begin()))
+ {
+ if (replace_const_node(const_node_user, const_node))
+ {
+ loco::replace(dequant).with(const_node_user);
+ changed = true;
+ }
+ }
+ }
+ }
+ }
+ }
+
+ return changed;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/FuseAddWithTConvPass.cpp b/compiler/luci/pass/src/FuseAddWithTConvPass.cpp
new file mode 100644
index 000000000..bd7805f6a
--- /dev/null
+++ b/compiler/luci/pass/src/FuseAddWithTConvPass.cpp
@@ -0,0 +1,120 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/FuseAddWithTConvPass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+namespace
+{
+/**
+ * Fuse add to TCONV if possible
+ *
+ * BEFORE
+ *
+ * [CircleTransposeConv]
+ * |
+ * [add]
+ * AFTER
+ *
+ * [CircleTransposeConv]
+ */
+bool fuse_add_with_tconv(luci::CircleTransposeConv *tconv)
+{
+ // check whether it has bias or not. This optimization works only if it doesn't.
+ auto bias = dynamic_cast<luci::CircleOutputExclude *>(tconv->bias());
+ if (not bias)
+ return false;
+
+ // get weight of tconv
+ auto filter = dynamic_cast<luci::CircleConst *>(tconv->filter());
+ if (not filter)
+ return false;
+ if (filter->dtype() != loco::DataType::FLOAT32)
+ return false;
+
+ // get add node
+ auto tconv_output = loco::succs(tconv);
+ assert(tconv_output.size() == 1);
+ auto add = dynamic_cast<luci::CircleAdd *>(*tconv_output.begin());
+ if (not add)
+ return false;
+ if (add->dtype() != loco::DataType::FLOAT32)
+ return false;
+ if (add->fusedActivationFunction() != luci::FusedActFunc::NONE &&
+ add->fusedActivationFunction() != luci::FusedActFunc::RELU6)
+ return false;
+
+ // get addition
+ luci::CircleConst *addition = nullptr;
+ if (add->x() == tconv)
+ addition = dynamic_cast<luci::CircleConst *>(add->y());
+ else
+ addition = dynamic_cast<luci::CircleConst *>(add->x());
+
+ if (not addition)
+ return false;
+
+ // addition dim(0) == tconv filter channel dim
+ if (addition->rank() != 1)
+ return false;
+ auto addition_dim = addition->dim(0).value();
+ auto filter_channel_dim = filter->dim(0).value();
+ if (filter_channel_dim != addition_dim)
+ return false;
+
+ // fuse addition with transposed conv
+ tconv->bias(addition);
+
+ if (add->fusedActivationFunction() == luci::FusedActFunc::RELU6)
+ {
+ // separate relu op from add op
+ auto relu = add->graph()->nodes()->create<luci::CircleRelu6>();
+ relu->features(tconv);
+
+ // remove add node
+ replace(add).with(relu);
+ }
+ else
+ {
+ replace(add).with(tconv);
+ }
+
+ return true;
+}
+
+} // namespace
+
+namespace luci
+{
+
+bool FuseAddWithTConvPass::run(loco::Graph *g)
+{
+ bool changed = false;
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ auto tconv = dynamic_cast<luci::CircleTransposeConv *>(node);
+ if (not tconv)
+ continue;
+
+ if (fuse_add_with_tconv(tconv))
+ changed = true;
+ }
+
+ return changed;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/FuseBCQPass.cpp b/compiler/luci/pass/src/FuseBCQPass.cpp
index 7aa2e3e80..ebf28779b 100644
--- a/compiler/luci/pass/src/FuseBCQPass.cpp
+++ b/compiler/luci/pass/src/FuseBCQPass.cpp
@@ -17,163 +17,139 @@
#include "luci/Pass/FuseBCQPass.h"
#include <luci/IR/CircleNodes.h>
+#include <luci/Log.h>
#include <cassert>
-#include <string>
#include <set>
namespace
{
-/**
- * @brief Circle nodes including BCQ information and a circle node to which BCQ will be applied
- * are connected with their name. And their names include common prefix.
- * However, after pb file is converted to tflite file, some nodes' name are changed.
- * Thus this function will return original common prefix.
- *
- * @note All the re-naming rule of TFLite converter is not figured out.
- * Therefore, if new naming rule is detected, this function should be updated.
- */
-const std::string node_name_prefix(luci::NodeName node_name)
-{
- std::string prefix = node_name;
-
- if (prefix.find("/ReadVariableOp/resource") != std::string::npos)
- {
- const auto start_index = prefix.find("/ReadVariableOp/resource");
-
- const auto left_prefix = prefix.substr(0, start_index);
- const auto right_prefix = prefix.substr(start_index + 24);
-
- prefix = left_prefix + right_prefix;
- }
-
- if (prefix.find("Tensordot/") != std::string::npos)
- {
- const auto index = prefix.find("Tensordot/");
- prefix = prefix.substr(0, index - 1);
- }
- else if (prefix.find("/MatMul") != std::string::npos)
- {
- const auto index = prefix.find("/MatMul");
- prefix = prefix.substr(0, index);
- }
- else if (prefix.find("kernel/") != std::string::npos)
- {
- const auto index = prefix.find("kernel/");
- prefix = prefix.substr(0, index - 1);
- }
- else if (prefix.find("/bcqinfo_") != std::string::npos)
- {
- const auto index = prefix.find("/bcqinfo_");
- prefix = prefix.substr(0, index);
- }
-
- return prefix;
-}
-
-/**
- * @brief Create CircleOutputExclude operation, which has same shape and dtype with
- * original circle_node.
- */
-luci::CircleOutputExclude *createNoOp(luci::CircleNode *circle_node)
-{
- auto graph = circle_node->graph();
- auto noOp = graph->nodes()->create<luci::CircleOutputExclude>();
-
- if (circle_node->shape_status() == luci::ShapeStatus::VALID)
- {
- noOp->dtype(circle_node->dtype());
- noOp->rank(circle_node->rank());
- for (uint32_t i = 0; i < circle_node->rank(); ++i)
- noOp->dim(i) = circle_node->dim(i);
- }
- else
- {
- // For type inference
- noOp->dtype(loco::DataType::FLOAT32);
- }
-
- return noOp;
-};
-
-} // namespace
-
-namespace
-{
-
// V means the version of BCQ.
template <int32_t V> class BCQFuser;
template <> class BCQFuser<1>
{
public:
+ BCQFuser<1>(int32_t original_output_cnt, int32_t bundle_cnt)
+ : _original_output_cnt{original_output_cnt}, _bundle_cnt{bundle_cnt}
+ {
+ // Do nothing
+ }
+
+public:
bool fuseBCQ(loco::Graph *g)
{
- bool changed = false;
- for (auto node : loco::all_nodes(g))
+ const auto output_nodes = loco::output_nodes(g);
+ for (auto node : output_nodes)
{
- if (auto circle_const = dynamic_cast<luci::CircleConst *>(node))
+ auto output_node = loco::must_cast<luci::CircleOutput *>(node);
+
+ /**
+ * The first output of the model is BCQ metadata. Refer to the following example.
+ *
+ * When original_output_cnt is 2,
+ * BCQ_METADATA, original_output_1, original_output_2, BCQ_INFO_1, ...
+ */
+ if ((int)output_node->index() > _original_output_cnt)
{
- add_BCQ_info_node(circle_const);
+ const auto prefix = (output_node->index() - (_original_output_cnt + 1)) / (_bundle_cnt);
+ const MetadataType metadata_type = static_cast<MetadataType>(
+ (output_node->index() - (_original_output_cnt + 1)) % (_bundle_cnt));
+ const auto circle_node = loco::must_cast<luci::CircleNode *>(output_node->from());
+ add_BCQ_info_node(prefix, metadata_type, circle_node);
}
}
if (!is_bcqinfo_valid())
return false;
- for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ for (auto f : _fusable_op)
{
+ auto prefix = f.first;
+ luci::CircleNode *node = f.second;
+
+ if (!is_valid_prefix(prefix))
+ continue;
+
+ // Fuse Gather to BCQGather
if (auto gather = dynamic_cast<luci::CircleGather *>(node))
{
- auto params = dynamic_cast<luci::CircleConst *>(gather->params());
- if (params != nullptr && has_BCQ_info(params))
+ if (auto params = dynamic_cast<luci::CircleConst *>(gather->params()))
{
auto bcq_gather = g->nodes()->create<luci::CircleBCQGather>();
bcq_gather->op_version(1);
- bcq_gather->input_scales(get_alpha(params));
- bcq_gather->input_binary(get_packed_binary_code(params));
+ bcq_gather->input_scales(_alpha[prefix]);
+ bcq_gather->input_binary(_packed_binary_code[prefix]);
bcq_gather->indices(gather->indices());
- bcq_gather->input_clusters(packed_clusters(params));
+ bcq_gather->input_clusters(packed_clusters(g, prefix));
- // input_binary shape : [output_size, hidden_size]
- const auto binary_hidden_size =
- loco::must_cast<luci::CircleConst *>(bcq_gather->input_binary())->dim(1).value() * 32;
- bcq_gather->input_hidden_size(binary_hidden_size);
-
- if (do_w_x(params))
+ if (_do_w_x[prefix]->at<loco::DataType::BOOL>(0))
{
+ bcq_gather->input_hidden_size(params->dim(1).value());
bcq_gather->axis(gather->axis());
+ loco::replace(gather).with(bcq_gather);
}
else
{
+ bcq_gather->input_hidden_size(params->dim(0).value());
const auto axis_transpose = (gather->axis() == 0) ? 1 : 0;
bcq_gather->axis(axis_transpose);
+
+ const auto indices_rank =
+ loco::must_cast<luci::CircleNode *>(gather->indices())->rank();
+
+ auto perm = g->nodes()->create<luci::CircleConst>();
+ perm->dtype(loco::DataType::S32);
+ perm->size<loco::DataType::S32>(1 + indices_rank);
+ perm->rank(1);
+ perm->dim(0) = 1 + indices_rank;
+ for (uint32_t idx = 0; idx < indices_rank; ++idx)
+ perm->at<loco::DataType::S32>(idx) = idx + 1;
+ perm->at<loco::DataType::S32>(indices_rank) = 0;
+ perm->shape_status(luci::ShapeStatus::VALID);
+
+ auto output_transpose = g->nodes()->create<luci::CircleTranspose>();
+ output_transpose->a(bcq_gather);
+ output_transpose->perm(perm);
+
+ loco::replace(gather).with(output_transpose);
}
- loco::replace(gather).with(bcq_gather);
+ return true;
+ }
+ }
- changed = true;
+ // Einsum is unpacked to FullyConnected, Pack and Reshape
+ if (auto reshape = dynamic_cast<luci::CircleReshape *>(node))
+ {
+ node = dynamic_cast<luci::CircleNode *>(reshape->tensor());
+ }
+ if (auto pack = dynamic_cast<luci::CirclePack *>(node))
+ {
+ if (pack->values_count() == 1 && pack->rank() == 3)
+ {
+ node = dynamic_cast<luci::CircleNode *>(pack->values(0));
}
}
- else if (auto fully_connected = dynamic_cast<luci::CircleFullyConnected *>(node))
+
+ // Fuse FullyConnected to BCQFullyConnected
+ if (auto fully_connected = dynamic_cast<luci::CircleFullyConnected *>(node))
{
- auto weights = dynamic_cast<luci::CircleConst *>(fully_connected->weights());
- if (weights != nullptr && has_BCQ_info(weights))
+ if (auto weights = dynamic_cast<luci::CircleConst *>(fully_connected->weights()))
{
auto bcq_fc = g->nodes()->create<luci::CircleBCQFullyConnected>();
bcq_fc->op_version(1);
- bcq_fc->weights_scales(get_alpha(weights));
- bcq_fc->weights_binary(get_packed_binary_code(weights));
+ bcq_fc->weights_scales(_alpha[prefix]);
+ bcq_fc->weights_binary(_packed_binary_code[prefix]);
bcq_fc->bias(fully_connected->bias());
- bcq_fc->weights_clusters(packed_clusters(weights));
+ bcq_fc->weights_clusters(packed_clusters(g, prefix));
bcq_fc->fusedActivationFunction(fully_connected->fusedActivationFunction());
loco::Node *bcq_input = fully_connected->input();
- int32_t batch_rank = 0;
// If input of BCQFullyConnected has more than rank 2, we should reshape it as rank 2
const auto original_input = loco::must_cast<luci::CircleNode *>(fully_connected->input());
@@ -200,27 +176,18 @@ public:
reshape->shape(new_shape);
bcq_input = reshape;
- batch_rank = original_input->rank() - 2;
}
// If x_w formation, we should insert Transpose in front and back of BCQFullyConnected
- if (do_w_x(weights))
+ if (_do_w_x[prefix]->at<loco::DataType::BOOL>(0))
{
- const auto binary_hidden_size =
- loco::must_cast<luci::CircleNode *>(fully_connected->input())
- ->dim(batch_rank)
- .value();
- bcq_fc->weights_hidden_size(binary_hidden_size);
+ bcq_fc->weights_hidden_size(weights->dim(0).value());
bcq_fc->input(bcq_input);
loco::replace(fully_connected).with(bcq_fc);
}
else
{
- const auto binary_hidden_size =
- loco::must_cast<luci::CircleNode *>(fully_connected->input())
- ->dim(1 + batch_rank)
- .value();
- bcq_fc->weights_hidden_size(binary_hidden_size);
+ bcq_fc->weights_hidden_size(weights->dim(1).value());
auto perm = g->nodes()->create<luci::CircleConst>();
perm->dtype(loco::DataType::S32);
@@ -244,159 +211,183 @@ public:
loco::replace(fully_connected).with(output_transpose);
}
- changed = true;
+ return true;
+ }
+ else
+ {
+ // TODO Is there any case that input() is constant, instead of weights()?
}
}
}
- if (changed)
- clear_BCQ_nodes();
-
- return changed;
+ return false;
}
private:
- void add_BCQ_info_node(luci::CircleConst *node)
+ enum MetadataType
{
- const auto node_name = node->name();
- const auto prefix = node_name_prefix(node_name);
-
- // If bcqinfo_* nodes are held by Reshape operation,
- // shape of bcqinfo_* nodes are copied to `shape` input of Reshape operation.
- // Then the name becomes bcqinfo_*_copy_shape.
- // We should prevent this node not to added to bcq information.
- if (node_name.find("_copy_shape") != std::string::npos)
+ DO_W_X,
+ ALPHA,
+ BINARY_CODE,
+ NUM_OF_CLUSTERS,
+ SIZE_OF_CLUSTERS,
+ QBITS_OF_CLUSTERS,
+ FUSABLE_OP,
+ DEQUANT_WEIGHT,
+ };
+
+ void add_BCQ_info_node(int32_t prefix, MetadataType metadata_type, luci::CircleNode *node)
+ {
+ if (metadata_type == MetadataType::FUSABLE_OP)
+ {
+ _fusable_op[prefix] = node;
return;
+ }
- if (node_name.find("bcqinfo_do_w_x") != std::string::npos)
- _do_w_x[prefix] = node;
- else if (node_name.find("bcqinfo_alpha") != std::string::npos)
- _alpha[prefix] = node;
- else if (node_name.find("bcqinfo_packed_binary_code") != std::string::npos)
- _packed_binary_code[prefix] = node;
- else if (node_name.find("bcqinfo_number_of_clusters") != std::string::npos)
- _number_of_clusters[prefix] = node;
- else if (node_name.find("bcqinfo_size_of_clusters") != std::string::npos)
- _size_of_clusters[prefix] = node;
- else if (node_name.find("bcqinfo_qbits_of_clusters") != std::string::npos)
- _qbits_of_clusters[prefix] = node;
- else if (node_name.find("bcqinfo_dequant_weight") != std::string::npos)
- _dequant_weight[prefix] = node;
- }
+ luci::CircleConst *const_node;
- bool has_BCQ_info(luci::CircleConst *node)
- {
- const auto prefix = node_name_prefix(node->name());
- bool has_info = true;
-
- has_info &= (_do_w_x.find(prefix) != _do_w_x.end());
- has_info &= (_alpha.find(prefix) != _alpha.end());
- has_info &= (_packed_binary_code.find(prefix) != _packed_binary_code.end());
- has_info &= (_number_of_clusters.find(prefix) != _number_of_clusters.end());
- has_info &= (_size_of_clusters.find(prefix) != _size_of_clusters.end());
- has_info &= (_qbits_of_clusters.find(prefix) != _qbits_of_clusters.end());
- // bcqinfo_dequant_weight is just for validation, so not always exists.
-
- return has_info;
+  // Converter in TensorFlow v1.x sometimes generates a Reshape op
+ if (auto reshape = dynamic_cast<luci::CircleReshape *>(node))
+ const_node = loco::must_cast<luci::CircleConst *>(reshape->tensor());
+ else
+ const_node = loco::must_cast<luci::CircleConst *>(node);
+
+ if (metadata_type == MetadataType::DO_W_X)
+ _do_w_x[prefix] = const_node;
+ else if (metadata_type == MetadataType::ALPHA)
+ _alpha[prefix] = const_node;
+ else if (metadata_type == MetadataType::BINARY_CODE)
+ _packed_binary_code[prefix] = const_node;
+ else if (metadata_type == MetadataType::NUM_OF_CLUSTERS)
+ _number_of_clusters[prefix] = const_node;
+ else if (metadata_type == MetadataType::SIZE_OF_CLUSTERS)
+ _size_of_clusters[prefix] = const_node;
+ else if (metadata_type == MetadataType::QBITS_OF_CLUSTERS)
+ _qbits_of_clusters[prefix] = const_node;
+ else
+ _dequant_weight[prefix] = const_node;
}
- /**
- * @brief Exclude BCQ information nodes which are used for fusing BCQ operations
- * from graph output by using CircleOutputExclude
- */
- void clear_BCQ_nodes()
+ bool is_bcqinfo_valid()
{
- auto clear_nodes = [](std::map<std::string, luci::CircleConst *> &nodes) {
- for (auto &n : nodes)
+ LOGGER(l);
+
+ for (auto n : _do_w_x)
+ {
+ // do_w_x should be BOOL type
+ if (n.second->dtype() != loco::DataType::BOOL)
{
- auto node = n.second;
+ WARN(l) << "FuseBCQPass : do_w_x has wrong type" << std::endl;
+ return false;
+ }
+ }
- for (auto s : loco::succs(node))
- {
- if (auto outnode = dynamic_cast<luci::CircleOutput *>(s))
- {
- outnode->from(createNoOp(node));
- }
- else if (auto reshape_node = dynamic_cast<luci::CircleReshape *>(s))
- {
- for (auto o : loco::succs(reshape_node))
- {
- auto circle_output = loco::must_cast<luci::CircleOutput *>(o);
- circle_output->from(createNoOp(reshape_node));
- }
- }
- }
+ for (auto n : _alpha)
+ {
+ // alpha should be FLOAT32 type
+ if (n.second->dtype() != loco::DataType::FLOAT32)
+ {
+ WARN(l) << "FuseBCQPass : alpha has wrong type" << std::endl;
+ return false;
}
- };
-
- clear_nodes(_do_w_x);
- clear_nodes(_alpha);
- clear_nodes(_packed_binary_code);
- clear_nodes(_number_of_clusters);
- clear_nodes(_size_of_clusters);
- clear_nodes(_qbits_of_clusters);
- clear_nodes(_dequant_weight);
- }
+ }
- bool is_bcqinfo_valid()
- {
- // do_w_x should be int32 or bool type
- for (auto n : _do_w_x)
+ for (auto n : _packed_binary_code)
+ {
+ // packed_binary_code should be INT32 type
+ if (n.second->dtype() != loco::DataType::S32)
+ {
+ WARN(l) << "FuseBCQPass : packed_binary_code has wrong type" << std::endl;
+ return false;
+ }
+ }
+
+ for (auto n : _number_of_clusters)
{
- if (n.second->dtype() != loco::DataType::BOOL && n.second->dtype() != loco::DataType::S32)
+ // number_of_clusters should be INT32 type
+ if (n.second->dtype() != loco::DataType::S32)
+ {
+ WARN(l) << "FuseBCQPass : number_of_clusters has wrong type" << std::endl;
return false;
+ }
}
+ for (auto n : _size_of_clusters)
+ {
+ // size_of_clusters should be INT32 type
+ if (n.second->dtype() != loco::DataType::S32)
+ {
+ WARN(l) << "FuseBCQPass : size_of_clusters has wrong type" << std::endl;
+ return false;
+ }
+ }
+
+ for (auto n : _qbits_of_clusters)
+ {
+ // qbits_of_clusters should be INT32 type
+ if (n.second->dtype() != loco::DataType::S32)
+ {
+ WARN(l) << "FuseBCQPass : qbits_of_clusters has wrong type" << std::endl;
+ return false;
+ }
+ }
+
+ // As dequant_weight is not used for fusing, skip validation.
+
return true;
}
-private:
- bool do_w_x(luci::CircleConst *node)
+ bool is_valid_prefix(int32_t prefix)
{
- const auto prefix = node_name_prefix(node->name());
+ LOGGER(l);
- if (_do_w_x[prefix]->dtype() == loco::DataType::S32)
- return _do_w_x[prefix]->at<loco::DataType::S32>(0) == 1;
- else
- return _do_w_x[prefix]->at<loco::DataType::BOOL>(0);
- }
+ if (_do_w_x.find(prefix) == _do_w_x.end())
+ {
+ WARN(l) << "do_w_x is not found" << std::endl;
+ return false;
+ }
- luci::CircleConst *get_alpha(luci::CircleConst *node)
- {
- const auto prefix = node_name_prefix(node->name());
- return _alpha[prefix];
- }
+ if (_alpha.find(prefix) == _alpha.end())
+ {
+ WARN(l) << "alpha is not found" << std::endl;
+ return false;
+ }
- luci::CircleConst *get_packed_binary_code(luci::CircleConst *node)
- {
- const auto prefix = node_name_prefix(node->name());
- return _packed_binary_code[prefix];
- }
+ if (_packed_binary_code.find(prefix) == _packed_binary_code.end())
+ {
+ WARN(l) << "packed_binary_code is not found" << std::endl;
+ return false;
+ }
- luci::CircleConst *get_number_of_clusters(luci::CircleConst *node)
- {
- const auto prefix = node_name_prefix(node->name());
- return _number_of_clusters[prefix];
- }
+ if (_number_of_clusters.find(prefix) == _number_of_clusters.end())
+ {
+ WARN(l) << "number_of_clusters is not found" << std::endl;
+ return false;
+ }
- luci::CircleConst *get_size_of_clusters(luci::CircleConst *node)
- {
- const auto prefix = node_name_prefix(node->name());
- return _size_of_clusters[prefix];
- }
+ if (_size_of_clusters.find(prefix) == _size_of_clusters.end())
+ {
+ WARN(l) << "size_of_clusters is not found" << std::endl;
+ return false;
+ }
- luci::CircleConst *get_qbits_of_clusters(luci::CircleConst *node)
- {
- const auto prefix = node_name_prefix(node->name());
- return _qbits_of_clusters[prefix];
+ if (_qbits_of_clusters.find(prefix) == _qbits_of_clusters.end())
+ {
+ WARN(l) << "qbits_of_clusters is not found" << std::endl;
+ return false;
+ }
+
+ // As dequant_weight is not used for fusing, skip validation.
+
+ return true;
}
- luci::CircleConst *packed_clusters(luci::CircleConst *node)
+private:
+ luci::CircleConst *packed_clusters(loco::Graph *graph, int32_t prefix)
{
- auto graph = node->graph();
- auto qbits_of_clusters = get_qbits_of_clusters(node);
- auto size_of_clusters = get_size_of_clusters(node);
- const auto number_of_clusters = get_number_of_clusters(node)->at<loco::DataType::S32>(0);
+ auto qbits_of_clusters = _qbits_of_clusters[prefix];
+ auto size_of_clusters = _size_of_clusters[prefix];
+ const auto number_of_clusters = _number_of_clusters[prefix]->at<loco::DataType::S32>(0);
auto packed_clusters = graph->nodes()->create<luci::CircleConst>();
packed_clusters->dtype(loco::DataType::S32);
@@ -418,13 +409,18 @@ private:
}
private:
- std::map<std::string, luci::CircleConst *> _do_w_x;
- std::map<std::string, luci::CircleConst *> _alpha;
- std::map<std::string, luci::CircleConst *> _packed_binary_code;
- std::map<std::string, luci::CircleConst *> _number_of_clusters;
- std::map<std::string, luci::CircleConst *> _size_of_clusters;
- std::map<std::string, luci::CircleConst *> _qbits_of_clusters;
- std::map<std::string, luci::CircleConst *> _dequant_weight;
+ std::map<int32_t, luci::CircleConst *> _do_w_x;
+ std::map<int32_t, luci::CircleConst *> _alpha;
+ std::map<int32_t, luci::CircleConst *> _packed_binary_code;
+ std::map<int32_t, luci::CircleConst *> _number_of_clusters;
+ std::map<int32_t, luci::CircleConst *> _size_of_clusters;
+ std::map<int32_t, luci::CircleConst *> _qbits_of_clusters;
+ std::map<int32_t, luci::CircleConst *> _dequant_weight;
+ std::map<int32_t, luci::CircleNode *> _fusable_op;
+
+private:
+ int32_t _original_output_cnt = 0;
+ int32_t _bundle_cnt = 0;
};
} // namespace
@@ -436,38 +432,72 @@ bool FuseBCQPass::run(loco::Graph *g)
{
bool changed = false;
- // Find BCQ version information and check validity.
- luci::CircleConst *version_node = nullptr;
- for (auto node : loco::all_nodes(g))
+ const int32_t start_magicnum = -2e9 + 27;
+ const int32_t end_magicnum = 2e9 - 27;
+
+ luci::CircleConst *metadata_node = nullptr;
+ for (auto node : loco::output_nodes(g))
{
- if (auto circle_const = dynamic_cast<luci::CircleConst *>(node))
+ auto output_node = loco::must_cast<luci::CircleOutput *>(node);
+
+    // Metadata node should be the first output
+ if (output_node->index() != 0)
+ continue;
+
+ // Metadata should be constant and dtype should be S32
+ auto const_node = dynamic_cast<luci::CircleConst *>(output_node->from());
+ if (const_node == nullptr || const_node->dtype() != loco::DataType::S32)
+ continue;
+
+ // Metadata has at least four elements
+ const auto element_cnt = const_node->size<loco::DataType::S32>();
+ if (element_cnt < 4)
+ continue;
+
+ // Metadata has magic numbers at first and at last
+ const auto start_value = const_node->at<loco::DataType::S32>(0);
+ const auto end_value = const_node->at<loco::DataType::S32>(element_cnt - 1);
+ if (start_value == start_magicnum && end_value == end_magicnum)
{
- if (circle_const->name().find("/bcqinfo_version") != std::string::npos)
- {
- // There should be only one bcqinfo_version in the model
- if (version_node != nullptr)
- {
- assert(false && "Multiple version information found");
- return false;
- }
-
- version_node = circle_const;
- }
+ metadata_node = const_node;
+ break;
}
}
- // If version node is not found, regard it as version 1.
- int32_t bcq_version = (version_node != nullptr) ? version_node->at<loco::DataType::S32>(0) : 1;
+ if (metadata_node != nullptr)
+ {
+ const auto bcq_version = metadata_node->at<loco::DataType::S32>(1);
+ const auto original_output_cnt = metadata_node->at<loco::DataType::S32>(2);
- if (bcq_version == 1)
- changed = BCQFuser<1>().fuseBCQ(g);
- else
- assert(false && "Not supported BCQ version");
+ if (bcq_version == 1)
+ {
+ const auto bundle_cnt = metadata_node->at<loco::DataType::S32>(3);
- if (changed && version_node != nullptr)
- {
- // If BCQ is applied and version node was found, remove the node.
- loco::replace(version_node).with(createNoOp(version_node));
+ BCQFuser<1> fuser{original_output_cnt, bundle_cnt};
+ if (fuser.fuseBCQ(g))
+ changed = true;
+ }
+ else
+ {
+ LOGGER(l);
+ WARN(l) << "Not supported BCQ version is found." << std::endl;
+ }
+
+    // Remove all BCQ information nodes iff there is no change
+ if (changed == false)
+ {
+ for (auto node : loco::output_nodes(g))
+ {
+ auto output_node = loco::must_cast<luci::CircleOutput *>(node);
+ if (output_node->index() == 0 || (int)output_node->index() > original_output_cnt)
+ {
+ auto noOp = g->nodes()->create<luci::CircleOutputExclude>();
+ noOp->dtype(loco::DataType::FLOAT32); // TODO Remove this setting
+ output_node->from(noOp);
+ changed = true;
+ }
+ }
+ }
}
return changed;
diff --git a/compiler/luci/pass/src/FuseBatchNormWithTConv.cpp b/compiler/luci/pass/src/FuseBatchNormWithTConv.cpp
index e39455b1a..95ccd8176 100644
--- a/compiler/luci/pass/src/FuseBatchNormWithTConv.cpp
+++ b/compiler/luci/pass/src/FuseBatchNormWithTConv.cpp
@@ -77,11 +77,11 @@ bool fused_batch_norm_with_tconv(luci::CircleTransposeConv *tconv)
// scale dim(0) == tconv filter channel dim
if (filter->rank() != 4)
return false;
- auto filter_channel_dim = filter->dim(3).value();
+ auto filter_out_dim = filter->dim(0).value();
if (scale->rank() != 1)
return false;
auto scale_dim = scale->dim(0).value();
- if (filter_channel_dim != scale_dim)
+ if (filter_out_dim != scale_dim)
return false;
// get shift of batchnorm
@@ -93,23 +93,23 @@ bool fused_batch_norm_with_tconv(luci::CircleTransposeConv *tconv)
if (shift->rank() != 1)
return false;
auto shift_dim = shift->dim(0).value();
- if (filter_channel_dim != shift_dim)
+ if (filter_out_dim != shift_dim)
return false;
// filter weight = filter weight * mul(scale) + add(shift)
- uint32_t filter_batch_dim = filter->dim(0).value();
uint32_t filter_height_dim = filter->dim(1).value();
uint32_t filter_width_dim = filter->dim(2).value();
- for (uint32_t c = 0; c < filter_channel_dim; c++)
+ uint32_t filter_in_dim = filter->dim(3).value();
+ for (uint32_t c = 0; c < filter_out_dim; c++)
{
- for (uint32_t n = 0; n < filter_batch_dim; n++)
+ for (uint32_t h = 0; h < filter_height_dim; h++)
{
- for (uint32_t h = 0; h < filter_height_dim; h++)
+ for (uint32_t w = 0; w < filter_width_dim; w++)
{
- for (uint32_t w = 0; w < filter_width_dim; w++)
+ for (uint32_t b = 0; b < filter_in_dim; b++)
{
- uint32_t offset = n * filter_height_dim * filter_width_dim * filter_channel_dim +
- h * filter_width_dim * filter_channel_dim + w * filter_channel_dim + c;
+ uint32_t offset = c * filter_height_dim * filter_width_dim * filter_in_dim +
+ h * filter_width_dim * filter_in_dim + w * filter_in_dim + b;
filter->at<loco::DataType::FLOAT32>(offset) *= scale->at<loco::DataType::FLOAT32>(c);
}
}
diff --git a/compiler/luci/pass/src/PropagateConcatenationQparam.test.cpp b/compiler/luci/pass/src/PropagateConcatenationQparam.test.cpp
new file mode 100644
index 000000000..0f8d562e9
--- /dev/null
+++ b/compiler/luci/pass/src/PropagateConcatenationQparam.test.cpp
@@ -0,0 +1,372 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "QuantizationUtils.h"
+
+#include <luci/IR/CircleQuantParam.h>
+
+#include <math.h>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+void addQuantParam(luci::CircleNode &node, const std::vector<float> &scale,
+ const std::vector<int64_t> &zp)
+{
+ assert(node.quantparam() == nullptr);
+
+ auto quantparam = std::make_unique<luci::CircleQuantParam>();
+ quantparam->scale = scale;
+ quantparam->zerop = zp;
+ node.quantparam(std::move(quantparam));
+}
+
+int32_t quantize(float f, luci::CircleQuantParam *qparam)
+{
+ float scale = qparam->scale[0];
+ int64_t zp = qparam->zerop[0];
+
+ return std::round(f / scale) + zp;
+}
+
+class SimpleConcatGraph
+{
+public:
+ SimpleConcatGraph(loco::DataType quant_type)
+ {
+ concat_node.dtype(quant_type);
+ concat_node.fusedActivationFunction(luci::FusedActFunc::NONE);
+ input_1.dtype(quant_type);
+ input_2.dtype(quant_type);
+
+ concat_node.values(0, &input_1);
+ concat_node.values(1, &input_2);
+
+ if (quant_type == loco::DataType::U8)
+ {
+ addQuantParam(concat_node, {3.14}, {77});
+ addQuantParam(input_1, {1.0}, {1});
+ addQuantParam(input_2, {2.0}, {2});
+ }
+ else if (quant_type == loco::DataType::S16)
+ {
+ addQuantParam(concat_node, {3.14}, {0});
+ addQuantParam(input_1, {1.0}, {0});
+ addQuantParam(input_2, {2.0}, {0});
+ }
+ else
+ {
+ throw std::runtime_error("Unsupported quantization type");
+ }
+ }
+
+ ~SimpleConcatGraph()
+ {
+ concat_node.values(0, nullptr);
+ concat_node.values(1, nullptr);
+ }
+
+public:
+ luci::CircleConcatenation concat_node{2};
+ luci::CircleConv2D input_1;
+ luci::CircleConv2D input_2;
+};
+
+class SubsequentConcatGraph
+{
+public:
+ SubsequentConcatGraph(loco::DataType quant_type)
+ {
+ concat_node.dtype(quant_type);
+ concat_node.fusedActivationFunction(luci::FusedActFunc::NONE);
+ input_1.dtype(quant_type);
+ input_2.dtype(quant_type);
+
+ concat_node.values(0, &input_1);
+ concat_node.values(1, &input_2);
+
+ if (quant_type == loco::DataType::U8)
+ {
+ addQuantParam(concat_node, {3.14}, {77});
+ addQuantParam(input_1, {1.0}, {1});
+ addQuantParam(input_2, {2.0}, {2});
+ }
+ else if (quant_type == loco::DataType::S16)
+ {
+ addQuantParam(concat_node, {3.14}, {0});
+ addQuantParam(input_1, {1.0}, {0});
+ addQuantParam(input_2, {2.0}, {0});
+ }
+ else
+ {
+ throw std::runtime_error("Unsupported quantization type");
+ }
+ }
+
+ ~SubsequentConcatGraph()
+ {
+ concat_node.values(0, nullptr);
+ concat_node.values(1, nullptr);
+ }
+
+public:
+ luci::CircleConcatenation concat_node{2};
+ luci::CircleConcatenation input_1{2};
+ luci::CircleConv2D input_2;
+};
+
+class ConstInputConcatGraph
+{
+public:
+ ConstInputConcatGraph(loco::DataType quant_type)
+ {
+ concat_node.dtype(quant_type);
+ concat_node.fusedActivationFunction(luci::FusedActFunc::NONE);
+ input_1.dtype(loco::DataType::FLOAT32);
+ input_1.size<loco::DataType::FLOAT32>(5);
+ for (int i = 0; i < 5; i++)
+ {
+ // Set data {-2, -1, 0, 1, 2}
+ input_1.at<loco::DataType::FLOAT32>(i) = i - 2.0;
+ }
+
+ input_2.dtype(quant_type);
+
+ concat_node.values(0, &input_1);
+ concat_node.values(1, &input_2);
+
+ if (quant_type == loco::DataType::U8)
+ {
+ addQuantParam(concat_node, {0.1}, {10});
+ addQuantParam(input_2, {2.0}, {2});
+ }
+ else if (quant_type == loco::DataType::S16)
+ {
+ addQuantParam(concat_node, {0.1}, {0});
+ addQuantParam(input_2, {2.0}, {0});
+ }
+ else
+ {
+ throw std::runtime_error("Unsupported quantization type");
+ }
+ }
+
+ ~ConstInputConcatGraph()
+ {
+ concat_node.values(0, nullptr);
+ concat_node.values(1, nullptr);
+ }
+
+public:
+ luci::CircleConcatenation concat_node{2};
+ luci::CircleConst input_1;
+ luci::CircleConv2D input_2;
+};
+
+} // namespace
+
+TEST(PropagateConcatenationQparam, propagate_concat_quantparam_u8)
+{
+ // Check cases where qparam of concat_node is propagated
+ // (1) normal case: qparam is propagated to input_1 and input_2
+ // (2) input used by other Op: input_1 is an input of input_2. qparam is propagated only to
+ // input_2
+ // (3) subsequent concat: input_1 is concat. qparam is propagated only to input_2
+ // (4) const input: input_1 is const. constant values are quantized
+
+ // normal case: qparam of concat_node is propagated to input_1 and input_2
+ SimpleConcatGraph g(loco::DataType::U8);
+ luci::propagate_concat_quantparam(&g.concat_node, loco::DataType::U8);
+ EXPECT_FLOAT_EQ(3.14, g.concat_node.quantparam()->scale[0]);
+ EXPECT_EQ(77, g.concat_node.quantparam()->zerop[0]);
+ EXPECT_FLOAT_EQ(3.14, g.input_1.quantparam()->scale[0]);
+ EXPECT_EQ(77, g.input_1.quantparam()->zerop[0]);
+ EXPECT_FLOAT_EQ(3.14, g.input_2.quantparam()->scale[0]);
+ EXPECT_EQ(77, g.input_2.quantparam()->zerop[0]);
+
+ // input_1 is an input of input_2. qparam is propagated only to input_2
+ SimpleConcatGraph g2(loco::DataType::U8);
+ g2.input_2.input(&g2.input_1);
+ luci::propagate_concat_quantparam(&g2.concat_node, loco::DataType::U8);
+ EXPECT_FLOAT_EQ(3.14, g2.concat_node.quantparam()->scale[0]);
+ EXPECT_EQ(77, g2.concat_node.quantparam()->zerop[0]);
+ EXPECT_FLOAT_EQ(1.0, g2.input_1.quantparam()->scale[0]);
+ EXPECT_EQ(1, g2.input_1.quantparam()->zerop[0]);
+ EXPECT_FLOAT_EQ(3.14, g2.input_2.quantparam()->scale[0]);
+ EXPECT_EQ(77, g2.input_2.quantparam()->zerop[0]);
+
+ // input_1 is concat. qparam is propagated only to input_2
+ SubsequentConcatGraph sg(loco::DataType::U8);
+ luci::propagate_concat_quantparam(&sg.concat_node, loco::DataType::U8);
+ EXPECT_FLOAT_EQ(3.14, sg.concat_node.quantparam()->scale[0]);
+ EXPECT_EQ(77, sg.concat_node.quantparam()->zerop[0]);
+ EXPECT_FLOAT_EQ(1.0, sg.input_1.quantparam()->scale[0]);
+ EXPECT_EQ(1, sg.input_1.quantparam()->zerop[0]);
+ EXPECT_FLOAT_EQ(3.14, sg.input_2.quantparam()->scale[0]);
+ EXPECT_EQ(77, sg.input_2.quantparam()->zerop[0]);
+
+ // input_1 is const. const values are quantized with the qparam of concat
+ ConstInputConcatGraph cg(loco::DataType::U8);
+ luci::propagate_concat_quantparam(&cg.concat_node, loco::DataType::U8);
+ EXPECT_FLOAT_EQ(0.1, cg.concat_node.quantparam()->scale[0]);
+ EXPECT_EQ(10, cg.concat_node.quantparam()->zerop[0]);
+ EXPECT_FLOAT_EQ(0.1, cg.input_1.quantparam()->scale[0]);
+ EXPECT_EQ(10, cg.input_1.quantparam()->zerop[0]);
+ EXPECT_FLOAT_EQ(0.1, cg.input_2.quantparam()->scale[0]);
+ EXPECT_EQ(10, cg.input_2.quantparam()->zerop[0]);
+ EXPECT_EQ(loco::DataType::U8, cg.input_1.dtype());
+ EXPECT_EQ(0, cg.input_1.at<loco::DataType::U8>(0));
+ EXPECT_EQ(0, cg.input_1.at<loco::DataType::U8>(1));
+ EXPECT_EQ(10, cg.input_1.at<loco::DataType::U8>(2));
+ EXPECT_EQ(20, cg.input_1.at<loco::DataType::U8>(3));
+ EXPECT_EQ(30, cg.input_1.at<loco::DataType::U8>(4));
+}
+
+TEST(PropagateConcatenationQparam, propagate_concat_quantparam_u8_NEG)
+{
+ // Check negative cases where qparam is not propagated
+ // (1) concat has fused activation function
+ // (2) concat has fused activation function and input is const
+
+ SimpleConcatGraph g(loco::DataType::U8);
+
+ // concat has fused activation function
+ g.concat_node.fusedActivationFunction(luci::FusedActFunc::RELU);
+ luci::propagate_concat_quantparam(&g.concat_node, loco::DataType::U8);
+ EXPECT_FLOAT_EQ(3.14, g.concat_node.quantparam()->scale[0]);
+ EXPECT_EQ(77, g.concat_node.quantparam()->zerop[0]);
+ EXPECT_FLOAT_EQ(1.0, g.input_1.quantparam()->scale[0]);
+ EXPECT_EQ(1, g.input_1.quantparam()->zerop[0]);
+ EXPECT_FLOAT_EQ(2.0, g.input_2.quantparam()->scale[0]);
+ EXPECT_EQ(2, g.input_2.quantparam()->zerop[0]);
+ g.concat_node.fusedActivationFunction(luci::FusedActFunc::NONE);
+
+ // concat has fused activation function and input_1 is const.
+ // const values are quantized using its min/max
+ ConstInputConcatGraph cg(loco::DataType::U8);
+ cg.concat_node.fusedActivationFunction(luci::FusedActFunc::RELU);
+ luci::propagate_concat_quantparam(&cg.concat_node, loco::DataType::U8);
+ EXPECT_FLOAT_EQ(0.1, cg.concat_node.quantparam()->scale[0]);
+ EXPECT_EQ(10, cg.concat_node.quantparam()->zerop[0]);
+ EXPECT_FLOAT_EQ(0.015686275, cg.input_1.quantparam()->scale[0]);
+ EXPECT_EQ(128, cg.input_1.quantparam()->zerop[0]);
+ EXPECT_FLOAT_EQ(2.0, cg.input_2.quantparam()->scale[0]);
+ EXPECT_EQ(2, cg.input_2.quantparam()->zerop[0]);
+ EXPECT_EQ(loco::DataType::U8, cg.input_1.dtype());
+ EXPECT_EQ(quantize(-2, cg.input_1.quantparam()), cg.input_1.at<loco::DataType::U8>(0));
+ EXPECT_EQ(quantize(-1, cg.input_1.quantparam()), cg.input_1.at<loco::DataType::U8>(1));
+ EXPECT_EQ(quantize(0, cg.input_1.quantparam()), cg.input_1.at<loco::DataType::U8>(2));
+ EXPECT_EQ(quantize(1, cg.input_1.quantparam()), cg.input_1.at<loco::DataType::U8>(3));
+ EXPECT_EQ(quantize(2, cg.input_1.quantparam()), cg.input_1.at<loco::DataType::U8>(4));
+}
+
+TEST(PropagateConcatenationQparam, propagate_concat_quantparam_i16)
+{
+ // Check cases where qparam of concat_node is propagated
+ // (1) normal case: qparam is propagated to input_1 and input_2
+ // (2) input used by other Op: input_1 is an input of input_2. qparam is propagated only to
+ // input_2
+ // (3) subsequent concat: input_1 is concat. qparam is propagated only to input_2
+ // (4) const input: input_1 is const. constant values are quantized
+
+ // normal case: qparam of concat_node is propagated to input_1 and input_2
+ SimpleConcatGraph g(loco::DataType::S16);
+ luci::propagate_concat_quantparam(&g.concat_node, loco::DataType::S16);
+ EXPECT_FLOAT_EQ(3.14, g.concat_node.quantparam()->scale[0]);
+ EXPECT_EQ(0, g.concat_node.quantparam()->zerop[0]);
+ EXPECT_FLOAT_EQ(3.14, g.input_1.quantparam()->scale[0]);
+ EXPECT_EQ(0, g.input_1.quantparam()->zerop[0]);
+ EXPECT_FLOAT_EQ(3.14, g.input_2.quantparam()->scale[0]);
+ EXPECT_EQ(0, g.input_2.quantparam()->zerop[0]);
+
+ // input_1 is an input of input_2. qparam is propagated only to input_2
+ SimpleConcatGraph g2(loco::DataType::S16);
+ g2.input_2.input(&g2.input_1);
+ luci::propagate_concat_quantparam(&g2.concat_node, loco::DataType::S16);
+ EXPECT_FLOAT_EQ(3.14, g2.concat_node.quantparam()->scale[0]);
+ EXPECT_EQ(0, g2.concat_node.quantparam()->zerop[0]);
+ EXPECT_FLOAT_EQ(1.0, g2.input_1.quantparam()->scale[0]);
+ EXPECT_EQ(0, g2.input_1.quantparam()->zerop[0]);
+ EXPECT_FLOAT_EQ(3.14, g2.input_2.quantparam()->scale[0]);
+ EXPECT_EQ(0, g2.input_2.quantparam()->zerop[0]);
+
+ // input_1 is concat. qparam is propagated only to input_2
+ SubsequentConcatGraph sg(loco::DataType::S16);
+ luci::propagate_concat_quantparam(&sg.concat_node, loco::DataType::S16);
+ EXPECT_FLOAT_EQ(3.14, sg.concat_node.quantparam()->scale[0]);
+ EXPECT_EQ(0, sg.concat_node.quantparam()->zerop[0]);
+ EXPECT_FLOAT_EQ(1.0, sg.input_1.quantparam()->scale[0]);
+ EXPECT_EQ(0, sg.input_1.quantparam()->zerop[0]);
+ EXPECT_FLOAT_EQ(3.14, sg.input_2.quantparam()->scale[0]);
+ EXPECT_EQ(0, sg.input_2.quantparam()->zerop[0]);
+
+ // input_1 is const. const values are quantized with the qparam of concat
+ ConstInputConcatGraph cg(loco::DataType::S16);
+ luci::propagate_concat_quantparam(&cg.concat_node, loco::DataType::S16);
+ EXPECT_FLOAT_EQ(0.1, cg.concat_node.quantparam()->scale[0]);
+ EXPECT_EQ(0, cg.concat_node.quantparam()->zerop[0]);
+ EXPECT_FLOAT_EQ(0.1, cg.input_1.quantparam()->scale[0]);
+ EXPECT_EQ(0, cg.input_1.quantparam()->zerop[0]);
+ EXPECT_FLOAT_EQ(0.1, cg.input_2.quantparam()->scale[0]);
+ EXPECT_EQ(0, cg.input_2.quantparam()->zerop[0]);
+ EXPECT_EQ(loco::DataType::S16, cg.input_1.dtype());
+ EXPECT_EQ(-20, cg.input_1.at<loco::DataType::S16>(0));
+ EXPECT_EQ(-10, cg.input_1.at<loco::DataType::S16>(1));
+ EXPECT_EQ(0, cg.input_1.at<loco::DataType::S16>(2));
+ EXPECT_EQ(10, cg.input_1.at<loco::DataType::S16>(3));
+ EXPECT_EQ(20, cg.input_1.at<loco::DataType::S16>(4));
+}
+
+TEST(PropagateConcatenationQparam, propagate_concat_quantparam_i16_NEG)
+{
+ // Check negative cases where qparam is not propagated
+ // (1) concat has fused activation function
+ // (2) concat has fused activation function and input is const
+
+ SimpleConcatGraph g(loco::DataType::S16);
+
+ // concat has fused activation function
+ g.concat_node.fusedActivationFunction(luci::FusedActFunc::RELU);
+ luci::propagate_concat_quantparam(&g.concat_node, loco::DataType::S16);
+ EXPECT_FLOAT_EQ(3.14, g.concat_node.quantparam()->scale[0]);
+ EXPECT_EQ(0, g.concat_node.quantparam()->zerop[0]);
+ EXPECT_FLOAT_EQ(1.0, g.input_1.quantparam()->scale[0]);
+ EXPECT_EQ(0, g.input_1.quantparam()->zerop[0]);
+ EXPECT_FLOAT_EQ(2.0, g.input_2.quantparam()->scale[0]);
+ EXPECT_EQ(0, g.input_2.quantparam()->zerop[0]);
+ g.concat_node.fusedActivationFunction(luci::FusedActFunc::NONE);
+
+ // concat has fused activation function and input_1 is const.
+ // const values are quantized using its min/max
+ ConstInputConcatGraph cg(loco::DataType::S16);
+ cg.concat_node.fusedActivationFunction(luci::FusedActFunc::RELU);
+ luci::propagate_concat_quantparam(&cg.concat_node, loco::DataType::S16);
+ EXPECT_FLOAT_EQ(0.1, cg.concat_node.quantparam()->scale[0]);
+ EXPECT_EQ(0, cg.concat_node.quantparam()->zerop[0]);
+ EXPECT_FLOAT_EQ(0.000061037, cg.input_1.quantparam()->scale[0]);
+ EXPECT_EQ(0, cg.input_1.quantparam()->zerop[0]);
+ EXPECT_FLOAT_EQ(2.0, cg.input_2.quantparam()->scale[0]);
+ EXPECT_EQ(0, cg.input_2.quantparam()->zerop[0]);
+ EXPECT_EQ(loco::DataType::S16, cg.input_1.dtype());
+ EXPECT_EQ(quantize(-2, cg.input_1.quantparam()), cg.input_1.at<loco::DataType::S16>(0));
+ EXPECT_EQ(quantize(-1, cg.input_1.quantparam()), cg.input_1.at<loco::DataType::S16>(1));
+ EXPECT_EQ(quantize(0, cg.input_1.quantparam()), cg.input_1.at<loco::DataType::S16>(2));
+ EXPECT_EQ(quantize(1, cg.input_1.quantparam()), cg.input_1.at<loco::DataType::S16>(3));
+ EXPECT_EQ(quantize(2, cg.input_1.quantparam()), cg.input_1.at<loco::DataType::S16>(4));
+}
diff --git a/compiler/luci/pass/src/QuantizationUtils.cpp b/compiler/luci/pass/src/QuantizationUtils.cpp
index e18690605..9af52a4c4 100644
--- a/compiler/luci/pass/src/QuantizationUtils.cpp
+++ b/compiler/luci/pass/src/QuantizationUtils.cpp
@@ -31,6 +31,66 @@ uint8_t fp32_to_uint8_cast(float f)
return static_cast<uint8_t>(f);
}
+// Per-layer quantization of weights (const tensor) using given min/max values
+void asymmetric_wquant_with_minmax_per_layer(CircleConst *node, float min, float max,
+ float &scaling_factor, int64_t &zp, float &nudged_min,
+ float &nudged_max)
+{
+ const int32_t kMinScale = 0;
+ const int32_t kMaxScale = 255;
+
+ uint32_t size = node->size<loco::DataType::FLOAT32>();
+ compute_asym_scale_zp(min, max, scaling_factor, zp, nudged_min, nudged_max);
+ const float scaling_factor_inv = 1.0 / scaling_factor;
+ std::vector<int32_t> quantized_values(size);
+ for (uint32_t i = 0; i < size; ++i)
+ {
+ // clipping
+ auto data = node->at<loco::DataType::FLOAT32>(i);
+ data = data < nudged_min ? nudged_min : data;
+ data = data > nudged_max ? nudged_max : data;
+ quantized_values[i] =
+ static_cast<int32_t>(std::round((data - nudged_min) * scaling_factor_inv));
+ }
+
+ node->dtype(loco::DataType::U8); // change the type of tensor
+ node->size<loco::DataType::U8>(size); // resize tensor
+ for (uint32_t i = 0; i < size; ++i)
+ {
+ node->at<loco::DataType::U8>(i) = std::min(kMaxScale, std::max(kMinScale, quantized_values[i]));
+ }
+}
+
+// Per-layer quantization of weights (const tensor) using given min/max values
+void symmetric_wquant_with_minmax_per_layer(CircleConst *node, float min, float max,
+ float &scaling_factor, int64_t &zp, float &nudged_min,
+ float &nudged_max)
+{
+ const int32_t kMaxScale = std::numeric_limits<int16_t>::max();
+ const int32_t kMinScale = -kMaxScale;
+
+ uint32_t size = node->size<loco::DataType::FLOAT32>();
+ compute_sym_scale_zp(min, max, scaling_factor, zp, nudged_min, nudged_max);
+ const float scaling_factor_inv = 1.0 / scaling_factor;
+ std::vector<int32_t> quantized_values(size);
+ for (uint32_t i = 0; i < size; ++i)
+ {
+ // clipping
+ auto data = node->at<loco::DataType::FLOAT32>(i);
+ data = data < nudged_min ? nudged_min : data;
+ data = data > nudged_max ? nudged_max : data;
+ quantized_values[i] = static_cast<int32_t>(std::round(data * scaling_factor_inv));
+ }
+
+ node->dtype(loco::DataType::S16); // change the type of tensor
+ node->size<loco::DataType::S16>(size); // resize tensor
+ for (uint32_t i = 0; i < size; ++i)
+ {
+ node->at<loco::DataType::S16>(i) =
+ std::min(kMaxScale, std::max(kMinScale, quantized_values[i]));
+ }
+}
+
void compute_sym_scale_zp(float min, float max, float &scaling_factor, int64_t &zp,
float &nudged_min, float &nudged_max)
{
diff --git a/compiler/luci/pass/src/QuantizationUtils.h b/compiler/luci/pass/src/QuantizationUtils.h
index ec0e86df8..f766bd66d 100644
--- a/compiler/luci/pass/src/QuantizationUtils.h
+++ b/compiler/luci/pass/src/QuantizationUtils.h
@@ -29,10 +29,20 @@ void compute_sym_scale_zp(float min, float max, float &scaling_factor, int64_t &
void compute_asym_scale_zp(float min, float max, float &scaling_factor, int64_t &zp,
float &nudged_min, float &nudged_max);
+void asymmetric_wquant_with_minmax_per_layer(CircleConst *node, float min, float max,
+ float &scaling_factor, int64_t &zp, float &nudged_min,
+ float &nudged_max);
+
+void symmetric_wquant_with_minmax_per_layer(CircleConst *node, float min, float max,
+ float &scaling_factor, int64_t &zp, float &nudged_min,
+ float &nudged_max);
+
bool get_channel_dim_index(CircleConst *node, loco::TensorShape &dimension, int &channel_dim_index);
uint32_t cal_offset(loco::TensorShape &dimension, uint32_t *indices);
+void propagate_concat_quantparam(luci::CircleConcatenation *concat, loco::DataType quant_type);
+
} // namespace luci
#endif // __LUCI_QUANTIZATION_UTILS_H__
diff --git a/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.cpp b/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.cpp
index c492234c7..e9925c7ff 100644
--- a/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.cpp
+++ b/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.cpp
@@ -284,36 +284,6 @@ void asymmetric_wdequant_per_channel(CircleConst *node, std::vector<float> &scal
}
}
-void asymmetric_wquant_with_minmax_per_layer(CircleConst *node, float min, float max,
- float &scaling_factor, int64_t &zp, float &nudged_min,
- float &nudged_max)
-{
-
- const int32_t kMinScale = 0;
- const int32_t kMaxScale = 255;
-
- uint32_t size = node->size<loco::DataType::FLOAT32>();
- compute_asym_scale_zp(min, max, scaling_factor, zp, nudged_min, nudged_max);
- const float scaling_factor_inv = 1.0 / scaling_factor;
- std::vector<int32_t> quantized_values(size);
- for (uint32_t i = 0; i < size; ++i)
- {
- // clipping
- auto data = node->at<loco::DataType::FLOAT32>(i);
- data = data < nudged_min ? nudged_min : data;
- data = data > nudged_max ? nudged_max : data;
- quantized_values[i] =
- static_cast<int32_t>(std::round((data - nudged_min) * scaling_factor_inv));
- }
-
- node->dtype(loco::DataType::U8); // change the type of tensor
- node->size<loco::DataType::U8>(size); // resize tensor
- for (uint32_t i = 0; i < size; ++i)
- {
- node->at<loco::DataType::U8>(i) = std::min(kMaxScale, std::max(kMinScale, quantized_values[i]));
- }
-}
-
void asymmetric_wdequant_with_minmax_per_layer(CircleConst *node, float scaling_factor,
float nudged_min)
{
diff --git a/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp b/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp
index 60c1cdd72..564e814f9 100644
--- a/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp
+++ b/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp
@@ -32,7 +32,99 @@ namespace luci
namespace
{
-// Check if the node is the bias of Conv2D, DepthwiseConv2D, or FullyConnected layer
+void overwrite_quantparam(luci::CircleConcatenation *concat, luci::CircleNode *target)
+{
+ auto concat_qparam = concat->quantparam();
+ if (concat_qparam == nullptr)
+ throw std::runtime_error("quantparam of concat is not found during overwrite");
+
+ auto target_qparam = target->quantparam();
+ if (target_qparam == nullptr)
+ {
+ auto quantparam = std::make_unique<CircleQuantParam>();
+ target->quantparam(std::move(quantparam));
+ target_qparam = target->quantparam();
+ }
+ target_qparam->min = concat_qparam->min;
+ target_qparam->max = concat_qparam->max;
+ target_qparam->scale = concat_qparam->scale;
+ target_qparam->zerop = concat_qparam->zerop;
+ target_qparam->quantized_dimension = concat_qparam->quantized_dimension;
+}
+
+void quant_const_values(luci::CircleConst *const_node, float scaling_factor, float zerop,
+ loco::DataType quant_type)
+{
+ uint32_t size = const_node->size<loco::DataType::FLOAT32>();
+
+ const float scaling_factor_inv = 1.0 / scaling_factor;
+ std::vector<int32_t> quantized_values(size);
+ for (uint32_t i = 0; i < size; ++i)
+ {
+ auto data = const_node->at<loco::DataType::FLOAT32>(i);
+ quantized_values[i] = static_cast<int32_t>(std::round(data * scaling_factor_inv) + zerop);
+ }
+
+ switch (quant_type)
+ {
+ case loco::DataType::U8:
+ const_node->dtype(loco::DataType::U8); // change the type of tensor
+ const_node->size<loco::DataType::U8>(size); // resize tensor
+ for (uint32_t i = 0; i < size; ++i)
+ const_node->at<loco::DataType::U8>(i) = std::min(255, std::max(0, quantized_values[i]));
+ break;
+ case loco::DataType::S16:
+ assert(zerop == 0);
+ const_node->dtype(loco::DataType::S16); // change the type of tensor
+ const_node->size<loco::DataType::S16>(size); // resize tensor
+ for (uint32_t i = 0; i < size; ++i)
+ const_node->at<loco::DataType::S16>(i) =
+ std::min(32767, std::max(-32767, quantized_values[i]));
+ break;
+ default:
+ throw std::runtime_error("Unsupported data type");
+ }
+}
+
+void quant_const(CircleConst *node, loco::DataType quant_type)
+{
+ assert(node->dtype() == loco::DataType::FLOAT32);
+
+ float min = std::numeric_limits<float>::max();
+ float max = std::numeric_limits<float>::lowest();
+ for (uint32_t i = 0; i < node->size<loco::DataType::FLOAT32>(); i++)
+ {
+ auto data = node->at<loco::DataType::FLOAT32>(i);
+ min = data < min ? data : min;
+ max = data > max ? data : max;
+ }
+
+ float scaling_factor{0.0};
+ int64_t zp{0};
+ float nudged_min{0.0};
+ float nudged_max{0.0};
+
+ switch (quant_type)
+ {
+ case loco::DataType::U8:
+ asymmetric_wquant_with_minmax_per_layer(node, min, max, scaling_factor, zp, nudged_min,
+ nudged_max);
+ break;
+ case loco::DataType::S16:
+ symmetric_wquant_with_minmax_per_layer(node, min, max, scaling_factor, zp, nudged_min,
+ nudged_max);
+ break;
+ default:
+ throw std::runtime_error("Unsupported data type");
+ }
+
+ auto quantparam = std::make_unique<CircleQuantParam>();
+ quantparam->scale.push_back(scaling_factor);
+ quantparam->zerop.push_back(zp);
+ node->quantparam(std::move(quantparam));
+}
+
+// Check if the node is the bias of Conv2D, DepthwiseConv2D, FullyConnected, or TransposeConv layer
// If true, return <input, weight> pair of the successor node (used to quantize bias)
// If flase, return <nullptr, nullptr>
std::pair<loco::Node *, loco::Node *> get_input_weight_of_bias(CircleNode *node)
@@ -68,6 +160,13 @@ std::pair<loco::Node *, loco::Node *> get_input_weight_of_bias(CircleNode *node)
assert(fc->weights() != nullptr);
return std::make_pair(fc->input(), fc->weights());
}
+ auto tconv = dynamic_cast<CircleTransposeConv *>(out);
+ if (tconv != nullptr && tconv->bias() == circle_const)
+ {
+ assert(tconv->outBackprop() != nullptr);
+ assert(tconv->filter() != nullptr);
+ return std::make_pair(tconv->outBackprop(), tconv->filter());
+ }
}
return std::make_pair(nullptr, nullptr);
}
@@ -514,8 +613,171 @@ struct QuantizeWeights final : public luci::CircleNodeMutableVisitor<bool>
}
};
+/**
+ * @brief Quantize const input tensors using min/max of const values
+ */
+void quantize_const_inputs(luci::CircleNode *node, loco::DataType output_type)
+{
+ auto opcode = node->opcode();
+ auto arity = node->arity();
+
+ loco::Node *input_node{nullptr};
+ luci::CircleConst *const_node{nullptr};
+
+ switch (opcode)
+ {
+ case luci::CircleOpcode::CONV_2D:
+ case luci::CircleOpcode::DEPTHWISE_CONV_2D:
+ case luci::CircleOpcode::FULLY_CONNECTED:
+ case luci::CircleOpcode::TRANSPOSE_CONV:
+ // Handled in QuantizeWeights and QuantizeBias
+ break;
+
+ case luci::CircleOpcode::CONCATENATION:
+ // Handled in propagate_concat_quantparam
+ break;
+
+ case luci::CircleOpcode::ARG_MAX:
+ case luci::CircleOpcode::ARG_MIN:
+ case luci::CircleOpcode::MEAN:
+ case luci::CircleOpcode::PAD:
+ case luci::CircleOpcode::REDUCE_ANY:
+ case luci::CircleOpcode::REDUCE_PROD:
+ case luci::CircleOpcode::REDUCE_MAX:
+ case luci::CircleOpcode::REDUCE_MIN:
+ case luci::CircleOpcode::RESHAPE:
+ case luci::CircleOpcode::SUM:
+ // The second input of these Ops should not be quantized
+ // Ex: axis, paddings
+ input_node = node->arg(0);
+ const_node = dynamic_cast<luci::CircleConst *>(input_node);
+ if (const_node != nullptr)
+ quant_const(const_node, output_type);
+ break;
+
+ case luci::CircleOpcode::ADD:
+ case luci::CircleOpcode::ADD_N:
+ case luci::CircleOpcode::DIV:
+ case luci::CircleOpcode::EQUAL:
+ case luci::CircleOpcode::GREATER:
+ case luci::CircleOpcode::GREATER_EQUAL:
+ case luci::CircleOpcode::LESS:
+ case luci::CircleOpcode::LESS_EQUAL:
+ case luci::CircleOpcode::MAXIMUM:
+ case luci::CircleOpcode::MINIMUM:
+ case luci::CircleOpcode::MUL:
+ case luci::CircleOpcode::NOT_EQUAL:
+ case luci::CircleOpcode::PRELU:
+ case luci::CircleOpcode::SUB:
+ // Quantize all const inputs using their values
+ for (uint32_t i = 0; i < arity; i++)
+ {
+ input_node = node->arg(i);
+ const_node = dynamic_cast<luci::CircleConst *>(input_node);
+ if (const_node != nullptr)
+ quant_const(const_node, output_type);
+ }
+ break;
+
+ default:
+ for (uint32_t i = 0; i < arity; i++)
+ {
+ input_node = node->arg(i);
+ const_node = dynamic_cast<luci::CircleConst *>(input_node);
+ if (const_node != nullptr)
+ throw std::runtime_error("Unsupported Op for const inputs");
+ }
+ break;
+ }
+}
+
} // namespace
+/** BEFORE
+ *
+ * [CircleNode] [CircleConst]
+ * (U8 qparam1) (FP32)
+ * \ /
+ * \ /
+ * [CircleConcatenation]
+ * (U8 qparam2)
+ *
+ * AFTER
+ * [CircleNode] [CircleConst]
+ * (U8 qparam2) (U8 qparam2)
+ * \ /
+ * \ /
+ * [CircleConcatenation]
+ * (U8 qparam2)
+ */
+void propagate_concat_quantparam(luci::CircleConcatenation *concat, loco::DataType quant_type)
+{
+ assert(concat->quantparam() != nullptr);
+
+ const auto num_inputs = concat->numValues();
+
+ // Quantize const inputs using their values if concat has fused act function
+ if (concat->fusedActivationFunction() != luci::FusedActFunc::NONE)
+ {
+ for (uint32_t i = 0; i < num_inputs; i++)
+ {
+ auto node = concat->arg(i);
+ auto const_node = dynamic_cast<luci::CircleConst *>(node);
+ if (const_node != nullptr)
+ quant_const(const_node, quant_type);
+ }
+ return;
+ }
+
+ for (uint32_t i = 0; i < num_inputs; i++)
+ {
+ auto node = loco::must_cast<luci::CircleNode *>(concat->arg(i));
+
+ // Skip if this input is CONCAT Op
+ if (node->opcode() == luci::CircleOpcode::CONCATENATION)
+ continue;
+
+ // Skip if this input is used by other Ops
+ auto succs = loco::succs(node);
+ if (succs.size() != 1)
+ {
+ if (node->opcode() == luci::CircleOpcode::CIRCLECONST)
+ {
+ luci::CircleConst *const_node = loco::must_cast<luci::CircleConst *>(node);
+ quant_const(const_node, quant_type);
+ }
+ continue;
+ }
+
+ assert(succs.find(concat) != succs.end());
+
+ // Quantize constant values
+ if (node->opcode() == luci::CircleOpcode::CIRCLECONST)
+ {
+ luci::CircleConst *const_node = loco::must_cast<luci::CircleConst *>(node);
+ if (const_node->dtype() != loco::DataType::FLOAT32)
+ throw std::runtime_error("Unsupported data type for constant input of concatenation Op");
+
+ const auto concat_qparam = concat->quantparam();
+ if (concat_qparam == nullptr)
+ throw std::runtime_error("quantparam of concat is not found during propagation");
+
+ assert(concat_qparam->scale.size() == 1);
+ const auto scaling_factor = concat_qparam->scale[0];
+ const auto zerop = concat_qparam->zerop[0];
+
+ quant_const_values(const_node, scaling_factor, zerop, quant_type);
+ }
+ else
+ {
+ // Non-const input must have been quantized
+ assert(node->quantparam() != nullptr);
+ }
+
+ overwrite_quantparam(concat, node);
+ }
+}
+
bool QuantizeWithMinMaxPass::run(loco::Graph *g)
{
LOGGER(l);
@@ -538,11 +800,37 @@ bool QuantizeWithMinMaxPass::run(loco::Graph *g)
}
// Quantize bias
+ // (For int16 quantization, bias is not quantized)
+ if (_output_dtype == loco::DataType::U8)
+ {
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ QuantizeBias qb(_input_dtype, _output_dtype, _granularity);
+ auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+ circle_node->accept(&qb);
+ }
+ }
+
+ // Quantize const inputs other than weights and bias
for (auto node : loco::active_nodes(loco::output_nodes(g)))
{
- QuantizeBias qb(_input_dtype, _output_dtype, _granularity);
auto circle_node = loco::must_cast<luci::CircleNode *>(node);
- circle_node->accept(&qb);
+ quantize_const_inputs(circle_node, _output_dtype);
+ }
+
+ // Propagate quantization parameters of concat Op
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ auto concat = dynamic_cast<luci::CircleConcatenation *>(node);
+ if (not concat)
+ continue;
+
+ // Propagate qparam of concat to its inputs if
+ // (1) concat is uint8-quantized
+ // (2) concat has no fused activation function
+ // (3) the input is not concatenation Op
+ // (4) the input is not produced to Ops other than concat
+ propagate_concat_quantparam(concat, _output_dtype);
}
// Update output dtype
diff --git a/compiler/luci/pass/src/RequantizePass.cpp b/compiler/luci/pass/src/RequantizePass.cpp
index 49fbf76ec..fe84e3bc3 100644
--- a/compiler/luci/pass/src/RequantizePass.cpp
+++ b/compiler/luci/pass/src/RequantizePass.cpp
@@ -56,7 +56,9 @@ bool is_bias(CircleConst *node)
if (fc != nullptr && fc->bias() == node)
return true;
- // TODO: add TransposeConv when bias is supported in CircleTransposeConv
+ auto tconv = dynamic_cast<CircleTransposeConv *>(out);
+ if (tconv != nullptr && tconv->bias() == node)
+ return true;
}
return false;
}
diff --git a/compiler/luci/pass/src/Sparsifier.cpp b/compiler/luci/pass/src/Sparsifier.cpp
new file mode 100644
index 000000000..2aa542f15
--- /dev/null
+++ b/compiler/luci/pass/src/Sparsifier.cpp
@@ -0,0 +1,229 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Sparsifier.h"
+
+namespace luci
+{
+
+template <typename T>
+Sparsifier<T>::Sparsifier(const std::vector<int32_t> &shape,
+ const std::vector<int32_t> &traversal_order,
+ const std::vector<DimensionType> &format,
+ const std::vector<int32_t> &block_size,
+ const std::vector<int32_t> &block_map)
+ : _dense_shape(shape), _traversal_order(traversal_order), _block_size(block_size),
+ _block_map(block_map)
+{
+ _dense_size = 1;
+ int32_t block_dim = 0;
+ _blocked_shape.resize(shape.size());
+ _format.resize(shape.size() + block_map.size());
+ for (int32_t i = 0; i < static_cast<int32_t>(shape.size()); i++)
+ {
+ _format[i] = format[traversal_order[i]];
+ _dense_size *= shape[i];
+ if (block_dim < static_cast<int32_t>(block_map.size()) && block_map[block_dim] == i)
+ {
+ _blocked_shape[i] = shape[i] / block_size[block_dim];
+ block_dim++;
+ }
+ else
+ {
+ _blocked_shape[i] = shape[i];
+ }
+ }
+
+ // Only dense blocks are supported.
+ for (uint32_t i = 0; i < block_map.size(); i++)
+ {
+ _format[i + shape.size()] = DimensionType::DENSE;
+ }
+}
+
+template <typename T> void Sparsifier<T>::DenseToSparse(const T *src_data)
+{
+ int num_original_dims = _dense_shape.size();
+ int num_block_dims = _block_map.size();
+ int num_expanded_dims = num_original_dims + num_block_dims;
+ std::vector<int> expanded_shape(num_expanded_dims);
+ for (int i = 0; i < num_expanded_dims; i++)
+ {
+ if (i < num_original_dims)
+ {
+ expanded_shape[i] = _blocked_shape[i];
+ }
+ else
+ {
+ expanded_shape[i] = _block_size[i - num_original_dims];
+ }
+ }
+
+ std::vector<int> shape_offset(num_original_dims);
+ shape_offset[shape_offset.size() - 1] = 1;
+ for (int i = num_original_dims - 1; i > 0; --i)
+ {
+ shape_offset[i - 1] = shape_offset[i] * _dense_shape[i];
+ }
+
+ std::vector<int> expanded_shape_offset(num_expanded_dims);
+ for (int i = 0; i < num_original_dims; ++i)
+ {
+ expanded_shape_offset[i] = shape_offset[i];
+ }
+ for (int i = 0; i < num_block_dims; ++i)
+ {
+ int mapped_dim = _block_map[i];
+ expanded_shape_offset[num_original_dims + i] = shape_offset[mapped_dim];
+ expanded_shape_offset[mapped_dim] *= _block_size[i];
+ }
+
+ std::vector<int> dst_ordered_offset(num_expanded_dims);
+ for (int i = 0; i < num_expanded_dims; ++i)
+ {
+ dst_ordered_offset[i] = expanded_shape_offset[_traversal_order[i]];
+ }
+
+ std::vector<bool> dst_dim_has_nonzeroes(num_expanded_dims);
+ std::fill(dst_dim_has_nonzeroes.begin(), dst_dim_has_nonzeroes.end(), false);
+ std::vector<int> inner_compressed_dim(num_expanded_dims);
+ int most_recent_compressed_dim = -1;
+ std::vector<int> num_segments_of_next_compressed_dim(num_expanded_dims);
+ int segment_count = 1;
+ for (int i = num_expanded_dims - 1; i >= 0; --i)
+ {
+ inner_compressed_dim[i] = most_recent_compressed_dim;
+ if (_format[i] == DimensionType::SPARSE_CSR)
+ {
+ most_recent_compressed_dim = i;
+ num_segments_of_next_compressed_dim[i] = segment_count;
+ segment_count = 1;
+ }
+ else
+ {
+ num_segments_of_next_compressed_dim[i] = -1;
+ segment_count *= expanded_shape[_traversal_order[i]];
+ }
+ }
+
+ _dim_metadata.resize(num_expanded_dims * 2);
+ std::vector<int> dst_sparse_dims;
+ dst_sparse_dims.reserve(num_expanded_dims);
+ for (int i = 0; i < num_expanded_dims; ++i)
+ {
+ _dim_metadata[i * 2].clear();
+ _dim_metadata[i * 2 + 1].clear();
+ if (_format[i] == DimensionType::DENSE)
+ {
+ // If dimension is dense, just store the shape.
+ _dim_metadata[i * 2].push_back(expanded_shape[_traversal_order[i]]);
+ }
+ else
+ {
+ _dim_metadata[i * 2].push_back(0); // Segment array always begins with 0.
+ dst_sparse_dims.push_back(i); // Add dimension to the sparse list.
+ }
+ }
+
+ // This algorithm assumes that the block size is small enough for all the
+ // elements to fit in cache, so the strided accesses from different traversal
+ // order and the write-first-erase-later strategy shouldn't be too slow
+ int dst_dim_idx = num_expanded_dims;
+ std::vector<int> coordinate(num_expanded_dims, 0);
+ int dense_tensor_idx = 0;
+ while (dst_dim_idx >= 0)
+ {
+ if (dst_dim_idx == num_expanded_dims)
+ {
+ // We have a complete coordinate. Add the element to the value array if it
+ // is not zero, or if the last dimension is dense.
+ if (!IsZero(src_data[dense_tensor_idx]))
+ {
+ _data.push_back(src_data[dense_tensor_idx]);
+ // Mark all sparse dimensions that their current indices have nonzeroes.
+ for (auto dst_dim : dst_sparse_dims)
+ {
+ if (!dst_dim_has_nonzeroes[dst_dim])
+ {
+ // Only add the index to the indices array if the current nonzero
+ // is the first nonzero of the block.
+ _dim_metadata[2 * dst_dim + 1].push_back(coordinate[dst_dim]);
+ dst_dim_has_nonzeroes[dst_dim] = true;
+ }
+ }
+ }
+ else if (_format[num_expanded_dims - 1] == DimensionType::DENSE)
+ {
+ _data.push_back(src_data[dense_tensor_idx]);
+ }
+ --dst_dim_idx;
+ }
+ else
+ {
+ int original_dim_idx = _traversal_order[dst_dim_idx];
+ int dim_size = expanded_shape[original_dim_idx];
+ if (dst_dim_has_nonzeroes[dst_dim_idx])
+ {
+ // If the previous block has nonzeroes, reset the flag to false since
+ // we have just moved to a new block.
+ dst_dim_has_nonzeroes[dst_dim_idx] = false;
+ }
+ else if (_format[dst_dim_idx] == DimensionType::SPARSE_CSR)
+ {
+ // This block is empty. Delete unnecessary values if compressed.
+ int next_compressed_dim = inner_compressed_dim[dst_dim_idx];
+ int erase_offset = _dim_metadata[2 * dst_dim_idx + 1].size() *
+ num_segments_of_next_compressed_dim[dst_dim_idx];
+ if (next_compressed_dim >= 0)
+ {
+ auto &segments = _dim_metadata[2 * inner_compressed_dim[dst_dim_idx]];
+ segments.erase(segments.begin() + 1 + erase_offset, segments.end());
+ }
+ else
+ {
+ _data.erase(_data.begin() + erase_offset, _data.end());
+ }
+ }
+ if (++coordinate[dst_dim_idx] < dim_size)
+ {
+ // The current dst_dim_idx is valid (not out of bound).
+ dense_tensor_idx += dst_ordered_offset[dst_dim_idx];
+ ++dst_dim_idx;
+ }
+ else
+ {
+ // dst_dim_idx has reached its dim size. Update segment array and go
+ // back to incrementing the previous dimension (dst_dim_idx - 1).
+ if (_format[dst_dim_idx] == DimensionType::SPARSE_CSR)
+ {
+ _dim_metadata[2 * dst_dim_idx].push_back(_dim_metadata[2 * dst_dim_idx + 1].size());
+ }
+ coordinate[dst_dim_idx] = -1;
+ dense_tensor_idx -= dst_ordered_offset[dst_dim_idx] * dim_size;
+ --dst_dim_idx;
+ }
+ }
+ }
+}
+
+template <typename T> bool Sparsifier<T>::IsZero(const T val) { return (val == 0); }
+
+template class Sparsifier<int32_t>;
+template class Sparsifier<int8_t>;
+template class Sparsifier<float>;
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/Sparsifier.h b/compiler/luci/pass/src/Sparsifier.h
new file mode 100644
index 000000000..71ea28da9
--- /dev/null
+++ b/compiler/luci/pass/src/Sparsifier.h
@@ -0,0 +1,87 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_SPARSIFIER_H__
+#define __LUCI_SPARSIFIER_H__
+
+#include <vector>
+
+#include <luci/IR/SparsityParam.h>
+
+namespace luci
+{
+
+template <typename T> class Sparsifier
+{
+public:
+ /*
+ * Creates a dense to sparse converter.
+ * @param shape Shape of the dense tensor.
+ * @param traversal_order In what order to traverse all dimensions,
+ * including block dimensions.
+ * @param format Whether each dimension in converted tensor is
+ * dense or sparse (not in the traversal order).
+ * @param block_size Size of each block dimension.
+ * @param block_map Map from block dimension to original tensor
+ * dimension.
+ */
+ Sparsifier(const std::vector<int> &shape, const std::vector<int> &traversal_order,
+ const std::vector<DimensionType> &format, const std::vector<int> &block_size = {},
+ const std::vector<int> &block_map = {});
+
+ std::vector<T> GetData() { return _data; }
+ std::vector<std::vector<int>> GetDimMetadata() { return _dim_metadata; }
+
+ void DenseToSparse(const T *src_data);
+
+private:
+ // Check if val is equal to zero.
+ bool IsZero(const T val);
+
+ // Shape of the conceptual dense tensor.
+ std::vector<int> _dense_shape;
+ // Shape of the dense tensor with inner blocks reduced. For example, a (4, 4)
+ // tensor with (2, 2) block has blocked_shape (2, 2).
+ std::vector<int> _blocked_shape;
+ // Total number of elements in the dense tensor.
+ uint64_t _dense_size;
+ // Has n(original dimension)+k(block_dimension) elements.
+ std::vector<int> _traversal_order;
+ // Format of each dimension in the traversal order.
+ std::vector<DimensionType> _format;
+ // Size of each block dimension, in the same order as block map.
+ std::vector<int> _block_size;
+ // Map from block dimension to the original tensor dimension.
+ std::vector<int> _block_map;
+ // Metadata of each dimension in the traversal order.
+ // Each dimension needs two vectors. For dense dimensions, the first vector
+ // stores the size of that dimension, and the second vector is empty. For
+ // sparse dimensions, the first vector stores the segments and the second one
+ // stores the indices.
+ std::vector<std::vector<int>> _dim_metadata;
+ // Actual buffer holding data after conversion. Could be sparse buffer or
+ // dense buffer.
+ std::vector<T> _data;
+};
+
+extern template class Sparsifier<int32_t>;
+extern template class Sparsifier<int8_t>;
+extern template class Sparsifier<float>;
+
+} // namespace luci
+
+#endif // __LUCI_SPARSIFIER_H__
diff --git a/compiler/luci/pass/src/SparsifyTensorPass.cpp b/compiler/luci/pass/src/SparsifyTensorPass.cpp
new file mode 100644
index 000000000..2f1a36e77
--- /dev/null
+++ b/compiler/luci/pass/src/SparsifyTensorPass.cpp
@@ -0,0 +1,123 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/SparsifyTensorPass.h"
+
+#include "Sparsifier.h"
+
+#include <luci/IR/CircleNodes.h>
+
+namespace luci
+{
+
+template <loco::DataType DT> void SparsifyTensorPass::sparsify_tensor(luci::CircleConst *cop)
+{
+ using PRIMITIVE_DTYPE = typename loco::DataTypeImpl<DT>::Type;
+
+ std::vector<int32_t> dense_tensor_shape(cop->rank());
+ for (uint32_t d = 0; d < cop->rank(); d++)
+ {
+ dense_tensor_shape.at(d) = cop->dim(d).value();
+ }
+
+ Sparsifier<PRIMITIVE_DTYPE> sparsifier(dense_tensor_shape, _traversal_order, _format, _block_size,
+ _block_map);
+ // get dense tensor data
+ uint32_t dense_tensor_data_size = cop->size<DT>();
+ std::vector<PRIMITIVE_DTYPE> dense_tensor_data(dense_tensor_data_size);
+ for (uint32_t i = 0; i < dense_tensor_data_size; i++)
+ {
+ dense_tensor_data.at(i) = cop->at<DT>(i);
+ }
+ // sparsify
+ sparsifier.DenseToSparse(dense_tensor_data.data());
+ // get sparse tensor data
+ std::vector<PRIMITIVE_DTYPE> sparse_tensor_data = sparsifier.GetData();
+ uint32_t sparse_tensor_data_size = sparse_tensor_data.size();
+ cop->size<DT>(sparse_tensor_data_size);
+ for (uint32_t i = 0; i < sparse_tensor_data_size; i++)
+ {
+ cop->at<DT>(i) = sparse_tensor_data.at(i);
+ }
+ // make sparsity parameter
+ auto sparsityparam = std::make_unique<SparsityParam>();
+ sparsityparam->traversal_order = _traversal_order;
+ sparsityparam->block_map = _block_map;
+ // get dimension meta data
+ const auto dim_metadata = sparsifier.GetDimMetadata();
+ for (uint32_t idx = 0; idx < _format.size(); idx++)
+ {
+ if (_format.at(idx) == DimensionType::DENSE)
+ {
+ sparsityparam->dim_metadata.emplace_back(DimensionType::DENSE,
+ dim_metadata.at(idx * 2).at(0));
+ }
+ // TODO Set SparseIndexVectorType according to its data range
+ else if (_format.at(idx) == DimensionType::SPARSE_CSR)
+ {
+ sparsityparam->dim_metadata.emplace_back(
+ DimensionType::SPARSE_CSR, /* dense size */ 0,
+ /* array_segments */ SparseIndexVector{SparseIndexVectorType::U16,
+ dim_metadata.at(idx * 2)},
+ /* array_indices */ SparseIndexVector{SparseIndexVectorType::U16,
+ dim_metadata.at(idx * 2 + 1)});
+ }
+ }
+ for (uint32_t i = 0; i < _block_size.size(); i++)
+ {
+ assert(_block_size.at(i) == dim_metadata.at((_format.size() + i) * 2).at(0));
+ sparsityparam->dim_metadata.emplace_back(DimensionType::DENSE, _block_size.at(i));
+ }
+ cop->sparsityparam(std::move(sparsityparam));
+}
+
+bool SparsifyTensorPass::run(loco::Graph *g)
+{
+ bool changed = false;
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ auto cop = dynamic_cast<luci::CircleConst *>(node);
+ if (not cop)
+ continue;
+
+ if (cop->name() != _tensor_name)
+ continue;
+
+ switch (cop->dtype())
+ {
+ case loco::DataType::S32:
+ sparsify_tensor<loco::DataType::S32>(cop);
+ break;
+ case loco::DataType::S8:
+ sparsify_tensor<loco::DataType::S8>(cop);
+ break;
+ case loco::DataType::FLOAT32:
+ sparsify_tensor<loco::DataType::FLOAT32>(cop);
+ break;
+ default:
+ throw std::runtime_error("SparsifyTensorPass: Unsupported dtype.");
+ }
+ changed = true;
+ }
+
+ return changed;
+}
+
+template void SparsifyTensorPass::sparsify_tensor<loco::DataType::S32>(luci::CircleConst *cop);
+template void SparsifyTensorPass::sparsify_tensor<loco::DataType::S8>(luci::CircleConst *cop);
+template void SparsifyTensorPass::sparsify_tensor<loco::DataType::FLOAT32>(luci::CircleConst *cop);
+
+} // namespace luci