summaryrefslogtreecommitdiff
path: root/compiler/luci/pass/src
diff options
context:
space:
mode:
authorChunseok Lee <chunseok.lee@samsung.com>2022-04-15 19:15:11 +0900
committerChunseok Lee <chunseok.lee@samsung.com>2022-04-15 19:15:11 +0900
commit3ad689f0803519e343c36d5700646e86059df961 (patch)
tree862346c401a5577518fa7f042532aa931b53aa0e /compiler/luci/pass/src
parentac6e4dd7b480e83b586ef533d7b29a8a97eb48fe (diff)
downloadnnfw-3ad689f0803519e343c36d5700646e86059df961.tar.gz
nnfw-3ad689f0803519e343c36d5700646e86059df961.tar.bz2
nnfw-3ad689f0803519e343c36d5700646e86059df961.zip
Imported Upstream version 1.20.0upstream/1.20.0submit/tizen/20220415.103159
Diffstat (limited to 'compiler/luci/pass/src')
-rw-r--r--compiler/luci/pass/src/BatchNormPatternFinder.cpp40
-rw-r--r--compiler/luci/pass/src/BatchNormPatternFinder.test.cpp107
-rw-r--r--compiler/luci/pass/src/CircleOptimizer.cpp224
-rw-r--r--compiler/luci/pass/src/CircleOptimizer.test.cpp168
-rw-r--r--compiler/luci/pass/src/CircleQuantizer.cpp458
-rw-r--r--compiler/luci/pass/src/CircleQuantizer.test.cpp191
-rw-r--r--compiler/luci/pass/src/ConvertNCHWToNHWCPass.cpp6
-rw-r--r--compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp36
-rw-r--r--compiler/luci/pass/src/ConvertToFakeQuantizedModelPass.cpp214
-rw-r--r--compiler/luci/pass/src/ConvertToFakeQuantizedModelPass.test.cpp277
-rw-r--r--compiler/luci/pass/src/CopyQuantParamPass.cpp82
-rw-r--r--compiler/luci/pass/src/FoldGatherPass.cpp185
-rw-r--r--compiler/luci/pass/src/FoldGatherPass.test.cpp214
-rw-r--r--compiler/luci/pass/src/PropagateConcatenationQparam.test.cpp36
-rw-r--r--compiler/luci/pass/src/PropagateQParamBackwardPass.cpp482
-rw-r--r--compiler/luci/pass/src/PropagateQParamBackwardPass.test.cpp167
-rw-r--r--compiler/luci/pass/src/PropagateQParamForwardPass.cpp194
-rw-r--r--compiler/luci/pass/src/PropagateQParamForwardPass.test.cpp260
-rw-r--r--compiler/luci/pass/src/PropagateQuantParamPass.cpp107
-rw-r--r--compiler/luci/pass/src/PropagateQuantParamPass.test.cpp125
-rw-r--r--compiler/luci/pass/src/QuantizationUtils.cpp158
-rw-r--r--compiler/luci/pass/src/QuantizationUtils.h36
-rw-r--r--compiler/luci/pass/src/QuantizeActivation.cpp296
-rw-r--r--compiler/luci/pass/src/QuantizeActivation.h165
-rw-r--r--compiler/luci/pass/src/QuantizeBias.cpp300
-rw-r--r--compiler/luci/pass/src/QuantizeBias.h56
-rw-r--r--compiler/luci/pass/src/QuantizeDequantizeWeightsPass.cpp259
-rw-r--r--compiler/luci/pass/src/QuantizeDequantizeWeightsPass.test.cpp14
-rw-r--r--compiler/luci/pass/src/QuantizePreCheckerPass.cpp119
-rw-r--r--compiler/luci/pass/src/QuantizePreCheckerPass.test.cpp401
-rw-r--r--compiler/luci/pass/src/QuantizeWeights.cpp394
-rw-r--r--compiler/luci/pass/src/QuantizeWeights.h55
-rw-r--r--compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp1773
-rw-r--r--compiler/luci/pass/src/QuantizeWithMinMaxPass.test.cpp49
-rw-r--r--compiler/luci/pass/src/QuantizedModelVerifier.cpp70
-rw-r--r--compiler/luci/pass/src/QuantizedModelVerifier.h30
-rw-r--r--compiler/luci/pass/src/QuantizedModelVerifier.test.cpp497
-rw-r--r--compiler/luci/pass/src/RemoveRedundantQuantizePass.cpp104
-rw-r--r--compiler/luci/pass/src/RemoveRedundantQuantizePass.test.cpp166
-rw-r--r--compiler/luci/pass/src/RemoveRedundantTransposePass.cpp2
-rw-r--r--compiler/luci/pass/src/RemoveRedundantTransposePass.test.cpp25
-rw-r--r--compiler/luci/pass/src/RemoveUnnecessaryReshapePass.cpp19
-rw-r--r--compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.cpp26
-rw-r--r--compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.test.cpp46
-rw-r--r--compiler/luci/pass/src/SubstituteSplitVToSplitPass.cpp13
-rw-r--r--compiler/luci/pass/src/SubstituteSqueezeToReshapePass.cpp14
-rw-r--r--compiler/luci/pass/src/SubstituteStridedSliceToReshapePass.cpp2
-rw-r--r--compiler/luci/pass/src/VerifyQuantizedBiasScale.cpp105
-rw-r--r--compiler/luci/pass/src/VerifyQuantizedBiasScale.h59
-rw-r--r--compiler/luci/pass/src/VerifyQuantizedNodeGranularity.cpp38
-rw-r--r--compiler/luci/pass/src/VerifyQuantizedNodeGranularity.h (renamed from compiler/luci/pass/src/VerifyQuantizedNodeChannelWiseGranularity.h)301
-rw-r--r--compiler/luci/pass/src/VerifyQuantizedNodeLayerWiseGranularity.h473
-rw-r--r--compiler/luci/pass/src/VerifyQuantizedNodeS16Type.h516
-rw-r--r--compiler/luci/pass/src/VerifyQuantizedNodeType.cpp554
-rw-r--r--compiler/luci/pass/src/VerifyQuantizedNodeType.h157
-rw-r--r--compiler/luci/pass/src/VerifyQuantizedNodeU8Type.h518
-rw-r--r--compiler/luci/pass/src/helpers/LayerInfoMap.cpp189
-rw-r--r--compiler/luci/pass/src/helpers/LayerInfoMap.h33
-rw-r--r--compiler/luci/pass/src/helpers/LayerInfoMap.test.cpp201
59 files changed, 7907 insertions, 3899 deletions
diff --git a/compiler/luci/pass/src/BatchNormPatternFinder.cpp b/compiler/luci/pass/src/BatchNormPatternFinder.cpp
index c1a06bfda..e3f126b15 100644
--- a/compiler/luci/pass/src/BatchNormPatternFinder.cpp
+++ b/compiler/luci/pass/src/BatchNormPatternFinder.cpp
@@ -44,10 +44,26 @@ bool is_batchnorm_add(const luci::CircleAdd *add, luci::CircleMul *&mul, luci::C
return false;
}
- if (constant->rank() != 1)
+ uint32_t channel_dim = 0;
+
+ if (constant->rank() == 1)
+ {
+ channel_dim = constant->dim(0).value();
+ }
+ else if (constant->rank() == 4)
+ {
+ for (uint32_t i = 0; i < 3; i++)
+ {
+ if (constant->dim(i).value() != 1)
+ return false;
+ }
+ channel_dim = constant->dim(3).value();
+ }
+ else
+ {
return false;
+ }
- auto channel_dim = constant->dim(0);
// Assumption: Layout is channel-last
if (!(channel_dim == add->dim(add->rank() - 1)))
return false;
@@ -90,10 +106,26 @@ bool is_batchnorm_mul(const luci::CircleMul *mul, luci::CircleNode *&pred_node,
return false;
}
- if (constant->rank() != 1)
+ uint32_t channel_dim = 0;
+
+ if (constant->rank() == 1)
+ {
+ channel_dim = constant->dim(0).value();
+ }
+ else if (constant->rank() == 4)
+ {
+ for (uint32_t i = 0; i < 3; i++)
+ {
+ if (constant->dim(i).value() != 1)
+ return false;
+ }
+ channel_dim = constant->dim(3).value();
+ }
+ else
+ {
return false;
+ }
- auto channel_dim = constant->dim(0);
// Assumption: Layout is channel-last
if (!(channel_dim == mul->dim(mul->rank() - 1)))
return false;
diff --git a/compiler/luci/pass/src/BatchNormPatternFinder.test.cpp b/compiler/luci/pass/src/BatchNormPatternFinder.test.cpp
index 08e7fac1c..cc8c5615f 100644
--- a/compiler/luci/pass/src/BatchNormPatternFinder.test.cpp
+++ b/compiler/luci/pass/src/BatchNormPatternFinder.test.cpp
@@ -50,7 +50,7 @@ public:
auto channel_size = *last_it;
_add->shape(shape);
- _add_beta->shape({channel_size});
+ set_beta_shape(channel_size);
_add_beta->size<loco::DataType::FLOAT32>(channel_size);
for (uint32_t i = 0; i < channel_size; i++)
_add_beta->at<loco::DataType::FLOAT32>(i) = i;
@@ -63,10 +63,23 @@ public:
luci::CircleAdd *add() { return _add; }
protected:
+ virtual void set_beta_shape(uint32_t channel) = 0;
+
+protected:
luci::CircleAdd *_add = nullptr;
luci::CircleConst *_add_beta = nullptr;
};
+class AddRank1BetaGraphlet : public AddBetaGraphlet
+{
+ void set_beta_shape(uint32_t channel) final { _add_beta->shape({channel}); }
+};
+
+class AddRank4BetaGraphlet : public AddBetaGraphlet
+{
+ void set_beta_shape(uint32_t channel) final { _add_beta->shape({1, 1, 1, channel}); }
+};
+
/**
* @brief Graphlet with Mul and Const as gamma from BatchNorm
*/
@@ -90,7 +103,7 @@ public:
auto channel_size = *last_it;
_mul->shape(shape);
- _mul_gamma->shape({channel_size});
+ set_gamma_shape(channel_size);
_mul_gamma->size<loco::DataType::FLOAT32>(channel_size);
for (uint32_t i = 0; i < channel_size; i++)
_mul_gamma->at<loco::DataType::FLOAT32>(i) = i;
@@ -103,14 +116,27 @@ public:
luci::CircleMul *mul(void) { return _mul; }
protected:
+ virtual void set_gamma_shape(uint32_t channel) = 0;
+
+protected:
luci::CircleMul *_mul = nullptr;
luci::CircleConst *_mul_gamma = nullptr;
};
+class MulRank1GammaGraphlet : public MulGammaGraphlet
+{
+ void set_gamma_shape(uint32_t channel) final { _mul_gamma->shape({channel}); }
+};
+
+class MulRank4GammaGraphlet : public MulGammaGraphlet
+{
+ void set_gamma_shape(uint32_t channel) final { _mul_gamma->shape({1, 1, 1, channel}); }
+};
+
/**
* @brief Graph of Mul-Add pattern from BatchNorm
*/
-class MulAddGraph : public TestIOGraph, public AddBetaGraphlet, public MulGammaGraphlet
+class MulAddGraph : public TestIOGraph, public AddRank1BetaGraphlet, public MulRank1GammaGraphlet
{
public:
MulAddGraph() = default;
@@ -118,8 +144,30 @@ public:
void init(const ShapeU32 shape_in, const ShapeU32 shape_out)
{
TestIOGraph::init(shape_in, shape_out);
- MulGammaGraphlet::init(g(), shape_in, luci::FusedActFunc::NONE);
- AddBetaGraphlet::init(g(), shape_out, luci::FusedActFunc::RELU);
+ MulRank1GammaGraphlet::init(g(), shape_in, luci::FusedActFunc::NONE);
+ AddRank1BetaGraphlet::init(g(), shape_out, luci::FusedActFunc::RELU);
+
+ // connect network
+ _mul->x(input());
+ _mul->y(_mul_gamma);
+ _add->x(_mul);
+ _add->y(_add_beta);
+ output()->from(_add);
+ }
+};
+
+class MulAddRank4Graph : public TestIOGraph,
+ public AddRank4BetaGraphlet,
+ public MulRank4GammaGraphlet
+{
+public:
+ MulAddRank4Graph() = default;
+
+ void init(const ShapeU32 shape_in, const ShapeU32 shape_out)
+ {
+ TestIOGraph::init(shape_in, shape_out);
+ MulRank4GammaGraphlet::init(g(), shape_in, luci::FusedActFunc::NONE);
+ AddRank4BetaGraphlet::init(g(), shape_out, luci::FusedActFunc::RELU);
// connect network
_mul->x(input());
@@ -133,7 +181,7 @@ public:
/**
* @brief Graph of Add with Const
*/
-class AddGraph : public TestIOGraph, public AddBetaGraphlet
+class AddGraph : public TestIOGraph, public AddRank1BetaGraphlet
{
public:
AddGraph() = default;
@@ -141,7 +189,24 @@ public:
void init(const ShapeU32 shape_in, const ShapeU32 shape_out)
{
TestIOGraph::init(shape_in, shape_out);
- AddBetaGraphlet::init(g(), shape_in, luci::FusedActFunc::RELU);
+ AddRank1BetaGraphlet::init(g(), shape_in, luci::FusedActFunc::RELU);
+
+ // connect network
+ _add->x(input());
+ _add->y(_add_beta);
+ output()->from(_add);
+ }
+};
+
+class AddRank4Graph : public TestIOGraph, public AddRank4BetaGraphlet
+{
+public:
+ AddRank4Graph() = default;
+
+ void init(const ShapeU32 shape_in, const ShapeU32 shape_out)
+ {
+ TestIOGraph::init(shape_in, shape_out);
+ AddRank4BetaGraphlet::init(g(), shape_in, luci::FusedActFunc::RELU);
// connect network
_add->x(input());
@@ -160,6 +225,7 @@ public:
protected:
luci::test::MulAddGraph _mag;
+ luci::test::MulAddRank4Graph _mag_r4;
};
class BatchNormPatternFinderAddTest : public ::testing::Test
@@ -169,6 +235,7 @@ public:
protected:
luci::test::AddGraph _ag;
+ luci::test::AddRank4Graph _ag_r4;
};
TEST_F(BatchNormPatternFinderMulAddTest, is_batchnorm_add)
@@ -192,6 +259,19 @@ TEST_F(BatchNormPatternFinderMulAddTest, is_batchnorm_add2)
ASSERT_TRUE(res);
}
+TEST_F(BatchNormPatternFinderMulAddTest, is_batchnorm_add_rank4)
+{
+ _mag_r4.init({1, 16, 16, 4}, {1, 16, 16, 4});
+
+ luci::CircleMul *mul = nullptr;
+ luci::CircleConst *beta = nullptr;
+
+ auto res = luci::is_batchnorm_add(_mag_r4.add(), mul, beta);
+ ASSERT_TRUE(res);
+ ASSERT_NE(nullptr, mul);
+ ASSERT_NE(nullptr, beta);
+}
+
TEST_F(BatchNormPatternFinderAddTest, is_batchnorm_add_NEG)
{
_ag.init({1, 16, 16, 4}, {1, 16, 16, 4});
@@ -215,3 +295,16 @@ TEST_F(BatchNormPatternFinderMulAddTest, is_batchnorm_mul)
ASSERT_NE(nullptr, pred);
ASSERT_NE(nullptr, gamma);
}
+
+TEST_F(BatchNormPatternFinderMulAddTest, is_batchnorm_mul_rank4)
+{
+ _mag_r4.init({1, 16, 16, 4}, {1, 16, 16, 4});
+
+ luci::CircleNode *pred = nullptr;
+ luci::CircleConst *gamma = nullptr;
+
+ auto res = luci::is_batchnorm_mul(_mag_r4.mul(), pred, gamma);
+ ASSERT_TRUE(res);
+ ASSERT_NE(nullptr, pred);
+ ASSERT_NE(nullptr, gamma);
+}
diff --git a/compiler/luci/pass/src/CircleOptimizer.cpp b/compiler/luci/pass/src/CircleOptimizer.cpp
index 75f04b3b5..6dbb22d7c 100644
--- a/compiler/luci/pass/src/CircleOptimizer.cpp
+++ b/compiler/luci/pass/src/CircleOptimizer.cpp
@@ -22,9 +22,9 @@
#include "luci/Pass/FoldCastPass.h"
#include "luci/Pass/FoldDepthwiseConv2DPass.h"
#include "luci/Pass/FoldDequantizePass.h"
+#include "luci/Pass/FoldGatherPass.h"
#include "luci/Pass/FoldSparseToDensePass.h"
#include "luci/Pass/ForwardReshapeToUnaryOpPass.h"
-#include "luci/Pass/ForceQuantParamPass.h"
#include "luci/Pass/FuseActivationFunctionPass.h"
#include "luci/Pass/FuseAddWithFullyConnectedPass.h"
#include "luci/Pass/FuseAddWithTConvPass.h"
@@ -37,11 +37,11 @@
#include "luci/Pass/FusePreActivationBatchNormPass.h"
#include "luci/Pass/FuseTransposeWithMeanPass.h"
#include "luci/Pass/MakeBatchNormGammaPositivePass.h"
-#include "luci/Pass/PropagateQuantParamPass.h"
#include "luci/Pass/RemoveFakeQuantPass.h"
#include "luci/Pass/RemoveQuantDequantSeqPass.h"
#include "luci/Pass/RemoveRedundantReshapePass.h"
#include "luci/Pass/RemoveRedundantTransposePass.h"
+#include "luci/Pass/RemoveRedundantQuantizePass.h"
#include "luci/Pass/RemoveUnnecessaryReshapePass.h"
#include "luci/Pass/RemoveUnnecessarySlicePass.h"
#include "luci/Pass/RemoveUnnecessaryStridedSlicePass.h"
@@ -52,9 +52,6 @@
#include "luci/Pass/ResolveCustomOpBatchMatMulPass.h"
#include "luci/Pass/ResolveCustomOpMatMulPass.h"
#include "luci/Pass/ResolveCustomOpMaxPoolWithArgmaxPass.h"
-#include "luci/Pass/RequantizePass.h"
-#include "luci/Pass/QuantizeWithMinMaxPass.h"
-#include "luci/Pass/QuantizeDequantizeWeightsPass.h"
#include "luci/Pass/SparsifyTensorPass.h"
#include "luci/Pass/ShuffleWeightTo16x1Float32Pass.h"
#include "luci/Pass/SubstitutePackToReshapePass.h"
@@ -75,9 +72,6 @@
#include "ModulePhase.h"
#include "ProgressReporter.h"
-#include "helpers/Strings.h"
-
-#include "QuantizedModelVerifier.h"
#include <luci/IR/CircleNodes.h>
#include <logo/Phase.h>
@@ -91,37 +85,17 @@ namespace
using namespace luci;
-template <typename T> T lexical_cast(const std::string &str)
-{
- std::istringstream ss;
- ss.str(str);
- T data;
- ss >> data;
- return data;
-}
-
-template <typename T> std::vector<T> lexical_cast(std::vector<std::string> &sv)
-{
- std::vector<T> result;
- std::transform(sv.begin(), sv.end(), std::back_inserter(result),
- [](std::string str) -> T { return lexical_cast<T>(str); });
- return result;
-}
-
class OptimizeOptionsImpl final : public luci::CircleOptimizer::Options
{
public:
void enable(Algorithm) final;
void param(AlgorithmParameters, const std::string &) final;
const std::string param(AlgorithmParameters) const final;
- void params(AlgorithmParameters, std::vector<std::string> &) final;
- std::vector<std::string> params(AlgorithmParameters) const final;
bool query(Algorithm) final;
private:
std::vector<Algorithm> _algorithms;
std::map<AlgorithmParameters, const std::string> _algorithm_params;
- std::map<AlgorithmParameters, std::vector<std::string>> _multiple_params;
};
void OptimizeOptionsImpl::enable(Algorithm algo) { _algorithms.push_back(algo); }
@@ -144,24 +118,6 @@ const std::string OptimizeOptionsImpl::param(AlgorithmParameters param) const
}
}
-void OptimizeOptionsImpl::params(AlgorithmParameters param, std::vector<std::string> &vec)
-{
- _multiple_params[param] = vec;
-}
-
-std::vector<std::string> OptimizeOptionsImpl::params(AlgorithmParameters param) const
-{
- auto param_vec = _multiple_params.find(param);
- if (param_vec != _multiple_params.end())
- {
- return param_vec->second;
- }
- else
- {
- return std::vector<std::string>();
- }
-}
-
bool OptimizeOptionsImpl::query(Algorithm algo)
{
std::vector<Algorithm>::iterator it = std::find(_algorithms.begin(), _algorithms.end(), algo);
@@ -312,6 +268,10 @@ void CircleOptimizer::optimize(loco::Graph *g) const
{
phase.emplace_back(std::make_unique<luci::FoldDequantizePass>());
}
+ if (_options->query(Options::Algorithm::FoldGather))
+ {
+ phase.emplace_back(std::make_unique<luci::FoldGatherPass>());
+ }
if (_options->query(Options::Algorithm::FoldSparseToDense))
{
phase.emplace_back(std::make_unique<luci::FoldSparseToDensePass>());
@@ -368,6 +328,10 @@ void CircleOptimizer::optimize(loco::Graph *g) const
{
phase.emplace_back(std::make_unique<luci::RemoveRedundantTransposePass>());
}
+ if (_options->query(Options::Algorithm::RemoveRedundantQuantize))
+ {
+ phase.emplace_back(std::make_unique<luci::RemoveRedundantQuantizePass>());
+ }
if (_options->query(Options::Algorithm::ReplaceMulAddWithDepthwiseConv))
{
phase.emplace_back(std::make_unique<luci::ReplaceMulAddWithDepthwiseConvPass>());
@@ -417,174 +381,6 @@ void CircleOptimizer::optimize(loco::Graph *g) const
phase_runner.run(phase);
}
-void CircleOptimizer::quantize(loco::Graph *g) const
-{
- // Fake quantization of weights
- if (_options->query(Options::Algorithm::QuantizeDequantizeWeights))
- {
- static const std::vector<std::string> fakeq_supported_input_model_dtype{"float32"};
- static const std::vector<std::string> fakeq_supported_output_model_dtype{"uint8", "int16"};
- static const std::vector<std::string> fakeq_supported_granularity{"layer", "channel"};
-
- auto input_model_dtype =
- _options->param(Options::AlgorithmParameters::Quantize_input_model_dtype);
- auto output_model_dtype =
- _options->param(Options::AlgorithmParameters::Quantize_output_model_dtype);
- auto granularity = _options->param(Options::AlgorithmParameters::Quantize_granularity);
-
- if (!in_array(to_lower_case(input_model_dtype), fakeq_supported_input_model_dtype))
- throw std::runtime_error("Unsupported input type. List of supported input type: " +
- to_string(fakeq_supported_input_model_dtype));
-
- if (!in_array(to_lower_case(output_model_dtype), fakeq_supported_output_model_dtype))
- throw std::runtime_error("Unsupported output type. List of supported output type: " +
- to_string(fakeq_supported_output_model_dtype));
-
- if (!in_array(to_lower_case(granularity), fakeq_supported_granularity))
- throw std::runtime_error("Unsupported granularity. List of supported granularity: " +
- to_string(fakeq_supported_granularity));
-
- if (str_to_granularity(granularity) == QuantizationGranularity::LayerWise &&
- str_to_dtype(output_model_dtype) != loco::DataType::U8)
- throw std::runtime_error("Layer-wise quantization only supports uint8 dtype.");
-
- // Clear existing quantparams before doing fake quantization
- for (auto node : loco::active_nodes(loco::output_nodes(g)))
- {
- auto circle_node = loco::must_cast<luci::CircleNode *>(node);
- if (circle_node->quantparam() != nullptr)
- circle_node->quantparam(nullptr);
- }
-
- luci::QuantizeDequantizeWeightsPass fake_quantizer(str_to_dtype(input_model_dtype),
- str_to_dtype(output_model_dtype),
- str_to_granularity(granularity));
- fake_quantizer.run(g);
- }
-
- // Actual quantization of weights, bias, and activation
- if (_options->query(Options::Algorithm::QuantizeWithMinMax))
- {
- static const std::vector<std::string> qwmm_supported_input_model_dtype{"float32"};
- static const std::vector<std::string> qwmm_supported_output_model_dtype{"uint8", "int16"};
- static const std::vector<std::string> qwmm_supported_granularity{"layer", "channel"};
- static const std::vector<std::string> qwmm_supported_input_type{"uint8", "int16"};
- static const std::vector<std::string> qwmm_supported_output_type{"uint8", "int16"};
-
- auto input_model_dtype =
- _options->param(Options::AlgorithmParameters::Quantize_input_model_dtype);
- auto output_model_dtype =
- _options->param(Options::AlgorithmParameters::Quantize_output_model_dtype);
- auto granularity = _options->param(Options::AlgorithmParameters::Quantize_granularity);
- auto input_type = _options->param(Options::AlgorithmParameters::Quantize_input_type);
- if (input_type.empty())
- input_type = output_model_dtype;
- auto output_type = _options->param(Options::AlgorithmParameters::Quantize_output_type);
- if (output_type.empty())
- output_type = output_model_dtype;
-
- if (!in_array(to_lower_case(input_model_dtype), qwmm_supported_input_model_dtype))
- throw std::runtime_error("Unsupported input type. List of supported input types: " +
- to_string(qwmm_supported_input_model_dtype));
-
- if (!in_array(to_lower_case(output_model_dtype), qwmm_supported_output_model_dtype))
- throw std::runtime_error("Unsupported output type. List of supported output types: " +
- to_string(qwmm_supported_output_model_dtype));
-
- if (!in_array(to_lower_case(granularity), qwmm_supported_granularity))
- throw std::runtime_error("Unsupported granularity. List of supported granularity: " +
- to_string(qwmm_supported_granularity));
-
- if (!in_array(to_lower_case(input_type), qwmm_supported_input_type))
- throw std::runtime_error("Unsupported input type. List of supported input types: " +
- to_string(qwmm_supported_input_type));
-
- if (!in_array(to_lower_case(output_type), qwmm_supported_output_type))
- throw std::runtime_error("Unsupported output type. List of supported output types: " +
- to_string(qwmm_supported_output_type));
-
- if (str_to_granularity(granularity) == QuantizationGranularity::LayerWise &&
- str_to_dtype(output_model_dtype) != loco::DataType::U8)
- throw std::runtime_error("Layer-wise quantization only supports uint8 dtype.");
-
- luci::QuantizeWithMinMaxPass quantizer(
- str_to_dtype(input_model_dtype), str_to_dtype(output_model_dtype),
- str_to_granularity(granularity), str_to_dtype(input_type), str_to_dtype(output_type));
- quantizer.run(g);
-
- // Post-quantization optimizations
- logo::Phase phase;
-
- phase.emplace_back(std::make_unique<luci::PropagateQuantParamPass>());
-
- phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>());
- phase.emplace_back(std::make_unique<luci::CircleTypeInferencePass>());
- phase.emplace_back(std::make_unique<logo::RemoveDeadNodeWithQueryPass>());
-
- ProgressReporter prog(g, logo::PhaseStrategy::Saturate);
- logo::PhaseRunner<logo::PhaseStrategy::Saturate> phase_runner{g};
- phase_runner.attach(&prog);
- phase_runner.run(phase);
-
- // Verify the type/granularity of the quantized model
- luci::QuantizedModelVerifier verifier(str_to_dtype(output_model_dtype),
- str_to_granularity(granularity));
- verifier.verify(g);
- }
-
- // Requantize
- if (_options->query(Options::Algorithm::Requantize))
- {
- static const std::vector<std::string> rq_supported_input_model_dtype{"int8"};
- static const std::vector<std::string> rq_supported_output_model_dtype{"uint8"};
-
- auto input_model_dtype =
- _options->param(Options::AlgorithmParameters::Quantize_input_model_dtype);
- auto output_model_dtype =
- _options->param(Options::AlgorithmParameters::Quantize_output_model_dtype);
-
- if (!in_array(to_lower_case(input_model_dtype), rq_supported_input_model_dtype))
- throw std::runtime_error("Unsupported input type. List of supported input types: " +
- to_string(rq_supported_input_model_dtype));
-
- if (!in_array(to_lower_case(output_model_dtype), rq_supported_output_model_dtype))
- throw std::runtime_error("Unsupported output type. List of supported output types: " +
- to_string(rq_supported_output_model_dtype));
-
- luci::RequantizePass requantizer(str_to_dtype(input_model_dtype),
- str_to_dtype(output_model_dtype));
- requantizer.run(g);
- }
-
- // Force to write quantparam to specified tensors
- // NOTE Only per-tensor (not per-channel) qparam can be written
- if (_options->query(Options::Algorithm::ForceQuantParam))
- {
- ForceQuantParamPass::TensorVector tensors =
- _options->params(Options::AlgorithmParameters::Quantize_tensor_names);
- auto str_scales = _options->params(Options::AlgorithmParameters::Quantize_scales);
- auto str_zero_points = _options->params(Options::AlgorithmParameters::Quantize_zero_points);
-
- // Cast scales/zero_points to proper types
- ForceQuantParamPass::ScaleVector scales = lexical_cast<float>(str_scales);
- ForceQuantParamPass::ZPVector zero_points = lexical_cast<int64_t>(str_zero_points);
-
- ForceQuantParamPass fq(tensors, scales, zero_points);
- fq.run(g);
- }
-
- logo::Phase phase;
-
- // Do Shape/Type inference
- phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>());
- phase.emplace_back(std::make_unique<luci::CircleTypeInferencePass>());
-
- ProgressReporter prog(g, logo::PhaseStrategy::Saturate);
- logo::PhaseRunner<logo::PhaseStrategy::Saturate> phase_runner{g};
- phase_runner.attach(&prog);
- phase_runner.run(phase);
-}
-
void CircleOptimizer::sparsify(loco::Graph *g) const
{
if (_options->query(Options::Algorithm::SparsifyTensorPass))
diff --git a/compiler/luci/pass/src/CircleOptimizer.test.cpp b/compiler/luci/pass/src/CircleOptimizer.test.cpp
index a1b5c7f80..041fc7d75 100644
--- a/compiler/luci/pass/src/CircleOptimizer.test.cpp
+++ b/compiler/luci/pass/src/CircleOptimizer.test.cpp
@@ -71,171 +71,3 @@ TEST(CircleOptimizerTest, sparsify_simple)
SUCCEED();
}
-
-TEST(CircleOptimizerTest, quantize_quantdequant_simple)
-{
- loco::Graph g;
- luci::CircleOptimizer o;
-
- auto options = o.options();
-
- options->enable(Algorithms::QuantizeDequantizeWeights);
- options->param(AlgorithmParameters::Quantize_input_model_dtype, "float32");
- options->param(AlgorithmParameters::Quantize_output_model_dtype, "uint8");
- options->param(AlgorithmParameters::Quantize_granularity, "layer");
-
- o.quantize(&g);
-
- SUCCEED();
-}
-
-TEST(CircleOptimizerTest, quantize_quantdequant_input_NEG)
-{
- loco::Graph g;
- luci::CircleOptimizer o;
-
- auto options = o.options();
-
- options->enable(Algorithms::QuantizeDequantizeWeights);
- options->param(AlgorithmParameters::Quantize_input_model_dtype, "invalid");
- options->param(AlgorithmParameters::Quantize_output_model_dtype, "uint8");
- options->param(AlgorithmParameters::Quantize_granularity, "layer");
-
- EXPECT_THROW(o.quantize(&g), std::runtime_error);
-}
-
-TEST(CircleOptimizerTest, quantize_quantdequant_output_NEG)
-{
- loco::Graph g;
- luci::CircleOptimizer o;
-
- auto options = o.options();
-
- options->enable(Algorithms::QuantizeDequantizeWeights);
- options->param(AlgorithmParameters::Quantize_input_model_dtype, "float32");
- options->param(AlgorithmParameters::Quantize_output_model_dtype, "invalid");
- options->param(AlgorithmParameters::Quantize_granularity, "layer");
-
- EXPECT_THROW(o.quantize(&g), std::runtime_error);
-}
-
-TEST(CircleOptimizerTest, quantize_quantdequant_gran_NEG)
-{
- loco::Graph g;
- luci::CircleOptimizer o;
-
- auto options = o.options();
-
- options->enable(Algorithms::QuantizeDequantizeWeights);
- options->param(AlgorithmParameters::Quantize_input_model_dtype, "float32");
- options->param(AlgorithmParameters::Quantize_output_model_dtype, "uint8");
- options->param(AlgorithmParameters::Quantize_granularity, "invalid");
-
- EXPECT_THROW(o.quantize(&g), std::runtime_error);
-}
-
-TEST(CircleOptimizerTest, quantize_minmax_simple)
-{
- loco::Graph g;
- luci::CircleOptimizer o;
-
- auto options = o.options();
-
- options->enable(Algorithms::QuantizeWithMinMax);
- options->param(AlgorithmParameters::Quantize_input_model_dtype, "float32");
- options->param(AlgorithmParameters::Quantize_output_model_dtype, "uint8");
- options->param(AlgorithmParameters::Quantize_granularity, "layer");
-
- o.quantize(&g);
-
- SUCCEED();
-}
-
-TEST(CircleOptimizerTest, quantize_minmax_input_NEG)
-{
- loco::Graph g;
- luci::CircleOptimizer o;
-
- auto options = o.options();
-
- options->enable(Algorithms::QuantizeWithMinMax);
- options->param(AlgorithmParameters::Quantize_input_model_dtype, "invalid");
- options->param(AlgorithmParameters::Quantize_output_model_dtype, "uint8");
- options->param(AlgorithmParameters::Quantize_granularity, "layer");
-
- EXPECT_THROW(o.quantize(&g), std::runtime_error);
-}
-
-TEST(CircleOptimizerTest, quantize_minmax_output_NEG)
-{
- loco::Graph g;
- luci::CircleOptimizer o;
-
- auto options = o.options();
-
- options->enable(Algorithms::QuantizeWithMinMax);
- options->param(AlgorithmParameters::Quantize_input_model_dtype, "float32");
- options->param(AlgorithmParameters::Quantize_output_model_dtype, "invalid");
- options->param(AlgorithmParameters::Quantize_granularity, "layer");
-
- EXPECT_THROW(o.quantize(&g), std::runtime_error);
-}
-
-TEST(CircleOptimizerTest, quantize_minmax_gran_NEG)
-{
- loco::Graph g;
- luci::CircleOptimizer o;
-
- auto options = o.options();
-
- options->enable(Algorithms::QuantizeWithMinMax);
- options->param(AlgorithmParameters::Quantize_input_model_dtype, "float32");
- options->param(AlgorithmParameters::Quantize_output_model_dtype, "uint8");
- options->param(AlgorithmParameters::Quantize_granularity, "invalid");
-
- EXPECT_THROW(o.quantize(&g), std::runtime_error);
-}
-
-TEST(CircleOptimizerTest, quantize_requant_simple)
-{
- loco::Graph g;
- luci::CircleOptimizer o;
-
- auto options = o.options();
-
- options->enable(Algorithms::Requantize);
- options->param(AlgorithmParameters::Quantize_input_model_dtype, "int8");
- options->param(AlgorithmParameters::Quantize_output_model_dtype, "uint8");
-
- o.quantize(&g);
-
- SUCCEED();
-}
-
-TEST(CircleOptimizerTest, quantize_requant_input_NEG)
-{
- loco::Graph g;
- luci::CircleOptimizer o;
-
- auto options = o.options();
-
- options->enable(Algorithms::Requantize);
- options->param(AlgorithmParameters::Quantize_input_model_dtype, "invalid");
- options->param(AlgorithmParameters::Quantize_output_model_dtype, "uint8");
-
- EXPECT_THROW(o.quantize(&g), std::runtime_error);
-}
-
-TEST(CircleOptimizerTest, quantize_requant_output_NEG)
-{
- loco::Graph g;
- luci::CircleOptimizer o;
-
- auto options = o.options();
-
- options->enable(Algorithms::Requantize);
- options->param(AlgorithmParameters::Quantize_input_model_dtype, "int8");
- options->param(AlgorithmParameters::Quantize_output_model_dtype, "invalid");
-
- EXPECT_THROW(o.quantize(&g), std::runtime_error);
-}
diff --git a/compiler/luci/pass/src/CircleQuantizer.cpp b/compiler/luci/pass/src/CircleQuantizer.cpp
new file mode 100644
index 000000000..ce38a90b9
--- /dev/null
+++ b/compiler/luci/pass/src/CircleQuantizer.cpp
@@ -0,0 +1,458 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/CircleQuantizer.h"
+
+#include "luci/Pass/CopyQuantParamPass.h"
+#include "luci/Pass/ForceQuantParamPass.h"
+#include "luci/Pass/PropagateQParamForwardPass.h"
+#include "luci/Pass/RequantizePass.h"
+#include "luci/Pass/ConvertToFakeQuantizedModelPass.h"
+#include "luci/Pass/FoldDequantizePass.h"
+#include "luci/Pass/QuantizePreCheckerPass.h"
+#include "luci/Pass/QuantizeWithMinMaxPass.h"
+#include "luci/Pass/QuantizeDequantizeWeightsPass.h"
+
+#include "luci/Pass/CircleShapeInferencePass.h"
+#include "luci/Pass/CircleTypeInferencePass.h"
+
+// logo passes
+#include <logo/RemoveDeadNodeWithQueryPass.h>
+
+#include "ProgressReporter.h"
+#include "helpers/Strings.h"
+
+#include "QuantizedModelVerifier.h"
+
+#include <luci/IR/CircleNode.h>
+#include <logo/Phase.h>
+
+#include <memory>
+
+namespace
+{
+
+using namespace luci;
+using LayerParam = luci::CircleQuantizer::Options::LayerParam;
+
+template <typename T> T lexical_cast(const std::string &str)
+{
+ std::istringstream ss;
+ ss.str(str);
+ T data;
+ ss >> data;
+ return data;
+}
+
+template <typename T> std::vector<T> lexical_cast(std::vector<std::string> &sv)
+{
+ std::vector<T> result;
+ std::transform(sv.begin(), sv.end(), std::back_inserter(result),
+ [](std::string str) -> T { return lexical_cast<T>(str); });
+ return result;
+}
+
+class QuantizeOptionsImpl final : public luci::CircleQuantizer::Options
+{
+public:
+ void enable(Algorithm) final;
+ void param(AlgorithmParameters, const std::string &) final;
+ const std::string param(AlgorithmParameters) const final;
+ void params(AlgorithmParameters, std::vector<std::string> &) final;
+ std::vector<std::string> params(AlgorithmParameters) const final;
+ void layer_params(AlgorithmParameters, std::vector<std::shared_ptr<LayerParam>> &) final;
+ std::vector<std::shared_ptr<LayerParam>> layer_params(AlgorithmParameters) const final;
+ bool query(Algorithm) final;
+
+private:
+ std::vector<Algorithm> _algorithms;
+ std::map<AlgorithmParameters, const std::string> _algorithm_params;
+ std::map<AlgorithmParameters, std::vector<std::string>> _multiple_params;
+ std::map<AlgorithmParameters, std::vector<std::shared_ptr<LayerParam>>> _layer_params;
+};
+
+void QuantizeOptionsImpl::enable(Algorithm algo) { _algorithms.push_back(algo); }
+
+void QuantizeOptionsImpl::param(AlgorithmParameters param, const std::string &str)
+{
+ _algorithm_params.insert(std::pair<AlgorithmParameters, const std::string>(param, str));
+}
+
+const std::string QuantizeOptionsImpl::param(AlgorithmParameters param) const
+{
+ auto param_str = _algorithm_params.find(param);
+ if (param_str != _algorithm_params.end())
+ {
+ return param_str->second;
+ }
+ else
+ {
+ return std::string();
+ }
+}
+
+void QuantizeOptionsImpl::params(AlgorithmParameters param, std::vector<std::string> &vec)
+{
+ _multiple_params[param] = vec;
+}
+
+std::vector<std::string> QuantizeOptionsImpl::params(AlgorithmParameters param) const
+{
+ auto param_vec = _multiple_params.find(param);
+ if (param_vec != _multiple_params.end())
+ {
+ return param_vec->second;
+ }
+ else
+ {
+ return std::vector<std::string>();
+ }
+}
+
+void QuantizeOptionsImpl::layer_params(AlgorithmParameters param,
+ std::vector<std::shared_ptr<LayerParam>> &vec)
+{
+ _layer_params[param] = vec;
+}
+
+std::vector<std::shared_ptr<LayerParam>>
+QuantizeOptionsImpl::layer_params(AlgorithmParameters param) const
+{
+ auto param_vec = _layer_params.find(param);
+ if (param_vec != _layer_params.end())
+ {
+ return param_vec->second;
+ }
+ else
+ {
+ return std::vector<std::shared_ptr<LayerParam>>();
+ }
+}
+
+bool QuantizeOptionsImpl::query(Algorithm algo)
+{
+ std::vector<Algorithm>::iterator it = std::find(_algorithms.begin(), _algorithms.end(), algo);
+ if (it == _algorithms.end())
+ return false;
+
+ return true;
+}
+
+} // namespace
+
+namespace luci
+{
+
+CircleQuantizer::Options *CircleQuantizer::options(void)
+{
+ if (_options == nullptr)
+ {
+ _options = std::make_unique<QuantizeOptionsImpl>();
+ }
+
+ return _options.get();
+}
+
+void CircleQuantizer::quantize(loco::Graph *g) const
+{
+ // Fake quantization of weights
+ if (_options->query(Options::Algorithm::QuantizeDequantizeWeights))
+ {
+ static const std::vector<std::string> fakeq_supported_input_model_dtype{"float32"};
+ static const std::vector<std::string> fakeq_supported_output_model_dtype{"uint8", "int16"};
+ static const std::vector<std::string> fakeq_supported_granularity{"layer", "channel"};
+
+ auto input_model_dtype =
+ _options->param(Options::AlgorithmParameters::Quantize_input_model_dtype);
+ auto output_model_dtype =
+ _options->param(Options::AlgorithmParameters::Quantize_output_model_dtype);
+ auto granularity = _options->param(Options::AlgorithmParameters::Quantize_granularity);
+ auto layer_params = _options->layer_params(Options::AlgorithmParameters::Quantize_layer_params);
+
+ if (!in_array(to_lower_case(input_model_dtype), fakeq_supported_input_model_dtype))
+ throw std::runtime_error("Unsupported input type. List of supported input type: " +
+ to_string(fakeq_supported_input_model_dtype));
+
+ if (!in_array(to_lower_case(output_model_dtype), fakeq_supported_output_model_dtype))
+ throw std::runtime_error("Unsupported output type. List of supported output type: " +
+ to_string(fakeq_supported_output_model_dtype));
+
+ if (!in_array(to_lower_case(granularity), fakeq_supported_granularity))
+ throw std::runtime_error("Unsupported granularity. List of supported granularity: " +
+ to_string(fakeq_supported_granularity));
+
+ if (str_to_granularity(granularity) == QuantizationGranularity::LayerWise &&
+ str_to_dtype(output_model_dtype) != loco::DataType::U8)
+ throw std::runtime_error("Layer-wise quantization only supports uint8 dtype.");
+
+ // Check dtype/granularity of layer params
+ for (auto layer_param : layer_params)
+ {
+ auto name = layer_param->name;
+ if (!in_array(to_lower_case(layer_param->dtype), fakeq_supported_output_model_dtype))
+ {
+ throw std::runtime_error("Unsupported dtype in " + name + ". List of supported dtype: " +
+ to_string(fakeq_supported_output_model_dtype));
+ }
+ if (!in_array(to_lower_case(layer_param->granularity), fakeq_supported_granularity))
+ {
+ throw std::runtime_error(
+ "Unsupported granularity in " + name +
+ ". List of supported granularity: " + to_string(fakeq_supported_granularity));
+ }
+ }
+
+ // Clear existing quantparams before doing fake quantization
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+ if (circle_node->quantparam() != nullptr)
+ circle_node->quantparam(nullptr);
+ }
+
+ auto ctx = std::make_unique<luci::QuantizeDequantizeWeightsPass::Context>();
+ {
+ ctx->input_model_dtype = str_to_dtype(input_model_dtype);
+ ctx->output_model_dtype = str_to_dtype(output_model_dtype);
+ ctx->granularity = str_to_granularity(granularity);
+
+ for (auto layer_param : layer_params)
+ {
+ LayerInfo info;
+ {
+ info.name = layer_param->name;
+ info.dtype = str_to_dtype(layer_param->dtype);
+ info.granularity = str_to_granularity(layer_param->granularity);
+ }
+ ctx->layers_info.emplace_back(info);
+ }
+ }
+
+ luci::QuantizeDequantizeWeightsPass fake_quantizer(std::move(ctx));
+
+ fake_quantizer.run(g);
+ }
+
+ // Actual quantization of weights, bias, and activation
+ if (_options->query(Options::Algorithm::QuantizeWithMinMax))
+ {
+ static const std::vector<std::string> qwmm_supported_input_model_dtype{"float32"};
+ static const std::vector<std::string> qwmm_supported_output_model_dtype{"uint8", "int16"};
+ static const std::vector<std::string> qwmm_supported_granularity{"layer", "channel"};
+ static const std::vector<std::string> qwmm_supported_input_type{"uint8", "int16"};
+ static const std::vector<std::string> qwmm_supported_output_type{"uint8", "int16"};
+
+ auto input_model_dtype =
+ _options->param(Options::AlgorithmParameters::Quantize_input_model_dtype);
+ auto output_model_dtype =
+ _options->param(Options::AlgorithmParameters::Quantize_output_model_dtype);
+ auto granularity = _options->param(Options::AlgorithmParameters::Quantize_granularity);
+ auto input_type = _options->param(Options::AlgorithmParameters::Quantize_input_type);
+ if (input_type.empty())
+ input_type = output_model_dtype;
+ auto output_type = _options->param(Options::AlgorithmParameters::Quantize_output_type);
+ if (output_type.empty())
+ output_type = output_model_dtype;
+
+ bool TF_style_maxpool =
+ _options->param(Options::AlgorithmParameters::Quantize_TF_style_maxpool) == "True";
+
+ auto layer_params = _options->layer_params(Options::AlgorithmParameters::Quantize_layer_params);
+
+ if (!in_array(to_lower_case(input_model_dtype), qwmm_supported_input_model_dtype))
+ throw std::runtime_error("Unsupported input type. List of supported input types: " +
+ to_string(qwmm_supported_input_model_dtype));
+
+ if (!in_array(to_lower_case(output_model_dtype), qwmm_supported_output_model_dtype))
+ throw std::runtime_error("Unsupported output type. List of supported output types: " +
+ to_string(qwmm_supported_output_model_dtype));
+
+ if (!in_array(to_lower_case(granularity), qwmm_supported_granularity))
+ throw std::runtime_error("Unsupported granularity. List of supported granularity: " +
+ to_string(qwmm_supported_granularity));
+
+ if (!in_array(to_lower_case(input_type), qwmm_supported_input_type))
+ throw std::runtime_error("Unsupported input type. List of supported input types: " +
+ to_string(qwmm_supported_input_type));
+
+ if (!in_array(to_lower_case(output_type), qwmm_supported_output_type))
+ throw std::runtime_error("Unsupported output type. List of supported output types: " +
+ to_string(qwmm_supported_output_type));
+
+ if (str_to_granularity(granularity) == QuantizationGranularity::LayerWise &&
+ str_to_dtype(output_model_dtype) != loco::DataType::U8)
+ throw std::runtime_error("Layer-wise quantization only supports uint8 dtype.");
+
+ // Check dtype/granularity of layer params
+ for (auto layer_param : layer_params)
+ {
+ auto name = layer_param->name;
+ if (!in_array(to_lower_case(layer_param->dtype), qwmm_supported_output_model_dtype))
+ {
+ throw std::runtime_error("Unsupported dtype in " + name + ". List of supported dtype: " +
+ to_string(qwmm_supported_output_model_dtype));
+ }
+ if (!in_array(to_lower_case(layer_param->granularity), qwmm_supported_granularity))
+ {
+ throw std::runtime_error(
+ "Unsupported granularity in " + name +
+ ". List of supported granularity: " + to_string(qwmm_supported_granularity));
+ }
+ }
+
+ // Input model checker for quantization
+ luci::QuantizePreCheckerPass input_model_checker{};
+ input_model_checker.run(g);
+
+ auto ctx = std::make_unique<luci::QuantizeWithMinMaxPass::Context>();
+ {
+ ctx->input_model_dtype = str_to_dtype(input_model_dtype);
+ ctx->output_model_dtype = str_to_dtype(output_model_dtype);
+ ctx->granularity = str_to_granularity(granularity);
+ ctx->input_type = str_to_dtype(input_type);
+ ctx->output_type = str_to_dtype(output_type);
+ ctx->TF_style_maxpool = TF_style_maxpool;
+
+ for (auto layer_param : layer_params)
+ {
+ LayerInfo info;
+ {
+ info.name = layer_param->name;
+ info.dtype = str_to_dtype(layer_param->dtype);
+ info.granularity = str_to_granularity(layer_param->granularity);
+ }
+ ctx->layers_info.emplace_back(info);
+ }
+ }
+
+ luci::QuantizeWithMinMaxPass quantizer(std::move(ctx));
+
+ quantizer.run(g);
+
+ auto verify_ctx = std::make_unique<luci::QuantizedModelVerifier::Context>();
+ {
+ verify_ctx->output_model_dtype = str_to_dtype(output_model_dtype);
+ verify_ctx->granularity = str_to_granularity(granularity);
+ verify_ctx->input_type = str_to_dtype(input_type);
+ verify_ctx->output_type = str_to_dtype(output_type);
+ verify_ctx->TF_style_maxpool = TF_style_maxpool;
+
+ for (auto layer_param : layer_params)
+ {
+ LayerInfo info;
+ {
+ info.name = layer_param->name;
+ info.dtype = str_to_dtype(layer_param->dtype);
+ info.granularity = str_to_granularity(layer_param->granularity);
+ }
+ verify_ctx->layers_info.emplace_back(info);
+ }
+ }
+
+ // Verify the type/granularity of the quantized model
+ luci::QuantizedModelVerifier verifier(std::move(verify_ctx));
+
+ verifier.verify(g);
+ }
+
+ // Requantize
+ if (_options->query(Options::Algorithm::Requantize))
+ {
+ static const std::vector<std::string> rq_supported_input_model_dtype{"int8"};
+ static const std::vector<std::string> rq_supported_output_model_dtype{"uint8"};
+
+ auto input_model_dtype =
+ _options->param(Options::AlgorithmParameters::Quantize_input_model_dtype);
+ auto output_model_dtype =
+ _options->param(Options::AlgorithmParameters::Quantize_output_model_dtype);
+
+ if (!in_array(to_lower_case(input_model_dtype), rq_supported_input_model_dtype))
+ throw std::runtime_error("Unsupported input type. List of supported input types: " +
+ to_string(rq_supported_input_model_dtype));
+
+ if (!in_array(to_lower_case(output_model_dtype), rq_supported_output_model_dtype))
+ throw std::runtime_error("Unsupported output type. List of supported output types: " +
+ to_string(rq_supported_output_model_dtype));
+
+ luci::RequantizePass requantizer(str_to_dtype(input_model_dtype),
+ str_to_dtype(output_model_dtype));
+ requantizer.run(g);
+ }
+
+ // Force to write quantparam to specified tensors
+ // NOTE Only per-tensor (not per-channel) qparam can be written
+ if (_options->query(Options::Algorithm::ForceQuantParam))
+ {
+ ForceQuantParamPass::TensorVector tensors =
+ _options->params(Options::AlgorithmParameters::Quantize_tensor_names);
+ auto str_scales = _options->params(Options::AlgorithmParameters::Quantize_scales);
+ auto str_zero_points = _options->params(Options::AlgorithmParameters::Quantize_zero_points);
+
+ // Cast scales/zero_points to proper types
+ ForceQuantParamPass::ScaleVector scales = lexical_cast<float>(str_scales);
+ ForceQuantParamPass::ZPVector zero_points = lexical_cast<int64_t>(str_zero_points);
+
+ ForceQuantParamPass fq(tensors, scales, zero_points);
+ fq.run(g);
+ }
+
+ // Copy quantparam of a tensor to another tensor
+ if (_options->query(Options::Algorithm::CopyQuantParam))
+ {
+ CopyQuantParamPass::TensorVector src_tensors =
+ _options->params(Options::AlgorithmParameters::Quantize_src_tensor_names);
+ CopyQuantParamPass::TensorVector dst_tensors =
+ _options->params(Options::AlgorithmParameters::Quantize_dst_tensor_names);
+
+ CopyQuantParamPass cq(src_tensors, dst_tensors);
+ cq.run(g);
+ }
+
+ // Convert quantized model to fake-quantized model
+ if (_options->query(Options::Algorithm::ConvertToFakeQuantizedModel))
+ {
+ luci::ConvertToFakeQuantizedModelPass fake_quantizer;
+ fake_quantizer.run(g);
+
+ logo::Phase phase;
+
+ // Default passes
+ phase.emplace_back(std::make_unique<logo::RemoveDeadNodeWithQueryPass>());
+ phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>());
+ phase.emplace_back(std::make_unique<luci::CircleTypeInferencePass>());
+
+ // Fold Dequantize Ops generated during fake quantization
+ phase.emplace_back(std::make_unique<luci::FoldDequantizePass>());
+
+ ProgressReporter prog(g, logo::PhaseStrategy::Restart);
+ logo::PhaseRunner<logo::PhaseStrategy::Restart> phase_runner{g};
+ phase_runner.attach(&prog);
+ phase_runner.run(phase);
+ }
+
+ logo::Phase phase;
+
+ // Do Shape/Type inference
+ phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>());
+ phase.emplace_back(std::make_unique<luci::CircleTypeInferencePass>());
+
+ ProgressReporter prog(g, logo::PhaseStrategy::Saturate);
+ logo::PhaseRunner<logo::PhaseStrategy::Saturate> phase_runner{g};
+ phase_runner.attach(&prog);
+ phase_runner.run(phase);
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/CircleQuantizer.test.cpp b/compiler/luci/pass/src/CircleQuantizer.test.cpp
new file mode 100644
index 000000000..5766d5fe5
--- /dev/null
+++ b/compiler/luci/pass/src/CircleQuantizer.test.cpp
@@ -0,0 +1,191 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/CircleQuantizer.h"
+
+#include <gtest/gtest.h>
+
+using namespace luci;
+using Algorithms = luci::CircleQuantizer::Options::Algorithm;
+using AlgorithmParameters = luci::CircleQuantizer::Options::AlgorithmParameters;
+
+TEST(CircleQuantizerTest, quantize_quantdequant_simple)
+{
+ loco::Graph g;
+ luci::CircleQuantizer o;
+
+ auto options = o.options();
+
+ options->enable(Algorithms::QuantizeDequantizeWeights);
+ options->param(AlgorithmParameters::Quantize_input_model_dtype, "float32");
+ options->param(AlgorithmParameters::Quantize_output_model_dtype, "uint8");
+ options->param(AlgorithmParameters::Quantize_granularity, "layer");
+
+ o.quantize(&g);
+
+ SUCCEED();
+}
+
+TEST(CircleQuantizerTest, quantize_quantdequant_input_NEG)
+{
+ loco::Graph g;
+ luci::CircleQuantizer o;
+
+ auto options = o.options();
+
+ options->enable(Algorithms::QuantizeDequantizeWeights);
+ options->param(AlgorithmParameters::Quantize_input_model_dtype, "invalid");
+ options->param(AlgorithmParameters::Quantize_output_model_dtype, "uint8");
+ options->param(AlgorithmParameters::Quantize_granularity, "layer");
+
+ EXPECT_THROW(o.quantize(&g), std::runtime_error);
+}
+
+TEST(CircleQuantizerTest, quantize_quantdequant_output_NEG)
+{
+ loco::Graph g;
+ luci::CircleQuantizer o;
+
+ auto options = o.options();
+
+ options->enable(Algorithms::QuantizeDequantizeWeights);
+ options->param(AlgorithmParameters::Quantize_input_model_dtype, "float32");
+ options->param(AlgorithmParameters::Quantize_output_model_dtype, "invalid");
+ options->param(AlgorithmParameters::Quantize_granularity, "layer");
+
+ EXPECT_THROW(o.quantize(&g), std::runtime_error);
+}
+
+TEST(CircleQuantizerTest, quantize_quantdequant_gran_NEG)
+{
+ loco::Graph g;
+ luci::CircleQuantizer o;
+
+ auto options = o.options();
+
+ options->enable(Algorithms::QuantizeDequantizeWeights);
+ options->param(AlgorithmParameters::Quantize_input_model_dtype, "float32");
+ options->param(AlgorithmParameters::Quantize_output_model_dtype, "uint8");
+ options->param(AlgorithmParameters::Quantize_granularity, "invalid");
+
+ EXPECT_THROW(o.quantize(&g), std::runtime_error);
+}
+
+TEST(CircleQuantizerTest, quantize_minmax_simple)
+{
+ loco::Graph g;
+ luci::CircleQuantizer o;
+
+ auto options = o.options();
+
+ options->enable(Algorithms::QuantizeWithMinMax);
+ options->param(AlgorithmParameters::Quantize_input_model_dtype, "float32");
+ options->param(AlgorithmParameters::Quantize_output_model_dtype, "uint8");
+ options->param(AlgorithmParameters::Quantize_granularity, "layer");
+
+ o.quantize(&g);
+
+ SUCCEED();
+}
+
+TEST(CircleQuantizerTest, quantize_minmax_input_NEG)
+{
+ loco::Graph g;
+ luci::CircleQuantizer o;
+
+ auto options = o.options();
+
+ options->enable(Algorithms::QuantizeWithMinMax);
+ options->param(AlgorithmParameters::Quantize_input_model_dtype, "invalid");
+ options->param(AlgorithmParameters::Quantize_output_model_dtype, "uint8");
+ options->param(AlgorithmParameters::Quantize_granularity, "layer");
+
+ EXPECT_THROW(o.quantize(&g), std::runtime_error);
+}
+
+TEST(CircleQuantizerTest, quantize_minmax_output_NEG)
+{
+ loco::Graph g;
+ luci::CircleQuantizer o;
+
+ auto options = o.options();
+
+ options->enable(Algorithms::QuantizeWithMinMax);
+ options->param(AlgorithmParameters::Quantize_input_model_dtype, "float32");
+ options->param(AlgorithmParameters::Quantize_output_model_dtype, "invalid");
+ options->param(AlgorithmParameters::Quantize_granularity, "layer");
+
+ EXPECT_THROW(o.quantize(&g), std::runtime_error);
+}
+
+TEST(CircleQuantizerTest, quantize_minmax_gran_NEG)
+{
+ loco::Graph g;
+ luci::CircleQuantizer o;
+
+ auto options = o.options();
+
+ options->enable(Algorithms::QuantizeWithMinMax);
+ options->param(AlgorithmParameters::Quantize_input_model_dtype, "float32");
+ options->param(AlgorithmParameters::Quantize_output_model_dtype, "uint8");
+ options->param(AlgorithmParameters::Quantize_granularity, "invalid");
+
+ EXPECT_THROW(o.quantize(&g), std::runtime_error);
+}
+
+TEST(CircleQuantizerTest, quantize_requant_simple)
+{
+ loco::Graph g;
+ luci::CircleQuantizer o;
+
+ auto options = o.options();
+
+ options->enable(Algorithms::Requantize);
+ options->param(AlgorithmParameters::Quantize_input_model_dtype, "int8");
+ options->param(AlgorithmParameters::Quantize_output_model_dtype, "uint8");
+
+ o.quantize(&g);
+
+ SUCCEED();
+}
+
+TEST(CircleQuantizerTest, quantize_requant_input_NEG)
+{
+ loco::Graph g;
+ luci::CircleQuantizer o;
+
+ auto options = o.options();
+
+ options->enable(Algorithms::Requantize);
+ options->param(AlgorithmParameters::Quantize_input_model_dtype, "invalid");
+ options->param(AlgorithmParameters::Quantize_output_model_dtype, "uint8");
+
+ EXPECT_THROW(o.quantize(&g), std::runtime_error);
+}
+
+TEST(CircleQuantizerTest, quantize_requant_output_NEG)
+{
+ loco::Graph g;
+ luci::CircleQuantizer o;
+
+ auto options = o.options();
+
+ options->enable(Algorithms::Requantize);
+ options->param(AlgorithmParameters::Quantize_input_model_dtype, "int8");
+ options->param(AlgorithmParameters::Quantize_output_model_dtype, "invalid");
+
+ EXPECT_THROW(o.quantize(&g), std::runtime_error);
+}
diff --git a/compiler/luci/pass/src/ConvertNCHWToNHWCPass.cpp b/compiler/luci/pass/src/ConvertNCHWToNHWCPass.cpp
index 270714049..ce4f54035 100644
--- a/compiler/luci/pass/src/ConvertNCHWToNHWCPass.cpp
+++ b/compiler/luci/pass/src/ConvertNCHWToNHWCPass.cpp
@@ -228,6 +228,9 @@ bool check_4d_reshape(loco::Node *node, const std::vector<int32_t> indices)
if (input->shape_status() != luci::ShapeStatus::VALID)
return false;
+ if (input->rank() != 4)
+ return false;
+
if (reshape->shape_status() != luci::ShapeStatus::VALID)
return false;
@@ -804,6 +807,8 @@ class ConvertNCHWToNHWC final : public luci::CircleNodeMutableVisitor<bool>
return true;
}
+ bool visit(luci::CircleElu *node) { return convert_unary_features<luci::CircleElu>(node); }
+
bool visit(luci::CircleLeakyRelu *node)
{
return convert_unary_features<luci::CircleLeakyRelu>(node);
@@ -1240,6 +1245,7 @@ bool ConvertNCHWToNHWCPass::run(loco::Graph *g)
break;
case luci::CircleOpcode::ADD:
case luci::CircleOpcode::CONCATENATION:
+ case luci::CircleOpcode::ELU:
case luci::CircleOpcode::LEAKY_RELU:
case luci::CircleOpcode::LOGISTIC:
case luci::CircleOpcode::MAXIMUM:
diff --git a/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp b/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp
index c9412fbb1..dd81d1380 100644
--- a/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp
+++ b/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp
@@ -264,6 +264,22 @@ public:
luci::CircleConst *input2 = nullptr;
};
+class EluGraph final : public SimpleGraph
+{
+protected:
+ loco::Node *insertGraphBody(loco::Node *input) override
+ {
+ elu = g.nodes()->create<luci::CircleElu>();
+ elu->features(input);
+ elu->name("elu");
+
+ return elu;
+ }
+
+public:
+ luci::CircleElu *elu = nullptr;
+};
+
class LeakyReluGraph final : public SimpleGraph
{
protected:
@@ -941,6 +957,26 @@ TEST(ConvertNCHWToNHWC, Concatenation)
EXPECT_EQ(3, g.concat->axis());
}
+TEST(ConvertNCHWToNHWC, Elu)
+{
+ EluGraph g;
+ g.init();
+
+ run_phase(&g.g, true, true);
+
+ check_pre_trans(g.elu->features());
+
+ auto elu_succs = loco::succs(g.elu);
+ EXPECT_EQ(1, elu_succs.size());
+ check_post_trans(*elu_succs.begin());
+
+ // Check elu shape
+ EXPECT_EQ(1, g.elu->dim(0).value());
+ EXPECT_EQ(4, g.elu->dim(1).value());
+ EXPECT_EQ(4, g.elu->dim(2).value());
+ EXPECT_EQ(16, g.elu->dim(3).value());
+}
+
TEST(ConvertNCHWToNHWC, LeakyRelu)
{
LeakyReluGraph g;
diff --git a/compiler/luci/pass/src/ConvertToFakeQuantizedModelPass.cpp b/compiler/luci/pass/src/ConvertToFakeQuantizedModelPass.cpp
new file mode 100644
index 000000000..11970fff5
--- /dev/null
+++ b/compiler/luci/pass/src/ConvertToFakeQuantizedModelPass.cpp
@@ -0,0 +1,214 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/ConvertToFakeQuantizedModelPass.h"
+#include "luci/Pass/QuantizationParameters.h"
+
+#include "QuantizationUtils.h"
+
+#include <luci/Profile/CircleNodeOrigin.h>
+#include <luci/IR/CircleNodes.h>
+#include <luci/IR/CircleNodeVisitor.h>
+#include <luci/Log.h>
+
+namespace
+{
+
+// Create Quantize Op whose dtype/shape/qparam are the same with node
+luci::CircleQuantize *create_quantize(luci::CircleNode *node)
+{
+ auto quantize = node->graph()->nodes()->create<luci::CircleQuantize>();
+ quantize->name(node->name() + "_Quantize");
+ quantize->dtype(node->dtype());
+ quantize->rank(node->rank());
+ for (uint32_t i = 0; i < node->rank(); i++)
+ quantize->dim(i).set(node->dim(i).value());
+
+ quantize->shape_status(luci::ShapeStatus::VALID);
+
+ copy_quantparam(node, quantize);
+
+ luci::add_origin(quantize, luci::get_origin(node));
+
+ return quantize;
+}
+
+// Create Dequantize Op whose shape is the same with node
+luci::CircleDequantize *create_dequantize(luci::CircleNode *node)
+{
+ auto dequantize = node->graph()->nodes()->create<luci::CircleDequantize>();
+ dequantize->name(node->name() + "_Dequantize");
+ dequantize->dtype(loco::DataType::FLOAT32);
+ dequantize->rank(node->rank());
+ for (uint32_t i = 0; i < node->rank(); i++)
+ dequantize->dim(i).set(node->dim(i).value());
+
+ dequantize->shape_status(luci::ShapeStatus::VALID);
+
+ luci::add_origin(dequantize, luci::get_origin(node));
+
+ return dequantize;
+}
+
+// Return true if node is quantized activation
+// 1. dtype is u8 or s16
+// 2. node has qparam
+bool is_quant_act(const luci::CircleNode *node)
+{
+ if (node->dtype() != loco::DataType::U8 and node->dtype() != loco::DataType::S16)
+ return false;
+
+ if (not node->quantparam())
+ return false;
+
+ return true;
+}
+
+// Return true if node is quantized const
+// 1. dtype is not fp32
+// 2. node has qparam
+// NOTE Quantized const can have the following types
+// u8 (weights, activation), s16 (weights, activation), s32 (bias), s64 (bias)
+bool is_quant_const(const luci::CircleConst *node)
+{
+ if (node->dtype() == loco::DataType::FLOAT32)
+ return false;
+
+ if (not node->quantparam())
+ return false;
+
+ return true;
+}
+
+// Insert dequantize Op after node
+void insert_dequantize(loco::Node *lnode)
+{
+ auto node = loco::must_cast<luci::CircleNode *>(lnode);
+ auto dequant = create_dequantize(node);
+ loco::replace(node).with(dequant);
+ dequant->input(node);
+}
+
+// Insert quantize Op after node and return the quantize Op
+luci::CircleQuantize *insert_quantize(loco::Node *lnode)
+{
+ auto node = loco::must_cast<luci::CircleNode *>(lnode);
+ auto quant = create_quantize(node);
+ loco::replace(node).with(quant);
+ quant->input(node);
+ return quant;
+}
+
+// Dequantize node
+void dequantize(luci::CircleNode *node)
+{
+ node->dtype(loco::DataType::FLOAT32);
+ node->quantparam(nullptr);
+}
+
+// Do fake quantization on quantized activation
+// 1. Insert Quantize-Dequantize Ops
+// 2. Update dtype/quantparam of node
+void fq_activation(luci::CircleNode *node)
+{
+ if (not is_quant_act(node))
+ return;
+
+ auto quant = insert_quantize(node);
+ insert_dequantize(quant);
+
+ dequantize(node);
+}
+
+#define RETURN_UNLESS(COND) \
+ if (not(COND)) \
+ return;
+
+// Visitor to do fake quantization for each Op
+// For non-const activation, insert Quantize-Dequantize after the ofm
+// For quantized const, insert Dequantize after the const
+struct FakeQuantize final : public luci::CircleNodeMutableVisitor<void>
+{
+ void visit(luci::CircleNode *node)
+ {
+ throw std::runtime_error("Unsupported op for fake quantization in " + node->name());
+ }
+
+ void visit(luci::CircleInput *node)
+ {
+ RETURN_UNLESS(is_quant_act(node));
+
+ auto quant = insert_quantize(node);
+ insert_dequantize(quant);
+
+ dequantize(node);
+
+ // Update graph input
+ const auto inputs = node->graph()->inputs();
+ auto graph_input = inputs->at(node->index());
+ graph_input->dtype(loco::DataType::FLOAT32);
+ }
+
+ void visit(luci::CircleOutput *node)
+ {
+ RETURN_UNLESS(is_quant_act(node));
+
+ dequantize(node);
+
+ // Update graph output
+ const auto outputs = node->graph()->outputs();
+ auto graph_output = outputs->at(node->index());
+ graph_output->dtype(loco::DataType::FLOAT32);
+ }
+
+ // For quantized const, insert Dequantize Op
+ void visit(luci::CircleConst *node)
+ {
+ RETURN_UNLESS(is_quant_const(node));
+
+ insert_dequantize(node);
+ }
+
+ // For non-const activation, insert Quantize-Dequantize Ops
+ // and dequantize the node
+ void visit(luci::CircleConv2D *node) { fq_activation(node); }
+ void visit(luci::CircleAdd *node) { fq_activation(node); }
+};
+
+#undef RETURN_UNLESS
+
+} // namespace
+
+namespace luci
+{
+
+bool ConvertToFakeQuantizedModelPass::run(loco::Graph *g)
+{
+ LOGGER(l);
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+ INFO(l) << "ConvertToFakeQuantizedModelPass visit node: " << circle_node->name() << std::endl;
+
+ FakeQuantize fq;
+ circle_node->accept(&fq);
+ }
+
+ // One time run
+ return false;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/ConvertToFakeQuantizedModelPass.test.cpp b/compiler/luci/pass/src/ConvertToFakeQuantizedModelPass.test.cpp
new file mode 100644
index 000000000..560d68a74
--- /dev/null
+++ b/compiler/luci/pass/src/ConvertToFakeQuantizedModelPass.test.cpp
@@ -0,0 +1,277 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <logo/Phase.h>
+
+#include "luci/Pass/ConvertToFakeQuantizedModelPass.h"
+#include <luci/IR/CircleNodes.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+// Check the below pattern
+// Quantize (scale, zp) -> Dequantize (node)
+void check_q_dq(loco::Node *node, float scale, int64_t zp)
+{
+ auto dequant = dynamic_cast<luci::CircleDequantize *>(node);
+ EXPECT_TRUE(dequant != nullptr);
+ auto quant = dynamic_cast<luci::CircleQuantize *>(dequant->input());
+ EXPECT_TRUE(quant != nullptr);
+ auto qparam = quant->quantparam();
+ EXPECT_EQ(scale, qparam->scale[0]);
+ EXPECT_EQ(zp, qparam->zerop[0]);
+}
+
+// Check the below pattern
+// Dequantize (node)
+void check_dq(loco::Node *node)
+{
+ auto dequant = dynamic_cast<luci::CircleDequantize *>(node);
+ EXPECT_TRUE(dequant != nullptr);
+}
+
+void set_qparam(luci::CircleNode *node, float scale, int64_t zp)
+{
+ auto qparam = std::make_unique<luci::CircleQuantParam>();
+ {
+ qparam->scale.push_back(scale);
+ qparam->zerop.push_back(zp);
+ }
+ node->quantparam(std::move(qparam));
+}
+
+/**
+ * SimpleGraph for testing
+ * - Child class should implement insertGraphBody()
+ *
+ * Example (U8ConvGraph inherits SimpleGraph and create Conv2D Op)
+ *
+ * BEFORE
+ * - A model is quantized (ex: u8)
+ *
+ * [Input(u8)] [Filter(u8)] [Bias(s32)]
+ * \ | /
+ * \ | /
+ * \ | /
+ * [Conv2D(u8)]
+ * |
+ * [Output(u8)]
+ *
+ * AFTER
+ * - Ops are converted to fp32
+ * - Quantize/Dequantize Ops are inserted properly
+ * - Q-DQ is inserted after non-const activation
+ * - DQ is inserted after const
+ *
+ * [Input(u8)]
+ * |
+ * [Quant(u8)] [Filter(u8)] [Bias(s32)]
+ * | | |
+ * [Dequant(fp32)] [Dequant(fp32)] [Dequant(fp32)]
+ * \ | /
+ * \ | /
+ * \ | /
+ * [Conv2D(fp32)]
+ * |
+ * [Quant(u8)]
+ * |
+ * [Dequant(fp32)]
+ * |
+ * [Output(fp32)]
+ */
+template <loco::DataType T> class SimpleGraph
+{
+public:
+ void init()
+ {
+ input = g.nodes()->create<luci::CircleInput>();
+ output = g.nodes()->create<luci::CircleOutput>();
+ input->name("input");
+ output->name("output");
+
+ auto graph_input = g.inputs()->create();
+ input->index(graph_input->index());
+ auto graph_output = g.outputs()->create();
+ output->index(graph_output->index());
+
+ graph_input->dtype(T);
+ input->dtype(T);
+ output->dtype(T);
+ graph_output->dtype(T);
+
+ graph_input->shape({1, 4, 4, 4});
+ input->shape({1, 4, 4, 4});
+ output->shape({1, 4, 4, 4});
+ graph_output->shape({1, 4, 4, 4});
+
+ set_qparam(input, 1.0, 0);
+ set_qparam(output, 1.0, 0);
+
+ auto graph_body = insertGraphBody(input);
+ output->from(graph_body);
+ }
+
+ virtual ~SimpleGraph() = default;
+
+protected:
+ virtual loco::Node *insertGraphBody(loco::Node *input) = 0;
+
+public:
+ loco::Graph g;
+ luci::CircleInput *input = nullptr;
+ luci::CircleOutput *output = nullptr;
+};
+
+class U8ConvGraph final : public SimpleGraph<loco::DataType::U8>
+{
+protected:
+ loco::Node *insertGraphBody(loco::Node *input) override
+ {
+ conv = g.nodes()->create<luci::CircleConv2D>();
+ weights = g.nodes()->create<luci::CircleConst>();
+ bias = g.nodes()->create<luci::CircleConst>();
+
+ conv->dtype(loco::DataType::U8);
+ weights->dtype(loco::DataType::U8);
+ bias->dtype(loco::DataType::S32);
+
+ conv->shape({1, 4, 4, 4});
+ weights->shape({4, 1, 1, 4});
+ bias->shape({4});
+
+ weights->size<loco::DataType::U8>(16);
+ for (uint32_t i = 0; i < 16; i++)
+ weights->at<loco::DataType::U8>(i) = i;
+
+ bias->size<loco::DataType::S32>(4);
+ for (uint32_t i = 0; i < 4; i++)
+ bias->at<loco::DataType::S32>(i) = i;
+
+ set_qparam(conv, 2.0, 127);
+ set_qparam(weights, 2.0, 127);
+ set_qparam(bias, 2.0, 127);
+
+ conv->input(input);
+ conv->filter(weights);
+ conv->bias(bias);
+
+ conv->name("conv");
+ weights->name("weights");
+ bias->name("bias");
+
+ return conv;
+ }
+
+public:
+ luci::CircleConv2D *conv = nullptr;
+ luci::CircleConst *weights = nullptr;
+ luci::CircleConst *bias = nullptr;
+};
+
+class FP32ConvGraph final : public SimpleGraph<loco::DataType::FLOAT32>
+{
+protected:
+ loco::Node *insertGraphBody(loco::Node *input) override
+ {
+ conv = g.nodes()->create<luci::CircleConv2D>();
+ weights = g.nodes()->create<luci::CircleConst>();
+ bias = g.nodes()->create<luci::CircleConst>();
+
+ conv->dtype(loco::DataType::FLOAT32);
+ weights->dtype(loco::DataType::FLOAT32);
+ bias->dtype(loco::DataType::FLOAT32);
+
+ conv->shape({1, 4, 4, 4});
+ weights->shape({4, 1, 1, 4});
+ bias->shape({4});
+
+ weights->size<loco::DataType::FLOAT32>(16);
+ for (uint32_t i = 0; i < 16; i++)
+ weights->at<loco::DataType::FLOAT32>(i) = i;
+
+ bias->size<loco::DataType::FLOAT32>(4);
+ for (uint32_t i = 0; i < 4; i++)
+ bias->at<loco::DataType::FLOAT32>(i) = i;
+
+ conv->input(input);
+ conv->filter(weights);
+ conv->bias(bias);
+
+ conv->name("conv");
+ weights->name("weights");
+ bias->name("bias");
+
+ return conv;
+ }
+
+public:
+ luci::CircleConv2D *conv = nullptr;
+ luci::CircleConst *weights = nullptr;
+ luci::CircleConst *bias = nullptr;
+};
+
+} // namespace
+
+TEST(ConvertToFakeQuantizedModelTest, U8Conv2D)
+{
+ U8ConvGraph g;
+ g.init();
+
+ luci::ConvertToFakeQuantizedModelPass fq;
+ fq.run(&g.g);
+
+ // Check ifm
+ check_q_dq(g.conv->input(), 1.0, 0);
+
+ // Check weights
+ check_dq(g.conv->filter());
+
+ // Check bias
+ check_dq(g.conv->bias());
+
+ // Check ofm
+ check_q_dq(g.output->from(), 2.0, 127);
+
+ SUCCEED();
+}
+
+TEST(ConvertToFakeQuantizedModelTest, F32Conv2D_NEG)
+{
+ FP32ConvGraph g;
+ g.init();
+
+ luci::ConvertToFakeQuantizedModelPass fq;
+ fq.run(&g.g);
+
+ uint32_t dequant_count = 0;
+ uint32_t quant_count = 0;
+
+ for (auto node : loco::active_nodes(loco::output_nodes(&g.g)))
+ {
+ auto cnode = loco::must_cast<luci::CircleNode *>(node);
+ auto opcode = cnode->opcode();
+ if (opcode == luci::CircleOpcode::DEQUANTIZE)
+ dequant_count++;
+ if (opcode == luci::CircleOpcode::QUANTIZE)
+ quant_count++;
+ }
+
+ // Check no quant/dequant Op is inserted
+ EXPECT_EQ(0, quant_count);
+ EXPECT_EQ(0, dequant_count);
+}
diff --git a/compiler/luci/pass/src/CopyQuantParamPass.cpp b/compiler/luci/pass/src/CopyQuantParamPass.cpp
new file mode 100644
index 000000000..9b1bb0ea9
--- /dev/null
+++ b/compiler/luci/pass/src/CopyQuantParamPass.cpp
@@ -0,0 +1,82 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/CopyQuantParamPass.h"
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/Log.h>
+
+namespace luci
+{
+
+namespace
+{
+
+struct SrcDst
+{
+ CircleNode *src = nullptr;
+ CircleNode *dst = nullptr;
+};
+
+} // namespace
+
+bool CopyQuantParamPass::run(loco::Graph *g)
+{
+ LOGGER(l);
+
+ INFO(l) << "CopyQuantParamPass Start" << std::endl;
+
+ if (_src_tensors.size() != _dst_tensors.size())
+ throw std::runtime_error("The numbers of Source/Destination tensors do not match.");
+
+ // Return src/dst CircleNodes
+ auto get_src_dst = [&g](std::string src, std::string dst) {
+ SrcDst src_dst;
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ auto const cnode = loco::must_cast<CircleNode *>(node);
+ auto const name = cnode->name();
+ if (name == src)
+ src_dst.src = cnode;
+
+ if (name == dst)
+ src_dst.dst = cnode;
+ }
+ return src_dst;
+ };
+
+ for (uint32_t i = 0; i < _src_tensors.size(); i++)
+ {
+ auto src = _src_tensors[i];
+ auto dst = _dst_tensors[i];
+
+ auto nodes = get_src_dst(src, dst);
+ if (not nodes.src)
+ throw std::runtime_error("The tensor named " + src + " does not exist.");
+
+ if (not nodes.dst)
+ throw std::runtime_error("The tensor named " + dst + " does not exist.");
+
+ copy_quantparam(nodes.src, nodes.dst);
+
+ INFO(l) << "Quantparam of " << src << " is copied to " << dst << std::endl;
+ }
+
+ INFO(l) << "CopyQuantParamPass End" << std::endl;
+
+ return false; // one time run
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/FoldGatherPass.cpp b/compiler/luci/pass/src/FoldGatherPass.cpp
new file mode 100644
index 000000000..f179d74bd
--- /dev/null
+++ b/compiler/luci/pass/src/FoldGatherPass.cpp
@@ -0,0 +1,185 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/FoldGatherPass.h"
+#include "CircleOptimizerUtils.h"
+
+#include <luci/IR/CircleNodes.h>
+
+namespace
+{
+
+/**
+ * Fold to const if
+ *
+ * 1. params: const and dtype = S32 or S64
+ * 2. indices: const and dtype = S32 or S64
+ *
+ * BEFORE
+ *
+ * [CircleConst] [CircleConst]
+ * | |
+ * +---------[Gather]---------+
+ *
+ * AFTER
+ *
+ * [CircleConst]
+ *
+ **/
+template <loco::DataType InputT, loco::DataType IndexT>
+bool fold_gather(luci::CircleGather *gather_node)
+{
+ const auto params = loco::must_cast<luci::CircleConst *>(gather_node->params());
+ const auto indices = loco::must_cast<luci::CircleConst *>(gather_node->indices());
+
+ const auto rank = params->rank();
+ auto axis = gather_node->axis();
+ if (axis < 0)
+ {
+ axis += static_cast<int32_t>(rank);
+ }
+
+ if (axis < 0 or axis >= static_cast<int32_t>(rank))
+ throw std::runtime_error("Unsupported axis value");
+
+ const auto name = gather_node->name();
+ assert(name.length() > 0);
+
+ auto constant = gather_node->graph()->nodes()->create<luci::CircleConst>();
+ constant->dtype(InputT);
+ constant->name(name + "_folded");
+
+ constant->rank(rank + indices->rank() - 1);
+
+ assert(constant->rank() > 0);
+
+ std::vector<uint32_t> shape;
+ for (uint32_t i = 0; i < rank; ++i)
+ {
+ if (i != static_cast<uint32_t>(axis))
+ {
+ const auto dim = params->dim(i).value();
+ shape.push_back(dim);
+ }
+ else
+ {
+ for (uint32_t j = 0; j < indices->rank(); ++j)
+ {
+ const auto dim = indices->dim(j).value();
+ shape.push_back(dim);
+ }
+ }
+ }
+
+ uint32_t size = 1;
+ for (uint32_t i = 0; i < shape.size(); ++i)
+ {
+ constant->dim(i).set(shape.at(i));
+ size *= shape.at(i);
+ }
+
+ constant->size<InputT>(size);
+
+ uint32_t outer_size = 1;
+ for (uint32_t i = 0; i < static_cast<uint32_t>(axis); ++i)
+ {
+ outer_size *= params->dim(i).value();
+ }
+
+ uint32_t inner_size = 1;
+ for (uint32_t i = axis + 1; i < rank; ++i)
+ {
+ inner_size *= params->dim(i).value();
+ }
+
+ uint32_t coord_size = 1;
+ for (uint32_t i = 0; i < indices->rank(); ++i)
+ {
+ coord_size *= indices->dim(i).value();
+ }
+
+ const auto axis_size = params->dim(axis).value();
+
+ for (uint32_t outer = 0; outer < outer_size; ++outer)
+ {
+ for (uint32_t i = 0; i < coord_size; ++i)
+ {
+ constant->at<InputT>((outer * coord_size + i) * inner_size) =
+ params->at<InputT>((outer * axis_size + indices->at<IndexT>(i)) * inner_size);
+ }
+ }
+ loco::replace(gather_node).with(constant);
+
+ return true;
+}
+
+bool fold_gather(luci::CircleGather *gather_node)
+{
+ const auto params = dynamic_cast<luci::CircleConst *>(gather_node->params());
+ if (not params)
+ return false;
+
+ const auto indices = dynamic_cast<luci::CircleConst *>(gather_node->indices());
+ if (not indices)
+ return false;
+
+ // TODO: support more types
+ if (params->dtype() != loco::DataType::S32 and params->dtype() != loco::DataType::S64)
+ return false;
+
+ if (indices->dtype() != loco::DataType::S32 and indices->dtype() != loco::DataType::S64)
+ throw std::runtime_error("Unsupported type");
+
+ if (params->dtype() == loco::DataType::S64)
+ {
+ if (indices->dtype() == loco::DataType::S64)
+ return fold_gather<loco::DataType::S64, loco::DataType::S64>(gather_node);
+ else
+ return fold_gather<loco::DataType::S64, loco::DataType::S32>(gather_node);
+ }
+ else
+ {
+ if (indices->dtype() == loco::DataType::S64)
+ return fold_gather<loco::DataType::S32, loco::DataType::S64>(gather_node);
+ else
+ return fold_gather<loco::DataType::S32, loco::DataType::S32>(gather_node);
+ }
+}
+
+} // namespace
+
+namespace luci
+{
+
+/**
+ * Constant Folding for Gather Op
+ **/
+bool FoldGatherPass::run(loco::Graph *g)
+{
+ bool changed = false;
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ if (auto gather_node = dynamic_cast<luci::CircleGather *>(node))
+ {
+ if (fold_gather(gather_node))
+ changed = true;
+ }
+ }
+
+ return changed;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/FoldGatherPass.test.cpp b/compiler/luci/pass/src/FoldGatherPass.test.cpp
new file mode 100644
index 000000000..b02c034a5
--- /dev/null
+++ b/compiler/luci/pass/src/FoldGatherPass.test.cpp
@@ -0,0 +1,214 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/FoldGatherPass.h"
+#include "PassTestGraphs.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+/**
+ *
+ * Graph that has a Gather S64 Op with const inputs
+ *
+ * BEFORE
+ * params: [Const] (shape: [3], values: [1, 2, 3])
+ * indices: [Const] (shape: [1], values: [1])
+ *
+ * [params] [indices]
+ * | |
+ * ---[Gather]---
+ *
+ * AFTER
+ * [Const] (shape: [1], values: [2])
+ *
+ */
+class S64FoldGatherSimpleTest : public luci::ConstantFoldingAddTestGraph, public ::testing::Test
+{
+public:
+ S64FoldGatherSimpleTest() : luci::ConstantFoldingAddTestGraph({1}, loco::DataType::S64) {}
+
+ virtual void SetUp() { init(); }
+
+ loco::Node *createFoldedPattern() override
+ {
+ _gather = _g.nodes()->create<luci::CircleGather>();
+ _params = _g.nodes()->create<luci::CircleConst>();
+ _indices = _g.nodes()->create<luci::CircleConst>();
+
+ _gather->dtype(loco::DataType::S64);
+ _params->dtype(loco::DataType::S64);
+ _indices->dtype(loco::DataType::S64);
+
+ _params->shape({3});
+ _indices->shape({1});
+
+ _params->size<loco::DataType::S64>(3);
+ _params->at<loco::DataType::S64>(0) = 1;
+ _params->at<loco::DataType::S64>(1) = 2;
+ _params->at<loco::DataType::S64>(2) = 3;
+
+ _indices->size<loco::DataType::S64>(1);
+ _indices->at<loco::DataType::S64>(0) = 1;
+
+ _gather->params(_params);
+ _gather->indices(_indices);
+
+ _gather->name("gather");
+ _params->name("params");
+ _indices->name("indices");
+
+ return _gather;
+ }
+
+protected:
+ luci::CircleGather *_gather = nullptr;
+ luci::CircleConst *_params = nullptr;
+ luci::CircleConst *_indices = nullptr;
+};
+
+/**
+ *
+ * Graph that has a Gather S32 Op with axis = 1 and with const inputs
+ *
+ * BEFORE
+ * params: [Const] (shape: [2, 3], values: [0, 1, 2, 3, 4, 5])
+ * indices: [Const] (shape: [2], values: [2, 1])
+ *
+ * [params] [indices]
+ * | |
+ * ---[Gather]---
+ *
+ * AFTER
+ * [Const] (shape: [2, 2], values: [2, 1, 5, 4])
+ *
+ */
+
+class S32FoldGatherTwoDimsTest : public luci::ConstantFoldingAddTestGraph, public ::testing::Test
+{
+public:
+ S32FoldGatherTwoDimsTest() : luci::ConstantFoldingAddTestGraph({4, 2}, loco::DataType::S32) {}
+
+ virtual void SetUp() { init(); }
+
+ loco::Node *createFoldedPattern() override
+ {
+ _gather = _g.nodes()->create<luci::CircleGather>();
+ _params = _g.nodes()->create<luci::CircleConst>();
+ _indices = _g.nodes()->create<luci::CircleConst>();
+
+ _gather->dtype(loco::DataType::S32);
+ _params->dtype(loco::DataType::S32);
+ _indices->dtype(loco::DataType::S32);
+
+ _params->shape({2, 3});
+ _indices->shape({2});
+
+ _params->size<loco::DataType::S32>(6);
+ _params->at<loco::DataType::S32>(0) = 0;
+ _params->at<loco::DataType::S32>(1) = 1;
+ _params->at<loco::DataType::S32>(2) = 2;
+ _params->at<loco::DataType::S32>(3) = 3;
+ _params->at<loco::DataType::S32>(4) = 4;
+ _params->at<loco::DataType::S32>(5) = 5;
+
+ _indices->size<loco::DataType::S32>(2);
+ _indices->at<loco::DataType::S32>(0) = 2;
+ _indices->at<loco::DataType::S32>(1) = 1;
+
+ _gather->params(_params);
+ _gather->indices(_indices);
+
+ _gather->axis(1);
+
+ _gather->name("gather");
+ _params->name("params");
+ _indices->name("indices");
+
+ return _gather;
+ }
+
+protected:
+ luci::CircleGather *_gather = nullptr;
+ luci::CircleConst *_params = nullptr;
+ luci::CircleConst *_indices = nullptr;
+};
+
+} // namespace
+
+TEST(FoldGatherTest, name)
+{
+ luci::FoldGatherPass pass;
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
+
+TEST_F(S64FoldGatherSimpleTest, fold_gather_simple)
+{
+ luci::FoldGatherPass pass;
+ while (pass.run(graph()))
+ ;
+
+ auto folded_const = getFoldedPattern();
+ EXPECT_NE(nullptr, folded_const);
+
+ // Chec type, shape, values of folded const
+ EXPECT_EQ(loco::DataType::S64, folded_const->dtype());
+ EXPECT_EQ(1, folded_const->rank());
+ EXPECT_EQ(1, folded_const->dim(0).value());
+ EXPECT_EQ(2, folded_const->at<loco::DataType::S64>(0));
+}
+
+TEST_F(S32FoldGatherTwoDimsTest, fold_gather_with_two_dim)
+{
+ luci::FoldGatherPass pass;
+ while (pass.run(graph()))
+ ;
+
+ auto folded_const = getFoldedPattern();
+ EXPECT_NE(nullptr, folded_const);
+
+ // Chec type, shape, values of folded const
+ EXPECT_EQ(loco::DataType::S32, folded_const->dtype());
+ EXPECT_EQ(2, folded_const->rank());
+ EXPECT_EQ(2, folded_const->dim(0).value());
+ EXPECT_EQ(2, folded_const->dim(1).value());
+
+ EXPECT_EQ(2, folded_const->at<loco::DataType::S32>(0));
+ EXPECT_EQ(1, folded_const->at<loco::DataType::S32>(1));
+ EXPECT_EQ(5, folded_const->at<loco::DataType::S32>(2));
+ EXPECT_EQ(4, folded_const->at<loco::DataType::S32>(3));
+}
+
+TEST_F(S64FoldGatherSimpleTest, illegal_input_NEG)
+{
+ _indices->dtype(loco::DataType::FLOAT32);
+
+ luci::FoldGatherPass pass;
+ EXPECT_ANY_THROW(pass.run(graph()));
+}
+
+TEST_F(S64FoldGatherSimpleTest, illegal_axis_NEG)
+{
+ _gather->axis(1);
+
+ luci::FoldGatherPass pass;
+ EXPECT_ANY_THROW(pass.run(graph()));
+}
diff --git a/compiler/luci/pass/src/PropagateConcatenationQparam.test.cpp b/compiler/luci/pass/src/PropagateConcatenationQparam.test.cpp
index de973a431..68136b244 100644
--- a/compiler/luci/pass/src/PropagateConcatenationQparam.test.cpp
+++ b/compiler/luci/pass/src/PropagateConcatenationQparam.test.cpp
@@ -186,12 +186,12 @@ TEST(PropagateConcatenationQparam, propagate_concat_quantparam_u8)
// (1) normal case: qparam is propagated to input_1 and input_2
// (2) input used by other Op: input_1 is an input of input_2. qparam is propagated only to
// input_2
- // (3) subsequent concat: input_1 is concat. qparam is propagated only to input_2
+ // (3) subsequent concat: input_1 is concat. qparam is propagated to subsequent concat
// (4) const input: input_1 is const. constant values are quantized
// normal case: qparam of concat_node is propagated to input_1 and input_2
SimpleConcatGraph g(loco::DataType::U8);
- luci::propagate_concat_quantparam(&g.concat_node, loco::DataType::U8);
+ luci::propagate_concat_quantparam(&g.concat_node);
EXPECT_FLOAT_EQ(3.14, g.concat_node.quantparam()->scale[0]);
EXPECT_EQ(77, g.concat_node.quantparam()->zerop[0]);
EXPECT_FLOAT_EQ(3.14, g.input_1.quantparam()->scale[0]);
@@ -202,7 +202,7 @@ TEST(PropagateConcatenationQparam, propagate_concat_quantparam_u8)
// input_1 is an input of input_2. qparam is propagated only to input_2
SimpleConcatGraph g2(loco::DataType::U8);
g2.input_2.input(&g2.input_1);
- luci::propagate_concat_quantparam(&g2.concat_node, loco::DataType::U8);
+ luci::propagate_concat_quantparam(&g2.concat_node);
EXPECT_FLOAT_EQ(3.14, g2.concat_node.quantparam()->scale[0]);
EXPECT_EQ(77, g2.concat_node.quantparam()->zerop[0]);
EXPECT_FLOAT_EQ(1.0, g2.input_1.quantparam()->scale[0]);
@@ -210,19 +210,19 @@ TEST(PropagateConcatenationQparam, propagate_concat_quantparam_u8)
EXPECT_FLOAT_EQ(3.14, g2.input_2.quantparam()->scale[0]);
EXPECT_EQ(77, g2.input_2.quantparam()->zerop[0]);
- // input_1 is concat. qparam is propagated only to input_2
+ // input_1 is concat. qparam is propagated to subsequent concat
SubsequentConcatGraph sg(loco::DataType::U8);
- luci::propagate_concat_quantparam(&sg.concat_node, loco::DataType::U8);
+ luci::propagate_concat_quantparam(&sg.concat_node);
EXPECT_FLOAT_EQ(3.14, sg.concat_node.quantparam()->scale[0]);
EXPECT_EQ(77, sg.concat_node.quantparam()->zerop[0]);
- EXPECT_FLOAT_EQ(1.0, sg.input_1.quantparam()->scale[0]);
- EXPECT_EQ(1, sg.input_1.quantparam()->zerop[0]);
+ EXPECT_FLOAT_EQ(3.14, sg.input_1.quantparam()->scale[0]);
+ EXPECT_EQ(77, sg.input_1.quantparam()->zerop[0]);
EXPECT_FLOAT_EQ(3.14, sg.input_2.quantparam()->scale[0]);
EXPECT_EQ(77, sg.input_2.quantparam()->zerop[0]);
// input_1 is const. const values are quantized with the qparam of concat
ConstInputConcatGraph cg(loco::DataType::U8);
- luci::propagate_concat_quantparam(cg.concat_node, loco::DataType::U8);
+ luci::propagate_concat_quantparam(cg.concat_node);
EXPECT_FLOAT_EQ(0.1, cg.concat_node->quantparam()->scale[0]);
EXPECT_EQ(10, cg.concat_node->quantparam()->zerop[0]);
const auto cg_input_1 = loco::must_cast<luci::CircleConst *>(cg.concat_node->values(0));
@@ -248,7 +248,7 @@ TEST(PropagateConcatenationQparam, propagate_concat_quantparam_u8_NEG)
// concat has fused activation function
g.concat_node.fusedActivationFunction(luci::FusedActFunc::RELU);
- luci::propagate_concat_quantparam(&g.concat_node, loco::DataType::U8);
+ luci::propagate_concat_quantparam(&g.concat_node);
EXPECT_FLOAT_EQ(3.14, g.concat_node.quantparam()->scale[0]);
EXPECT_EQ(77, g.concat_node.quantparam()->zerop[0]);
EXPECT_FLOAT_EQ(1.0, g.input_1.quantparam()->scale[0]);
@@ -261,7 +261,7 @@ TEST(PropagateConcatenationQparam, propagate_concat_quantparam_u8_NEG)
// const values are quantized using its min/max
ConstInputConcatGraph cg(loco::DataType::U8);
cg.concat_node->fusedActivationFunction(luci::FusedActFunc::RELU);
- luci::propagate_concat_quantparam(cg.concat_node, loco::DataType::U8);
+ luci::propagate_concat_quantparam(cg.concat_node);
EXPECT_FLOAT_EQ(0.1, cg.concat_node->quantparam()->scale[0]);
EXPECT_EQ(10, cg.concat_node->quantparam()->zerop[0]);
const auto cg_input_1 = loco::must_cast<luci::CircleConst *>(cg.concat_node->values(0));
@@ -283,12 +283,12 @@ TEST(PropagateConcatenationQparam, propagate_concat_quantparam_i16)
// (1) normal case: qparam is propagated to input_1 and input_2
// (2) input used by other Op: input_1 is an input of input_2. qparam is propagated only to
// input_2
- // (3) subsequent concat: input_1 is concat. qparam is propagated only to input_2
+ // (3) subsequent concat: input_1 is concat. qparam is propagated to subsequent concat
// (4) const input: input_1 is const. constant values are quantized
// normal case: qparam of concat_node is propagated to input_1 and input_2
SimpleConcatGraph g(loco::DataType::S16);
- luci::propagate_concat_quantparam(&g.concat_node, loco::DataType::S16);
+ luci::propagate_concat_quantparam(&g.concat_node);
EXPECT_FLOAT_EQ(3.14, g.concat_node.quantparam()->scale[0]);
EXPECT_EQ(0, g.concat_node.quantparam()->zerop[0]);
EXPECT_FLOAT_EQ(3.14, g.input_1.quantparam()->scale[0]);
@@ -299,7 +299,7 @@ TEST(PropagateConcatenationQparam, propagate_concat_quantparam_i16)
// input_1 is an input of input_2. qparam is propagated only to input_2
SimpleConcatGraph g2(loco::DataType::S16);
g2.input_2.input(&g2.input_1);
- luci::propagate_concat_quantparam(&g2.concat_node, loco::DataType::S16);
+ luci::propagate_concat_quantparam(&g2.concat_node);
EXPECT_FLOAT_EQ(3.14, g2.concat_node.quantparam()->scale[0]);
EXPECT_EQ(0, g2.concat_node.quantparam()->zerop[0]);
EXPECT_FLOAT_EQ(1.0, g2.input_1.quantparam()->scale[0]);
@@ -309,17 +309,17 @@ TEST(PropagateConcatenationQparam, propagate_concat_quantparam_i16)
// input_1 is concat. qparam is propagated only to input_2
SubsequentConcatGraph sg(loco::DataType::S16);
- luci::propagate_concat_quantparam(&sg.concat_node, loco::DataType::S16);
+ luci::propagate_concat_quantparam(&sg.concat_node);
EXPECT_FLOAT_EQ(3.14, sg.concat_node.quantparam()->scale[0]);
EXPECT_EQ(0, sg.concat_node.quantparam()->zerop[0]);
- EXPECT_FLOAT_EQ(1.0, sg.input_1.quantparam()->scale[0]);
+ EXPECT_FLOAT_EQ(3.14, sg.input_1.quantparam()->scale[0]);
EXPECT_EQ(0, sg.input_1.quantparam()->zerop[0]);
EXPECT_FLOAT_EQ(3.14, sg.input_2.quantparam()->scale[0]);
EXPECT_EQ(0, sg.input_2.quantparam()->zerop[0]);
// input_1 is const. const values are quantized with the qparam of concat
ConstInputConcatGraph cg(loco::DataType::S16);
- luci::propagate_concat_quantparam(cg.concat_node, loco::DataType::S16);
+ luci::propagate_concat_quantparam(cg.concat_node);
EXPECT_FLOAT_EQ(0.1, cg.concat_node->quantparam()->scale[0]);
EXPECT_EQ(0, cg.concat_node->quantparam()->zerop[0]);
const auto cg_input_1 = loco::must_cast<luci::CircleConst *>(cg.concat_node->values(0));
@@ -345,7 +345,7 @@ TEST(PropagateConcatenationQparam, propagate_concat_quantparam_i16_NEG)
// concat has fused activation function
g.concat_node.fusedActivationFunction(luci::FusedActFunc::RELU);
- luci::propagate_concat_quantparam(&g.concat_node, loco::DataType::S16);
+ luci::propagate_concat_quantparam(&g.concat_node);
EXPECT_FLOAT_EQ(3.14, g.concat_node.quantparam()->scale[0]);
EXPECT_EQ(0, g.concat_node.quantparam()->zerop[0]);
EXPECT_FLOAT_EQ(1.0, g.input_1.quantparam()->scale[0]);
@@ -358,7 +358,7 @@ TEST(PropagateConcatenationQparam, propagate_concat_quantparam_i16_NEG)
// const values are quantized using its min/max
ConstInputConcatGraph cg(loco::DataType::S16);
cg.concat_node->fusedActivationFunction(luci::FusedActFunc::RELU);
- luci::propagate_concat_quantparam(cg.concat_node, loco::DataType::S16);
+ luci::propagate_concat_quantparam(cg.concat_node);
EXPECT_FLOAT_EQ(0.1, cg.concat_node->quantparam()->scale[0]);
EXPECT_EQ(0, cg.concat_node->quantparam()->zerop[0]);
const auto cg_input_1 = loco::must_cast<luci::CircleConst *>(cg.concat_node->values(0));
diff --git a/compiler/luci/pass/src/PropagateQParamBackwardPass.cpp b/compiler/luci/pass/src/PropagateQParamBackwardPass.cpp
new file mode 100644
index 000000000..b4975486d
--- /dev/null
+++ b/compiler/luci/pass/src/PropagateQParamBackwardPass.cpp
@@ -0,0 +1,482 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/PropagateQParamBackwardPass.h"
+#include "QuantizationUtils.h"
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/IR/CircleNodeVisitor.h>
+#include <luci/Service/Nodes/CircleConst.h>
+#include <luci/Log.h>
+
+#include <cmath>
+
+namespace
+{
+
+void quant_const_values(luci::CircleConst *const_node, float scaling_factor, float zerop,
+ loco::DataType quant_type)
+{
+ uint32_t size = const_node->size<loco::DataType::FLOAT32>();
+
+ const float scaling_factor_inv = 1.0 / scaling_factor;
+ std::vector<int32_t> quantized_values(size);
+ for (uint32_t i = 0; i < size; ++i)
+ {
+ auto data = static_cast<double>(const_node->at<loco::DataType::FLOAT32>(i));
+ double quantized_data = std::round(data * scaling_factor_inv) + zerop;
+ constexpr double int_max = static_cast<double>(std::numeric_limits<int32_t>::max());
+ constexpr double int_min = static_cast<double>(std::numeric_limits<int32_t>::min());
+ quantized_data = std::min(int_max, std::max(int_min, quantized_data));
+
+ quantized_values[i] = static_cast<int32_t>(quantized_data);
+ }
+
+ switch (quant_type)
+ {
+ case loco::DataType::U8:
+ const_node->dtype(loco::DataType::U8); // change the type of tensor
+ const_node->size<loco::DataType::U8>(size); // resize tensor
+ for (uint32_t i = 0; i < size; ++i)
+ const_node->at<loco::DataType::U8>(i) = std::min(255, std::max(0, quantized_values[i]));
+ break;
+ case loco::DataType::S16:
+ assert(zerop == 0);
+ const_node->dtype(loco::DataType::S16); // change the type of tensor
+ const_node->size<loco::DataType::S16>(size); // resize tensor
+ for (uint32_t i = 0; i < size; ++i)
+ const_node->at<loco::DataType::S16>(i) =
+ std::min(32767, std::max(-32767, quantized_values[i]));
+ break;
+ default:
+ throw std::runtime_error("Unsupported data type");
+ }
+}
+
+void overwrite_quantparam(const luci::CircleNode *source, luci::CircleNode *target)
+{
+ auto source_qparam = source->quantparam();
+ if (source_qparam == nullptr)
+ throw std::runtime_error("source quantparam is not found during overwrite");
+
+ auto target_qparam = target->quantparam();
+ if (target_qparam == nullptr)
+ {
+ auto quantparam = std::make_unique<luci::CircleQuantParam>();
+ target->quantparam(std::move(quantparam));
+ target_qparam = target->quantparam();
+
+ if (target_qparam == nullptr)
+ throw std::runtime_error("Creating new quant param failed");
+ }
+ target_qparam->min = source_qparam->min;
+ target_qparam->max = source_qparam->max;
+ target_qparam->scale = source_qparam->scale;
+ target_qparam->zerop = source_qparam->zerop;
+ target_qparam->quantized_dimension = source_qparam->quantized_dimension;
+}
+
+/**
+ * Tells if pad_v2 quantization should ignore padding value
+ * In that case padding const will be quantized with input parameters, and probably clipped
+ */
+bool ignore_pad_v2_const_quantization(const luci::CirclePadV2 *pad)
+{
+ // This is a workaround to quantize pad generated from MaxPoolWithArgmax operation properly
+ // TODO use metadata hints to detect this case
+ auto const_value_node = dynamic_cast<const luci::CircleConst *>(pad->arg(2));
+ if (!const_value_node)
+ return false;
+ if (const_value_node->dtype() == loco::DataType::FLOAT32)
+ {
+ float const_value = const_value_node->at<loco::DataType::FLOAT32>(0);
+ if (const_value == std::numeric_limits<float>::lowest())
+ return true;
+ }
+ return false;
+}
+
+/** EXAMPLE
+ *
+ * BEFORE
+ *
+ * [CircleNode] [CircleConst]
+ * (qparam1) (FP32)
+ * \ /
+ * \ /
+ * [CirclePack]
+ * (qparam2)
+ *
+ * AFTER
+ *
+ * [CircleNode] [CircleConst] [CircleConst] <- Dead node
+ * (qparam2) (qparam2) (FP32)
+ * \ /
+ * \ /
+ * [CirclePack]
+ * (qparam2)
+ *
+ * NOTE Quantization parameter of CirclePack (qparam2) is propagated to the inputs.
+ */
+void propagate_pack_quantparam(luci::CirclePack *pack)
+{
+ assert(pack->quantparam() != nullptr);
+
+ const auto num_inputs = pack->values_count();
+
+ for (uint32_t i = 0; i < num_inputs; i++)
+ {
+ auto node = loco::must_cast<luci::CircleNode *>(pack->arg(i));
+
+ // Quantize constant values
+ if (node->opcode() == luci::CircleOpcode::CIRCLECONST)
+ {
+ luci::CircleConst *const_node = loco::must_cast<luci::CircleConst *>(node);
+ if (const_node->dtype() != loco::DataType::FLOAT32)
+ throw std::runtime_error("Unsupported data type for constant input of pack Op");
+
+ const auto pack_qparam = pack->quantparam();
+ if (pack_qparam == nullptr)
+ throw std::runtime_error("quantparam of pack is not found during propagation");
+
+ assert(pack_qparam->scale.size() == 1);
+ assert(pack_qparam->zerop.size() == 1);
+ const auto scaling_factor = pack_qparam->scale[0];
+ const auto zerop = pack_qparam->zerop[0];
+
+ auto new_const = luci::clone(const_node);
+ quant_const_values(new_const, scaling_factor, zerop, pack->dtype());
+ pack->values(i, new_const);
+ overwrite_quantparam(pack, new_const);
+ }
+ else
+ {
+ const auto succs = loco::succs(node);
+ if (succs.size() > 1)
+ continue;
+
+ // Non-const input must have been quantized
+ assert(node->quantparam() != nullptr);
+ overwrite_quantparam(pack, node);
+ }
+ }
+}
+
+/** EXAMPLE
+ *
+ *
+ *
+ * BEFORE
+ *
+ * [CircleNode] [CircleConst] [CircleConst] [CircleNode]
+ * (S32) (S32) (FP32) (U8 qparam1)
+ * \ \ / /
+ * \ \ / /
+ * \ \ / /
+ * -------[CircleOneHot]-------
+ * (U8 qparam2)
+ *
+ * AFTER
+ *
+ * [CircleNode] [CircleConst] [CircleConst] [CircleNode] [CircleConst] <- Dead node
+ * (S32) (S32) (U8 qparam2) (U8 qparam2) (FP32)
+ * \ \ / /
+ * \ \ / /
+ * \ \ / /
+ * -------[CircleOneHot]-------
+ * (U8 qparam2)
+ *
+ * NOTE Quantization parameter of CircleOneHot (qparam2) is propagated to on_value/off_value.
+ */
+void propagate_one_hot_quantparam(luci::CircleOneHot *one_hot)
+{
+ assert(one_hot->quantparam() != nullptr);
+
+ // Propagate quantization parameters from output to inputs,
+ // to fit both input and counstant_value in one quant range.
+ auto quant_input = [one_hot](void (luci::CircleOneHot::*arg_setter)(loco::Node *),
+ loco::Node *(luci::CircleOneHot::*arg_getter)() const) {
+ auto node = loco::must_cast<luci::CircleNode *>((one_hot->*arg_getter)());
+
+ // Quantize constant values
+ if (node->opcode() == luci::CircleOpcode::CIRCLECONST)
+ {
+ luci::CircleConst *const_node = loco::must_cast<luci::CircleConst *>(node);
+ if (is_quantized(const_node))
+ return;
+
+ if (const_node->dtype() != loco::DataType::FLOAT32)
+ throw std::runtime_error("Unsupported data type for constant input of OneHot Op");
+
+ const auto qparam = one_hot->quantparam();
+ if (qparam == nullptr)
+ throw std::runtime_error("quantparam of OneHot is not found during propagation");
+
+ assert(qparam->scale.size() == 1);
+ const auto scaling_factor = qparam->scale.at(0);
+ const auto zerop = qparam->zerop.at(0);
+
+ auto new_const = luci::clone(const_node);
+ quant_const_values(new_const, scaling_factor, zerop, one_hot->dtype());
+ overwrite_quantparam(one_hot, new_const);
+ (one_hot->*arg_setter)(new_const);
+ }
+ else
+ {
+ const auto succs = loco::succs(node);
+ if (succs.size() > 1)
+ return;
+
+ // Non-const input must have been quantized
+ assert(node->quantparam() != nullptr);
+ overwrite_quantparam(one_hot, node);
+ }
+ };
+
+ quant_input(&luci::CircleOneHot::on_value, &luci::CircleOneHot::on_value);
+ quant_input(&luci::CircleOneHot::off_value, &luci::CircleOneHot::off_value);
+}
+
+} // namespace
+
+namespace luci
+{
+
+/** BEFORE
+ *
+ * [CircleNode] [CircleConst]
+ * (U8 qparam1) (FP32)
+ * \ /
+ * \ /
+ * [CircleConcatenation]
+ * (U8 qparam2)
+ *
+ * AFTER
+ * [CircleNode] [CircleConst] [CircleConst] <- Dead node
+ * (U8 qparam2) (U8 qparam2) (FP32)
+ * \ /
+ * \ /
+ * [CircleConcatenation]
+ * (U8 qparam2)
+ */
+void propagate_concat_quantparam(luci::CircleConcatenation *concat)
+{
+ assert(concat->quantparam() != nullptr);
+
+ const auto num_inputs = concat->numValues();
+
+ // Quantize const inputs using their values if concat has fused act function
+ if (concat->fusedActivationFunction() != luci::FusedActFunc::NONE)
+ {
+ for (uint32_t i = 0; i < num_inputs; i++)
+ {
+ auto node = concat->arg(i);
+ auto const_node = dynamic_cast<luci::CircleConst *>(node);
+ if (const_node != nullptr)
+ {
+ auto new_const = luci::clone(const_node);
+ quant_const(new_const, concat->dtype());
+ concat->values(i, new_const);
+ }
+ }
+ return;
+ }
+
+ for (uint32_t i = 0; i < num_inputs; i++)
+ {
+ auto node = loco::must_cast<luci::CircleNode *>(concat->arg(i));
+
+ // Quantize constant values
+ if (node->opcode() == luci::CircleOpcode::CIRCLECONST)
+ {
+ luci::CircleConst *const_node = loco::must_cast<luci::CircleConst *>(node);
+
+ const auto concat_qparam = concat->quantparam();
+ assert(concat_qparam->scale.size() == 1);
+ const auto scaling_factor = concat_qparam->scale[0];
+ const auto zerop = concat_qparam->zerop[0];
+
+ auto new_const = luci::clone(const_node);
+ quant_const_values(new_const, scaling_factor, zerop, concat->dtype());
+ concat->values(i, new_const);
+ overwrite_quantparam(concat, new_const);
+ }
+ else
+ {
+ const auto succs = loco::succs(node);
+ if (succs.size() > 1)
+ continue;
+
+ // Non-const input must have been quantized
+ assert(node->quantparam() != nullptr);
+ overwrite_quantparam(concat, node);
+ }
+ }
+}
+
+/** BEFORE
+ *
+ * [CircleNode] [CircleConst] [CircleConst]
+ * (U8 qparam1) (S32) (FP32)
+ * \ | /
+ * \ | /
+ * [CirclePadV2]
+ * (U8 qparam2)
+ *
+ * AFTER (case 1)
+ *
+ * By default qparam is propagated from output to inputs to meet backend requirements.
+ *
+ * [CircleNode] [CircleConst] [CircleConst] [CircleConst] <- Dead node
+ * (U8 qparam2) (S32) (U8 qparam2) (FP32)
+ * \ | /
+ * \ | /
+ * [CirclePadV2]
+ * (U8 qparam2)
+ *
+ * AFTER (case 2)
+ *
+ * In case padded value is the lowest float value
+ * Qparam is propagated from input to output and constant.
+ *
+ * This is a special case for optimization constructed pad, needed to guarantee that
+ * extremely large negative constant do not stretch output quantization range.
+ *
+ * [CircleNode] [CircleConst] [CircleConst] [CircleConst] <- Dead node
+ * (U8 qparam1) (S32) (U8 qparam1) (FP32)
+ * \ | /
+ * \ | /
+ * [CirclePadV2]
+ * (U8 qparam1)
+ */
+void propagate_pad_v2_quantparam(luci::CirclePadV2 *pad_v2)
+{
+ if (ignore_pad_v2_const_quantization(pad_v2))
+ {
+ // propagate input quantization paramters from input to output and padding const value
+ auto pad_v2_input = loco::must_cast<luci::CircleNode *>(pad_v2->arg(0));
+ overwrite_quantparam(pad_v2_input, pad_v2);
+
+ auto const_value_node = loco::must_cast<luci::CircleConst *>(
+ pad_v2->arg(2)); // FIX ignore_pad_v2_const_quantization UNLESS
+ auto new_const = luci::clone(const_value_node);
+
+ const auto pad_v2_input_qparam = pad_v2_input->quantparam();
+ assert(pad_v2_input_qparam != nullptr);
+ assert(pad_v2_input_qparam->scale.size() == 1);
+ const auto scaling_factor = pad_v2_input_qparam->scale.at(0);
+ const auto zerop = pad_v2_input_qparam->zerop.at(0);
+
+ quant_const_values(new_const, scaling_factor, zerop, pad_v2->dtype());
+ overwrite_quantparam(pad_v2_input, new_const);
+ pad_v2->constant_values(new_const);
+ return;
+ }
+
+ // Propagate quantization paramters from output to inputs,
+ // to fit both input and counstant_value in one quant range.
+ auto quant_input = [pad_v2](void (CirclePadV2::*arg_setter)(loco::Node *), uint32_t arg) {
+ auto node = loco::must_cast<luci::CircleNode *>(pad_v2->arg(arg));
+
+ // Quantize constant values
+ if (node->opcode() == luci::CircleOpcode::CIRCLECONST)
+ {
+ luci::CircleConst *const_node = loco::must_cast<luci::CircleConst *>(node);
+ if (is_quantized(const_node))
+ return;
+
+ if (const_node->dtype() != loco::DataType::FLOAT32)
+ throw std::runtime_error("Unsupported data type for constant input of PadV2 Op");
+
+ const auto pad_v2_qparam = pad_v2->quantparam();
+ if (pad_v2_qparam == nullptr)
+ throw std::runtime_error("quantparam of PadV2 is not found during propagation");
+
+ assert(pad_v2_qparam->scale.size() == 1);
+ const auto scaling_factor = pad_v2_qparam->scale.at(0);
+ const auto zerop = pad_v2_qparam->zerop.at(0);
+
+ auto new_const = luci::clone(const_node);
+ quant_const_values(new_const, scaling_factor, zerop, pad_v2->dtype());
+ overwrite_quantparam(pad_v2, new_const);
+ (pad_v2->*arg_setter)(new_const);
+ }
+ else
+ {
+ const auto succs = loco::succs(node);
+ if (succs.size() > 1)
+ return;
+
+ // Non-const input must have been quantized
+ assert(node->quantparam() != nullptr);
+ overwrite_quantparam(pad_v2, node);
+ }
+ };
+
+ quant_input(&CirclePadV2::input, 0);
+ quant_input(&CirclePadV2::constant_values, 2);
+}
+
+} // namespace luci
+
+namespace
+{
+
+// Visitor to propagate quantization parameters backwards
+struct PropagateQParamBackward final : public luci::CircleNodeMutableVisitor<void>
+{
+ void visit(luci::CircleNode *) {}
+
+ void visit(luci::CircleConcatenation *node) { propagate_concat_quantparam(node); }
+
+ void visit(luci::CircleOneHot *node) { propagate_one_hot_quantparam(node); }
+
+ void visit(luci::CirclePack *node) { propagate_pack_quantparam(node); }
+
+ void visit(luci::CirclePadV2 *node) { propagate_pad_v2_quantparam(node); }
+};
+
+} // namespace
+
+namespace luci
+{
+
+bool PropagateQParamBackwardPass::run(loco::Graph *g)
+{
+ LOGGER(l);
+
+ // We use reverse post-order traversal as qparam is propagated backward
+ auto nodes = loco::postorder_traversal(loco::output_nodes(g));
+ std::reverse(nodes.begin(), nodes.end());
+ for (auto node : nodes)
+ {
+ auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+ INFO(l) << "PropagateQParamBackwardPass visit node: " << circle_node->name() << std::endl;
+
+ // We can't propagate non-existent qparam
+ if (circle_node->quantparam() == nullptr)
+ continue;
+
+ PropagateQParamBackward pqb;
+ circle_node->accept(&pqb);
+ }
+
+ // This pass is only run once, so return false
+ // TODO Refactoring not to return meaningless value
+ return false;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/PropagateQParamBackwardPass.test.cpp b/compiler/luci/pass/src/PropagateQParamBackwardPass.test.cpp
new file mode 100644
index 000000000..33af70449
--- /dev/null
+++ b/compiler/luci/pass/src/PropagateQParamBackwardPass.test.cpp
@@ -0,0 +1,167 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/PropagateQParamBackwardPass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <gtest/gtest.h>
+
+using namespace luci;
+
+namespace
+{
+
+void set_qparam(luci::CircleNode *node, float scale, int64_t zp)
+{
+ auto qparam = std::make_unique<luci::CircleQuantParam>();
+ qparam->scale.emplace_back(scale);
+ qparam->zerop.emplace_back(zp);
+
+ node->quantparam(std::move(qparam));
+}
+
+/**
+ * @brief Base Test Graph
+ */
+struct TestGraph
+{
+public:
+ virtual void init(void) = 0;
+};
+
+/**
+ * Graph with two concats
+ *
+ * [CircleInput] [CircleConst]
+ * \ /
+ * [CircleConcatenation] [CircleConst]
+ * | |
+ * [CircleConcatenation]
+ * |
+ * [CircleOutput]
+ *
+ * BEFORE
+ * - Concat1 and Concat 2 have different qparams
+ *
+ * AFTER
+ * - All Ops have the same qparam
+ */
+struct SubsequentConcatGraph : public TestGraph
+{
+public:
+ void init(void) final
+ {
+ // graph input and output
+ auto graph_input = g.inputs()->create();
+ auto graph_output = g.outputs()->create();
+
+ // input
+ input = g.nodes()->create<luci::CircleInput>();
+ input->index(graph_input->index());
+ input->shape({1, 4, 4, 3});
+ input->dtype(loco::DataType::U8);
+ set_qparam(input, 1.0, 1);
+
+ // const1
+ const1 = g.nodes()->create<luci::CircleConst>();
+ const1->shape({1, 4, 4, 3});
+ const1->dtype(loco::DataType::FLOAT32);
+ const1->size<loco::DataType::FLOAT32>(48);
+ for (uint32_t i = 0; i < 48; i++)
+ const1->at<loco::DataType::FLOAT32>(i) = i;
+
+ // concat1
+ concat1 = g.nodes()->create<luci::CircleConcatenation>(2);
+ concat1->shape({1, 4, 4, 6});
+ concat1->dtype(loco::DataType::U8);
+ set_qparam(concat1, 2.0, 2);
+ concat1->values(0, input);
+ concat1->values(1, const1);
+ concat1->fusedActivationFunction(luci::FusedActFunc::NONE);
+
+ // const2
+ const2 = g.nodes()->create<luci::CircleConst>();
+ const2->shape({1, 4, 4, 3});
+ const2->dtype(loco::DataType::FLOAT32);
+ const2->size<loco::DataType::FLOAT32>(48);
+ for (uint32_t i = 0; i < 48; i++)
+ const2->at<loco::DataType::FLOAT32>(i) = i;
+
+ // concat2
+ concat2 = g.nodes()->create<luci::CircleConcatenation>(2);
+ concat2->shape({1, 4, 4, 9});
+ concat2->dtype(loco::DataType::U8);
+ set_qparam(concat2, 3.0, 3);
+ concat2->values(0, concat1);
+ concat2->values(1, const2);
+ concat2->fusedActivationFunction(luci::FusedActFunc::NONE);
+
+ // output
+ output = g.nodes()->create<luci::CircleOutput>();
+ output->index(graph_output->index());
+ output->from(concat2);
+ output->shape({1, 4, 4, 9});
+ output->dtype(loco::DataType::U8);
+ set_qparam(output, 3.0, 3);
+ }
+
+public:
+ loco::Graph g;
+ CircleInput *input = nullptr;
+ CircleConcatenation *concat1 = nullptr;
+ CircleConcatenation *concat2 = nullptr;
+ CircleConst *const1 = nullptr;
+ CircleConst *const2 = nullptr;
+ CircleOutput *output = nullptr;
+};
+
+} // namespace
+
+TEST(PropagateQParamBackwardPassTest, name)
+{
+ luci::PropagateQParamBackwardPass pass(loco::DataType::U8);
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
+
+TEST(PropagateQParamBackwardPassTest, subsequent_propagation)
+{
+ SubsequentConcatGraph graph;
+
+ graph.init();
+
+ luci::PropagateQParamBackwardPass pass(loco::DataType::U8);
+
+ pass.run(&graph.g);
+
+ EXPECT_EQ(3.0, graph.concat2->quantparam()->scale[0]);
+ EXPECT_EQ(3, graph.concat2->quantparam()->zerop[0]);
+
+ auto const2 = loco::must_cast<CircleNode *>(graph.concat2->values(1));
+ EXPECT_EQ(3.0, const2->quantparam()->scale[0]);
+ EXPECT_EQ(3, const2->quantparam()->zerop[0]);
+
+ EXPECT_EQ(3.0, graph.concat1->quantparam()->scale[0]);
+ EXPECT_EQ(3, graph.concat1->quantparam()->zerop[0]);
+
+ auto const1 = loco::must_cast<CircleNode *>(graph.concat1->values(1));
+ EXPECT_EQ(3.0, const1->quantparam()->scale[0]);
+ EXPECT_EQ(3, const1->quantparam()->zerop[0]);
+
+ EXPECT_EQ(3.0, graph.input->quantparam()->scale[0]);
+ EXPECT_EQ(3, graph.input->quantparam()->zerop[0]);
+}
diff --git a/compiler/luci/pass/src/PropagateQParamForwardPass.cpp b/compiler/luci/pass/src/PropagateQParamForwardPass.cpp
new file mode 100644
index 000000000..003e4c293
--- /dev/null
+++ b/compiler/luci/pass/src/PropagateQParamForwardPass.cpp
@@ -0,0 +1,194 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/PropagateQParamForwardPass.h"
+
+#include "QuantizationUtils.h"
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/IR/CircleNodeVisitor.h>
+#include <luci/Log.h>
+
+#include <iostream>
+
+namespace
+{
+
+bool copy_qparam(luci::CircleQuantParam *src, luci::CircleQuantParam *dst)
+{
+ assert(src->scale.size() == dst->scale.size());
+ assert(src->zerop.size() == dst->zerop.size());
+
+ // src and dst have the same qparam
+ if (std::equal(src->scale.begin(), src->scale.end(), dst->scale.begin()) &&
+ std::equal(src->zerop.begin(), src->zerop.end(), dst->zerop.begin()) &&
+ src->quantized_dimension == dst->quantized_dimension)
+ return false;
+
+ dst->scale.assign(src->scale.begin(), src->scale.end());
+ dst->zerop.assign(src->zerop.begin(), src->zerop.end());
+ dst->quantized_dimension = src->quantized_dimension;
+ return true;
+}
+
+bool copy_qparam(luci::CircleNode *src, luci::CircleNode *dst)
+{
+ // Skip nodes that do not have quantparams
+ auto src_qparam = src->quantparam();
+ if (not src_qparam)
+ return false;
+
+ auto dst_qparam = dst->quantparam();
+ if (not dst_qparam)
+ return false;
+
+ return copy_qparam(src_qparam, dst_qparam);
+}
+
+// Visitor to propagate quantization parameters
+struct PropagateQParamForward final : public luci::CircleNodeMutableVisitor<bool>
+{
+ PropagateQParamForward() = default;
+
+ bool visit(luci::CircleNode *) { return false; }
+
+ bool visit(luci::CircleGather *node)
+ {
+ auto input_node = loco::must_cast<luci::CircleNode *>(node->params());
+ return copy_qparam(input_node, node);
+ }
+
+ bool visit(luci::CircleReshape *node)
+ {
+ auto input_node = loco::must_cast<luci::CircleNode *>(node->tensor());
+ return copy_qparam(input_node, node);
+ }
+
+ bool visit(luci::CircleTranspose *node)
+ {
+ auto input_node = loco::must_cast<luci::CircleNode *>(node->a());
+ return copy_qparam(input_node, node);
+ }
+
+ bool visit(luci::CircleStridedSlice *node)
+ {
+ auto input_node = loco::must_cast<luci::CircleNode *>(node->input());
+ return copy_qparam(input_node, node);
+ }
+
+ bool visit(luci::CircleSplitOut *node)
+ {
+ auto split = loco::must_cast<luci::CircleSplit *>(node->input());
+ auto input_node = loco::must_cast<luci::CircleNode *>(split->input());
+ return copy_qparam(input_node, node);
+ }
+
+ bool visit(luci::CircleSplitVOut *node)
+ {
+ auto splitv = loco::must_cast<luci::CircleSplitV *>(node->input());
+ auto input_node = loco::must_cast<luci::CircleNode *>(splitv->input());
+ return copy_qparam(input_node, node);
+ }
+
+ bool visit(luci::CircleUnpackOut *node)
+ {
+ auto unpack = loco::must_cast<luci::CircleUnpack *>(node->input());
+ auto input_node = loco::must_cast<luci::CircleNode *>(unpack->value());
+ return copy_qparam(input_node, node);
+ }
+
+ // Propagate qparam across Quantize op to ensure
+ // special qparams (pre-defined values, integer scale)
+ bool visit(luci::CircleQuantize *node)
+ {
+ auto input_node = loco::must_cast<luci::CircleNode *>(node->input());
+
+ // Skip if input_node is not quantized activation
+ if (input_node->dtype() != loco::DataType::U8 and input_node->dtype() != loco::DataType::S16)
+ return false;
+
+ // If input_node and node have the same dtype, Quantize op
+ // will do rescale, not requantize for mixed-precision
+ if (input_node->dtype() == node->dtype())
+ return false;
+
+ assert(node->dtype() == loco::DataType::U8 or node->dtype() == loco::DataType::S16);
+
+ auto prev_qparam = node->quantparam();
+ assert(prev_qparam);
+ assert(prev_qparam->scale.size() == 1);
+ assert(prev_qparam->zerop.size() == 1);
+
+ const auto prev_scale = prev_qparam->scale[0];
+ const auto prev_zerop = prev_qparam->zerop[0];
+
+ auto qtype = luci::activation_qtype(input_node);
+ switch (qtype)
+ {
+ case luci::ActivationQType::PreDefinedValue:
+ node->quantparam(luci::make_predefined_qparam(input_node->opcode(), node->dtype()));
+ break;
+ case luci::ActivationQType::IntScale:
+ luci::set_int_scale(node);
+ break;
+ default:
+ break;
+ }
+
+ assert(node->quantparam());
+ assert(node->quantparam()->scale.size() == 1);
+ assert(node->quantparam()->zerop.size() == 1);
+
+ const auto scale = node->quantparam()->scale[0];
+ const auto zerop = node->quantparam()->zerop[0];
+
+ // Compare qparam with saved values to detect update
+ return scale != prev_scale or zerop != prev_zerop;
+ }
+};
+
+} // namespace
+
+namespace luci
+{
+
+bool PropagateQParamForwardPass::run(loco::Graph *g)
+{
+ bool changed = false;
+ LOGGER(l);
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+ INFO(l) << "PropagateQParamForwardPass visit node: " << circle_node->name() << std::endl;
+
+ PropagateQParamForward pqp;
+ if (circle_node->accept(&pqp))
+ changed = true;
+
+ if (_TF_style_maxpool)
+ {
+ if (auto maxpool = dynamic_cast<luci::CircleMaxPool2D *>(node))
+ {
+ auto input = loco::must_cast<luci::CircleNode *>(maxpool->value());
+ copy_qparam(input, maxpool);
+ }
+ }
+ }
+
+ return changed;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/PropagateQParamForwardPass.test.cpp b/compiler/luci/pass/src/PropagateQParamForwardPass.test.cpp
new file mode 100644
index 000000000..a734c0873
--- /dev/null
+++ b/compiler/luci/pass/src/PropagateQParamForwardPass.test.cpp
@@ -0,0 +1,260 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/PropagateQParamForwardPass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+void addQuantParam(luci::CircleNode *node, const std::vector<float> &scale,
+ const std::vector<int64_t> &zp)
+{
+ assert(node->quantparam() == nullptr);
+
+ auto quantparam = std::make_unique<luci::CircleQuantParam>();
+ quantparam->scale = scale;
+ quantparam->zerop = zp;
+ node->quantparam(std::move(quantparam));
+}
+
+/**
+ * Simple graph for test
+ *
+ * BEFORE
+ *
+ * [Conv] (qparam 1)
+ * |
+ * [Reshape] (qparam 2)
+ *
+ * AFTER
+ *
+ * [Conv] (qparam 2)
+ * |
+ * [Reshape] (qparam 2)
+ *
+ */
+class SimpleGraph
+{
+public:
+ SimpleGraph()
+ {
+ input = g.nodes()->create<luci::CircleInput>();
+ conv = g.nodes()->create<luci::CircleConv2D>();
+ reshape = g.nodes()->create<luci::CircleReshape>();
+ output = g.nodes()->create<luci::CircleOutput>();
+
+ auto graph_input = g.inputs()->create();
+ input->index(graph_input->index());
+ auto graph_output = g.outputs()->create();
+ output->index(graph_output->index());
+
+ addQuantParam(conv, {0.1, 0.2, 0.3}, {0, 10, 20});
+ addQuantParam(reshape, {0.2, 0.4, 0.6}, {-10, 0, 10});
+
+ conv->input(input);
+ reshape->tensor(conv);
+ output->from(reshape);
+ }
+
+public:
+ loco::Graph g;
+ luci::CircleInput *input = nullptr;
+ luci::CircleConv2D *conv = nullptr;
+ luci::CircleReshape *reshape = nullptr;
+ luci::CircleOutput *output = nullptr;
+};
+
+/**
+ * Test graph for forward propagation in Quantize Op
+ *
+ * BEFORE
+ *
+ * [Tanh U8] (qparam 1 - pre-defined for U8)
+ * |
+ * [Quantize S16] (qparam 2 - not pre-defined value)
+ *
+ * AFTER
+ *
+ * [Tanh U8] (qparam 1 - pre-defined for U8)
+ * |
+ * [Quantize S16] (qparam 3 - pre-defined for S16)
+ *
+ */
+class TanhQuantizeGraph
+{
+public:
+ TanhQuantizeGraph()
+ {
+ input = g.nodes()->create<luci::CircleInput>();
+ tanh = g.nodes()->create<luci::CircleTanh>();
+ quantize = g.nodes()->create<luci::CircleQuantize>();
+ output = g.nodes()->create<luci::CircleOutput>();
+
+ auto graph_input = g.inputs()->create();
+ input->index(graph_input->index());
+ auto graph_output = g.outputs()->create();
+ output->index(graph_output->index());
+
+ tanh->dtype(loco::DataType::U8);
+ quantize->dtype(loco::DataType::S16);
+
+ addQuantParam(tanh, {2.0f / 256.0f}, {128}); // pre-defined qparam for U8
+ addQuantParam(quantize, {1.0}, {0}); // not pre-defined values
+
+ tanh->x(input);
+ quantize->input(tanh);
+ output->from(quantize);
+ }
+
+public:
+ loco::Graph g;
+ luci::CircleInput *input = nullptr;
+ luci::CircleTanh *tanh = nullptr;
+ luci::CircleQuantize *quantize = nullptr;
+ luci::CircleOutput *output = nullptr;
+};
+
+/**
+ * Test graph for forward propagation in Quantize Op
+ *
+ * BEFORE
+ *
+ * [Floor U8] (qparam 1 - int scale)
+ * |
+ * [Quantize S16] (qparam 2 - not int scale)
+ *
+ * AFTER
+ *
+ * [Floor U8] (qparam 1 - int scale)
+ * |
+ * [Quantize S16] (qparam 3 - int scale)
+ *
+ */
+class FloorQuantizeGraph
+{
+public:
+ FloorQuantizeGraph()
+ {
+ input = g.nodes()->create<luci::CircleInput>();
+ floor = g.nodes()->create<luci::CircleFloor>();
+ quantize = g.nodes()->create<luci::CircleQuantize>();
+ output = g.nodes()->create<luci::CircleOutput>();
+
+ auto graph_input = g.inputs()->create();
+ input->index(graph_input->index());
+ auto graph_output = g.outputs()->create();
+ output->index(graph_output->index());
+
+ floor->dtype(loco::DataType::U8);
+ quantize->dtype(loco::DataType::S16);
+
+ addQuantParam(floor, {4.0f}, {128}); // int scale
+ addQuantParam(quantize, {0.3}, {0}); // not int scale
+
+ floor->x(input);
+ quantize->input(floor);
+ output->from(quantize);
+ }
+
+public:
+ loco::Graph g;
+ luci::CircleInput *input = nullptr;
+ luci::CircleFloor *floor = nullptr;
+ luci::CircleQuantize *quantize = nullptr;
+ luci::CircleOutput *output = nullptr;
+};
+
+} // namespace
+
+TEST(PropagateQParamForwardPassTest, name)
+{
+ luci::PropagateQParamForwardPass pass;
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
+
+TEST(PropagateQParamForward, simple)
+{
+ SimpleGraph g;
+
+ luci::PropagateQParamForwardPass pass;
+ while (pass.run(&g.g))
+ ;
+
+ EXPECT_FLOAT_EQ(0.1, g.reshape->quantparam()->scale[0]);
+ EXPECT_FLOAT_EQ(0.2, g.reshape->quantparam()->scale[1]);
+ EXPECT_FLOAT_EQ(0.3, g.reshape->quantparam()->scale[2]);
+ EXPECT_EQ(0, g.reshape->quantparam()->zerop[0]);
+ EXPECT_EQ(10, g.reshape->quantparam()->zerop[1]);
+ EXPECT_EQ(20, g.reshape->quantparam()->zerop[2]);
+}
+
+TEST(PropagateQParamForward, wrong_op_NEG)
+{
+ SimpleGraph g;
+ g.output->from(g.conv);
+ g.reshape->drop();
+
+ luci::PropagateQParamForwardPass pass;
+ while (pass.run(&g.g))
+ ;
+
+ EXPECT_FLOAT_EQ(0.1, g.conv->quantparam()->scale[0]);
+ EXPECT_FLOAT_EQ(0.2, g.conv->quantparam()->scale[1]);
+ EXPECT_FLOAT_EQ(0.3, g.conv->quantparam()->scale[2]);
+ EXPECT_EQ(0, g.conv->quantparam()->zerop[0]);
+ EXPECT_EQ(10, g.conv->quantparam()->zerop[1]);
+ EXPECT_EQ(20, g.conv->quantparam()->zerop[2]);
+}
+
+TEST(PropagateQParamForward, tanh_predefined_value)
+{
+ TanhQuantizeGraph g;
+
+ luci::PropagateQParamForwardPass pass;
+ while (pass.run(&g.g))
+ ;
+
+ EXPECT_FLOAT_EQ(1.0f / 32768.0f, g.quantize->quantparam()->scale[0]);
+}
+
+TEST(PropagateQParamForward, floor_int_scale)
+{
+ FloorQuantizeGraph g;
+
+ luci::PropagateQParamForwardPass pass;
+ while (pass.run(&g.g))
+ ;
+
+ EXPECT_FLOAT_EQ(1.0f, g.quantize->quantparam()->scale[0]);
+}
+
+TEST(PropagateQParamForward, same_dtype_NEG)
+{
+ FloorQuantizeGraph g;
+ g.quantize->dtype(loco::DataType::U8);
+
+ luci::PropagateQParamForwardPass pass;
+ while (pass.run(&g.g))
+ ;
+
+ // Qparam is not propagated as ifm/ofm of Quantize Op have the same dtype
+ EXPECT_FLOAT_EQ(0.3f, g.quantize->quantparam()->scale[0]);
+}
diff --git a/compiler/luci/pass/src/PropagateQuantParamPass.cpp b/compiler/luci/pass/src/PropagateQuantParamPass.cpp
deleted file mode 100644
index b1cb7a418..000000000
--- a/compiler/luci/pass/src/PropagateQuantParamPass.cpp
+++ /dev/null
@@ -1,107 +0,0 @@
-/*
- * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "luci/Pass/PropagateQuantParamPass.h"
-
-#include <luci/IR/CircleNodes.h>
-#include <luci/IR/CircleNodeVisitor.h>
-#include <luci/Log.h>
-
-#include <iostream>
-
-namespace
-{
-
-bool copy_qparam(luci::CircleQuantParam *src, luci::CircleQuantParam *dst)
-{
- assert(src->scale.size() == dst->scale.size());
- assert(src->zerop.size() == dst->zerop.size());
-
- // src and dst have the same qparam
- if (std::equal(src->scale.begin(), src->scale.end(), dst->scale.begin()) &&
- std::equal(src->zerop.begin(), src->zerop.end(), dst->zerop.begin()) &&
- src->quantized_dimension == dst->quantized_dimension)
- return false;
-
- dst->scale.assign(src->scale.begin(), src->scale.end());
- dst->zerop.assign(src->zerop.begin(), src->zerop.end());
- dst->quantized_dimension = src->quantized_dimension;
- return true;
-}
-
-bool copy_qparam(luci::CircleNode *src, luci::CircleNode *dst)
-{
- // Skip nodes that do not have quantparams
- auto src_qparam = src->quantparam();
- if (not src_qparam)
- return false;
-
- auto dst_qparam = dst->quantparam();
- if (not dst_qparam)
- return false;
-
- return copy_qparam(src_qparam, dst_qparam);
-}
-
-// Visitor to propagate quantization parameters
-struct PropagateQuantParam final : public luci::CircleNodeMutableVisitor<bool>
-{
- PropagateQuantParam() = default;
-
- bool visit(luci::CircleNode *) { return false; }
-
- bool visit(luci::CircleReshape *node)
- {
- auto input = node->tensor();
- if (loco::succs(input).size() != 1)
- return false;
-
- auto input_node = loco::must_cast<luci::CircleNode *>(input);
- return copy_qparam(input_node, node);
- }
-
- bool visit(luci::CircleTranspose *node)
- {
- auto input_node = loco::must_cast<luci::CircleNode *>(node->a());
- return copy_qparam(input_node, node);
- }
-
- // TODO : Add more Ops (e.g., layout-changing Ops)
-};
-
-} // namespace
-
-namespace luci
-{
-
-bool PropagateQuantParamPass::run(loco::Graph *g)
-{
- bool changed = false;
- LOGGER(l);
- for (auto node : loco::active_nodes(loco::output_nodes(g)))
- {
- auto circle_node = loco::must_cast<luci::CircleNode *>(node);
- INFO(l) << "PropagateQuantParamPass visit node: " << circle_node->name() << std::endl;
-
- PropagateQuantParam pqp;
- if (circle_node->accept(&pqp))
- changed = true;
- }
-
- return changed;
-}
-
-} // namespace luci
diff --git a/compiler/luci/pass/src/PropagateQuantParamPass.test.cpp b/compiler/luci/pass/src/PropagateQuantParamPass.test.cpp
deleted file mode 100644
index 0f1564223..000000000
--- a/compiler/luci/pass/src/PropagateQuantParamPass.test.cpp
+++ /dev/null
@@ -1,125 +0,0 @@
-/*
- * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "luci/Pass/PropagateQuantParamPass.h"
-
-#include <luci/IR/CircleNodes.h>
-
-#include <gtest/gtest.h>
-
-namespace
-{
-
-void addQuantParam(luci::CircleNode *node, const std::vector<float> &scale,
- const std::vector<int64_t> &zp)
-{
- assert(node->quantparam() == nullptr);
-
- auto quantparam = std::make_unique<luci::CircleQuantParam>();
- quantparam->scale = scale;
- quantparam->zerop = zp;
- node->quantparam(std::move(quantparam));
-}
-
-/**
- * Simple graph for test
- *
- * BEFORE
- *
- * [Conv] (qparam 1)
- * |
- * [Reshape] (qparam 2)
- *
- * AFTER
- *
- * [Conv] (qparam 2)
- * |
- * [Reshape] (qparam 2)
- *
- */
-class SimpleGraph
-{
-public:
- SimpleGraph()
- {
- input = g.nodes()->create<luci::CircleInput>();
- conv = g.nodes()->create<luci::CircleConv2D>();
- reshape = g.nodes()->create<luci::CircleReshape>();
- output = g.nodes()->create<luci::CircleOutput>();
-
- auto graph_input = g.inputs()->create();
- input->index(graph_input->index());
- auto graph_output = g.outputs()->create();
- output->index(graph_output->index());
-
- addQuantParam(conv, {0.1, 0.2, 0.3}, {0, 10, 20});
- addQuantParam(reshape, {0.2, 0.4, 0.6}, {-10, 0, 10});
-
- conv->input(input);
- reshape->tensor(conv);
- output->from(reshape);
- }
-
-public:
- loco::Graph g;
- luci::CircleInput *input;
- luci::CircleConv2D *conv;
- luci::CircleReshape *reshape;
- luci::CircleOutput *output;
-};
-
-} // namespace
-
-TEST(PropagateQuantParamPassTest, name)
-{
- luci::PropagateQuantParamPass pass;
- auto const name = pass.name();
- ASSERT_NE(nullptr, name);
-}
-
-TEST(PropagateQuantParam, simple)
-{
- SimpleGraph g;
-
- luci::PropagateQuantParamPass pass;
- while (pass.run(&g.g))
- ;
-
- EXPECT_FLOAT_EQ(0.1, g.reshape->quantparam()->scale[0]);
- EXPECT_FLOAT_EQ(0.2, g.reshape->quantparam()->scale[1]);
- EXPECT_FLOAT_EQ(0.3, g.reshape->quantparam()->scale[2]);
- EXPECT_EQ(0, g.reshape->quantparam()->zerop[0]);
- EXPECT_EQ(10, g.reshape->quantparam()->zerop[1]);
- EXPECT_EQ(20, g.reshape->quantparam()->zerop[2]);
-}
-
-TEST(PropagateQuantParam, wrong_op_NEG)
-{
- SimpleGraph g;
- g.output->from(g.conv);
- g.reshape->drop();
-
- luci::PropagateQuantParamPass pass;
- while (pass.run(&g.g))
- ;
-
- EXPECT_FLOAT_EQ(0.1, g.conv->quantparam()->scale[0]);
- EXPECT_FLOAT_EQ(0.2, g.conv->quantparam()->scale[1]);
- EXPECT_FLOAT_EQ(0.3, g.conv->quantparam()->scale[2]);
- EXPECT_EQ(0, g.conv->quantparam()->zerop[0]);
- EXPECT_EQ(10, g.conv->quantparam()->zerop[1]);
- EXPECT_EQ(20, g.conv->quantparam()->zerop[2]);
-}
diff --git a/compiler/luci/pass/src/QuantizationUtils.cpp b/compiler/luci/pass/src/QuantizationUtils.cpp
index 2f6fed46e..ad86cedf4 100644
--- a/compiler/luci/pass/src/QuantizationUtils.cpp
+++ b/compiler/luci/pass/src/QuantizationUtils.cpp
@@ -33,43 +33,6 @@ bool is_quantized(const CircleNode *node)
node->dtype() == loco::DataType::S64); // bias (int16 quant)
}
-// Check if node is weights of conv2d, depthwise_conv2d, or fully_connected layer
-bool is_weights(CircleNode *node)
-{
- auto circle_const = dynamic_cast<CircleConst *>(node);
- if (circle_const == nullptr)
- return false;
-
- auto succs = loco::succs(node);
-
- // Node is weights if it is the weights of all of its successors
- for (auto out : succs)
- {
- bool is_weights = false;
-
- auto conv = dynamic_cast<CircleConv2D *>(out);
- if (conv != nullptr && conv->filter() == circle_const)
- is_weights = true;
-
- auto dw_conv = dynamic_cast<CircleDepthwiseConv2D *>(out);
- if (dw_conv != nullptr && dw_conv->filter() == circle_const)
- is_weights = true;
-
- auto t_conv = dynamic_cast<CircleTransposeConv *>(out);
- if (t_conv != nullptr && t_conv->filter() == circle_const && circle_const->rank() == 4)
- is_weights = true;
-
- auto fc = dynamic_cast<CircleFullyConnected *>(out);
- if (fc != nullptr && fc->weights() == circle_const)
- is_weights = true;
-
- if (!is_weights)
- return false;
- }
-
- return true;
-}
-
uint8_t fp32_to_uint8_cast(float f)
{
assert(std::numeric_limits<uint8_t>::min() <= f);
@@ -77,7 +40,6 @@ uint8_t fp32_to_uint8_cast(float f)
return static_cast<uint8_t>(f);
}
-// Per-layer quantization of weights (const tensor) using given min/max values
void asymmetric_wquant_with_minmax_per_layer(CircleConst *node, float min, float max,
float &scaling_factor, int64_t &zp, float &nudged_min,
float &nudged_max)
@@ -107,7 +69,6 @@ void asymmetric_wquant_with_minmax_per_layer(CircleConst *node, float min, float
}
}
-// Per-layer quantization of weights (const tensor) using given min/max values
void symmetric_wquant_with_minmax_per_layer(CircleConst *node, float min, float max,
float &scaling_factor, int64_t &zp, float &nudged_min,
float &nudged_max)
@@ -315,4 +276,123 @@ uint32_t cal_offset(loco::TensorShape &dimension, uint32_t *indices)
indices[2] * dimension.dim(3).value() + indices[3];
}
+ActivationQType activation_qtype(const CircleNode *node)
+{
+ auto fused_act_node = dynamic_cast<const CircleNodeMixin<CircleNodeTrait::FusedActFunc> *>(node);
+ if (fused_act_node && fused_act_node->fusedActivationFunction() == FusedActFunc::TANH)
+ return ActivationQType::PreDefinedValue;
+
+ switch (node->opcode())
+ {
+ case CircleOpcode::LOGISTIC:
+ case CircleOpcode::TANH:
+ case CircleOpcode::SOFTMAX:
+ return ActivationQType::PreDefinedValue;
+ case CircleOpcode::FLOOR:
+ case CircleOpcode::FLOOR_DIV:
+ case CircleOpcode::FLOOR_MOD:
+ case CircleOpcode::CEIL:
+ return ActivationQType::IntScale;
+ default:
+ break;
+ }
+
+ return ActivationQType::MinMax;
+}
+
+std::unique_ptr<CircleQuantParam> make_predefined_qparam(CircleOpcode opcode, loco::DataType dtype)
+{
+ auto qparam = std::make_unique<CircleQuantParam>();
+
+ auto set_qparam = [&qparam](float scale, int64_t zp) {
+ qparam->scale.emplace_back(scale);
+ qparam->zerop.emplace_back(zp);
+ };
+
+ switch (opcode)
+ {
+ case CircleOpcode::LOGISTIC:
+ if (dtype == loco::DataType::U8)
+ set_qparam(1.0f / 256.0f, 0);
+ else
+ {
+ assert(dtype == loco::DataType::S16);
+ set_qparam(1.0f / 32768.0f, 0);
+ }
+ break;
+ case CircleOpcode::TANH:
+ if (dtype == loco::DataType::U8)
+ set_qparam(2.0f / 256.0f, 128);
+ else
+ {
+ assert(dtype == loco::DataType::S16);
+ set_qparam(1.0f / 32768.0f, 0);
+ }
+ break;
+ case CircleOpcode::SOFTMAX:
+ if (dtype == loco::DataType::U8)
+ set_qparam(1.0f / 255.0f, 0);
+ else
+ {
+ assert(dtype == loco::DataType::S16);
+ set_qparam(1.0f / 32767.0f, 0);
+ }
+ break;
+ default:
+ throw std::runtime_error("Unsupported opcode with pre-defined qparam");
+ }
+ return std::move(qparam);
+}
+
+// For nodes with integer output, we use integer scale
+void set_int_scale(luci::CircleNode *node)
+{
+ assert(node); // FIX_CALLER_UNLESS
+
+ auto qparam = node->quantparam();
+ assert(qparam); // FIX_CALLER_UNLESS
+ assert(qparam->scale.size() == 1); // FIX_CALLER_UNLESS
+
+ auto fp_scale = qparam->scale[0];
+ qparam->scale[0] = fp_scale < 1 ? 1.0f : std::round(fp_scale);
+}
+
+void quant_const(luci::CircleConst *node, loco::DataType quant_type)
+{
+ assert(node->dtype() == loco::DataType::FLOAT32);
+
+ float min = std::numeric_limits<float>::max();
+ float max = std::numeric_limits<float>::lowest();
+ for (uint32_t i = 0; i < node->size<loco::DataType::FLOAT32>(); i++)
+ {
+ auto data = node->at<loco::DataType::FLOAT32>(i);
+ min = data < min ? data : min;
+ max = data > max ? data : max;
+ }
+
+ float scaling_factor{0.0};
+ int64_t zp{0};
+ float nudged_min{0.0};
+ float nudged_max{0.0};
+
+ switch (quant_type)
+ {
+ case loco::DataType::U8:
+ asymmetric_wquant_with_minmax_per_layer(node, min, max, scaling_factor, zp, nudged_min,
+ nudged_max);
+ break;
+ case loco::DataType::S16:
+ symmetric_wquant_with_minmax_per_layer(node, min, max, scaling_factor, zp, nudged_min,
+ nudged_max);
+ break;
+ default:
+ throw std::runtime_error("Unsupported data type");
+ }
+
+ auto quantparam = std::make_unique<luci::CircleQuantParam>();
+ quantparam->scale.push_back(scaling_factor);
+ quantparam->zerop.push_back(zp);
+ node->quantparam(std::move(quantparam));
+}
+
} // namespace luci
diff --git a/compiler/luci/pass/src/QuantizationUtils.h b/compiler/luci/pass/src/QuantizationUtils.h
index 605f6a77e..cd8cec95a 100644
--- a/compiler/luci/pass/src/QuantizationUtils.h
+++ b/compiler/luci/pass/src/QuantizationUtils.h
@@ -23,33 +23,61 @@
namespace luci
{
+// Compute scale/zp using given min/max for symmetric quantization (int16)
void compute_sym_scale_zp(float min, float max, float &scaling_factor, int64_t &zp,
float &nudged_min, float &nudged_max);
+// Compute scale/zp using given min/max for asymmetric quantization (uint8)
void compute_asym_scale_zp(float min, float max, float &scaling_factor, int64_t &zp,
float &nudged_min, float &nudged_max);
+// Asymmetric per-layer quantization of weights (const tensor) using given min/max values
+// NOTE: in-place update of node data
void asymmetric_wquant_with_minmax_per_layer(CircleConst *node, float min, float max,
float &scaling_factor, int64_t &zp, float &nudged_min,
float &nudged_max);
+// Symmetric per-layer quantization of weights (const tensor) using given min/max values
+// NOTE: in-place update of node data
void symmetric_wquant_with_minmax_per_layer(CircleConst *node, float min, float max,
float &scaling_factor, int64_t &zp, float &nudged_min,
float &nudged_max);
+// Helper function to get channel dimension
+// TODO Embed this function into iterate_per_channel
bool get_channel_dim_index(CircleConst *node, loco::TensorShape &dimension,
int32_t &channel_dim_index);
+// Calculate offset of the given indices in dimension
uint32_t cal_offset(loco::TensorShape &dimension, uint32_t *indices);
-void propagate_concat_quantparam(luci::CircleConcatenation *concat, loco::DataType quant_type);
+// Backward propagation of concatenation qparam
+void propagate_concat_quantparam(luci::CircleConcatenation *concat);
-void propagate_pad_v2_quantparam(luci::CirclePadV2 *pad_v2, loco::DataType quant_type);
-
-bool is_weights(CircleNode *node);
+// Backward propagation of pad_v2 qparam
+void propagate_pad_v2_quantparam(luci::CirclePadV2 *pad_v2);
+// Return true if the node is quantized
bool is_quantized(const CircleNode *node);
+enum ActivationQType
+{
+ MinMax, // Quantize using recorded min/max
+ PreDefinedValue, // Quantize using pre-defined values
+ IntScale, // Round scale to a positive integer
+};
+
+ActivationQType activation_qtype(const CircleNode *node);
+
+// Create qparam with pre-defined values for speical operators
+std::unique_ptr<CircleQuantParam> make_predefined_qparam(CircleOpcode opcode, loco::DataType dtype);
+
+// Update node's scale to a positive integer (for special Ops e.g., Floor, Ceil)
+void set_int_scale(luci::CircleNode *node);
+
+// Quantize const tensor using its min/max values
+void quant_const(luci::CircleConst *node, loco::DataType quant_type);
+
} // namespace luci
#endif // __LUCI_QUANTIZATION_UTILS_H__
diff --git a/compiler/luci/pass/src/QuantizeActivation.cpp b/compiler/luci/pass/src/QuantizeActivation.cpp
new file mode 100644
index 000000000..149331824
--- /dev/null
+++ b/compiler/luci/pass/src/QuantizeActivation.cpp
@@ -0,0 +1,296 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "QuantizeActivation.h"
+#include "QuantizationUtils.h"
+
+#include <luci/Service/Nodes/CircleConst.h>
+#include <luci/Log.h>
+
+#include <algorithm>
+#include <cmath>
+
+using namespace luci;
+
+namespace
+{
+
+bool has_min_max(const CircleNode *node)
+{
+ return node->quantparam() && !node->quantparam()->min.empty() && !node->quantparam()->max.empty();
+}
+
+} // namespace
+
+// QuantizeActivation
+namespace luci
+{
+
+void QuantizeActivation::visit(luci::CircleNode *node)
+{
+ LOGGER(l);
+ INFO(l) << "QuantizeActivation visit node: " << node->name() << std::endl;
+
+ // Check if this is already quantized
+ if (is_quantized(node))
+ return;
+
+ // Check if this is bool type (bool type is not quantized)
+ if (node->dtype() == loco::DataType::BOOL)
+ return;
+
+ // Check if this is const (const activation is handled by QuantizeConstInputActivation)
+ // NOTE QuantizePreChecker guarantees weights/bias are const.
+ // Update this code when we accept non-const weights/bias.
+ if (node->opcode() == luci::CircleOpcode::CIRCLECONST)
+ return;
+
+ // Check if this is activation
+ // We assume min/max are recorded only for activations
+ if (has_min_max(node))
+ {
+ // Quantize using recorded min/max
+ auto quantparam = node->quantparam();
+ assert(quantparam);
+ assert(quantparam->min.size() == 1); // only support layer-wise quant
+ assert(quantparam->max.size() == 1); // only support layer-wise quant
+ auto min = quantparam->min[0];
+ auto max = quantparam->max[0];
+
+ float scaling_factor{0};
+ int64_t zp{0};
+ float nudged_min{0};
+ float nudged_max{0};
+
+ if (output_type == loco::DataType::U8)
+ {
+ compute_asym_scale_zp(min, max, scaling_factor, zp, nudged_min, nudged_max);
+ node->dtype(loco::DataType::U8);
+ }
+ else
+ {
+ compute_sym_scale_zp(min, max, scaling_factor, zp, nudged_min, nudged_max);
+ node->dtype(loco::DataType::S16);
+ }
+
+ node->quantparam()->scale.push_back(scaling_factor);
+ node->quantparam()->zerop.push_back(zp);
+ }
+ // Fix special attributes
+ if (node->opcode() == luci::CircleOpcode::CAST)
+ {
+ auto *cast = loco::must_cast<luci::CircleCast *>(node);
+ auto *cast_input = loco::must_cast<luci::CircleNode *>(cast->x());
+
+ // make sure that cast_input is already quantized
+ assert(cast_input->dtype() != loco::DataType::FLOAT32);
+ cast->in_data_type(cast_input->dtype());
+ cast->out_data_type(cast->dtype());
+ }
+}
+
+} // namespace luci
+
+// QuantizeSpecialActivation
+namespace luci
+{
+
+void QuantizeSpecialActivation::visit(luci::CircleNode *node)
+{
+ // Nodes fused with activation functions which need special quantization
+ auto fused_act_node = dynamic_cast<CircleNodeMixin<CircleNodeTrait::FusedActFunc> *>(node);
+ if (fused_act_node != nullptr && fused_act_node->fusedActivationFunction() == FusedActFunc::TANH)
+ {
+ auto qparam = make_predefined_qparam(luci::CircleOpcode::TANH, output_type);
+ node->quantparam(std::move(qparam));
+ }
+}
+
+void QuantizeSpecialActivation::visit(luci::CircleLogistic *node)
+{
+ assert(activation_qtype(node) == luci::ActivationQType::PreDefinedValue);
+ auto qparam = make_predefined_qparam(luci::CircleOpcode::LOGISTIC, output_type);
+ node->quantparam(std::move(qparam));
+}
+
+void QuantizeSpecialActivation::visit(luci::CircleTanh *node)
+{
+ assert(activation_qtype(node) == luci::ActivationQType::PreDefinedValue);
+ auto qparam = make_predefined_qparam(luci::CircleOpcode::TANH, output_type);
+ node->quantparam(std::move(qparam));
+}
+
+void QuantizeSpecialActivation::visit(luci::CircleSoftmax *node)
+{
+ assert(activation_qtype(node) == luci::ActivationQType::PreDefinedValue);
+ auto qparam = make_predefined_qparam(luci::CircleOpcode::SOFTMAX, output_type);
+ node->quantparam(std::move(qparam));
+}
+
+void QuantizeSpecialActivation::visit(luci::CircleFloor *node)
+{
+ assert(activation_qtype(node) == luci::ActivationQType::IntScale);
+ set_int_scale(node);
+}
+
+void QuantizeSpecialActivation::visit(luci::CircleFloorDiv *node)
+{
+ assert(activation_qtype(node) == luci::ActivationQType::IntScale);
+ set_int_scale(node);
+}
+
+void QuantizeSpecialActivation::visit(luci::CircleFloorMod *node)
+{
+ assert(activation_qtype(node) == luci::ActivationQType::IntScale);
+ set_int_scale(node);
+}
+
+void QuantizeSpecialActivation::visit(luci::CircleCeil *node)
+{
+ assert(activation_qtype(node) == luci::ActivationQType::IntScale);
+ set_int_scale(node);
+}
+
+} // namespace luci
+
+// QuantizeConstInputActivation
+namespace luci
+{
+
+// Default behavior (NYI)
+void QuantizeConstInputActivation::visit(luci::CircleNode *node)
+{
+ for (uint32_t i = 0; i < node->arity(); i++)
+ {
+ auto input_node = node->arg(i);
+ auto const_node = dynamic_cast<luci::CircleConst *>(input_node);
+ if (const_node != nullptr)
+ throw std::runtime_error("Unsupported Op for const inputs");
+ }
+}
+
+// INPUT_NAME is the only activation of NODE
+#define QUANTIZE_SINGLE_CONST_INPUT(NODE, INPUT_NAME) \
+ void QuantizeConstInputActivation::visit(NODE *node) \
+ { \
+ auto input = node->INPUT_NAME(); \
+ auto const_node = dynamic_cast<luci::CircleConst *>(input); \
+ if (const_node && !is_quantized(const_node)) \
+ { \
+ auto new_const = luci::clone(const_node); \
+ quant_const(new_const, _output_type); \
+ node->INPUT_NAME(new_const); \
+ } \
+ }
+
+// INPUT_NAME1 and INPUT_NAME2 are the only activations of NODE
+#define QUANTIZE_TWO_CONST_INPUTS(NODE, INPUT_NAME1, INPUT_NAME2) \
+ void QuantizeConstInputActivation::visit(NODE *node) \
+ { \
+ auto input1 = node->INPUT_NAME1(); \
+ auto const_node1 = dynamic_cast<luci::CircleConst *>(input1); \
+ if (const_node1 && !is_quantized(const_node1)) \
+ { \
+ auto new_const1 = luci::clone(const_node1); \
+ quant_const(new_const1, _output_type); \
+ node->INPUT_NAME1(new_const1); \
+ } \
+ auto input2 = node->INPUT_NAME2(); \
+ auto const_node2 = dynamic_cast<luci::CircleConst *>(input2); \
+ if (const_node2 && !is_quantized(const_node2)) \
+ { \
+ auto new_const2 = luci::clone(const_node2); \
+ quant_const(new_const2, _output_type); \
+ node->INPUT_NAME2(new_const2); \
+ } \
+ }
+
+// Ops that receive a single activation as an input
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleArgMax, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleArgMin, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleBatchToSpaceND, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleDepthToSpace, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleElu, features)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleExp, x)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleFloor, x)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleGather, params)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleLocalResponseNormalization, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleLogistic, x)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleMean, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleMirrorPad, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CirclePad, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleReduceAny, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleReduceProd, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleReduceMax, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleReduceMin, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleReshape, tensor)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleResizeBilinear, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleResizeNearestNeighbor, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleReverseSequence, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleRsqrt, x)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleSlice, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleSoftmax, logits)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleSpaceToBatchND, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleSpaceToDepth, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleSplit, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleSplitV, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleSqrt, x)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleStridedSlice, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleSum, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleTanh, x)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleTile, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleTopKV2, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleTranspose, a)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleUnpack, value)
+
+// Ops that receive two activations as inputs
+QUANTIZE_TWO_CONST_INPUTS(luci::CircleAdd, x, y)
+QUANTIZE_TWO_CONST_INPUTS(luci::CircleBatchMatMul, x, y)
+QUANTIZE_TWO_CONST_INPUTS(luci::CircleDiv, x, y)
+QUANTIZE_TWO_CONST_INPUTS(luci::CircleEqual, x, y)
+QUANTIZE_TWO_CONST_INPUTS(luci::CircleFloorDiv, x, y)
+QUANTIZE_TWO_CONST_INPUTS(luci::CircleGreater, x, y)
+QUANTIZE_TWO_CONST_INPUTS(luci::CircleGreaterEqual, x, y)
+QUANTIZE_TWO_CONST_INPUTS(luci::CircleLess, x, y)
+QUANTIZE_TWO_CONST_INPUTS(luci::CircleLessEqual, x, y)
+QUANTIZE_TWO_CONST_INPUTS(luci::CircleMaximum, x, y)
+QUANTIZE_TWO_CONST_INPUTS(luci::CircleMinimum, x, y)
+QUANTIZE_TWO_CONST_INPUTS(luci::CircleMul, x, y)
+QUANTIZE_TWO_CONST_INPUTS(luci::CircleNotEqual, x, y)
+QUANTIZE_TWO_CONST_INPUTS(luci::CirclePow, x, y)
+QUANTIZE_TWO_CONST_INPUTS(luci::CircleSub, x, y)
+
+// AddN has arbitrary number of inputs
+void QuantizeConstInputActivation::visit(luci::CircleAddN *node)
+{
+ auto arity = node->arity();
+ for (uint32_t i = 0; i < arity; i++)
+ {
+ auto input_node = node->inputs(i);
+ auto const_node = dynamic_cast<luci::CircleConst *>(input_node);
+ if (const_node && !is_quantized(const_node))
+ {
+ auto new_const = luci::clone(const_node);
+ quant_const(new_const, _output_type);
+ node->inputs(i, new_const);
+ }
+ }
+}
+
+#undef QUANTIZE_SINGLE_CONST_INPUT
+#undef QUANTIZE_TWO_CONST_INPUTS
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/QuantizeActivation.h b/compiler/luci/pass/src/QuantizeActivation.h
new file mode 100644
index 000000000..fc32d1cde
--- /dev/null
+++ b/compiler/luci/pass/src/QuantizeActivation.h
@@ -0,0 +1,165 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_QUANTIZATION_ACTIVATION_H__
+#define __LUCI_QUANTIZATION_ACTIVATION_H__
+
+#include <luci/IR/CircleNodeVisitor.h>
+
+namespace luci
+{
+
+/**
+ * @brief Quantize non-const activation using recorded min/max values
+ */
+struct QuantizeActivation final : public luci::CircleNodeMutableVisitor<void>
+{
+ QuantizeActivation(loco::DataType input, loco::DataType output)
+ : input_type(input), output_type(output)
+ {
+ }
+
+ loco::DataType input_type;
+ loco::DataType output_type;
+
+ // Quantize each node using recorded min/max
+ void visit(luci::CircleNode *node);
+};
+
+/**
+ * @brief Quantize non-const activaion using pre-defined scale/zp for special Ops
+ */
+struct QuantizeSpecialActivation final : public luci::CircleNodeMutableVisitor<void>
+{
+ QuantizeSpecialActivation(loco::DataType input, loco::DataType output)
+ : input_type(input), output_type(output)
+ {
+ }
+
+ loco::DataType input_type;
+ loco::DataType output_type;
+
+ void visit(luci::CircleNode *node);
+ void visit(luci::CircleLogistic *node);
+ void visit(luci::CircleTanh *node);
+ void visit(luci::CircleSoftmax *node);
+ void visit(luci::CircleFloor *node);
+ void visit(luci::CircleFloorDiv *node);
+ void visit(luci::CircleFloorMod *node);
+ void visit(luci::CircleCeil *node);
+};
+
+// Quantize constant input activation of a node
+// The input of a node is quantized if it is
+// 1. Constant (instance of CircleConst*)
+// 2. Activation (other inputs e.g., weights, bias, axis, etc should not be quantized here)
+struct QuantizeConstInputActivation final : public luci::CircleNodeMutableVisitor<void>
+{
+ QuantizeConstInputActivation(loco::DataType output_type) : _output_type(output_type) {}
+
+private:
+ loco::DataType _output_type;
+
+// Skip NODE
+#define SKIP(NODE) \
+ void visit(NODE *) {}
+
+ // Handled in QuantizeWeights and QuantizeBias
+ SKIP(luci::CircleConv2D)
+ SKIP(luci::CircleDepthwiseConv2D)
+ SKIP(luci::CircleFullyConnected)
+ SKIP(luci::CircleInstanceNorm)
+ SKIP(luci::CirclePRelu)
+ SKIP(luci::CircleTransposeConv)
+
+ // Handled in PropagateQParamBackwardPass
+ SKIP(luci::CircleConcatenation)
+ SKIP(luci::CirclePadV2)
+ SKIP(luci::CirclePack)
+ SKIP(luci::CircleOneHot)
+
+ // Inputs of logical Ops are bool, thus not quantized
+ SKIP(luci::CircleLogicalOr)
+ SKIP(luci::CircleLogicalAnd)
+ SKIP(luci::CircleLogicalNot)
+
+#undef SKIP
+
+ // Default behavior (NYI)
+ void visit(luci::CircleNode *node);
+
+ // Ops that receive a single activation as an input
+ void visit(luci::CircleArgMax *node);
+ void visit(luci::CircleArgMin *node);
+ void visit(luci::CircleBatchToSpaceND *node);
+ void visit(luci::CircleDepthToSpace *node);
+ void visit(luci::CircleElu *node);
+ void visit(luci::CircleExp *node);
+ void visit(luci::CircleFloor *node);
+ void visit(luci::CircleGather *node);
+ void visit(luci::CircleLocalResponseNormalization *node);
+ void visit(luci::CircleLogistic *node);
+ void visit(luci::CircleMean *node);
+ void visit(luci::CircleMirrorPad *node);
+ void visit(luci::CirclePad *node);
+ void visit(luci::CircleReduceAny *node);
+ void visit(luci::CircleReduceProd *node);
+ void visit(luci::CircleReduceMax *node);
+ void visit(luci::CircleReduceMin *node);
+ void visit(luci::CircleReshape *node);
+ void visit(luci::CircleResizeBilinear *node);
+ void visit(luci::CircleResizeNearestNeighbor *node);
+ void visit(luci::CircleReverseSequence *node);
+ void visit(luci::CircleRsqrt *node);
+ void visit(luci::CircleSlice *node);
+ void visit(luci::CircleSoftmax *node);
+ void visit(luci::CircleSpaceToBatchND *node);
+ void visit(luci::CircleSpaceToDepth *node);
+ void visit(luci::CircleSplit *node);
+ void visit(luci::CircleSplitV *node);
+ void visit(luci::CircleSqrt *node);
+ void visit(luci::CircleStridedSlice *node);
+ void visit(luci::CircleSum *node);
+ void visit(luci::CircleTanh *node);
+ void visit(luci::CircleTile *node);
+ void visit(luci::CircleTopKV2 *node);
+ void visit(luci::CircleTranspose *node);
+ void visit(luci::CircleUnpack *node);
+
+ // Ops that receive two activations as inputs
+ void visit(luci::CircleAdd *node);
+ void visit(luci::CircleBatchMatMul *node);
+ void visit(luci::CircleDiv *node);
+ void visit(luci::CircleEqual *node);
+ void visit(luci::CircleFloorDiv *node);
+ void visit(luci::CircleGreater *node);
+ void visit(luci::CircleGreaterEqual *node);
+ void visit(luci::CircleLess *node);
+ void visit(luci::CircleLessEqual *node);
+ void visit(luci::CircleMaximum *node);
+ void visit(luci::CircleMinimum *node);
+ void visit(luci::CircleMul *node);
+ void visit(luci::CircleNotEqual *node);
+ void visit(luci::CirclePow *node);
+ void visit(luci::CircleSub *node);
+
+ // AddN has arbitrary number of inputs
+ void visit(luci::CircleAddN *node);
+};
+
+} // namespace luci
+
+#endif // __LUCI_QUANTIZATION_ACTIVATION_H__
diff --git a/compiler/luci/pass/src/QuantizeBias.cpp b/compiler/luci/pass/src/QuantizeBias.cpp
new file mode 100644
index 000000000..aa496232a
--- /dev/null
+++ b/compiler/luci/pass/src/QuantizeBias.cpp
@@ -0,0 +1,300 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "QuantizeBias.h"
+#include "QuantizationUtils.h"
+
+#include <luci/Service/Nodes/CircleConst.h>
+#include <luci/Log.h>
+
+#include <algorithm>
+#include <cmath>
+
+using namespace luci;
+
+namespace
+{
+
+// struct to carry Input/Weights/Bias
+struct IWB
+{
+ CircleNode *input = nullptr;
+ CircleNode *weights = nullptr;
+ CircleConst *bias = nullptr;
+
+ IWB(loco::Node *i, loco::Node *w, loco::Node *b)
+ {
+ input = dynamic_cast<luci::CircleNode *>(i);
+ weights = dynamic_cast<luci::CircleNode *>(w);
+ bias = dynamic_cast<luci::CircleConst *>(b);
+ }
+
+ // Return true if bias can be quantized with valid input an weights
+ operator bool()
+ {
+ if (bias == nullptr || is_quantized(bias))
+ return false;
+ if (input == nullptr || weights == nullptr)
+ return false;
+ return true;
+ }
+};
+
+// Create a new const node from an existing node.
+// The new node has the following characteristics
+// type: T
+// shape: same with 'node' (given as an argument)
+// buffer size: 'size' (given as an argument)
+// Note that contents are not filled in this function.
+template <loco::DataType T>
+luci::CircleConst *create_empty_const_from(luci::CircleConst *node, uint32_t size)
+{
+ auto new_node = node->graph()->nodes()->create<CircleConst>();
+ // TODO: We don't have any naming convention for quantized nodes yet.
+ // Fix this when we have one.
+ new_node->name(node->name());
+ new_node->dtype(T);
+ new_node->rank(node->rank());
+ for (uint32_t i = 0; i < node->rank(); i++)
+ new_node->dim(i).set(node->dim(i).value());
+
+ new_node->size<T>(size);
+ new_node->shape_status(luci::ShapeStatus::VALID);
+
+ return new_node;
+}
+
+CircleConst *asym_quant_bias_per_layer(CircleConst *node, float input_scale, float weight_scale,
+ float *scaling_factor, int64_t *zp)
+{
+ float scale = input_scale * weight_scale;
+ const float scaling_factor_inv = (scale == 0) ? 0 : 1.0 / scale;
+
+ uint32_t size = node->size<loco::DataType::FLOAT32>();
+ std::vector<int32_t> quantized_values(size);
+ for (uint32_t i = 0; i < size; ++i)
+ {
+ quantized_values[i] =
+ static_cast<int32_t>(std::round(node->at<loco::DataType::FLOAT32>(i) * scaling_factor_inv));
+ }
+
+ auto new_bias = create_empty_const_from<loco::DataType::S32>(node, size);
+
+ const int32_t kMinScale = std::numeric_limits<int32_t>::lowest();
+ const int32_t kMaxScale = std::numeric_limits<int32_t>::max();
+ for (uint32_t i = 0; i < size; ++i)
+ {
+ new_bias->at<loco::DataType::S32>(i) =
+ std::min(kMaxScale, std::max(kMinScale, quantized_values[i]));
+ }
+ *scaling_factor = scale;
+ *zp = 0;
+
+ return new_bias;
+}
+
+CircleConst *quant_bias_per_channel(CircleConst *node, float input_scale,
+ std::vector<float> &weight_scale,
+ std::vector<float> &scaling_factor, std::vector<int64_t> &zp)
+{
+ float scaling_factor_inv{0};
+
+ uint32_t size = node->size<loco::DataType::FLOAT32>();
+ std::vector<int32_t> quantized_values(size);
+
+ for (uint32_t i = 0; i < size; ++i)
+ {
+ scaling_factor[i] = input_scale * weight_scale[i];
+ scaling_factor_inv = (scaling_factor[i] == 0) ? 0 : 1.0 / scaling_factor[i];
+ quantized_values[i] =
+ static_cast<int32_t>(std::round(node->at<loco::DataType::FLOAT32>(i) * scaling_factor_inv));
+ zp[i] = 0;
+ }
+
+ auto new_bias = create_empty_const_from<loco::DataType::S32>(node, size);
+
+ const int32_t kMinScale = std::numeric_limits<int32_t>::lowest();
+ const int32_t kMaxScale = std::numeric_limits<int32_t>::max();
+ for (uint32_t i = 0; i < size; ++i)
+ {
+ new_bias->at<loco::DataType::S32>(i) =
+ std::min(kMaxScale, std::max(kMinScale, quantized_values[i]));
+ }
+
+ return new_bias;
+}
+
+CircleConst *int16_quant_bias_per_channel(CircleConst *node, float input_scale,
+ std::vector<float> &weight_scale,
+ std::vector<float> &scaling_factor,
+ std::vector<int64_t> &zp)
+{
+ float scaling_factor_inv{0};
+
+ uint32_t size = node->size<loco::DataType::FLOAT32>();
+ std::vector<int64_t> quantized_values(size);
+
+ for (uint32_t i = 0; i < size; ++i)
+ {
+ scaling_factor[i] = input_scale * weight_scale[i];
+ scaling_factor_inv = (scaling_factor[i] == 0) ? 0 : 1.0 / scaling_factor[i];
+ quantized_values[i] =
+ static_cast<int64_t>(std::round(node->at<loco::DataType::FLOAT32>(i) * scaling_factor_inv));
+ zp[i] = 0;
+ }
+
+ auto new_bias = create_empty_const_from<loco::DataType::S64>(node, size);
+
+ for (uint32_t i = 0; i < size; ++i)
+ {
+ new_bias->at<loco::DataType::S64>(i) = quantized_values[i];
+ }
+
+ return new_bias;
+}
+
+} // namespace
+
+namespace luci
+{
+
+// Return a quantized bias node
+CircleConst *QuantizeBias::quantized_bias(CircleNode *input, const CircleNode *weight,
+ CircleNode *bias)
+{
+ auto const_bias = loco::must_cast<luci::CircleConst *>(bias);
+ assert(const_bias->dtype() == loco::DataType::FLOAT32);
+
+ // If input is const, it is quantized here, not in QuantizeActivation
+ if (auto const_input = dynamic_cast<luci::CircleConst *>(input))
+ {
+ quant_const(const_input, output_type);
+ }
+
+ CircleConst *new_bias = nullptr;
+
+ if (granularity == QuantizationGranularity::ChannelWise)
+ {
+ auto input_q = input->quantparam();
+ assert(input_q);
+ assert(input_q->scale.size() == 1); // input scale's layer-wise
+ auto input_scale = input_q->scale[0];
+
+ assert(weight->quantparam() != nullptr); // weight scale's channel-wise
+ auto weight_scale = weight->quantparam()->scale;
+
+ uint32_t size = const_bias->size<loco::DataType::FLOAT32>();
+ assert(size == weight_scale.size());
+ std::vector<float> scaling_factor(size);
+ std::vector<int64_t> zp(size);
+
+ if (output_type == loco::DataType::U8)
+ {
+ new_bias = quant_bias_per_channel(const_bias, input_scale, weight_scale, scaling_factor, zp);
+ }
+ else if (output_type == loco::DataType::S16)
+ {
+ new_bias =
+ int16_quant_bias_per_channel(const_bias, input_scale, weight_scale, scaling_factor, zp);
+ }
+ else
+ {
+ throw std::runtime_error("Unsupported quantization type.");
+ }
+
+ auto quantparam = std::make_unique<CircleQuantParam>();
+ quantparam->scale = scaling_factor;
+ quantparam->zerop = zp;
+ assert(new_bias->quantparam() == nullptr); // bias should not be quantized before
+ new_bias->quantparam(std::move(quantparam));
+
+ return new_bias;
+ }
+ else
+ {
+ auto input_q = input->quantparam();
+ assert(input_q);
+ assert(input_q->scale.size() == 1); // Only support per-layer quant
+ auto input_scale = input_q->scale[0];
+
+ auto weight_q = weight->quantparam();
+ assert(weight_q);
+ assert(weight_q->scale.size() == 1); // Only support per-layer quant
+ auto weight_scale = weight_q->scale[0];
+
+ float scaling_factor{0};
+ int64_t zp{0};
+ new_bias =
+ asym_quant_bias_per_layer(const_bias, input_scale, weight_scale, &scaling_factor, &zp);
+ auto quantparam = std::make_unique<CircleQuantParam>();
+ quantparam->scale.push_back(scaling_factor);
+ quantparam->zerop.push_back(zp);
+ assert(new_bias->quantparam() == nullptr); // bias should not be quantized before
+ new_bias->quantparam(std::move(quantparam));
+
+ return new_bias;
+ }
+}
+
+void QuantizeBias::visit(luci::CircleConv2D *node)
+{
+ LOGGER(l);
+ INFO(l) << "QuantizeBias QuantizeBias::visit node: " << node->name() << std::endl;
+
+ if (auto iwb = IWB(node->input(), node->filter(), node->bias()))
+ {
+ auto new_bias = quantized_bias(iwb.input, iwb.weights, iwb.bias);
+ node->bias(new_bias);
+ }
+}
+
+void QuantizeBias::visit(luci::CircleDepthwiseConv2D *node)
+{
+ LOGGER(l);
+ INFO(l) << "QuantizeBias QuantizeBias::visit node: " << node->name() << std::endl;
+
+ if (auto iwb = IWB(node->input(), node->filter(), node->bias()))
+ {
+ auto new_bias = quantized_bias(iwb.input, iwb.weights, iwb.bias);
+ node->bias(new_bias);
+ }
+}
+
+void QuantizeBias::visit(luci::CircleTransposeConv *node)
+{
+ LOGGER(l);
+ INFO(l) << "QuantizeBias QuantizeBias::visit node: " << node->name() << std::endl;
+
+ if (auto iwb = IWB(node->outBackprop(), node->filter(), node->bias()))
+ {
+ auto new_bias = quantized_bias(iwb.input, iwb.weights, iwb.bias);
+ node->bias(new_bias);
+ }
+}
+
+void QuantizeBias::visit(luci::CircleFullyConnected *node)
+{
+ LOGGER(l);
+ INFO(l) << "QuantizeBias visit node: " << node->name() << std::endl;
+
+ if (auto iwb = IWB(node->input(), node->weights(), node->bias()))
+ {
+ auto new_bias = quantized_bias(iwb.input, iwb.weights, iwb.bias);
+ node->bias(new_bias);
+ }
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/QuantizeBias.h b/compiler/luci/pass/src/QuantizeBias.h
new file mode 100644
index 000000000..8de09df72
--- /dev/null
+++ b/compiler/luci/pass/src/QuantizeBias.h
@@ -0,0 +1,56 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_QUANTIZE_BIAS_H__
+#define __LUCI_QUANTIZE_BIAS_H__
+
+#include <luci/Pass/QuantizationParameters.h>
+#include <luci/IR/CircleNodeVisitor.h>
+
+namespace luci
+{
+
+/**
+ * @brief QuantizeBias quantizes tensors for bias
+ * @details Use input/weights scale to quantize values
+ */
+struct QuantizeBias final : public luci::CircleNodeMutableVisitor<void>
+{
+ QuantizeBias(loco::DataType input, loco::DataType output, QuantizationGranularity gr)
+ : input_type(input), output_type(output), granularity(gr)
+ {
+ }
+
+ loco::DataType input_type;
+ loco::DataType output_type;
+ QuantizationGranularity granularity;
+
+private:
+ // Return a quantized bias node
+ CircleConst *quantized_bias(CircleNode *input, const CircleNode *weight, CircleNode *bias);
+
+ void visit(luci::CircleConv2D *node);
+ void visit(luci::CircleDepthwiseConv2D *node);
+ void visit(luci::CircleTransposeConv *node);
+ void visit(luci::CircleFullyConnected *node);
+
+ // Default behavior
+ void visit(luci::CircleNode *) {}
+};
+
+} // namespace luci
+
+#endif // __LUCI_QUANTIZE_BIAS_H__
diff --git a/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.cpp b/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.cpp
index c8ad87e3d..c9b35e0be 100644
--- a/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.cpp
+++ b/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.cpp
@@ -16,9 +16,11 @@
#include "luci/Pass/QuantizeDequantizeWeightsPass.h"
#include "QuantizationUtils.h"
+#include "helpers/LayerInfoMap.h"
#include <luci/IR/CircleNodes.h>
#include <luci/IR/CircleNodeVisitor.h>
+#include <luci/Service/Nodes/CircleConst.h>
#include <luci/Log.h>
#include <loco/IR/TensorShape.h>
@@ -251,7 +253,7 @@ void asymmetric_wdequant_with_minmax_per_layer(CircleConst *node, float scaling_
* @brief QuantizeDequantizeWeights quantizes and dequantizes tensors for weights
* @details Find min/max values on the fly, quantize the model, and dequantize the model
*/
-struct QuantizeDequantizeWeights final : public luci::CircleNodeMutableVisitor<bool>
+struct QuantizeDequantizeWeights final : public luci::CircleNodeMutableVisitor<void>
{
QuantizeDequantizeWeights(loco::DataType input, loco::DataType output,
QuantizationGranularity granularity)
@@ -263,88 +265,164 @@ struct QuantizeDequantizeWeights final : public luci::CircleNodeMutableVisitor<b
loco::DataType output_type;
QuantizationGranularity granularity;
- // Quantize and dequantize input tensors of each node
- bool visit(luci::CircleNode *node)
+private:
+ // Fake quantize weights (Only u8 quantization is supported for LWQ)
+ void fake_quantize_lwq(luci::CircleConst *weights) const
{
- assert(output_type == loco::DataType::U8 || output_type == loco::DataType::S16);
- LOGGER(l);
- INFO(l) << "QuantizeDequantizeWeights visit node: " << node->name() << std::endl;
- auto arity = node->arity();
- for (uint32_t i = 0; i < arity; i++)
+ assert(output_type == loco::DataType::U8); // FIX_CALLER_UNLESS
+
+ // Find min/max per layer
+ float min = std::numeric_limits<float>::max();
+ float max = std::numeric_limits<float>::lowest();
+ for (uint32_t i = 0; i < weights->size<loco::DataType::FLOAT32>(); i++)
{
- auto input_node = node->arg(i);
- auto circle_node = loco::must_cast<luci::CircleNode *>(input_node);
+ auto data = weights->at<loco::DataType::FLOAT32>(i);
+ min = data < min ? data : min;
+ max = data > max ? data : max;
+ }
+ float scaling_factor{0};
+ int64_t zp{0};
+ float nudged_min{0};
+ float nudged_max{0};
+
+ asymmetric_wquant_with_minmax_per_layer(weights, min, max, scaling_factor, zp, nudged_min,
+ nudged_max);
+ asymmetric_wdequant_with_minmax_per_layer(weights, scaling_factor, nudged_min);
+ auto quantparam = std::make_unique<CircleQuantParam>();
+ quantparam->min.push_back(nudged_min);
+ quantparam->max.push_back(nudged_max);
+ quantparam->scale.push_back(scaling_factor);
+ quantparam->zerop.push_back(zp);
+ weights->quantparam(std::move(quantparam));
+ }
- // Check if this is already quantized
- if (is_quantized(circle_node))
- continue;
+private:
+ // Fake quantize weights (u8/s16 quantization are supported for CWQ)
+ void fake_quantize_cwq(luci::CircleConst *weights) const
+ {
+ assert(output_type == loco::DataType::U8 ||
+ output_type == loco::DataType::S16); // FIX_CALLER_UNLESS
- if (is_weights(circle_node))
- {
- auto circle_const = loco::must_cast<luci::CircleConst *>(circle_node);
+ // Find min/max per channel
+ std::vector<float> min;
+ std::vector<float> max;
- // Find min/max per channel-wise
- if (granularity == QuantizationGranularity::ChannelWise)
- {
- std::vector<float> min;
- std::vector<float> max;
-
- cal_minmax_per_channel(circle_const, min, max);
-
- std::vector<float> nudged_min(min.size());
- std::vector<float> nudged_max(min.size());
- std::vector<float> scaling_factor(min.size());
- std::vector<int64_t> zp(min.size());
-
- if (output_type == loco::DataType::U8)
- {
- asymmetric_wquant_per_channel(circle_const, min, max, scaling_factor, zp, nudged_min,
- nudged_max);
- asymmetric_wdequant_per_channel(circle_const, scaling_factor, nudged_min);
- }
- else
- {
- sym_wquant_per_channel(circle_const, min, max, scaling_factor, zp, nudged_min,
- nudged_max);
- sym_wdequant_per_channel(circle_const, scaling_factor);
- }
-
- auto quantparam = std::make_unique<CircleQuantParam>();
- quantparam->min = nudged_min;
- quantparam->max = nudged_max;
- quantparam->scale = scaling_factor;
- quantparam->zerop = zp;
- circle_node->quantparam(std::move(quantparam));
- }
- // Find min/max per layer-wise
- else
- {
- float min = std::numeric_limits<float>::max();
- float max = std::numeric_limits<float>::lowest();
- for (uint32_t i = 0; i < circle_const->size<loco::DataType::FLOAT32>(); i++)
- {
- auto data = circle_const->at<loco::DataType::FLOAT32>(i);
- min = data < min ? data : min;
- max = data > max ? data : max;
- }
- float scaling_factor{0};
- int64_t zp{0};
- float nudged_min{0};
- float nudged_max{0};
-
- asymmetric_wquant_with_minmax_per_layer(circle_const, min, max, scaling_factor, zp,
- nudged_min, nudged_max);
- asymmetric_wdequant_with_minmax_per_layer(circle_const, scaling_factor, nudged_min);
- auto quantparam = std::make_unique<CircleQuantParam>();
- quantparam->min.push_back(nudged_min);
- quantparam->max.push_back(nudged_max);
- quantparam->scale.push_back(scaling_factor);
- quantparam->zerop.push_back(zp);
- circle_node->quantparam(std::move(quantparam));
- }
- }
+ cal_minmax_per_channel(weights, min, max);
+
+ std::vector<float> nudged_min(min.size());
+ std::vector<float> nudged_max(min.size());
+ std::vector<float> scaling_factor(min.size());
+ std::vector<int64_t> zp(min.size());
+
+ if (output_type == loco::DataType::U8)
+ {
+ asymmetric_wquant_per_channel(weights, min, max, scaling_factor, zp, nudged_min, nudged_max);
+ asymmetric_wdequant_per_channel(weights, scaling_factor, nudged_min);
+ }
+ else
+ {
+ sym_wquant_per_channel(weights, min, max, scaling_factor, zp, nudged_min, nudged_max);
+ sym_wdequant_per_channel(weights, scaling_factor);
}
- return false;
+
+ auto quantparam = std::make_unique<CircleQuantParam>();
+ quantparam->min = nudged_min;
+ quantparam->max = nudged_max;
+ quantparam->scale = scaling_factor;
+ quantparam->zerop = zp;
+ weights->quantparam(std::move(quantparam));
+ }
+
+private:
+ void fake_quantize(luci::CircleConst *weights) const
+ {
+ switch (granularity)
+ {
+ case luci::QuantizationGranularity::ChannelWise:
+ fake_quantize_cwq(weights);
+ break;
+ case luci::QuantizationGranularity::LayerWise:
+ fake_quantize_lwq(weights);
+ break;
+ default:
+ throw std::invalid_argument("Unsupported granularity");
+ }
+ }
+
+private:
+ // Check if
+ // 1. node is const
+ // 2. node was not quantized
+ bool is_quantizable(loco::Node *node)
+ {
+ auto const_node = dynamic_cast<luci::CircleConst *>(node);
+ if (not const_node)
+ return false;
+
+ // Skip if this is already quantized
+ if (is_quantized(const_node))
+ return false;
+
+ return true;
+ }
+
+ // Default behavior (Do nothing)
+ void visit(luci::CircleNode *) {}
+
+ void visit(luci::CircleConv2D *node)
+ {
+ LOGGER(l);
+ INFO(l) << "QuantizeDequantizeWeights visit node: " << node->name() << std::endl;
+
+ if (not is_quantizable(node->filter()))
+ return;
+
+ auto weights = loco::must_cast<luci::CircleConst *>(node->filter());
+ auto new_weights = luci::clone(weights);
+ node->filter(new_weights);
+ fake_quantize(new_weights);
+ }
+
+ void visit(luci::CircleDepthwiseConv2D *node)
+ {
+ LOGGER(l);
+ INFO(l) << "QuantizeDequantizeWeights visit node: " << node->name() << std::endl;
+
+ if (not is_quantizable(node->filter()))
+ return;
+
+ auto weights = loco::must_cast<luci::CircleConst *>(node->filter());
+ auto new_weights = luci::clone(weights);
+ node->filter(new_weights);
+ fake_quantize(new_weights);
+ }
+
+ void visit(luci::CircleTransposeConv *node)
+ {
+ LOGGER(l);
+ INFO(l) << "QuantizeDequantizeWeights visit node: " << node->name() << std::endl;
+
+ if (not is_quantizable(node->filter()))
+ return;
+
+ auto weights = loco::must_cast<luci::CircleConst *>(node->filter());
+ auto new_weights = luci::clone(weights);
+ node->filter(new_weights);
+ fake_quantize(new_weights);
+ }
+
+ void visit(luci::CircleFullyConnected *node)
+ {
+ LOGGER(l);
+ INFO(l) << "QuantizeDequantizeWeights visit node: " << node->name() << std::endl;
+
+ if (not is_quantizable(node->weights()))
+ return;
+
+ auto weights = loco::must_cast<luci::CircleConst *>(node->weights());
+ auto new_weights = luci::clone(weights);
+ node->weights(new_weights);
+ fake_quantize(new_weights);
}
};
@@ -355,11 +433,36 @@ bool QuantizeDequantizeWeightsPass::run(loco::Graph *g)
LOGGER(l);
INFO(l) << "QuantizeDequantizeWeightsPass Start" << std::endl;
+ auto info_by_name = layer_info_map(g, _ctx->layers_info);
+
+ auto quantize_dtype = [&](const luci::CircleNode *node) {
+ auto iter = info_by_name.find(node->name());
+
+ // Return designated quantization dtype
+ if (iter != info_by_name.end())
+ return iter->second.dtype;
+
+ // Return default quantization dtype
+ return _ctx->output_model_dtype;
+ };
+
+ auto quantize_granularity = [&](const luci::CircleNode *node) {
+ auto iter = info_by_name.find(node->name());
+
+ // Return designated quantization granularity
+ if (iter != info_by_name.end())
+ return iter->second.granularity;
+
+ // Return default quantization granularity
+ return _ctx->granularity;
+ };
+
// Quantize weights
for (auto node : loco::active_nodes(loco::output_nodes(g)))
{
- QuantizeDequantizeWeights qw(_input_model_dtype, _output_model_dtype, _granularity);
auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+ QuantizeDequantizeWeights qw(_ctx->input_model_dtype, quantize_dtype(circle_node),
+ quantize_granularity(circle_node));
circle_node->accept(&qw);
}
diff --git a/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.test.cpp b/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.test.cpp
index f226253c2..15f5ca7ac 100644
--- a/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.test.cpp
+++ b/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.test.cpp
@@ -25,3 +25,17 @@ TEST(QuantizeDequantizeWeightsPassTest, name)
auto const name = pass.name();
ASSERT_NE(nullptr, name);
}
+
+TEST(QuantizeDequantizeWeightsPassTest, name_ctx)
+{
+ auto ctx = std::make_unique<luci::QuantizeDequantizeWeightsPass::Context>();
+ {
+ ctx->input_model_dtype = loco::DataType::FLOAT32;
+ ctx->output_model_dtype = loco::DataType::U8;
+ ctx->granularity = luci::QuantizationGranularity::LayerWise;
+ }
+
+ luci::QuantizeDequantizeWeightsPass pass(std::move(ctx));
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
diff --git a/compiler/luci/pass/src/QuantizePreCheckerPass.cpp b/compiler/luci/pass/src/QuantizePreCheckerPass.cpp
new file mode 100644
index 000000000..4b3b7e330
--- /dev/null
+++ b/compiler/luci/pass/src/QuantizePreCheckerPass.cpp
@@ -0,0 +1,119 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/QuantizePreCheckerPass.h"
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/IR/CircleNodeVisitor.h>
+
+#include <luci/Log.h>
+
+namespace luci
+{
+
+namespace
+{
+
+void check_const_opcode(luci::CircleNode *node)
+{
+ if (node == nullptr)
+ return;
+
+ if (node->opcode() != luci::CircleOpcode::CIRCLECONST and
+ node->opcode() != luci::CircleOpcode::CIRCLEOUTPUTEXCLUDE)
+ {
+ throw std::runtime_error("Unsupported non const input " + node->name());
+ }
+}
+
+struct ConstInputChecker final : public luci::CircleNodeMutableVisitor<void>
+{
+// INPUT_NAME is name for input const for current NODE
+#define CHECK_NODE_WITH_ONE_INPUT_CONST(NODE, INPUT_NAME) \
+ void visit(NODE *node) \
+ { \
+ const auto input = dynamic_cast<luci::CircleNode *>(node->INPUT_NAME()); \
+ check_const_opcode(input); \
+ }
+
+// INPUT_NAME_1 and INPUT_NAME_2 are names for input const for current NODE
+#define CHECK_NODE_WITH_TWO_INPUT_CONST(NODE, INPUT_NAME_1, INPUT_NAME_2) \
+ void visit(NODE *node) \
+ { \
+ const auto input_1 = dynamic_cast<luci::CircleNode *>(node->INPUT_NAME_1()); \
+ const auto input_2 = dynamic_cast<luci::CircleNode *>(node->INPUT_NAME_2()); \
+ \
+ check_const_opcode(input_1); \
+ check_const_opcode(input_2); \
+ }
+
+// INPUT_NAME_1, INPUT_NAME_2 and INPUT_NAME_3 are names for input const for current NODE
+#define CHECK_NODE_WITH_THREE_INPUT_CONST(NODE, INPUT_NAME_1, INPUT_NAME_2, INPUT_NAME_3) \
+ void visit(NODE *node) \
+ { \
+ const auto input_1 = dynamic_cast<luci::CircleNode *>(node->INPUT_NAME_1()); \
+ const auto input_2 = dynamic_cast<luci::CircleNode *>(node->INPUT_NAME_2()); \
+ const auto input_3 = dynamic_cast<luci::CircleNode *>(node->INPUT_NAME_3()); \
+ \
+ check_const_opcode(input_1); \
+ check_const_opcode(input_2); \
+ check_const_opcode(input_3); \
+ }
+
+ // Skip other circle node
+ void visit(luci::CircleNode *) {}
+
+ // Ops that receive one const nodes as inputs
+ CHECK_NODE_WITH_ONE_INPUT_CONST(luci::CirclePRelu, alpha)
+
+ // Ops that receive two const node as an inputs
+ CHECK_NODE_WITH_TWO_INPUT_CONST(luci::CircleConv2D, filter, bias)
+ CHECK_NODE_WITH_TWO_INPUT_CONST(luci::CircleDepthwiseConv2D, filter, bias)
+ CHECK_NODE_WITH_TWO_INPUT_CONST(luci::CircleFullyConnected, weights, bias)
+ CHECK_NODE_WITH_TWO_INPUT_CONST(luci::CircleInstanceNorm, gamma, beta)
+
+ // Ops that receive three const nodes as an inputs
+ CHECK_NODE_WITH_THREE_INPUT_CONST(luci::CircleTransposeConv, inputSizes, filter, bias)
+
+#undef CHECK_NODE_WITH_ONE_INPUT_CONST
+#undef CHECK_NODE_WITH_TWO_INPUT_CONST
+#undef CHECK_NODE_WITH_THREE_INPUT_CONST
+};
+
+} // namespace
+
+/**
+ * Verify the input model has the form acceptable by quantizer
+ */
+bool QuantizePreCheckerPass::run(loco::Graph *g)
+{
+ LOGGER(l);
+ INFO(l) << "QuantizePreCheckerPass Start" << std::endl;
+
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ // Check const inputs
+ auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+ ConstInputChecker checker{};
+ circle_node->accept(&checker);
+ }
+
+ INFO(l) << "QuantizePreCheckerPass End" << std::endl;
+
+ return false; // one time run
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/QuantizePreCheckerPass.test.cpp b/compiler/luci/pass/src/QuantizePreCheckerPass.test.cpp
new file mode 100644
index 000000000..788353cd8
--- /dev/null
+++ b/compiler/luci/pass/src/QuantizePreCheckerPass.test.cpp
@@ -0,0 +1,401 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/QuantizePreCheckerPass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <gtest/gtest.h>
+
+class SimpleConv2DGraph
+{
+public:
+ SimpleConv2DGraph(bool make_valid)
+ {
+ conv2d_node = g.nodes()->create<luci::CircleConv2D>();
+ input_1 = g.nodes()->create<luci::CircleInput>();
+ filter = g.nodes()->create<luci::CircleConst>();
+
+ conv2d_node->input(input_1);
+ conv2d_node->filter(filter);
+
+ if (make_valid)
+ {
+ bias = g.nodes()->create<luci::CircleConst>();
+ conv2d_node->bias(bias);
+ }
+ else
+ {
+ input_2 = g.nodes()->create<luci::CircleInput>();
+ conv2d_node->bias(input_2);
+ }
+
+ output = g.nodes()->create<luci::CircleOutput>();
+
+ auto graph_output = g.outputs()->create();
+ output->index(graph_output->index());
+
+ output->from(conv2d_node);
+ }
+
+public:
+ loco::Graph g;
+
+private:
+ luci::CircleConv2D *conv2d_node = nullptr;
+ luci::CircleInput *input_1 = nullptr;
+ luci::CircleInput *input_2 = nullptr;
+ luci::CircleConst *filter = nullptr;
+ luci::CircleConst *bias = nullptr;
+ luci::CircleOutput *output = nullptr;
+};
+
+class SimpleDepthConv2DGraph
+{
+public:
+ SimpleDepthConv2DGraph(bool make_valid)
+ {
+ depth_conv2d_node = g.nodes()->create<luci::CircleDepthwiseConv2D>();
+ input_1 = g.nodes()->create<luci::CircleInput>();
+ filter = g.nodes()->create<luci::CircleConst>();
+
+ depth_conv2d_node->input(input_1);
+ depth_conv2d_node->filter(filter);
+
+ if (make_valid)
+ {
+ bias = g.nodes()->create<luci::CircleConst>();
+ depth_conv2d_node->bias(bias);
+ }
+ else
+ {
+ input_2 = g.nodes()->create<luci::CircleInput>();
+ depth_conv2d_node->bias(input_2);
+ }
+
+ output = g.nodes()->create<luci::CircleOutput>();
+
+ auto graph_output = g.outputs()->create();
+ output->index(graph_output->index());
+
+ output->from(depth_conv2d_node);
+ }
+
+public:
+ loco::Graph g;
+
+private:
+ luci::CircleDepthwiseConv2D *depth_conv2d_node = nullptr;
+ luci::CircleInput *input_1 = nullptr;
+ luci::CircleInput *input_2 = nullptr;
+ luci::CircleConst *filter = nullptr;
+ luci::CircleConst *bias = nullptr;
+ luci::CircleOutput *output = nullptr;
+};
+
+class SimpleFCGraph
+{
+public:
+ SimpleFCGraph(bool make_valid)
+ {
+ fc_node = g.nodes()->create<luci::CircleFullyConnected>();
+ input_1 = g.nodes()->create<luci::CircleInput>();
+ weights = g.nodes()->create<luci::CircleConst>();
+
+ fc_node->input(input_1);
+ fc_node->weights(weights);
+
+ if (make_valid)
+ {
+ bias = g.nodes()->create<luci::CircleConst>();
+ fc_node->bias(bias);
+ }
+ else
+ {
+ input_2 = g.nodes()->create<luci::CircleInput>();
+ fc_node->bias(input_2);
+ }
+
+ output = g.nodes()->create<luci::CircleOutput>();
+
+ auto graph_output = g.outputs()->create();
+ output->index(graph_output->index());
+
+ output->from(fc_node);
+ }
+
+public:
+ loco::Graph g;
+
+private:
+ luci::CircleFullyConnected *fc_node = nullptr;
+ luci::CircleInput *input_1 = nullptr;
+ luci::CircleInput *input_2 = nullptr;
+ luci::CircleConst *weights = nullptr;
+ luci::CircleConst *bias = nullptr;
+ luci::CircleOutput *output = nullptr;
+};
+
+class SimpleInstanceNormGraph
+{
+public:
+ SimpleInstanceNormGraph(bool make_valid)
+ {
+ instance_norm_node = g.nodes()->create<luci::CircleInstanceNorm>();
+ input_1 = g.nodes()->create<luci::CircleInput>();
+ gamma = g.nodes()->create<luci::CircleConst>();
+
+ instance_norm_node->input(input_1);
+ instance_norm_node->gamma(gamma);
+
+ if (make_valid)
+ {
+ beta = g.nodes()->create<luci::CircleConst>();
+ instance_norm_node->beta(beta);
+ }
+ else
+ {
+ input_2 = g.nodes()->create<luci::CircleInput>();
+ instance_norm_node->beta(input_2);
+ }
+
+ output = g.nodes()->create<luci::CircleOutput>();
+
+ auto graph_output = g.outputs()->create();
+ output->index(graph_output->index());
+
+ output->from(instance_norm_node);
+ }
+
+public:
+ loco::Graph g;
+
+private:
+ luci::CircleInstanceNorm *instance_norm_node = nullptr;
+ luci::CircleInput *input_1 = nullptr;
+ luci::CircleInput *input_2 = nullptr;
+ luci::CircleConst *gamma = nullptr;
+ luci::CircleConst *beta = nullptr;
+ luci::CircleOutput *output = nullptr;
+};
+
+class SimpleTransposeConvGraph
+{
+public:
+ SimpleTransposeConvGraph(bool make_valid)
+ {
+ transpose_conv = g.nodes()->create<luci::CircleTransposeConv>();
+ input_1 = g.nodes()->create<luci::CircleInput>();
+
+ input_sizes = g.nodes()->create<luci::CircleConst>();
+ filter = g.nodes()->create<luci::CircleConst>();
+
+ transpose_conv->outBackprop(input_1);
+ transpose_conv->filter(filter);
+ transpose_conv->inputSizes(input_sizes);
+
+ if (make_valid)
+ {
+ bias = g.nodes()->create<luci::CircleConst>();
+ transpose_conv->bias(bias);
+ }
+ else
+ {
+ input_2 = g.nodes()->create<luci::CircleInput>();
+ transpose_conv->bias(input_2);
+ }
+
+ output = g.nodes()->create<luci::CircleOutput>();
+
+ auto graph_output = g.outputs()->create();
+ output->index(graph_output->index());
+
+ output->from(transpose_conv);
+ }
+
+public:
+ loco::Graph g;
+
+private:
+ luci::CircleTransposeConv *transpose_conv = nullptr;
+ luci::CircleInput *input_1 = nullptr;
+ luci::CircleInput *input_2 = nullptr;
+ luci::CircleConst *input_sizes = nullptr;
+ luci::CircleConst *filter = nullptr;
+ luci::CircleConst *bias = nullptr;
+ luci::CircleOutput *output = nullptr;
+};
+
+class SimplePReluGraph
+{
+public:
+ SimplePReluGraph(bool make_valid)
+ {
+ prelu = g.nodes()->create<luci::CirclePRelu>();
+ input_1 = g.nodes()->create<luci::CircleInput>();
+
+ prelu->input(input_1);
+
+ if (make_valid)
+ {
+ alpha = g.nodes()->create<luci::CircleConst>();
+ prelu->alpha(alpha);
+ }
+ else
+ {
+ input_2 = g.nodes()->create<luci::CircleInput>();
+ prelu->alpha(input_2);
+ }
+
+ output = g.nodes()->create<luci::CircleOutput>();
+
+ auto graph_output = g.outputs()->create();
+ output->index(graph_output->index());
+
+ output->from(prelu);
+ }
+
+public:
+ loco::Graph g;
+
+private:
+ luci::CirclePRelu *prelu = nullptr;
+ luci::CircleInput *input_1 = nullptr;
+ luci::CircleInput *input_2 = nullptr;
+ luci::CircleConst *alpha = nullptr;
+ luci::CircleOutput *output = nullptr;
+};
+
+TEST(QuantizePreCheckerPassTest, name)
+{
+ luci::QuantizePreCheckerPass pass{};
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
+
+// Test Conv2d
+TEST(QuantizePreCheckerPassTest, conv2d)
+{
+ SimpleConv2DGraph valid_graph(true);
+
+ luci::QuantizePreCheckerPass checker{};
+
+ EXPECT_NO_THROW(checker.run(&valid_graph.g));
+}
+
+TEST(QuantizePreCheckerPassTest, conv2d_NEG)
+{
+ SimpleConv2DGraph invalid_graph(false);
+
+ luci::QuantizePreCheckerPass checker{};
+
+ EXPECT_ANY_THROW(checker.run(&invalid_graph.g));
+}
+
+// Test DepthwiseConv2d
+TEST(QuantizePreCheckerPassTest, depthwise_conv2d)
+{
+ SimpleDepthConv2DGraph valid_graph(true);
+
+ luci::QuantizePreCheckerPass checker{};
+
+ EXPECT_NO_THROW(checker.run(&valid_graph.g));
+}
+
+TEST(QuantizePreCheckerPassTest, depthwise_conv2d_NEG)
+{
+ SimpleDepthConv2DGraph invalid_graph(false);
+
+ luci::QuantizePreCheckerPass checker{};
+
+ EXPECT_ANY_THROW(checker.run(&invalid_graph.g));
+}
+
+// Test FullyConnected
+TEST(QuantizePreCheckerPassTest, fully_connected)
+{
+ SimpleFCGraph valid_graph(true);
+
+ luci::QuantizePreCheckerPass checker{};
+
+ EXPECT_NO_THROW(checker.run(&valid_graph.g));
+}
+
+TEST(QuantizePreCheckerPassTest, fully_connected_NEG)
+{
+ SimpleFCGraph invalid_graph(false);
+
+ luci::QuantizePreCheckerPass checker{};
+
+ EXPECT_ANY_THROW(checker.run(&invalid_graph.g));
+}
+
+// Test InstanceNorm
+TEST(QuantizePreCheckerPassTest, instance_norm)
+{
+ SimpleInstanceNormGraph valid_graph(true);
+
+ luci::QuantizePreCheckerPass checker{};
+
+ EXPECT_NO_THROW(checker.run(&valid_graph.g));
+}
+
+TEST(QuantizePreCheckerPassTest, instance_norm_NEG)
+{
+ SimpleInstanceNormGraph invalid_graph(false);
+
+ luci::QuantizePreCheckerPass checker{};
+
+ EXPECT_ANY_THROW(checker.run(&invalid_graph.g));
+}
+
+// Test TransposeConv
+TEST(QuantizePreCheckerPassTest, transpose_conv)
+{
+ SimpleTransposeConvGraph valid_graph(true);
+
+ luci::QuantizePreCheckerPass checker{};
+
+ EXPECT_NO_THROW(checker.run(&valid_graph.g));
+}
+
+TEST(QuantizePreCheckerPassTest, transpose_conv_NEG)
+{
+ SimpleTransposeConvGraph invalid_graph(false);
+
+ luci::QuantizePreCheckerPass checker{};
+
+ EXPECT_ANY_THROW(checker.run(&invalid_graph.g));
+}
+
+// Test PRelu
+TEST(QuantizePreCheckerPassTest, prelu)
+{
+ SimplePReluGraph valid_graph(true);
+
+ luci::QuantizePreCheckerPass checker{};
+
+ EXPECT_NO_THROW(checker.run(&valid_graph.g));
+}
+
+TEST(QuantizePreCheckerPassTest, prelu_NEG)
+{
+ SimplePReluGraph invalid_graph(false);
+
+ luci::QuantizePreCheckerPass checker{};
+
+ EXPECT_ANY_THROW(checker.run(&invalid_graph.g));
+}
diff --git a/compiler/luci/pass/src/QuantizeWeights.cpp b/compiler/luci/pass/src/QuantizeWeights.cpp
new file mode 100644
index 000000000..11322ab44
--- /dev/null
+++ b/compiler/luci/pass/src/QuantizeWeights.cpp
@@ -0,0 +1,394 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "QuantizeWeights.h"
+#include "QuantizationUtils.h"
+
+#include <luci/Service/Nodes/CircleConst.h>
+#include <luci/Log.h>
+
+#include <cmath>
+#include <vector>
+#include <functional>
+
+using namespace luci;
+
+namespace
+{
+
+using IterFunc = std::function<void(uint32_t *, loco::TensorShape &, int32_t)>;
+
+void iterate_per_channel(CircleConst *node, int32_t &channel_dim_index, IterFunc func)
+{
+ loco::TensorShape dimension;
+ dimension.rank(4);
+ uint32_t indices[4] = {
+ 0,
+ };
+
+ if (!get_channel_dim_index(node, dimension, channel_dim_index))
+ {
+ assert(false);
+ return;
+ }
+
+ for (indices[0] = 0; indices[0] < dimension.dim(0).value(); indices[0]++)
+ {
+ for (indices[1] = 0; indices[1] < dimension.dim(1).value(); indices[1]++)
+ {
+ for (indices[2] = 0; indices[2] < dimension.dim(2).value(); indices[2]++)
+ {
+ for (indices[3] = 0; indices[3] < dimension.dim(3).value(); indices[3]++)
+ {
+ func(indices, dimension, channel_dim_index);
+ }
+ }
+ }
+ }
+}
+
+void asym_wquant_per_channel(CircleConst *node, std::vector<float> &min,
+ std::vector<float> &scaling_factor, int32_t &channel_dim_index)
+{
+ assert(node->dtype() == loco::DataType::FLOAT32);
+
+ const int32_t kMinScale = 0;
+ const int32_t kMaxScale = 255;
+
+ uint32_t size = node->size<loco::DataType::FLOAT32>();
+ std::vector<int32_t> quantized_values(size);
+
+ auto quantize = [&](uint32_t *indices, loco::TensorShape &dimension, int32_t channel_dim_index) {
+ int channel_idx = indices[channel_dim_index];
+ const float scaling_factor_inv = 1.0 / scaling_factor[channel_idx];
+ auto data = node->at<loco::DataType::FLOAT32>(cal_offset(dimension, indices));
+ quantized_values[cal_offset(dimension, indices)] =
+ static_cast<int32_t>(std::round((data - min[channel_idx]) * scaling_factor_inv));
+ };
+
+ iterate_per_channel(node, channel_dim_index, quantize);
+
+ node->dtype(loco::DataType::U8); // change the type of tensor
+ node->size<loco::DataType::U8>(size); // resize tensor
+ for (uint32_t i = 0; i < size; ++i)
+ {
+ node->at<loco::DataType::U8>(i) = std::min(kMaxScale, std::max(kMinScale, quantized_values[i]));
+ }
+}
+
+void sym_wquant_per_channel(CircleConst *node, std::vector<float> &scaling_factor,
+ int32_t &channel_dim_index)
+{
+ assert(node->dtype() == loco::DataType::FLOAT32);
+
+ const int32_t kMaxScale = std::numeric_limits<int16_t>::max();
+ const int32_t kMinScale = -kMaxScale;
+
+ uint32_t size = node->size<loco::DataType::FLOAT32>();
+ std::vector<int32_t> quantized_values(size);
+
+ auto quantize = [&](uint32_t *indices, loco::TensorShape &dimension, int32_t channel_dim_index) {
+ int channel_idx = indices[channel_dim_index];
+ const float scaling_factor_inv = 1.0 / scaling_factor[channel_idx];
+ auto data = node->at<loco::DataType::FLOAT32>(cal_offset(dimension, indices));
+ quantized_values[cal_offset(dimension, indices)] =
+ static_cast<int32_t>(std::round(data * scaling_factor_inv));
+ };
+
+ iterate_per_channel(node, channel_dim_index, quantize);
+
+ node->dtype(loco::DataType::S16); // change the type of tensor
+ node->size<loco::DataType::S16>(size); // resize tensor
+ for (uint32_t i = 0; i < size; ++i)
+ {
+ node->at<loco::DataType::S16>(i) =
+ std::min(kMaxScale, std::max(kMinScale, quantized_values[i]));
+ }
+}
+
+void asym_wquant_per_layer(CircleConst *node, float min, float scaling_factor)
+{
+ const int32_t kMinScale = 0;
+ const int32_t kMaxScale = 255;
+
+ uint32_t size = node->size<loco::DataType::FLOAT32>();
+
+ const float scaling_factor_inv = 1.0 / scaling_factor;
+ std::vector<int32_t> quantized_values(size);
+ for (uint32_t i = 0; i < size; ++i)
+ {
+ auto data = node->at<loco::DataType::FLOAT32>(i);
+ quantized_values[i] = static_cast<int32_t>(std::round((data - min) * scaling_factor_inv));
+ }
+
+ node->dtype(loco::DataType::U8); // change the type of tensor
+ node->size<loco::DataType::U8>(size); // resize tensor
+ for (uint32_t i = 0; i < size; ++i)
+ {
+ node->at<loco::DataType::U8>(i) = std::min(kMaxScale, std::max(kMinScale, quantized_values[i]));
+ }
+}
+
+// Quantize const per channel
+//
+// The last dimension of const is the same as the dimension of channel
+// And the rest of the const dimensions should be 1
+// So, a 'single value' is quantized per channel
+//
+// Quantization spec (f: fp value, q: quantized value)
+//
+// uint8
+// Positive f: f = f * (q - 0) [q = 1, scale = f, zp = 0]
+// Negative f: f = (-f) * (q - 1) [q = 0, scale = -f, zp = 1]
+//
+// int16
+// Positive f: f = f * (q - 0) [q = 1, scale = f, zp = 0]
+// Negative f: f = (-f) * (q - 0) [q = -1, scale = -f, zp = 0]
+void quant_const_per_channel(CircleConst *node, loco::DataType quant_type)
+{
+ assert(node->dtype() == loco::DataType::FLOAT32);
+ assert(node->rank() > 0);
+
+ for (uint32_t i = 0; i < node->rank() - 1; i++)
+ {
+ // Caller should call this function when the below condition is satisfied
+ if (node->dim(i).value() != 1)
+ throw std::runtime_error("Non-channel dimension of const node must be 1");
+ }
+
+ uint32_t size = node->size<loco::DataType::FLOAT32>();
+ assert(size == node->dim(node->rank() - 1).value());
+
+ auto quantparam = std::make_unique<CircleQuantParam>();
+ quantparam->quantized_dimension = node->rank() - 1;
+ std::vector<int32_t> quantized_data(size);
+
+ for (uint32_t i = 0; i < size; ++i)
+ {
+ auto data = node->at<loco::DataType::FLOAT32>(i);
+ if (quant_type == loco::DataType::U8)
+ {
+ if (data >= 0)
+ {
+ quantparam->scale.push_back(data);
+ quantparam->zerop.push_back(0);
+ quantized_data[i] = 1;
+ }
+ else
+ {
+ quantparam->scale.push_back(-data);
+ quantparam->zerop.push_back(1);
+ quantized_data[i] = 0;
+ }
+ }
+ else if (quant_type == loco::DataType::S16)
+ {
+ if (data >= 0)
+ {
+ quantparam->scale.push_back(data);
+ quantized_data[i] = 1;
+ }
+ else
+ {
+ quantparam->scale.push_back(-data);
+ quantized_data[i] = -1;
+ }
+ quantparam->zerop.push_back(0);
+ }
+ }
+ node->quantparam(std::move(quantparam));
+
+ switch (quant_type)
+ {
+ case loco::DataType::U8:
+ node->dtype(loco::DataType::U8);
+ node->size<loco::DataType::U8>(size);
+ for (uint32_t i = 0; i < size; ++i)
+ {
+ assert(quantized_data[i] == 0 || quantized_data[i] == 1);
+ node->at<loco::DataType::U8>(i) = quantized_data[i];
+ }
+ break;
+ case loco::DataType::S16:
+ node->dtype(loco::DataType::S16);
+ node->size<loco::DataType::S16>(size);
+ for (uint32_t i = 0; i < size; ++i)
+ {
+ assert(quantized_data[i] == -1 || quantized_data[i] == 1);
+ node->at<loco::DataType::S16>(i) = quantized_data[i];
+ }
+ break;
+ default:
+ throw std::runtime_error("Unsupported data type");
+ }
+}
+
+} // namespace
+
+namespace luci
+{
+
+void QuantizeWeights::quantize_weights(luci::CircleConst *weights)
+{
+ // Find min/max per channel-wise
+ if (granularity == QuantizationGranularity::ChannelWise)
+ {
+ auto quantparam = weights->quantparam();
+ if (quantparam == nullptr)
+ {
+ assert(false && "quantparam is nullptr");
+ return;
+ }
+
+ auto min = quantparam->min;
+ auto scaling_factor = quantparam->scale;
+ int32_t channel_dim_index = 0;
+
+ if (output_type == loco::DataType::U8)
+ {
+ asym_wquant_per_channel(weights, min, scaling_factor, channel_dim_index);
+ }
+ else
+ {
+ sym_wquant_per_channel(weights, scaling_factor, channel_dim_index);
+ }
+ quantparam->min.clear();
+ quantparam->max.clear();
+ quantparam->quantized_dimension = channel_dim_index;
+ }
+ // Find min/max per layer-wise
+ else
+ {
+ // Quantize using recorded quantparam
+ auto quantparam = weights->quantparam();
+ assert(quantparam != nullptr);
+ assert(quantparam->min.size() == 1); // only support layer-wise quant
+ assert(quantparam->scale.size() == 1); // only support layer-wise quant
+ auto min = quantparam->min[0];
+ auto scaling_factor = quantparam->scale[0];
+ asym_wquant_per_layer(weights, min, scaling_factor);
+ quantparam->min.clear();
+ quantparam->max.clear();
+ }
+}
+void QuantizeWeights::visit(luci::CircleConv2D *node)
+{
+ LOGGER(l);
+ INFO(l) << "QuantizeWeights QuantizeWeights::visit node: " << node->name() << std::endl;
+
+ auto weights = loco::must_cast<luci::CircleConst *>(node->filter());
+ if (!is_quantized(weights))
+ {
+ auto new_weights = luci::clone(weights);
+ node->filter(new_weights);
+ quantize_weights(new_weights);
+ }
+}
+
+void QuantizeWeights::visit(luci::CircleDepthwiseConv2D *node)
+{
+ LOGGER(l);
+ INFO(l) << "QuantizeWeights QuantizeWeights::visit node: " << node->name() << std::endl;
+
+ auto weights = loco::must_cast<luci::CircleConst *>(node->filter());
+ if (!is_quantized(weights))
+ {
+ auto new_weights = luci::clone(weights);
+ node->filter(new_weights);
+ quantize_weights(new_weights);
+ }
+}
+
+void QuantizeWeights::visit(luci::CircleInstanceNorm *node)
+{
+ LOGGER(l);
+ INFO(l) << "QuantizeWeights QuantizeWeights::visit node: " << node->name() << std::endl;
+
+ auto gamma = loco::must_cast<luci::CircleConst *>(node->gamma());
+ auto beta = loco::must_cast<luci::CircleConst *>(node->beta());
+
+ if (!is_quantized(gamma))
+ {
+ assert(gamma->dtype() == loco::DataType::FLOAT32);
+ auto new_gamma = luci::clone(gamma);
+ if (granularity == QuantizationGranularity::LayerWise)
+ quant_const(new_gamma, output_type);
+ else if (granularity == QuantizationGranularity::ChannelWise)
+ quant_const_per_channel(new_gamma, output_type);
+ node->gamma(new_gamma);
+ }
+ if (!is_quantized(beta))
+ {
+ assert(beta->dtype() == loco::DataType::FLOAT32);
+ auto new_beta = luci::clone(beta);
+ if (granularity == QuantizationGranularity::LayerWise)
+ quant_const(new_beta, output_type);
+ else if (granularity == QuantizationGranularity::ChannelWise)
+ quant_const_per_channel(new_beta, output_type);
+ node->beta(new_beta);
+ }
+}
+
+void QuantizeWeights::visit(luci::CirclePRelu *node)
+{
+ LOGGER(l);
+ INFO(l) << "QuantizeWeights QuantizeWeights::visit node: " << node->name() << std::endl;
+
+ auto alpha = loco::must_cast<luci::CircleConst *>(node->alpha());
+
+ if (!is_quantized(alpha))
+ {
+ assert(alpha->dtype() == loco::DataType::FLOAT32);
+ auto new_alpha = luci::clone(alpha);
+ if (granularity == QuantizationGranularity::LayerWise)
+ quant_const(new_alpha, output_type);
+ else if (granularity == QuantizationGranularity::ChannelWise)
+ quant_const_per_channel(new_alpha, output_type);
+ node->alpha(new_alpha);
+ }
+}
+
+void QuantizeWeights::visit(luci::CircleTransposeConv *node)
+{
+ LOGGER(l);
+ INFO(l) << "QuantizeWeights QuantizeWeights::visit node: " << node->name() << std::endl;
+
+ auto weights = loco::must_cast<luci::CircleConst *>(node->filter());
+ if (!is_quantized(weights))
+ {
+ auto new_weights = luci::clone(weights);
+ node->filter(new_weights);
+ quantize_weights(new_weights);
+ }
+}
+
+void QuantizeWeights::visit(luci::CircleFullyConnected *node)
+{
+ LOGGER(l);
+ INFO(l) << "QuantizeWeights QuantizeWeights::visit node: " << node->name() << std::endl;
+
+ auto weights = loco::must_cast<luci::CircleConst *>(node->weights());
+ if (!is_quantized(weights))
+ {
+ auto new_weights = luci::clone(weights);
+ node->weights(new_weights);
+ quantize_weights(new_weights);
+ }
+}
+
+void QuantizeWeights::visit(luci::CircleNode *) {}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/QuantizeWeights.h b/compiler/luci/pass/src/QuantizeWeights.h
new file mode 100644
index 000000000..f62cd40f3
--- /dev/null
+++ b/compiler/luci/pass/src/QuantizeWeights.h
@@ -0,0 +1,55 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_QUANTIZE_WEIGHTS_H__
+#define __LUCI_QUANTIZE_WEIGHTS_H__
+
+#include <luci/Pass/QuantizationParameters.h>
+#include <luci/IR/CircleNodeVisitor.h>
+
+namespace luci
+{
+
+/**
+ * @brief QuantizeWeights quantizes tensors for weights
+ * @details Find min/max values on the fly and then quantize
+ */
+struct QuantizeWeights final : public luci::CircleNodeMutableVisitor<void>
+{
+ QuantizeWeights(loco::DataType input, loco::DataType output, QuantizationGranularity gr)
+ : input_type(input), output_type(output), granularity(gr)
+ {
+ }
+
+ loco::DataType input_type;
+ loco::DataType output_type;
+ QuantizationGranularity granularity;
+
+private:
+ void quantize_weights(luci::CircleConst *weights);
+
+ void visit(luci::CircleConv2D *node);
+ void visit(luci::CircleDepthwiseConv2D *node);
+ void visit(luci::CircleInstanceNorm *node);
+ void visit(luci::CirclePRelu *node);
+ void visit(luci::CircleTransposeConv *node);
+ void visit(luci::CircleFullyConnected *node);
+ void visit(luci::CircleNode *);
+};
+
+} // namespace luci
+
+#endif // __LUCI_QUANTIZE_WEIGHTS_H__
diff --git a/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp b/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp
index c3552ec52..d9a9d4db7 100644
--- a/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp
+++ b/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp
@@ -15,55 +15,32 @@
*/
#include "luci/Pass/QuantizeWithMinMaxPass.h"
+#include "luci/Pass/PropagateQParamForwardPass.h"
+#include "luci/Pass/PropagateQParamBackwardPass.h"
+#include "luci/Pass/RemoveRedundantQuantizePass.h"
+#include "QuantizeActivation.h"
+#include "QuantizeWeights.h"
+#include "QuantizeBias.h"
#include "QuantizationUtils.h"
+#include "ProgressReporter.h"
+#include "helpers/LayerInfoMap.h"
#include <luci/IR/CircleNodes.h>
#include <luci/IR/CircleNodeVisitor.h>
#include <luci/Service/Nodes/CircleConst.h>
#include <luci/Profile/CircleNodeOrigin.h>
#include <luci/Log.h>
+#include <logo/Phase.h>
#include <oops/UserExn.h>
#include <iostream>
#include <cmath>
-#include <functional>
namespace
{
using namespace luci;
-using IterFunc = std::function<void(uint32_t *, loco::TensorShape &, int32_t)>;
-
-void iterate_per_channel(CircleConst *node, int32_t &channel_dim_index, IterFunc func)
-{
- loco::TensorShape dimension;
- dimension.rank(4);
- uint32_t indices[4] = {
- 0,
- };
-
- if (!get_channel_dim_index(node, dimension, channel_dim_index))
- {
- assert(false);
- return;
- }
-
- for (indices[0] = 0; indices[0] < dimension.dim(0).value(); indices[0]++)
- {
- for (indices[1] = 0; indices[1] < dimension.dim(1).value(); indices[1]++)
- {
- for (indices[2] = 0; indices[2] < dimension.dim(2).value(); indices[2]++)
- {
- for (indices[3] = 0; indices[3] < dimension.dim(3).value(); indices[3]++)
- {
- func(indices, dimension, channel_dim_index);
- }
- }
- }
- }
-}
-
// Create a Quantize Op whose
// dtype is out_type
// shape is the same with node
@@ -80,7 +57,17 @@ luci::CircleQuantize *create_quantize_op(luci::CircleNode *node, loco::DataType
quantize->shape_status(luci::ShapeStatus::VALID);
auto qparam = node->quantparam();
- assert(qparam); // FIX_CALLER_UNLESS
+ assert(qparam); // FIX_CALLER_UNLESS
+
+ auto qtype = luci::activation_qtype(node);
+ if (qtype == ActivationQType::PreDefinedValue)
+ {
+ quantize->quantparam(luci::make_predefined_qparam(node->opcode(), out_type));
+ return quantize;
+ }
+
+ assert(qtype == ActivationQType::MinMax or qtype == ActivationQType::IntScale);
+
assert(qparam->min.size() == 1); // FIX_CALLER_UNLESS
assert(qparam->max.size() == 1); // FIX_CALLER_UNLESS
auto min = qparam->min[0];
@@ -104,9 +91,17 @@ luci::CircleQuantize *create_quantize_op(luci::CircleNode *node, loco::DataType
auto quantparam = std::make_unique<CircleQuantParam>();
quantparam->scale.push_back(scaling_factor);
quantparam->zerop.push_back(zp);
+ // Save original min/max (not nudged_min/max). Nudged min/max
+ // is different from the real min/max values, causing wrong
+ // qparam when quantization dtype is changed.
+ quantparam->min.push_back(min);
+ quantparam->max.push_back(max);
quantize->quantparam(std::move(quantparam));
+ if (qtype == ActivationQType::IntScale)
+ set_int_scale(quantize);
+
return quantize;
}
@@ -118,1412 +113,232 @@ namespace luci
namespace
{
-// Create a new const node from an existing node.
-// The new node has the following characteristics
-// type: T
-// shape: same with 'node' (given as an argument)
-// buffer size: 'size' (given as an argument)
-// Note that contents are not filled in this function.
-template <loco::DataType T>
-luci::CircleConst *create_empty_const_from(luci::CircleConst *node, uint32_t size)
-{
- auto new_node = node->graph()->nodes()->create<CircleConst>();
- // TODO: We don't have any naming convention for quantized nodes yet.
- // Fix this when we have one.
- new_node->name(node->name());
- new_node->dtype(T);
- new_node->rank(node->rank());
- for (uint32_t i = 0; i < node->rank(); i++)
- new_node->dim(i).set(node->dim(i).value());
-
- new_node->size<T>(size);
- new_node->shape_status(luci::ShapeStatus::VALID);
-
- return new_node;
-}
-
-void overwrite_quantparam(luci::CircleNode *source, luci::CircleNode *target)
-{
- auto source_qparam = source->quantparam();
- if (source_qparam == nullptr)
- throw std::runtime_error("source quantparam is not found during overwrite");
-
- auto target_qparam = target->quantparam();
- if (target_qparam == nullptr)
- {
- auto quantparam = std::make_unique<CircleQuantParam>();
- target->quantparam(std::move(quantparam));
- target_qparam = target->quantparam();
-
- if (target_qparam == nullptr)
- throw std::runtime_error("Creating new quant param failed");
- }
- target_qparam->min = source_qparam->min;
- target_qparam->max = source_qparam->max;
- target_qparam->scale = source_qparam->scale;
- target_qparam->zerop = source_qparam->zerop;
- target_qparam->quantized_dimension = source_qparam->quantized_dimension;
-}
-
-void quant_const_values(luci::CircleConst *const_node, float scaling_factor, float zerop,
- loco::DataType quant_type)
-{
- uint32_t size = const_node->size<loco::DataType::FLOAT32>();
-
- const float scaling_factor_inv = 1.0 / scaling_factor;
- std::vector<int32_t> quantized_values(size);
- for (uint32_t i = 0; i < size; ++i)
- {
- auto data = static_cast<double>(const_node->at<loco::DataType::FLOAT32>(i));
- double quantized_float = std::round(data * scaling_factor_inv) + zerop;
- constexpr auto int_max = static_cast<double>(std::numeric_limits<int32_t>::max());
- constexpr auto int_min = static_cast<double>(std::numeric_limits<int32_t>::min());
- quantized_float = std::min(int_max, std::max(int_min, quantized_float));
-
- quantized_values[i] = static_cast<int32_t>(quantized_float);
- }
-
- switch (quant_type)
- {
- case loco::DataType::U8:
- const_node->dtype(loco::DataType::U8); // change the type of tensor
- const_node->size<loco::DataType::U8>(size); // resize tensor
- for (uint32_t i = 0; i < size; ++i)
- const_node->at<loco::DataType::U8>(i) = std::min(255, std::max(0, quantized_values[i]));
- break;
- case loco::DataType::S16:
- assert(zerop == 0);
- const_node->dtype(loco::DataType::S16); // change the type of tensor
- const_node->size<loco::DataType::S16>(size); // resize tensor
- for (uint32_t i = 0; i < size; ++i)
- const_node->at<loco::DataType::S16>(i) =
- std::min(32767, std::max(-32767, quantized_values[i]));
- break;
- default:
- throw std::runtime_error("Unsupported data type");
- }
-}
-
-// Quantize const per channel
-//
-// The last dimension of const is the same as the dimension of channel
-// And the rest of the const dimensions should be 1
-// So, a 'single value' is quantized per channel
-//
-// Quantization spec (f: fp value, q: quantized value)
-//
-// uint8
-// Positive f: f = f * (q - 0) [q = 1, scale = f, zp = 0]
-// Negative f: f = (-f) * (q - 1) [q = 0, scale = -f, zp = 1]
-//
-// int16
-// Positive f: f = f * (q - 0) [q = 1, scale = f, zp = 0]
-// Negative f: f = (-f) * (q - 0) [q = -1, scale = -f, zp = 0]
-void quant_const_per_channel(CircleConst *node, loco::DataType quant_type)
-{
- assert(node->dtype() == loco::DataType::FLOAT32);
- assert(node->rank() > 0);
-
- for (uint32_t i = 0; i < node->rank() - 1; i++)
- {
- // Caller should call this function when the below condition is satisfied
- if (node->dim(i).value() != 1)
- throw std::runtime_error("Non-channel dimension of const node must be 1");
- }
-
- uint32_t size = node->size<loco::DataType::FLOAT32>();
- assert(size == node->dim(node->rank() - 1).value());
-
- auto quantparam = std::make_unique<CircleQuantParam>();
- quantparam->quantized_dimension = node->rank() - 1;
- std::vector<int32_t> quantized_data(size);
-
- for (uint32_t i = 0; i < size; ++i)
- {
- auto data = node->at<loco::DataType::FLOAT32>(i);
- if (quant_type == loco::DataType::U8)
- {
- if (data >= 0)
- {
- quantparam->scale.push_back(data);
- quantparam->zerop.push_back(0);
- quantized_data[i] = 1;
- }
- else
- {
- quantparam->scale.push_back(-data);
- quantparam->zerop.push_back(1);
- quantized_data[i] = 0;
- }
- }
- else if (quant_type == loco::DataType::S16)
- {
- if (data >= 0)
- {
- quantparam->scale.push_back(data);
- quantized_data[i] = 1;
- }
- else
- {
- quantparam->scale.push_back(-data);
- quantized_data[i] = -1;
- }
- quantparam->zerop.push_back(0);
- }
- }
- node->quantparam(std::move(quantparam));
-
- switch (quant_type)
- {
- case loco::DataType::U8:
- node->dtype(loco::DataType::U8);
- node->size<loco::DataType::U8>(size);
- for (uint32_t i = 0; i < size; ++i)
- {
- assert(quantized_data[i] == 0 || quantized_data[i] == 1);
- node->at<loco::DataType::U8>(i) = quantized_data[i];
- }
- break;
- case loco::DataType::S16:
- node->dtype(loco::DataType::S16);
- node->size<loco::DataType::S16>(size);
- for (uint32_t i = 0; i < size; ++i)
- {
- assert(quantized_data[i] == -1 || quantized_data[i] == 1);
- node->at<loco::DataType::S16>(i) = quantized_data[i];
- }
- break;
- default:
- throw std::runtime_error("Unsupported data type");
- }
-}
-
-void quant_const(CircleConst *node, loco::DataType quant_type)
-{
- assert(node->dtype() == loco::DataType::FLOAT32);
-
- float min = std::numeric_limits<float>::max();
- float max = std::numeric_limits<float>::lowest();
- for (uint32_t i = 0; i < node->size<loco::DataType::FLOAT32>(); i++)
- {
- auto data = node->at<loco::DataType::FLOAT32>(i);
- min = data < min ? data : min;
- max = data > max ? data : max;
- }
-
- float scaling_factor{0.0};
- int64_t zp{0};
- float nudged_min{0.0};
- float nudged_max{0.0};
-
- switch (quant_type)
- {
- case loco::DataType::U8:
- asymmetric_wquant_with_minmax_per_layer(node, min, max, scaling_factor, zp, nudged_min,
- nudged_max);
- break;
- case loco::DataType::S16:
- symmetric_wquant_with_minmax_per_layer(node, min, max, scaling_factor, zp, nudged_min,
- nudged_max);
- break;
- default:
- throw std::runtime_error("Unsupported data type");
- }
-
- auto quantparam = std::make_unique<CircleQuantParam>();
- quantparam->scale.push_back(scaling_factor);
- quantparam->zerop.push_back(zp);
- node->quantparam(std::move(quantparam));
-}
-
-// Check if the node is the bias of Conv2D, DepthwiseConv2D, FullyConnected, or TransposeConv layer
-// Returns a list of <input, weights, output> vectors for the above operators.
-// Note that it returns a 'list' because bias can be used by multiple operators.
-std::vector<std::vector<loco::Node *>> get_input_weight_output_of_bias(CircleNode *node)
-{
- std::vector<std::vector<loco::Node *>> result;
- auto circle_const = dynamic_cast<CircleConst *>(node);
- if (circle_const == nullptr)
- return result;
-
- auto succs = loco::succs(node);
-
- for (auto out : succs)
- {
- auto conv = dynamic_cast<CircleConv2D *>(out);
- if (conv != nullptr && conv->bias() == circle_const)
- {
- assert(conv->input() != nullptr);
- assert(conv->filter() != nullptr);
- result.push_back({conv->input(), conv->filter(), conv});
- continue;
- }
- auto dw_conv = dynamic_cast<CircleDepthwiseConv2D *>(out);
- if (dw_conv != nullptr && dw_conv->bias() == circle_const)
- {
- assert(dw_conv->input() != nullptr);
- assert(dw_conv->filter() != nullptr);
- result.push_back({dw_conv->input(), dw_conv->filter(), dw_conv});
- continue;
- }
- auto fc = dynamic_cast<CircleFullyConnected *>(out);
- if (fc != nullptr && fc->bias() == circle_const)
- {
- assert(fc->input() != nullptr);
- assert(fc->weights() != nullptr);
- result.push_back({fc->input(), fc->weights(), fc});
- continue;
- }
- auto tconv = dynamic_cast<CircleTransposeConv *>(out);
- if (tconv != nullptr && tconv->bias() == circle_const)
- {
- assert(tconv->outBackprop() != nullptr);
- assert(tconv->filter() != nullptr);
- result.push_back({tconv->outBackprop(), tconv->filter(), tconv});
- continue;
- }
- }
- return result;
-}
-
-CircleConst *asym_quant_bias_per_layer(CircleConst *node, float input_scale, float weight_scale,
- float *scaling_factor, int64_t *zp)
-{
- float scale = input_scale * weight_scale;
- const float scaling_factor_inv = (scale == 0) ? 0 : 1.0 / scale;
-
- uint32_t size = node->size<loco::DataType::FLOAT32>();
- std::vector<int32_t> quantized_values(size);
- for (uint32_t i = 0; i < size; ++i)
- {
- quantized_values[i] =
- static_cast<int32_t>(std::round(node->at<loco::DataType::FLOAT32>(i) * scaling_factor_inv));
- }
-
- auto new_bias = create_empty_const_from<loco::DataType::S32>(node, size);
-
- const int32_t kMinScale = std::numeric_limits<int32_t>::lowest();
- const int32_t kMaxScale = std::numeric_limits<int32_t>::max();
- for (uint32_t i = 0; i < size; ++i)
- {
- new_bias->at<loco::DataType::S32>(i) =
- std::min(kMaxScale, std::max(kMinScale, quantized_values[i]));
- }
- *scaling_factor = scale;
- *zp = 0;
-
- return new_bias;
-}
-
-CircleConst *quant_bias_per_channel(CircleConst *node, float input_scale,
- std::vector<float> &weight_scale,
- std::vector<float> &scaling_factor, std::vector<int64_t> &zp)
-{
- float scaling_factor_inv{0};
-
- uint32_t size = node->size<loco::DataType::FLOAT32>();
- std::vector<int32_t> quantized_values(size);
-
- for (uint32_t i = 0; i < size; ++i)
- {
- scaling_factor[i] = input_scale * weight_scale[i];
- scaling_factor_inv = (scaling_factor[i] == 0) ? 0 : 1.0 / scaling_factor[i];
- quantized_values[i] =
- static_cast<int32_t>(std::round(node->at<loco::DataType::FLOAT32>(i) * scaling_factor_inv));
- zp[i] = 0;
- }
-
- auto new_bias = create_empty_const_from<loco::DataType::S32>(node, size);
-
- const int32_t kMinScale = std::numeric_limits<int32_t>::lowest();
- const int32_t kMaxScale = std::numeric_limits<int32_t>::max();
- for (uint32_t i = 0; i < size; ++i)
- {
- new_bias->at<loco::DataType::S32>(i) =
- std::min(kMaxScale, std::max(kMinScale, quantized_values[i]));
- }
-
- return new_bias;
-}
-
-CircleConst *int16_quant_bias_per_channel(CircleConst *node, float input_scale,
- std::vector<float> &weight_scale,
- std::vector<float> &scaling_factor,
- std::vector<int64_t> &zp)
-{
- float scaling_factor_inv{0};
-
- uint32_t size = node->size<loco::DataType::FLOAT32>();
- std::vector<int64_t> quantized_values(size);
-
- for (uint32_t i = 0; i < size; ++i)
- {
- scaling_factor[i] = input_scale * weight_scale[i];
- scaling_factor_inv = (scaling_factor[i] == 0) ? 0 : 1.0 / scaling_factor[i];
- quantized_values[i] =
- static_cast<int64_t>(std::round(node->at<loco::DataType::FLOAT32>(i) * scaling_factor_inv));
- zp[i] = 0;
- }
-
- auto new_bias = create_empty_const_from<loco::DataType::S64>(node, size);
-
- for (uint32_t i = 0; i < size; ++i)
- {
- new_bias->at<loco::DataType::S64>(i) = quantized_values[i];
- }
-
- return new_bias;
-}
-
-bool has_min_max(const CircleNode *node)
-{
- return node->quantparam() && !node->quantparam()->min.empty() && !node->quantparam()->max.empty();
-}
-
-void sym_wquant_per_channel(CircleConst *node, std::vector<float> &scaling_factor,
- int32_t &channel_dim_index)
-{
- assert(node->dtype() == loco::DataType::FLOAT32);
-
- const int32_t kMaxScale = std::numeric_limits<int16_t>::max();
- const int32_t kMinScale = -kMaxScale;
-
- uint32_t size = node->size<loco::DataType::FLOAT32>();
- std::vector<int32_t> quantized_values(size);
-
- auto quantize = [&](uint32_t *indices, loco::TensorShape &dimension, int32_t channel_dim_index) {
- int channel_idx = indices[channel_dim_index];
- const float scaling_factor_inv = 1.0 / scaling_factor[channel_idx];
- auto data = node->at<loco::DataType::FLOAT32>(cal_offset(dimension, indices));
- quantized_values[cal_offset(dimension, indices)] =
- static_cast<int32_t>(std::round(data * scaling_factor_inv));
- };
-
- iterate_per_channel(node, channel_dim_index, quantize);
-
- node->dtype(loco::DataType::S16); // change the type of tensor
- node->size<loco::DataType::S16>(size); // resize tensor
- for (uint32_t i = 0; i < size; ++i)
- {
- node->at<loco::DataType::S16>(i) =
- std::min(kMaxScale, std::max(kMinScale, quantized_values[i]));
- }
-}
-
-void asym_wquant_per_channel(CircleConst *node, std::vector<float> &min,
- std::vector<float> &scaling_factor, int32_t &channel_dim_index)
-{
- assert(node->dtype() == loco::DataType::FLOAT32);
-
- const int32_t kMinScale = 0;
- const int32_t kMaxScale = 255;
-
- uint32_t size = node->size<loco::DataType::FLOAT32>();
- std::vector<int32_t> quantized_values(size);
-
- auto quantize = [&](uint32_t *indices, loco::TensorShape &dimension, int32_t channel_dim_index) {
- int channel_idx = indices[channel_dim_index];
- const float scaling_factor_inv = 1.0 / scaling_factor[channel_idx];
- auto data = node->at<loco::DataType::FLOAT32>(cal_offset(dimension, indices));
- quantized_values[cal_offset(dimension, indices)] =
- static_cast<int32_t>(std::round((data - min[channel_idx]) * scaling_factor_inv));
- };
-
- iterate_per_channel(node, channel_dim_index, quantize);
-
- node->dtype(loco::DataType::U8); // change the type of tensor
- node->size<loco::DataType::U8>(size); // resize tensor
- for (uint32_t i = 0; i < size; ++i)
- {
- node->at<loco::DataType::U8>(i) = std::min(kMaxScale, std::max(kMinScale, quantized_values[i]));
- }
-}
-
-void asym_wquant_per_layer(CircleConst *node, float min, float scaling_factor)
-{
- const int32_t kMinScale = 0;
- const int32_t kMaxScale = 255;
-
- uint32_t size = node->size<loco::DataType::FLOAT32>();
-
- const float scaling_factor_inv = 1.0 / scaling_factor;
- std::vector<int32_t> quantized_values(size);
- for (uint32_t i = 0; i < size; ++i)
- {
- auto data = node->at<loco::DataType::FLOAT32>(i);
- quantized_values[i] = static_cast<int32_t>(std::round((data - min) * scaling_factor_inv));
- }
-
- node->dtype(loco::DataType::U8); // change the type of tensor
- node->size<loco::DataType::U8>(size); // resize tensor
- for (uint32_t i = 0; i < size; ++i)
- {
- node->at<loco::DataType::U8>(i) = std::min(kMaxScale, std::max(kMinScale, quantized_values[i]));
- }
-}
-
-void set_bias(luci::CircleNode *node, luci::CircleConst *bias)
-{
- if (auto conv = dynamic_cast<CircleConv2D *>(node))
- conv->bias(bias);
- else if (auto dconv = dynamic_cast<CircleDepthwiseConv2D *>(node))
- dconv->bias(bias);
- else if (auto tconv = dynamic_cast<CircleTransposeConv *>(node))
- tconv->bias(bias);
- else if (auto fc = dynamic_cast<CircleFullyConnected *>(node))
- fc->bias(bias);
- else
- throw std::runtime_error("Only convolution, depthwise convolution, transposed convolution, and "
- "fully-connected layer have bias");
-}
-
-void set_act_qparam(luci::CircleNode *node, float scale, int64_t zp)
-{
- assert(node); // FIX_CALLER_UNLESS
- assert(node->quantparam()); // FIX_CALLER_UNLESS
-
- auto qparam = node->quantparam();
- assert(qparam->scale.size() == 1); // FIX_CALLER_UNLESS
- assert(qparam->zerop.size() == 1); // FIX_CALLER_UNLESS
- qparam->scale[0] = scale;
- qparam->zerop[0] = zp;
-}
-
-/**
- * @brief Manually set scale/zp of output tensor of special Ops
- */
-struct QuantizeSpecialActivation final : public luci::CircleNodeMutableVisitor<void>
-{
- QuantizeSpecialActivation(loco::DataType input, loco::DataType output)
- : input_type(input), output_type(output)
- {
- }
-
- loco::DataType input_type;
- loco::DataType output_type;
-
- void visit(luci::CircleNode *)
- {
- // Do nothing by default
- }
-
- void visit(luci::CircleLogistic *node)
- {
- if (output_type == loco::DataType::U8)
- set_act_qparam(node, 1.0f / 256.0f, 0);
- else
- {
- assert(output_type == loco::DataType::S16);
- set_act_qparam(node, 1.0f / 32768.0f, 0);
- }
- }
-
- void visit(luci::CircleTanh *node)
- {
- if (output_type == loco::DataType::U8)
- set_act_qparam(node, 2.0f / 256.0f, 128);
- else
- {
- assert(output_type == loco::DataType::S16);
- set_act_qparam(node, 1.0f / 32768.0f, 0);
- }
- }
-
- void visit(luci::CircleStridedSlice *node)
- {
- auto input = loco::must_cast<luci::CircleNode *>(node->input());
- auto i_qparam = input->quantparam();
- assert(i_qparam);
- assert(i_qparam->scale.size() == 1); // FIX_CALLER_UNLESS
- assert(i_qparam->zerop.size() == 1); // FIX_CALLER_UNLESS
- auto i_scale = i_qparam->scale[0];
- auto i_zp = i_qparam->zerop[0];
-
- set_act_qparam(node, i_scale, i_zp);
- }
-
- void visit(luci::CircleSplitOut *node)
- {
- auto split = loco::must_cast<luci::CircleSplit *>(node->input());
- auto input = loco::must_cast<luci::CircleNode *>(split->input());
- auto i_qparam = input->quantparam();
- assert(i_qparam);
- assert(i_qparam->scale.size() == 1); // FIX_CALLER_UNLESS
- assert(i_qparam->zerop.size() == 1); // FIX_CALLER_UNLESS
- auto i_scale = i_qparam->scale[0];
- auto i_zp = i_qparam->zerop[0];
-
- set_act_qparam(node, i_scale, i_zp);
- }
-
- void visit(luci::CircleSplitVOut *node)
- {
- auto splitv = loco::must_cast<luci::CircleSplitV *>(node->input());
- auto input = loco::must_cast<luci::CircleNode *>(splitv->input());
- auto i_qparam = input->quantparam();
- assert(i_qparam);
- assert(i_qparam->scale.size() == 1); // FIX_CALLER_UNLESS
- assert(i_qparam->zerop.size() == 1); // FIX_CALLER_UNLESS
- auto i_scale = i_qparam->scale[0];
- auto i_zp = i_qparam->zerop[0];
-
- set_act_qparam(node, i_scale, i_zp);
- }
-
- void visit(luci::CircleUnpackOut *node)
- {
- auto unpack = loco::must_cast<luci::CircleUnpack *>(node->input());
- auto input = loco::must_cast<luci::CircleNode *>(unpack->value());
- auto i_qparam = input->quantparam();
- assert(i_qparam);
- assert(i_qparam->scale.size() == 1); // FIX_CALLER_UNLESS
- assert(i_qparam->zerop.size() == 1); // FIX_CALLER_UNLESS
- auto i_scale = i_qparam->scale[0];
- auto i_zp = i_qparam->zerop[0];
-
- set_act_qparam(node, i_scale, i_zp);
- }
-
- // TODO Move Softmax, Floor, Ceil from QuantizeActivation to here
-};
-
/**
- * @brief QuantizeActivation quantizes tensors for activations
- * @details Quantize using recorded min/max values
+ * Insert Quantize operator for mixed-precision quantization
+ * 1. Before input feature map (only for non-const)
+ * 2. After output feature map
+ *
+ * For example, if default_dtype = U8 and op_dtype = S16,
+ * 1. Quantize Op for U8->S16 is inserted before ifm
+ * 2. Quantize Op for S16->U8 is inserted after ofm
+ *
+ * Why not insert Quantize Op for const ifm?
+ * We quantize const tensor at once to preserve precision.
+ * For example, if default dtype = U8, op_dtype = S16, and op is CONV2D,
+ * We directly quantize weights to 16 bits, not 8->16 bits.
*/
-struct QuantizeActivation final : public luci::CircleNodeMutableVisitor<bool>
+struct InsertQuantizeOp final : public luci::CircleNodeMutableVisitor<void>
{
- QuantizeActivation(loco::DataType input, loco::DataType output)
- : input_type(input), output_type(output)
+ InsertQuantizeOp(loco::DataType default_dtype, loco::DataType op_dtype)
+ : _default_dtype(default_dtype), _op_dtype(op_dtype)
{
+ assert(default_dtype != op_dtype); // FIX_CALLER_UNLESS
}
- loco::DataType input_type;
- loco::DataType output_type;
+private:
+ loco::DataType _default_dtype;
+ loco::DataType _op_dtype;
- // Quantize input tensors of each node
- bool visit(luci::CircleNode *node)
+private:
+ luci::CircleQuantize *create_in_quantize(loco::Node *in, loco::Node *origin)
+ {
+ auto input = loco::must_cast<luci::CircleNode *>(in);
+ if (input->opcode() == luci::CircleOpcode::CIRCLECONST)
+ return nullptr;
+
+ auto input_quant = create_quantize_op(input, _op_dtype);
+ input_quant->input(input);
+ auto origin_node = loco::must_cast<luci::CircleNode *>(origin);
+ luci::add_origin(input_quant, luci::get_origin(origin_node));
+ return input_quant;
+ }
+
+ void insert_out_quantize(loco::Node *node)
+ {
+ auto output = loco::must_cast<luci::CircleNode *>(node);
+ assert(output->opcode() != luci::CircleOpcode::CIRCLECONST); // FIX_CALLER_UNLESS
+ auto output_quant = create_quantize_op(output, _default_dtype);
+
+ luci::add_origin(output_quant, luci::get_origin(output));
+ loco::replace(node).with(output_quant);
+ output_quant->input(node);
+ }
+
+// INPUT_NAME is the only activation of NODE
+#define INSERT_QUANTIZE_TO_UNARY_OP(NODE, INPUT_NAME) \
+ void visit(NODE *node) \
+ { \
+ if (auto input_quant = create_in_quantize(node->INPUT_NAME(), node)) \
+ node->INPUT_NAME(input_quant); \
+ \
+ insert_out_quantize(node); \
+ }
+
+// INPUT_NAME is the only activation of NODE
+#define INSERT_QUANTIZE_TO_UNARY_MULTI_OUTPUT_OP(NODE, INPUT_NAME, OUT_NAME) \
+ void visit(NODE *node) \
+ { \
+ if (auto input_quant = create_in_quantize(node->INPUT_NAME(), node)) \
+ node->INPUT_NAME(input_quant); \
+ \
+ auto out_nodes = loco::succs(node); \
+ for (auto out_node : out_nodes) \
+ { \
+ auto out_circle = loco::must_cast<OUT_NAME *>(out_node); \
+ insert_out_quantize(out_circle); \
+ } \
+ }
+
+// INPUT_NAME1 and INPUT_NAME2 are the only activations of NODE
+#define INSERT_QUANTIZE_TO_BINARY_OP(NODE, INPUT_NAME1, INPUT_NAME2) \
+ void visit(NODE *node) \
+ { \
+ if (auto input1_quant = create_in_quantize(node->INPUT_NAME1(), node)) \
+ node->INPUT_NAME1(input1_quant); \
+ \
+ if (auto input2_quant = create_in_quantize(node->INPUT_NAME2(), node)) \
+ node->INPUT_NAME2(input2_quant); \
+ \
+ insert_out_quantize(node); \
+ }
+
+ // Default behavior (NYI)
+ void visit(luci::CircleNode *node)
+ {
+ throw std::runtime_error("Unsupported Op for mixed-precision quantization. Layer name: " +
+ node->name());
+ }
+
+ // Skip output layer
+ void visit(luci::CircleOutput *) {}
+ void visit(luci::CircleSplitVOut *) {}
+ void visit(luci::CircleSplitOut *) {}
+ void visit(luci::CircleTopKV2Out *) {}
+ void visit(luci::CircleUniqueOut *) {}
+ void visit(luci::CircleUnpackOut *) {}
+
+ // Ops that receive a single activation as an input
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleAveragePool2D, value)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleBatchToSpaceND, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleConv2D, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleDepthToSpace, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleDepthwiseConv2D, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleElu, features)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleExp, x)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleFloor, x)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleFullyConnected, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleGather, params)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleInstanceNorm, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleLocalResponseNormalization, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleLogistic, x)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleMaxPool2D, value)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleMean, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleMirrorPad, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CirclePad, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CirclePadV2, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CirclePRelu, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleReduceProd, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleReduceMax, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleReduceMin, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleRelu, features)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleReshape, tensor)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleResizeBilinear, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleResizeNearestNeighbor, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleReverseSequence, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleRsqrt, x)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleSlice, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleSoftmax, logits)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleSpaceToBatchND, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleSpaceToDepth, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleSqrt, x)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleStridedSlice, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleSum, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleTanh, x)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleTile, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleTranspose, a)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleTransposeConv, outBackprop)
+
+ // Ops that receive two activations as inputs
+ INSERT_QUANTIZE_TO_BINARY_OP(luci::CircleAdd, x, y)
+ INSERT_QUANTIZE_TO_BINARY_OP(luci::CircleBatchMatMul, x, y)
+ INSERT_QUANTIZE_TO_BINARY_OP(luci::CircleDiv, x, y)
+ INSERT_QUANTIZE_TO_BINARY_OP(luci::CircleFloorDiv, x, y)
+ INSERT_QUANTIZE_TO_BINARY_OP(luci::CircleMaximum, x, y)
+ INSERT_QUANTIZE_TO_BINARY_OP(luci::CircleMinimum, x, y)
+ INSERT_QUANTIZE_TO_BINARY_OP(luci::CircleMul, x, y)
+ INSERT_QUANTIZE_TO_BINARY_OP(luci::CircleOneHot, on_value, off_value)
+ INSERT_QUANTIZE_TO_BINARY_OP(luci::CirclePow, x, y)
+ INSERT_QUANTIZE_TO_BINARY_OP(luci::CircleSub, x, y)
+
+ // Multiple-output ops that receive one activation as inputs
+ INSERT_QUANTIZE_TO_UNARY_MULTI_OUTPUT_OP(luci::CircleSplit, input, luci::CircleSplitOut)
+ INSERT_QUANTIZE_TO_UNARY_MULTI_OUTPUT_OP(luci::CircleSplitV, input, luci::CircleSplitVOut)
+ INSERT_QUANTIZE_TO_UNARY_MULTI_OUTPUT_OP(luci::CircleTopKV2, input, luci::CircleTopKV2Out)
+ INSERT_QUANTIZE_TO_UNARY_MULTI_OUTPUT_OP(luci::CircleUnique, input, luci::CircleUniqueOut)
+ INSERT_QUANTIZE_TO_UNARY_MULTI_OUTPUT_OP(luci::CircleUnpack, value, luci::CircleUnpackOut)
+
+ // AddN has arbitrary number of inputs
+ void visit(luci::CircleAddN *node)
{
- LOGGER(l);
- INFO(l) << "QuantizeActivation visit node: " << node->name() << std::endl;
auto arity = node->arity();
for (uint32_t i = 0; i < arity; i++)
{
- auto input_node = node->arg(i);
- auto circle_node = loco::must_cast<luci::CircleNode *>(input_node);
-
- // Check if this is already quantized
- if (is_quantized(circle_node))
- continue;
-
- // Check if this is bias (bias is quantized later)
- auto iwo = get_input_weight_output_of_bias(circle_node);
- if (iwo.size() > 0)
- continue;
-
- // Check if this is bool type (bool type is not quantized)
- if (circle_node->dtype() == loco::DataType::BOOL)
- continue;
-
- // Check if this is activation
- // We assume min/max are recorded only for activations
- if (has_min_max(circle_node) && !is_weights(circle_node))
- {
- // Quantize using recorded min/max
- auto quantparam = circle_node->quantparam();
- assert(quantparam);
- assert(quantparam->min.size() == 1); // only support layer-wise quant
- assert(quantparam->max.size() == 1); // only support layer-wise quant
- auto min = quantparam->min[0];
- auto max = quantparam->max[0];
-
- // Special values
- if (circle_node->opcode() == luci::CircleOpcode::SOFTMAX)
- {
- min = 0.0f;
- max = 1.0f;
- }
-
- float scaling_factor{0};
- int64_t zp{0};
- float nudged_min{0};
- float nudged_max{0};
-
- if (output_type == loco::DataType::U8)
- {
- compute_asym_scale_zp(min, max, scaling_factor, zp, nudged_min, nudged_max);
- circle_node->dtype(loco::DataType::U8);
- }
- else
- {
- compute_sym_scale_zp(min, max, scaling_factor, zp, nudged_min, nudged_max);
- circle_node->dtype(loco::DataType::S16);
- }
-
- // Nodes fused with activation functions which need special quantization
- auto fused_act_node =
- dynamic_cast<CircleNodeMixin<CircleNodeTrait::FusedActFunc> *>(circle_node);
- if (fused_act_node != nullptr &&
- fused_act_node->fusedActivationFunction() == FusedActFunc::TANH)
- {
- if (output_type == loco::DataType::U8)
- {
- scaling_factor = 2.0f / 256.0f;
- zp = 128;
- }
- else
- {
- assert(output_type == loco::DataType::S16);
- scaling_factor = 1.0f / 32768.0f;
- zp = 0;
- }
- }
-
- // The output of these Ops should be integer, so scale should be integer
- // TODO Handle cases where the integer scale needs to be propagated
- if (circle_node->opcode() == CircleOpcode::FLOOR ||
- circle_node->opcode() == CircleOpcode::FLOOR_DIV ||
- circle_node->opcode() == CircleOpcode::FLOOR_MOD ||
- circle_node->opcode() == CircleOpcode::CEIL)
- {
- assert(scaling_factor >= 0); // FIX_ME_UNLESS
- scaling_factor = scaling_factor < 1 ? 1.0f : std::round(scaling_factor);
- }
-
- circle_node->quantparam()->scale.push_back(scaling_factor);
- circle_node->quantparam()->zerop.push_back(zp);
- }
- // Fix special attributes
- if (circle_node->opcode() == luci::CircleOpcode::CAST)
- {
- auto *cast = loco::must_cast<luci::CircleCast *>(circle_node);
- auto *cast_input = loco::must_cast<luci::CircleNode *>(cast->x());
-
- // make sure that cast_input is already quantized
- assert(cast_input->dtype() != loco::DataType::FLOAT32);
- cast->in_data_type(cast_input->dtype());
- cast->out_data_type(cast->dtype());
- }
- }
- return false;
- }
-};
-
-struct QuantizeBias final : public luci::CircleNodeMutableVisitor<bool>
-{
- QuantizeBias(loco::DataType input, loco::DataType output, QuantizationGranularity gr)
- : input_type(input), output_type(output), granularity(gr)
- {
- }
-
- loco::DataType input_type;
- loco::DataType output_type;
- QuantizationGranularity granularity;
-
- // Quantize bias node
- bool visit(luci::CircleNode *node)
- {
- // Check if this is already quantized
- if (is_quantized(node))
- return false;
-
- auto iwo_list = get_input_weight_output_of_bias(node);
-
- for (auto iwo : iwo_list)
- {
- assert(iwo.size() == 3);
-
- auto input = loco::must_cast<luci::CircleNode *>(iwo[0]);
- auto weight = loco::must_cast<luci::CircleNode *>(iwo[1]);
- auto output = loco::must_cast<luci::CircleNode *>(iwo[2]);
-
- auto const_bias = loco::must_cast<luci::CircleConst *>(node);
- assert(const_bias->dtype() == loco::DataType::FLOAT32);
-
- // If input is const, it is quantized here, not in QuantizeActivation
- if (auto const_input = dynamic_cast<luci::CircleConst *>(input))
- {
- quant_const(const_input, output_type);
- }
-
- CircleConst *new_bias = nullptr;
-
- if (granularity == QuantizationGranularity::ChannelWise)
- {
- auto input_q = input->quantparam();
- assert(input_q);
- assert(input_q->scale.size() == 1); // input scale's layer-wise
- auto input_scale = input_q->scale[0];
-
- assert(weight->quantparam() != nullptr); // weight scale's channel-wise
- auto weight_scale = weight->quantparam()->scale;
-
- uint32_t size = const_bias->size<loco::DataType::FLOAT32>();
- assert(size == weight_scale.size());
- std::vector<float> scaling_factor(size);
- std::vector<int64_t> zp(size);
-
- if (output_type == loco::DataType::U8)
- {
- new_bias =
- quant_bias_per_channel(const_bias, input_scale, weight_scale, scaling_factor, zp);
- }
- else if (output_type == loco::DataType::S16)
- {
- new_bias =
- int16_quant_bias_per_channel(const_bias, input_scale, weight_scale, scaling_factor, zp);
- }
- else
- {
- throw std::runtime_error("Unsupported quantization type.");
- }
-
- auto quantparam = std::make_unique<CircleQuantParam>();
- quantparam->scale = scaling_factor;
- quantparam->zerop = zp;
- assert(new_bias->quantparam() == nullptr); // bias should not be quantized before
- new_bias->quantparam(std::move(quantparam));
-
- set_bias(output, new_bias);
- }
- else
- {
- auto input_q = input->quantparam();
- assert(input_q);
- assert(input_q->scale.size() == 1); // Only support per-layer quant
- auto input_scale = input_q->scale[0];
-
- auto weight_q = weight->quantparam();
- assert(weight_q);
- assert(weight_q->scale.size() == 1); // Only support per-layer quant
- auto weight_scale = weight_q->scale[0];
-
- float scaling_factor{0};
- int64_t zp{0};
- new_bias =
- asym_quant_bias_per_layer(const_bias, input_scale, weight_scale, &scaling_factor, &zp);
- auto quantparam = std::make_unique<CircleQuantParam>();
- quantparam->scale.push_back(scaling_factor);
- quantparam->zerop.push_back(zp);
- assert(new_bias->quantparam() == nullptr); // bias should not be quantized before
- new_bias->quantparam(std::move(quantparam));
-
- set_bias(output, new_bias);
- }
- }
- return false;
- }
-};
-
-/**
- * @brief QuantizeWeights quantizes tensors for weights
- * @details Find min/max values on the fly and then quantize
- */
-struct QuantizeWeights final : public luci::CircleNodeMutableVisitor<bool>
-{
- QuantizeWeights(loco::DataType input, loco::DataType output, QuantizationGranularity gr)
- : input_type(input), output_type(output), granularity(gr)
- {
- }
-
- loco::DataType input_type;
- loco::DataType output_type;
- QuantizationGranularity granularity;
-
-private:
- void quantize_weights(luci::CircleConst *weights)
- {
- // Find min/max per channel-wise
- if (granularity == QuantizationGranularity::ChannelWise)
- {
- auto quantparam = weights->quantparam();
- if (quantparam == nullptr)
- {
- assert(false && "quantparam is nullptr");
- return;
- }
-
- auto min = quantparam->min;
- auto scaling_factor = quantparam->scale;
- int32_t channel_dim_index = 0;
-
- if (output_type == loco::DataType::U8)
- {
- asym_wquant_per_channel(weights, min, scaling_factor, channel_dim_index);
- }
- else
- {
- sym_wquant_per_channel(weights, scaling_factor, channel_dim_index);
- }
- quantparam->min.clear();
- quantparam->max.clear();
- quantparam->quantized_dimension = channel_dim_index;
- }
- // Find min/max per layer-wise
- else
- {
- // Quantize using recorded quantparam
- auto quantparam = weights->quantparam();
- assert(quantparam != nullptr);
- assert(quantparam->min.size() == 1); // only support layer-wise quant
- assert(quantparam->scale.size() == 1); // only support layer-wise quant
- auto min = quantparam->min[0];
- auto scaling_factor = quantparam->scale[0];
- asym_wquant_per_layer(weights, min, scaling_factor);
- quantparam->min.clear();
- quantparam->max.clear();
- }
- }
-
- bool visit(luci::CircleConv2D *node)
- {
- LOGGER(l);
- INFO(l) << "QuantizeWeights visit node: " << node->name() << std::endl;
-
- auto weights = loco::must_cast<luci::CircleConst *>(node->filter());
- if (!is_quantized(weights))
- {
- auto new_weights = luci::clone(weights);
- node->filter(new_weights);
- quantize_weights(new_weights);
- return true;
+ if (auto input_quant = create_in_quantize(node->inputs(i), node))
+ node->inputs(i, input_quant);
}
- return false;
- }
-
- bool visit(luci::CircleDepthwiseConv2D *node)
- {
- LOGGER(l);
- INFO(l) << "QuantizeWeights visit node: " << node->name() << std::endl;
- auto weights = loco::must_cast<luci::CircleConst *>(node->filter());
- if (!is_quantized(weights))
- {
- auto new_weights = luci::clone(weights);
- node->filter(new_weights);
- quantize_weights(new_weights);
- return true;
- }
- return false;
+ insert_out_quantize(node);
}
- bool visit(luci::CircleInstanceNorm *node)
+ // Concat has arbitrary number of inputs
+ void visit(luci::CircleConcatenation *node)
{
- LOGGER(l);
- INFO(l) << "QuantizeWeights visit node: " << node->name() << std::endl;
-
- auto gamma = loco::must_cast<luci::CircleConst *>(node->gamma());
- auto beta = loco::must_cast<luci::CircleConst *>(node->beta());
-
- bool changed = false;
- if (!is_quantized(gamma))
- {
- assert(gamma->dtype() == loco::DataType::FLOAT32);
- auto new_gamma = luci::clone(gamma);
- if (granularity == QuantizationGranularity::LayerWise)
- quant_const(new_gamma, output_type);
- else if (granularity == QuantizationGranularity::ChannelWise)
- quant_const_per_channel(new_gamma, output_type);
- node->gamma(new_gamma);
- changed = true;
- }
- if (!is_quantized(beta))
- {
- assert(beta->dtype() == loco::DataType::FLOAT32);
- auto new_beta = luci::clone(beta);
- if (granularity == QuantizationGranularity::LayerWise)
- quant_const(new_beta, output_type);
- else if (granularity == QuantizationGranularity::ChannelWise)
- quant_const_per_channel(new_beta, output_type);
- node->beta(new_beta);
- changed = true;
- }
-
- return changed;
- }
-
- bool visit(luci::CirclePRelu *node)
- {
- LOGGER(l);
- INFO(l) << "QuantizeWeights visit node: " << node->name() << std::endl;
-
- auto alpha = loco::must_cast<luci::CircleConst *>(node->alpha());
-
- if (!is_quantized(alpha))
+ auto arity = node->arity();
+ for (uint32_t i = 0; i < arity; i++)
{
- assert(alpha->dtype() == loco::DataType::FLOAT32);
- auto new_alpha = luci::clone(alpha);
- if (granularity == QuantizationGranularity::LayerWise)
- quant_const(new_alpha, output_type);
- else if (granularity == QuantizationGranularity::ChannelWise)
- quant_const_per_channel(new_alpha, output_type);
- node->alpha(new_alpha);
- return true;
+ if (auto input_quant = create_in_quantize(node->values(i), node))
+ node->values(i, input_quant);
}
- return false;
+ insert_out_quantize(node);
}
- bool visit(luci::CircleTransposeConv *node)
+ // Pack has arbitrary number of inputs
+ void visit(luci::CirclePack *node)
{
- LOGGER(l);
- INFO(l) << "QuantizeWeights visit node: " << node->name() << std::endl;
-
- auto weights = loco::must_cast<luci::CircleConst *>(node->filter());
- if (!is_quantized(weights))
+ auto arity = node->arity();
+ for (uint32_t i = 0; i < arity; i++)
{
- auto new_weights = luci::clone(weights);
- node->filter(new_weights);
- quantize_weights(new_weights);
- return true;
+ if (auto input_quant = create_in_quantize(node->values(i), node))
+ node->values(i, input_quant);
}
- return false;
- }
-
- bool visit(luci::CircleFullyConnected *node)
- {
- LOGGER(l);
- INFO(l) << "QuantizeWeights visit node: " << node->name() << std::endl;
- auto weights = loco::must_cast<luci::CircleConst *>(node->weights());
- if (!is_quantized(weights))
- {
- auto new_weights = luci::clone(weights);
- node->weights(new_weights);
- quantize_weights(new_weights);
- return true;
- }
- return false;
+ insert_out_quantize(node);
}
- bool visit(luci::CircleNode *) { return false; }
+#undef INSERT_QUANTIZE_TO_UNARY_OP
+#undef INSERT_QUANTIZE_TO_BINARY_OP
+#undef INSERT_QUANTIZE_TO_UNARY_MULTI_OUTPUT_OP
};
-/** EXAMPLE
- *
- * BEFORE
- *
- * [CircleNode] [CircleConst]
- * (qparam1) (FP32)
- * \ /
- * \ /
- * [CirclePack]
- * (qparam2)
- *
- * AFTER
- *
- * [CircleNode] [CircleConst] [CircleConst] <- Dead node
- * (qparam2) (qparam2) (FP32)
- * \ /
- * \ /
- * [CirclePack]
- * (qparam2)
- *
- * NOTE Quantization parameter of CirclePack (qparam2) is propagated to the inputs.
- */
-void propagate_pack_quantparam(luci::CirclePack *pack, loco::DataType quant_type)
-{
- assert(pack->quantparam() != nullptr);
-
- const auto num_inputs = pack->values_count();
-
- for (uint32_t i = 0; i < num_inputs; i++)
- {
- auto node = loco::must_cast<luci::CircleNode *>(pack->arg(i));
-
- // Skip if this input is PACK Op
- if (node->opcode() == luci::CircleOpcode::PACK)
- continue;
-
- // Quantize constant values
- if (node->opcode() == luci::CircleOpcode::CIRCLECONST)
- {
- luci::CircleConst *const_node = loco::must_cast<luci::CircleConst *>(node);
- if (const_node->dtype() != loco::DataType::FLOAT32)
- throw std::runtime_error("Unsupported data type for constant input of pack Op");
-
- const auto pack_qparam = pack->quantparam();
- if (pack_qparam == nullptr)
- throw std::runtime_error("quantparam of pack is not found during propagation");
-
- assert(pack_qparam->scale.size() == 1);
- assert(pack_qparam->zerop.size() == 1);
- const auto scaling_factor = pack_qparam->scale[0];
- const auto zerop = pack_qparam->zerop[0];
-
- auto new_const = luci::clone(const_node);
- quant_const_values(new_const, scaling_factor, zerop, quant_type);
- pack->values(i, new_const);
- overwrite_quantparam(pack, new_const);
- }
- else
- {
- const auto succs = loco::succs(node);
- if (succs.size() > 1)
- continue;
-
- // Non-const input must have been quantized
- assert(node->quantparam() != nullptr);
- overwrite_quantparam(pack, node);
- }
- }
-}
-
-/**
- * @brief Quantize const input tensors using min/max of const values
- */
-void quantize_const_inputs(luci::CircleNode *node, loco::DataType output_type)
-{
- auto opcode = node->opcode();
- auto arity = node->arity();
-
- loco::Node *input_node{nullptr};
- luci::CircleConst *const_node{nullptr};
-
- switch (opcode)
- {
- case luci::CircleOpcode::CONV_2D:
- case luci::CircleOpcode::DEPTHWISE_CONV_2D:
- case luci::CircleOpcode::FULLY_CONNECTED:
- case luci::CircleOpcode::INSTANCE_NORM:
- case luci::CircleOpcode::PRELU:
- case luci::CircleOpcode::TRANSPOSE_CONV:
- // Handled in QuantizeWeights and QuantizeBias
- break;
-
- case luci::CircleOpcode::CONCATENATION:
- // Handled in propagate_concat_quantparam
- break;
-
- case luci::CircleOpcode::LOGICAL_OR:
- // Inputs of logical Ops are bool, thus not quantized
- break;
-
- case luci::CircleOpcode::ARG_MAX:
- case luci::CircleOpcode::ARG_MIN:
- case luci::CircleOpcode::BATCH_TO_SPACE_ND:
- case luci::CircleOpcode::LOCAL_RESPONSE_NORMALIZATION:
- case luci::CircleOpcode::MEAN:
- case luci::CircleOpcode::MIRROR_PAD:
- case luci::CircleOpcode::PAD:
- case luci::CircleOpcode::REDUCE_ANY:
- case luci::CircleOpcode::REDUCE_PROD:
- case luci::CircleOpcode::REDUCE_MAX:
- case luci::CircleOpcode::REDUCE_MIN:
- case luci::CircleOpcode::RESHAPE:
- case luci::CircleOpcode::RESIZE_BILINEAR:
- case luci::CircleOpcode::RESIZE_NEAREST_NEIGHBOR:
- case luci::CircleOpcode::REVERSE_SEQUENCE:
- case luci::CircleOpcode::SLICE:
- case luci::CircleOpcode::SPACE_TO_BATCH_ND:
- case luci::CircleOpcode::SPLIT_V:
- case luci::CircleOpcode::STRIDED_SLICE:
- case luci::CircleOpcode::SUM:
- case luci::CircleOpcode::TILE:
- case luci::CircleOpcode::TOPK_V2:
- case luci::CircleOpcode::TRANSPOSE:
- // The second input of these Ops should not be quantized
- // Ex: axis, paddings
- input_node = node->arg(0);
- const_node = dynamic_cast<luci::CircleConst *>(input_node);
- if (const_node != nullptr && !is_quantized(const_node))
- quant_const(const_node, output_type);
- break;
-
- case luci::CircleOpcode::ADD:
- case luci::CircleOpcode::ADD_N:
- case luci::CircleOpcode::DEPTH_TO_SPACE:
- case luci::CircleOpcode::DIV:
- case luci::CircleOpcode::ELU:
- case luci::CircleOpcode::EQUAL:
- case luci::CircleOpcode::EXP:
- case luci::CircleOpcode::FLOOR:
- case luci::CircleOpcode::FLOOR_DIV:
- case luci::CircleOpcode::GREATER:
- case luci::CircleOpcode::GREATER_EQUAL:
- case luci::CircleOpcode::LESS:
- case luci::CircleOpcode::LESS_EQUAL:
- case luci::CircleOpcode::LOGISTIC:
- case luci::CircleOpcode::MAXIMUM:
- case luci::CircleOpcode::MINIMUM:
- case luci::CircleOpcode::MUL:
- case luci::CircleOpcode::NOT_EQUAL:
- case luci::CircleOpcode::POW:
- case luci::CircleOpcode::RSQRT:
- case luci::CircleOpcode::SOFTMAX:
- case luci::CircleOpcode::SPACE_TO_DEPTH:
- case luci::CircleOpcode::SQRT:
- case luci::CircleOpcode::SUB:
- case luci::CircleOpcode::TANH:
- case luci::CircleOpcode::UNPACK:
- // Quantize all const inputs using their values
- for (uint32_t i = 0; i < arity; i++)
- {
- input_node = node->arg(i);
- const_node = dynamic_cast<luci::CircleConst *>(input_node);
- if (const_node != nullptr && !is_quantized(const_node))
- quant_const(const_node, output_type);
- }
- break;
-
- case luci::CircleOpcode::SPLIT:
- // Only the second input is quantized
- // First input should not be quantized (e.g., split_dim)
- input_node = node->arg(1);
- const_node = dynamic_cast<luci::CircleConst *>(input_node);
- if (const_node != nullptr && !is_quantized(const_node))
- quant_const(const_node, output_type);
- break;
-
- case luci::CircleOpcode::PADV2:
- // First and third constant inputs are quantized
- // Second input should not be quantized (e.g., paddings)
- // Quant params are propagated either from output range to the non-constant input
- // or from input to output and constant values
- propagate_pad_v2_quantparam(loco::must_cast<CirclePadV2 *>(node), output_type);
- break;
-
- case luci::CircleOpcode::PACK:
- // Quant param is propagated from output to inputs
- propagate_pack_quantparam(loco::must_cast<CirclePack *>(node), output_type);
- break;
-
- default:
- for (uint32_t i = 0; i < arity; i++)
- {
- input_node = node->arg(i);
- const_node = dynamic_cast<luci::CircleConst *>(input_node);
- if (const_node != nullptr)
- throw std::runtime_error("Unsupported Op for const inputs");
- }
- break;
- }
-}
-
} // namespace
-/** BEFORE
- *
- * [CircleNode] [CircleConst]
- * (U8 qparam1) (FP32)
- * \ /
- * \ /
- * [CircleConcatenation]
- * (U8 qparam2)
- *
- * AFTER
- * [CircleNode] [CircleConst] [CircleConst] <- Dead node
- * (U8 qparam2) (U8 qparam2) (FP32)
- * \ /
- * \ /
- * [CircleConcatenation]
- * (U8 qparam2)
- */
-void propagate_concat_quantparam(luci::CircleConcatenation *concat, loco::DataType quant_type)
-{
- assert(concat->quantparam() != nullptr);
-
- const auto num_inputs = concat->numValues();
-
- // Quantize const inputs using their values if concat has fused act function
- if (concat->fusedActivationFunction() != luci::FusedActFunc::NONE)
- {
- for (uint32_t i = 0; i < num_inputs; i++)
- {
- auto node = concat->arg(i);
- auto const_node = dynamic_cast<luci::CircleConst *>(node);
- if (const_node != nullptr)
- {
- auto new_const = luci::clone(const_node);
- quant_const(new_const, quant_type);
- concat->values(i, new_const);
- }
- }
- return;
- }
-
- for (uint32_t i = 0; i < num_inputs; i++)
- {
- auto node = loco::must_cast<luci::CircleNode *>(concat->arg(i));
-
- // Skip if this input is CONCAT Op
- if (node->opcode() == luci::CircleOpcode::CONCATENATION)
- continue;
-
- // Quantize constant values
- if (node->opcode() == luci::CircleOpcode::CIRCLECONST)
- {
- luci::CircleConst *const_node = loco::must_cast<luci::CircleConst *>(node);
- if (const_node->dtype() != loco::DataType::FLOAT32)
- throw std::runtime_error("Unsupported data type for constant input of concatenation Op");
-
- const auto concat_qparam = concat->quantparam();
- if (concat_qparam == nullptr)
- throw std::runtime_error("quantparam of concat is not found during propagation");
-
- assert(concat_qparam->scale.size() == 1);
- const auto scaling_factor = concat_qparam->scale[0];
- const auto zerop = concat_qparam->zerop[0];
-
- auto new_const = luci::clone(const_node);
- quant_const_values(new_const, scaling_factor, zerop, quant_type);
- concat->values(i, new_const);
- overwrite_quantparam(concat, new_const);
- }
- else
- {
- const auto succs = loco::succs(node);
- if (succs.size() > 1)
- continue;
-
- // Non-const input must have been quantized
- assert(node->quantparam() != nullptr);
- overwrite_quantparam(concat, node);
- }
- }
-}
-
-/**
- * tells if pad_v2 quantization should ignore padding value
- * In that case padding const will be quantized with input parameters, and probably clipped
- */
-bool ignore_pad_v2_const_quantization(luci::CirclePadV2 *pad)
-{
- // This is a workaround to quantize pad generated from MaxPoolWithArgmax operation properly
- // TODO use metadata hints to detect this case
- auto const_value_node = dynamic_cast<luci::CircleConst *>(pad->arg(2));
- if (!const_value_node)
- return false;
- if (const_value_node->dtype() == loco::DataType::FLOAT32)
- {
- float const_value = const_value_node->at<loco::DataType::FLOAT32>(0);
- if (const_value == std::numeric_limits<float>::lowest())
- return true;
- }
- return false;
-}
-
-/** BEFORE
- *
- * [CircleNode] [CircleConst] [CircleConst]
- * (U8 qparam1) (S32) (FP32)
- * \ | /
- * \ | /
- * [CirclePadV2]
- * (U8 qparam2)
- *
- * AFTER (case 1)
- *
- * By default qparam is propagated from output to inputs to meet backend requirements.
- *
- * [CircleNode] [CircleConst] [CircleConst] [CircleConst] <- Dead node
- * (U8 qparam2) (S32) (U8 qparam2) (FP32)
- * \ | /
- * \ | /
- * [CirclePadV2]
- * (U8 qparam2)
- *
- * AFTER (case 2)
- *
- * In case padded value is the lowest float value
- * Qparam is propagated from input to output and constant.
- *
- * This is a special case for optimization constructed pad, needed to guarantee that
- * extremely large negative constant do not stretch output quantization range.
- *
- * [CircleNode] [CircleConst] [CircleConst] [CircleConst] <- Dead node
- * (U8 qparam1) (S32) (U8 qparam1) (FP32)
- * \ | /
- * \ | /
- * [CirclePadV2]
- * (U8 qparam1)
- */
-void propagate_pad_v2_quantparam(luci::CirclePadV2 *pad_v2, loco::DataType quant_type)
-{
- if (ignore_pad_v2_const_quantization(pad_v2))
- {
- // propagate input quantization paramters from input to output and padding const value
- auto pad_v2_input = loco::must_cast<luci::CircleNode *>(pad_v2->arg(0));
- overwrite_quantparam(pad_v2_input, pad_v2);
-
- auto const_value_node = loco::must_cast<luci::CircleConst *>(
- pad_v2->arg(2)); // FIX ignore_pad_v2_const_quantization UNLESS
- auto new_const = luci::clone(const_value_node);
-
- const auto pad_v2_input_qparam = pad_v2_input->quantparam();
- assert(pad_v2_input_qparam != nullptr);
- assert(pad_v2_input_qparam->scale.size() == 1);
- const auto scaling_factor = pad_v2_input_qparam->scale.at(0);
- const auto zerop = pad_v2_input_qparam->zerop.at(0);
-
- quant_const_values(new_const, scaling_factor, zerop, quant_type);
- overwrite_quantparam(pad_v2_input, new_const);
- pad_v2->constant_values(new_const);
- return;
- }
-
- // Propagate quantization paramters from output to inputs,
- // to fit both input and counstant_value in one quant range.
- auto quant_input = [pad_v2, quant_type](void (CirclePadV2::*arg_setter)(loco::Node *),
- uint32_t arg) {
- auto node = loco::must_cast<luci::CircleNode *>(pad_v2->arg(arg));
-
- // Quantize constant values
- if (node->opcode() == luci::CircleOpcode::CIRCLECONST)
- {
- luci::CircleConst *const_node = loco::must_cast<luci::CircleConst *>(node);
- if (is_quantized(const_node))
- return;
-
- if (const_node->dtype() != loco::DataType::FLOAT32)
- throw std::runtime_error("Unsupported data type for constant input of PadV2 Op");
-
- const auto pad_v2_qparam = pad_v2->quantparam();
- if (pad_v2_qparam == nullptr)
- throw std::runtime_error("quantparam of PadV2 is not found during propagation");
-
- assert(pad_v2_qparam->scale.size() == 1);
- const auto scaling_factor = pad_v2_qparam->scale.at(0);
- const auto zerop = pad_v2_qparam->zerop.at(0);
-
- auto new_const = luci::clone(const_node);
- quant_const_values(new_const, scaling_factor, zerop, quant_type);
- overwrite_quantparam(pad_v2, new_const);
- (pad_v2->*arg_setter)(new_const);
- }
- // Subsequent PadV2 Ops quant params are not propagated
- else if (node->opcode() == luci::CircleOpcode::PADV2)
- {
- return;
- }
- else
- {
- const auto succs = loco::succs(node);
- if (succs.size() > 1)
- return;
-
- // Non-const input must have been quantized
- assert(node->quantparam() != nullptr);
- overwrite_quantparam(pad_v2, node);
- }
- };
-
- quant_input(&CirclePadV2::input, 0);
- quant_input(&CirclePadV2::constant_values, 2);
-}
-
void QuantizeWithMinMaxPass::set_input_type(loco::Graph *g) const
{
auto inputs = g->inputs();
for (auto node : loco::input_nodes(g))
{
auto input = loco::must_cast<luci::CircleInput *>(node);
- if (input->dtype() == _input_type)
+ if (input->dtype() == _ctx->input_type)
continue;
// Bool type is not quantizable
if (input->dtype() == loco::DataType::BOOL)
continue;
+ if (input->dtype() == loco::DataType::S32)
+ continue;
+ if (input->dtype() == loco::DataType::S64)
+ continue;
// Insert Quantize Op
auto quant_op = create_quantize_op(input, input->dtype());
@@ -1552,22 +367,22 @@ void QuantizeWithMinMaxPass::set_input_type(loco::Graph *g) const
float nudged_min{0};
float nudged_max{0};
- if (_input_type == loco::DataType::U8)
+ if (_ctx->input_type == loco::DataType::U8)
{
compute_asym_scale_zp(min, max, scaling_factor, zp, nudged_min, nudged_max);
}
else
{
- assert(_input_type == loco::DataType::S16);
+ assert(_ctx->input_type == loco::DataType::S16);
compute_sym_scale_zp(min, max, scaling_factor, zp, nudged_min, nudged_max);
}
- input->dtype(_input_type);
+ input->dtype(_ctx->input_type);
input->quantparam()->scale[0] = scaling_factor;
input->quantparam()->zerop[0] = zp;
}
auto graph_input = inputs->at(input->index());
- graph_input->dtype(_input_type);
+ graph_input->dtype(_ctx->input_type);
}
}
@@ -1577,7 +392,7 @@ void QuantizeWithMinMaxPass::set_output_type(loco::Graph *g) const
for (auto node : loco::output_nodes(g))
{
auto output = loco::must_cast<luci::CircleOutput *>(node);
- if (output->dtype() == _output_type)
+ if (output->dtype() == _ctx->output_type)
continue;
// Bool type is not quantizable
@@ -1591,7 +406,7 @@ void QuantizeWithMinMaxPass::set_output_type(loco::Graph *g) const
continue;
// Insert Quantize Op
- auto quant_op = create_quantize_op(from, _output_type);
+ auto quant_op = create_quantize_op(from, _ctx->output_type);
loco::replace(from).with(quant_op);
quant_op->input(from);
@@ -1599,67 +414,165 @@ void QuantizeWithMinMaxPass::set_output_type(loco::Graph *g) const
luci::add_origin(quant_op, luci::get_origin(from));
auto graph_output = outputs->at(output->index());
- graph_output->dtype(_output_type);
+ graph_output->dtype(_ctx->output_type);
}
}
+/**
+ * How QuantizeWithMinMax works?
+ *
+ * We categorized tensors into four groups
+ * - Activation: Feature maps (both Const/Non-const)
+ * - Weights: Const tensors of specific Ops (Conv, FC, ...)
+ * - Bias: Const tensors of specific Ops (Conv, FC, ...)
+ * - Others: padding value, one_hot value, axis, ..
+ *
+ * Activation is quantized in different ways
+ * 1. For non-constant activation, quantize using recorded min/max
+ * 2. For constant activation, quantize using min/max of its value
+ * 3. For some Ops (ex: pad_v2), output qparam is used as input qparam (backward propagation)
+ * 4. For some Ops (ex: reshape), input qparam is used as output qparam (forward propagation)
+ * 5. For some Ops (ex: tanh), output qparam has pre-defined values
+ *
+ * Weights is quantized using min/max of its value
+ *
+ * Bias is quantized using input scale (s_i) and weights scale (s_w)
+ * - Activation and weights should be quantized earlier than bias
+ *
+ * Quantization Steps
+ * 1. Quantize Activation
+ * - Quantize using recorded min/max (QuantizeActivation)
+ * - Insert Quantize Ops for mixed-precision quantization (InsertQuantizeOp)
+ * - Remove redundant Quantize Ops (RemoveRedundantQuantizePass)
+ * - Propagate qparam backward (PropagateQParamBackwardPass)
+ * - Quantize const inputs (QuantizeConstInputActivation)
+ * - Quantize using pre-defined values (QuantizeSpecialActivation)
+ * - Propagate qparam forward (PropagateQParamForwardPass)
+ * 2. Quantize Weights
+ * 3. Quantize Bias
+ * 4. Set input dtype
+ * 5. Set output dtype
+ *
+ * Why quantization sequence was determined as above?
+ * - Activation and weights should be quantized before bias (1->2->3). Input/Output
+ * dtype can be updated at the end (4->5).
+ * - During activation quantization,
+ * - Backward propagation is performed earlier than forward propagation. This allows
+ * backward-propagated qpram to be overwritten during forward propagation.
+ * We made this decision as Ops for forward propagation (reshape, transpose, ..)
+ * are more common than backward propagation. TODO Check this decision is safe.
+ * - QuantizeSpecialActivation is called before forward propagation to make sure that
+ * the pre-defined qparam values are propagated.
+ */
bool QuantizeWithMinMaxPass::run(loco::Graph *g)
{
LOGGER(l);
INFO(l) << "QuantizeWithMinMaxPass Start" << std::endl;
+ auto info_by_name = layer_info_map(g, _ctx->layers_info);
+
+ auto quantize_dtype = [&](const luci::CircleNode *node) {
+ auto iter = info_by_name.find(node->name());
+
+ // Return designated quantization dtype
+ if (iter != info_by_name.end())
+ return iter->second.dtype;
+
+ // Return default quantization dtype
+ return _ctx->output_model_dtype;
+ };
+
+ auto quantize_granularity = [&](const luci::CircleNode *node) {
+ auto iter = info_by_name.find(node->name());
+
+ // Return designated quantization granularity
+ if (iter != info_by_name.end())
+ return iter->second.granularity;
+
+ // Return default quantization granularity
+ return _ctx->granularity;
+ };
+
// Quantize activation
for (auto node : loco::active_nodes(loco::output_nodes(g)))
{
- QuantizeActivation qa(_input_model_dtype, _output_model_dtype);
auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+ QuantizeActivation qa(_ctx->input_model_dtype, quantize_dtype(circle_node));
circle_node->accept(&qa);
}
- // Quantize weights
+ // Insert Quantize Op
for (auto node : loco::active_nodes(loco::output_nodes(g)))
{
- QuantizeWeights qw(_input_model_dtype, _output_model_dtype, _granularity);
auto circle_node = loco::must_cast<luci::CircleNode *>(node);
- circle_node->accept(&qw);
+ auto op_dtype = quantize_dtype(circle_node);
+ if (op_dtype != _ctx->output_model_dtype)
+ {
+ InsertQuantizeOp iqo(_ctx->output_model_dtype, op_dtype);
+ circle_node->accept(&iqo);
+ }
}
- // Quantize bias
+ // Remove redundant Quantize Op
+ {
+ logo::Phase phase;
+
+ phase.emplace_back(std::make_unique<luci::RemoveRedundantQuantizePass>());
+
+ ProgressReporter prog(g, logo::PhaseStrategy::Saturate);
+ logo::PhaseRunner<logo::PhaseStrategy::Saturate> phase_runner{g};
+ phase_runner.attach(&prog);
+ phase_runner.run(phase);
+ }
+
+ // Backward propagation of activation qparam
+ {
+ PropagateQParamBackwardPass pqbp(_ctx->output_model_dtype);
+ pqbp.run(g);
+ }
+
+ // Quantize const input activation
for (auto node : loco::active_nodes(loco::output_nodes(g)))
{
- QuantizeBias qb(_input_model_dtype, _output_model_dtype, _granularity);
auto circle_node = loco::must_cast<luci::CircleNode *>(node);
- circle_node->accept(&qb);
+ QuantizeConstInputActivation qcia(quantize_dtype(circle_node));
+ circle_node->accept(&qcia);
}
- // Propagate quantization parameters of concat Op
+ // Update qparam of output of special Ops
for (auto node : loco::active_nodes(loco::output_nodes(g)))
{
- auto concat = dynamic_cast<luci::CircleConcatenation *>(node);
- if (not concat)
- continue;
-
- // Propagate qparam of concat to its inputs if
- // (1) concat is uint8-quantized
- // (2) concat has no fused activation function
- // (3) the input is not concatenation Op
- // (4) the input is not produced to Ops other than concat
- propagate_concat_quantparam(concat, _output_model_dtype);
+ auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+ QuantizeSpecialActivation qsa(_ctx->input_model_dtype, quantize_dtype(circle_node));
+ circle_node->accept(&qsa);
}
- // Quantize const inputs other than weights and bias
+ // Forward propagation of activation qparam
+ logo::Phase phase;
+
+ phase.emplace_back(std::make_unique<luci::PropagateQParamForwardPass>(_ctx->TF_style_maxpool));
+
+ ProgressReporter prog(g, logo::PhaseStrategy::Saturate);
+ logo::PhaseRunner<logo::PhaseStrategy::Saturate> phase_runner{g};
+ phase_runner.attach(&prog);
+ phase_runner.run(phase);
+
+ // Quantize weights
for (auto node : loco::active_nodes(loco::output_nodes(g)))
{
auto circle_node = loco::must_cast<luci::CircleNode *>(node);
- quantize_const_inputs(circle_node, _output_model_dtype);
+ QuantizeWeights qw(_ctx->input_model_dtype, quantize_dtype(circle_node),
+ quantize_granularity(circle_node));
+ circle_node->accept(&qw);
}
- // Update qparam of output of special Ops
+ // Quantize bias
for (auto node : loco::active_nodes(loco::output_nodes(g)))
{
- QuantizeSpecialActivation qsa(_input_model_dtype, _output_model_dtype);
auto circle_node = loco::must_cast<luci::CircleNode *>(node);
- circle_node->accept(&qsa);
+ QuantizeBias qb(_ctx->input_model_dtype, quantize_dtype(circle_node),
+ quantize_granularity(circle_node));
+ circle_node->accept(&qb);
}
// Update output dtype
@@ -1667,11 +580,11 @@ bool QuantizeWithMinMaxPass::run(loco::Graph *g)
for (auto node : loco::output_nodes(g))
{
auto circle_node = loco::must_cast<luci::CircleOutput *>(node);
- if (static_cast<luci::CircleNode *>(circle_node->from())->dtype() == _output_model_dtype)
+ if (static_cast<luci::CircleNode *>(circle_node->from())->dtype() == _ctx->output_model_dtype)
{
- circle_node->dtype(_output_model_dtype);
+ circle_node->dtype(_ctx->output_model_dtype);
auto graph_output = graph_outputs->at(circle_node->index());
- graph_output->dtype(_output_model_dtype);
+ graph_output->dtype(_ctx->output_model_dtype);
}
}
diff --git a/compiler/luci/pass/src/QuantizeWithMinMaxPass.test.cpp b/compiler/luci/pass/src/QuantizeWithMinMaxPass.test.cpp
index 75ec0cfd8..d5fa21ffd 100644
--- a/compiler/luci/pass/src/QuantizeWithMinMaxPass.test.cpp
+++ b/compiler/luci/pass/src/QuantizeWithMinMaxPass.test.cpp
@@ -16,8 +16,41 @@
#include "luci/Pass/QuantizeWithMinMaxPass.h"
+#include <luci/IR/CircleNodes.h>
+
#include <gtest/gtest.h>
+class SimpleConcatGraph
+{
+public:
+ SimpleConcatGraph(loco::DataType quant_type)
+ {
+ concat_node = g.nodes()->create<luci::CircleConcatenation>(2);
+ input_1 = g.nodes()->create<luci::CircleConst>();
+ input_2 = g.nodes()->create<luci::CircleConst>();
+
+ concat_node->dtype(quant_type);
+ concat_node->fusedActivationFunction(luci::FusedActFunc::NONE);
+ input_1->dtype(quant_type);
+ input_2->dtype(quant_type);
+
+ concat_node->values(0, input_1);
+ concat_node->values(1, input_2);
+ }
+
+ ~SimpleConcatGraph()
+ {
+ concat_node->values(0, nullptr);
+ concat_node->values(1, nullptr);
+ }
+
+public:
+ loco::Graph g;
+ luci::CircleConcatenation *concat_node = nullptr;
+ luci::CircleConst *input_1 = nullptr;
+ luci::CircleConst *input_2 = nullptr;
+};
+
TEST(QuantizeWithMinMaxPassTest, name)
{
luci::QuantizeWithMinMaxPass pass(loco::DataType::FLOAT32, loco::DataType::U8,
@@ -25,3 +58,19 @@ TEST(QuantizeWithMinMaxPassTest, name)
auto const name = pass.name();
ASSERT_NE(nullptr, name);
}
+
+// Test concat of integer tensors
+// Integer tensors are not quantized
+TEST(QuantizeWithMinMaxPassTest, int_concat)
+{
+ SimpleConcatGraph g(loco::DataType::S32);
+
+ luci::QuantizeWithMinMaxPass qwmm(loco::DataType::FLOAT32, loco::DataType::U8,
+ luci::QuantizationGranularity::LayerWise);
+
+ qwmm.run(&g.g);
+
+ EXPECT_EQ(nullptr, g.concat_node->quantparam());
+ EXPECT_EQ(nullptr, g.input_1->quantparam());
+ EXPECT_EQ(nullptr, g.input_2->quantparam());
+}
diff --git a/compiler/luci/pass/src/QuantizedModelVerifier.cpp b/compiler/luci/pass/src/QuantizedModelVerifier.cpp
index f02301ed1..684d5d48a 100644
--- a/compiler/luci/pass/src/QuantizedModelVerifier.cpp
+++ b/compiler/luci/pass/src/QuantizedModelVerifier.cpp
@@ -15,10 +15,10 @@
#include "QuantizedModelVerifier.h"
-#include "VerifyQuantizedNodeLayerWiseGranularity.h"
-#include "VerifyQuantizedNodeChannelWiseGranularity.h"
-#include "VerifyQuantizedNodeU8Type.h"
-#include "VerifyQuantizedNodeS16Type.h"
+#include "VerifyQuantizedNodeGranularity.h"
+#include "VerifyQuantizedNodeType.h"
+#include "VerifyQuantizedBiasScale.h"
+#include "helpers/LayerInfoMap.h"
#include <luci/IR/CircleNodes.h>
#include <luci/IR/CircleNodeVisitor.h>
@@ -28,12 +28,33 @@ namespace luci
void QuantizedModelVerifier::verify(loco::Graph *g)
{
- if (_quantized_dtype != Type::U8 && _quantized_dtype != Type::S16)
- throw std::runtime_error("Unsupported quantized dtype");
-
- if (_granularity != Granularity::ChannelWise && _granularity != Granularity::LayerWise)
+ if (_ctx->granularity != Granularity::ChannelWise && _ctx->granularity != Granularity::LayerWise)
throw std::runtime_error("Unsupported granularity");
+ auto info_by_name = layer_info_map(g, _ctx->layers_info);
+
+ auto quantize_dtype = [&](const luci::CircleNode *node) {
+ auto iter = info_by_name.find(node->name());
+
+ // Return designated quantization dtype
+ if (iter != info_by_name.end())
+ return iter->second.dtype;
+
+ // Return default quantization dtype
+ return _ctx->output_model_dtype;
+ };
+
+ auto quantize_granularity = [&](const luci::CircleNode *node) {
+ auto iter = info_by_name.find(node->name());
+
+ // Return designated quantization granularity
+ if (iter != info_by_name.end())
+ return iter->second.granularity;
+
+ // Return default quantization granularity
+ return _ctx->granularity;
+ };
+
for (auto node : loco::active_nodes(loco::output_nodes(g)))
{
auto circle_node = loco::must_cast<luci::CircleNode *>(node);
@@ -46,32 +67,17 @@ void QuantizedModelVerifier::verify(loco::Graph *g)
};
// Verify Type
- if (_quantized_dtype == Type::U8)
- {
- VerifyQuantizedNodeU8Type vt;
- if (!circle_node->accept(&vt))
- throw std::runtime_error("Wrong data type detected in " + node_name());
- }
- else if (_quantized_dtype == Type::S16)
- {
- VerifyQuantizedNodeS16Type vt;
- if (!circle_node->accept(&vt))
- throw std::runtime_error("Wrong data type detected in " + node_name());
- }
+ if (!VerifyQuantizedNodeType::create(quantize_dtype(circle_node))->verify(circle_node))
+ throw std::runtime_error("Wrong data type detected in " + node_name());
// Verify Granularity
- if (_granularity == Granularity::LayerWise)
- {
- VerifyQuantizedNodeLayerWiseGranularity vg;
- if (!circle_node->accept(&vg))
- throw std::runtime_error("Wrong granularity detected in " + node_name());
- }
- else if (_granularity == Granularity::ChannelWise)
- {
- VerifyQuantizedNodeChannelWiseGranularity vg;
- if (!circle_node->accept(&vg))
- throw std::runtime_error("Wrong granularity detected in " + node_name());
- }
+ if (!circle_node->accept(
+ VerifyQuantizedNodeGranularity::create(quantize_granularity(circle_node)).get()))
+ throw std::runtime_error("Wrong granularity detected in " + node_name());
+
+ // Verify Bias scale
+ if (!VerifyQuantizedBiasScale::create()->verify(circle_node))
+ throw std::runtime_error("Wrong bias scale detected in " + node_name());
}
}
diff --git a/compiler/luci/pass/src/QuantizedModelVerifier.h b/compiler/luci/pass/src/QuantizedModelVerifier.h
index d5fbb8e74..7409a51d7 100644
--- a/compiler/luci/pass/src/QuantizedModelVerifier.h
+++ b/compiler/luci/pass/src/QuantizedModelVerifier.h
@@ -21,6 +21,8 @@
#include <loco.h>
+#include <memory>
+
namespace luci
{
@@ -31,18 +33,40 @@ namespace luci
*/
struct QuantizedModelVerifier
{
+public:
+ struct Context
+ {
+ loco::DataType output_model_dtype = loco::DataType::Unknown;
+ QuantizationGranularity granularity = QuantizationGranularity::ChannelWise;
+ loco::DataType input_type = loco::DataType::Unknown;
+ loco::DataType output_type = loco::DataType::Unknown;
+ bool TF_style_maxpool = false;
+ std::vector<LayerInfo> layers_info;
+ };
public:
QuantizedModelVerifier(loco::DataType quantized_dtype, QuantizationGranularity granularity)
- : _quantized_dtype(quantized_dtype), _granularity(granularity)
{
+ _ctx = std::make_unique<Context>();
+ {
+ _ctx->output_model_dtype = quantized_dtype;
+ _ctx->granularity = granularity;
+ _ctx->input_type = quantized_dtype;
+ _ctx->output_type = quantized_dtype;
+ _ctx->TF_style_maxpool = false;
+ }
+ }
+
+public:
+ QuantizedModelVerifier(std::unique_ptr<Context> &&ctx) : _ctx{std::move(ctx)}
+ {
+ // DO NOTHING
}
void verify(loco::Graph *g);
private:
- loco::DataType _quantized_dtype;
- QuantizationGranularity _granularity;
+ std::unique_ptr<Context> _ctx;
};
} // namespace luci
diff --git a/compiler/luci/pass/src/QuantizedModelVerifier.test.cpp b/compiler/luci/pass/src/QuantizedModelVerifier.test.cpp
index 3a6d86c33..cebafd32b 100644
--- a/compiler/luci/pass/src/QuantizedModelVerifier.test.cpp
+++ b/compiler/luci/pass/src/QuantizedModelVerifier.test.cpp
@@ -17,6 +17,7 @@
#include "QuantizedModelVerifier.h"
#include "luci/Pass/QuantizeWithMinMaxPass.h"
+#include "luci/Pass/QuantizationParameters.h"
#include <luci/test/TestIOGraph.h>
@@ -112,57 +113,77 @@ void quantize_and_verify(loco::Graph *g, Type quantized_dtype, Granularity granu
verifier.verify(g);
}
-// Helper function to reduce duplicate test codes
-// Assumption: g->output()->from() is the target node
-void quantize_and_verify_with_wrong_type(luci::test::TestIOGraph *g, Type quantized_dtype,
- Granularity granularity, Type wrong_dtype)
+void quantize_and_verify_with_layer_info(loco::Graph *g, Type quantized_dtype,
+ Granularity granularity)
{
- luci::QuantizeWithMinMaxPass pass(Type::FLOAT32, quantized_dtype, granularity);
- pass.run(g->g());
-
- auto node = loco::must_cast<luci::CircleNode *>(g->output()->from());
- node->dtype(wrong_dtype);
+ // A layer named "test" has dtype different from quantized_dtype
+ luci::LayerInfo info;
+ {
+ info.name = "test";
+ // dtype is different from quantized_dtype
+ info.dtype = quantized_dtype == Type::U8 ? Type::S16 : Type::U8;
+ info.granularity = Granularity::ChannelWise;
+ }
- luci::QuantizedModelVerifier verifier(quantized_dtype, granularity);
- verifier.verify(g->g());
-}
+ // Do quantization
+ {
+ auto ctx = std::make_unique<luci::QuantizeWithMinMaxPass::Context>();
+ {
+ ctx->input_model_dtype = Type::FLOAT32;
+ ctx->output_model_dtype = quantized_dtype;
+ ctx->granularity = granularity;
+ ctx->input_type = quantized_dtype;
+ ctx->output_type = quantized_dtype;
+ ctx->TF_style_maxpool = false;
+ ctx->layers_info.push_back(info);
+ }
-void quantize_and_verify_with_wrong_type(luci::test::TestIOGraph *g, Type quantized_dtype,
- Granularity granularity, Type wrong_dtype,
- luci::CircleNode *target)
-{
- luci::QuantizeWithMinMaxPass pass(Type::FLOAT32, quantized_dtype, granularity);
- pass.run(g->g());
+ luci::QuantizeWithMinMaxPass pass(std::move(ctx));
+ pass.run(g);
+ }
- target->dtype(wrong_dtype);
+ // Do verification
+ {
+ auto ctx = std::make_unique<luci::QuantizedModelVerifier::Context>();
+ {
+ ctx->output_model_dtype = quantized_dtype;
+ ctx->granularity = granularity;
+ ctx->input_type = quantized_dtype;
+ ctx->output_type = quantized_dtype;
+ ctx->TF_style_maxpool = false;
+ ctx->layers_info.push_back(info);
+ }
- luci::QuantizedModelVerifier verifier(quantized_dtype, granularity);
- verifier.verify(g->g());
+ luci::QuantizedModelVerifier verifier(std::move(ctx));
+ verifier.verify(g);
+ }
}
// Helper function to reduce duplicate test codes
// Assumption: g->output()->from() is the target node
-void quantize_and_verify_with_wrong_granularity(luci::test::TestIOGraph *g, Type quantized_dtype,
- Granularity granularity)
+void quantize_and_verify_with_wrong_type(luci::test::TestIOGraph *g, Type quantized_dtype,
+ Granularity granularity, Type wrong_dtype)
{
luci::QuantizeWithMinMaxPass pass(Type::FLOAT32, quantized_dtype, granularity);
pass.run(g->g());
auto node = loco::must_cast<luci::CircleNode *>(g->output()->from());
- insert_scale_zp(node, 1.0, 1);
+ node->dtype(wrong_dtype);
luci::QuantizedModelVerifier verifier(quantized_dtype, granularity);
verifier.verify(g->g());
}
// Helper function to reduce duplicate test codes
+// Assumption: g->output()->from() is the target node
void quantize_and_verify_with_wrong_granularity(luci::test::TestIOGraph *g, Type quantized_dtype,
- Granularity granularity, luci::CircleNode *target)
+ Granularity granularity)
{
luci::QuantizeWithMinMaxPass pass(Type::FLOAT32, quantized_dtype, granularity);
pass.run(g->g());
- insert_scale_zp(target, 1.0, 1);
+ auto node = loco::must_cast<luci::CircleNode *>(g->output()->from());
+ insert_scale_zp(node, 1.0, 1);
luci::QuantizedModelVerifier verifier(quantized_dtype, granularity);
verifier.verify(g->g());
@@ -230,6 +251,8 @@ public:
_instnorm->input(input());
_instnorm->gamma(_gamma);
_instnorm->beta(_beta);
+ _instnorm->fusedActivationFunction(luci::FusedActFunc::NONE);
+ _instnorm->name("test");
}
output()->from(_instnorm);
@@ -256,6 +279,7 @@ public:
_logistic = g()->nodes()->create<luci::CircleLogistic>();
{
_logistic->x(input());
+ _logistic->name("test");
}
output()->from(_logistic);
@@ -275,6 +299,7 @@ public:
_lrn = g()->nodes()->create<luci::CircleLocalResponseNormalization>();
{
_lrn->input(input());
+ _lrn->name("test");
}
output()->from(_lrn);
@@ -295,6 +320,7 @@ public:
{
_softmax->logits(input());
_softmax->beta(0.1);
+ _softmax->name("test");
}
output()->from(_softmax);
@@ -324,6 +350,7 @@ public:
_stob->input(input());
_stob->block_shape(_block_shape);
_stob->paddings(_paddings);
+ _stob->name("test");
}
output()->from(_stob);
@@ -346,6 +373,7 @@ public:
{
_stod->input(input());
_stod->block_size(2);
+ _stod->name("test");
}
output()->from(_stod);
@@ -375,6 +403,7 @@ public:
_slice->input(input());
_slice->begin(_begin);
_slice->size(_size);
+ _slice->name("test");
}
output()->from(_slice);
@@ -472,6 +501,7 @@ public:
_slice->begin(_begin);
_slice->end(_end);
_slice->strides(_strides);
+ _slice->name("test");
}
output()->from(_slice);
@@ -499,6 +529,7 @@ public:
{
_reshape->tensor(input());
_reshape->shape(_shape);
+ _reshape->name("test");
}
output()->from(_reshape);
@@ -519,6 +550,7 @@ public:
_tanh = g()->nodes()->create<luci::CircleTanh>();
{
_tanh->x(input());
+ _tanh->name("test");
}
output()->from(_tanh);
@@ -538,6 +570,7 @@ public:
_floor = g()->nodes()->create<luci::CircleFloor>();
{
_floor->x(input());
+ _floor->name("test");
}
output()->from(_floor);
@@ -601,6 +634,7 @@ public:
_btos->input(input());
_btos->block_shape(_block_shape);
_btos->crops(_crops);
+ _btos->name("test");
}
output()->from(_btos);
@@ -623,6 +657,7 @@ public:
{
_dtos->input(input());
_dtos->block_size(2);
+ _dtos->name("test");
}
output()->from(_dtos);
@@ -645,6 +680,7 @@ public:
_pack->values(0, input());
_pack->values(1, _param);
_pack->axis(0);
+ _pack->name("test");
}
output()->from(_pack);
@@ -680,6 +716,7 @@ public:
{
_pad->input(input());
_pad->paddings(_paddings);
+ _pad->name("test");
}
output()->from(_pad);
@@ -707,6 +744,7 @@ public:
_pad->input(input());
_pad->paddings(_paddings);
_pad->constant_values(_constant_values);
+ _pad->name("test");
}
output()->from(_pad);
@@ -735,6 +773,7 @@ public:
_mirror_pad->input(input());
_mirror_pad->paddings(_paddings);
_mirror_pad->mode(luci::MirrorPadMode::REFLECT);
+ _mirror_pad->name("test");
}
output()->from(_mirror_pad);
@@ -761,6 +800,7 @@ public:
{
_transpose->a(input());
_transpose->perm(_perm);
+ _transpose->name("test");
}
output()->from(_transpose);
@@ -784,6 +824,8 @@ public:
_concat->values(0, input());
_concat->values(1, _param);
_concat->axis(0);
+ _concat->fusedActivationFunction(luci::FusedActFunc::NONE);
+ _concat->name("test");
}
output()->from(_concat);
@@ -795,6 +837,54 @@ private:
luci::CircleConst *_param = nullptr;
};
+template <Type indexT> class OneHotTestGraph final : public SimpleTestGraph
+{
+public:
+ void init(void) override
+ {
+ TestIOGraph::init({32}, {32, 10});
+ {
+ // input dtype is float by default, but OneHot's input should have indexType (s32/s64)
+ input()->dtype(indexT);
+ }
+
+ _depth = g()->nodes()->template create<luci::CircleConst>();
+ {
+ _depth->dtype(loco::DataType::S32);
+ }
+
+ _on_value = g()->nodes()->template create<luci::CircleConst>();
+ {
+ _on_value->dtype(loco::DataType::FLOAT32);
+ }
+
+ _off_value = g()->nodes()->template create<luci::CircleConst>();
+ {
+ _off_value->dtype(loco::DataType::FLOAT32);
+ }
+
+ _one_hot = g()->nodes()->template create<luci::CircleOneHot>();
+ {
+ _one_hot->indices(input());
+ _one_hot->depth(_depth);
+ _one_hot->on_value(_on_value);
+ _one_hot->off_value(_off_value);
+ _one_hot->axis(-1);
+ _one_hot->dtype(loco::DataType::FLOAT32);
+ _one_hot->name("test");
+ }
+ output()->from(_one_hot);
+
+ set_minmax_to_non_const(g(), -1, 1);
+ }
+
+private:
+ luci::CircleOneHot *_one_hot = nullptr;
+ luci::CircleConst *_depth = nullptr;
+ luci::CircleConst *_on_value = nullptr;
+ luci::CircleConst *_off_value = nullptr;
+};
+
// Test graph for comparison Ops
// GREATER, GREATER_EQUAL, LESS, LESS_EQUAL, EQUAL, NOT_EQUAL
template <class Op> class ComparisonOpTestGraph final : public SimpleTestGraph
@@ -866,6 +956,7 @@ public:
{
_div->x(input());
_div->y(_const);
+ _div->name("test");
}
output()->from(_div);
@@ -893,6 +984,7 @@ public:
{
_floor_div->x(input());
_floor_div->y(_const);
+ _floor_div->name("test");
}
output()->from(_floor_div);
@@ -917,6 +1009,7 @@ public:
_rsqrt = g()->nodes()->create<luci::CircleRsqrt>();
{
_rsqrt->x(input());
+ _rsqrt->name("test");
}
output()->from(_rsqrt);
@@ -936,6 +1029,7 @@ public:
_sqrt = g()->nodes()->create<luci::CircleSqrt>();
{
_sqrt->x(input());
+ _sqrt->name("test");
}
output()->from(_sqrt);
@@ -955,6 +1049,7 @@ public:
_elu = g()->nodes()->create<luci::CircleElu>();
{
_elu->features(input());
+ _elu->name("test");
}
output()->from(_elu);
@@ -977,6 +1072,7 @@ public:
{
_pow->x(input());
_pow->y(_const);
+ _pow->name("test");
}
output()->from(_pow);
@@ -1004,6 +1100,7 @@ public:
{
_resize_bilinear->input(input());
_resize_bilinear->size(_size);
+ _resize_bilinear->name("test");
}
output()->from(_resize_bilinear);
@@ -1027,6 +1124,7 @@ public:
{
_resize_nearest_neighbor->input(input());
_resize_nearest_neighbor->size(_size);
+ _resize_nearest_neighbor->name("test");
}
output()->from(_resize_nearest_neighbor);
@@ -1067,6 +1165,62 @@ private:
luci::CircleConst *_unpack_dim = nullptr;
};
+class MulTestGraph final : public SimpleTestGraph
+{
+public:
+ void init(void) override
+ {
+ TestIOGraph::init({32}, {32});
+
+ _const = create_dummy_const<Type::FLOAT32>(g(), {32});
+ _mul = g()->nodes()->create<luci::CircleMul>();
+ {
+ _mul->x(input());
+ _mul->y(_const);
+ _mul->fusedActivationFunction(luci::FusedActFunc::NONE);
+ _mul->name("test");
+ }
+ output()->from(_mul);
+
+ set_minmax_to_non_const(g(), -1, 1);
+ }
+
+ loco::Node *x() { return _mul->x(); }
+ loco::Node *y() { return _mul->y(); }
+
+private:
+ luci::CircleMul *_mul = nullptr;
+ luci::CircleConst *_const = nullptr;
+};
+
+class AddTestGraph final : public SimpleTestGraph
+{
+public:
+ void init(void) override
+ {
+ TestIOGraph::init({32}, {32});
+
+ _const = create_dummy_const<Type::FLOAT32>(g(), {32});
+ _add = g()->nodes()->create<luci::CircleAdd>();
+ {
+ _add->x(input());
+ _add->y(_const);
+ _add->fusedActivationFunction(luci::FusedActFunc::NONE);
+ _add->name("test");
+ }
+ output()->from(_add);
+
+ set_minmax_to_non_const(g(), -1, 1);
+ }
+
+ loco::Node *x() { return _add->x(); }
+ loco::Node *y() { return _add->y(); }
+
+private:
+ luci::CircleAdd *_add = nullptr;
+ luci::CircleConst *_const = nullptr;
+};
+
} // namespace
// Quantize and verify with given configurations
@@ -1078,6 +1232,15 @@ private:
EXPECT_NO_THROW(quantize_and_verify(g.g(), type, granularity)); \
} while (0)
+// Quantize and verify with layer info
+#define TEST_WITH_LAYER_INFO(graph, type, granularity) \
+ do \
+ { \
+ graph g; \
+ g.init(); \
+ EXPECT_NO_THROW(quantize_and_verify_with_layer_info(g.g(), type, granularity)); \
+ } while (0)
+
// Quantize and verify with wrong type
#define TEST_WITH_WRONG_TYPE(graph, type, granularity, wrong_dtype) \
do \
@@ -1098,25 +1261,34 @@ private:
// Quantize and verify with wrong type
// Users can specify the test target
-#define TEST_WITH_WRONG_TYPE_TARGET(graph, type, granularity, wrong_dtype, target) \
- do \
- { \
- graph g; \
- g.init(); \
- auto node = loco::must_cast<luci::CircleNode *>(target); \
- EXPECT_ANY_THROW( \
- quantize_and_verify_with_wrong_type(&g, type, granularity, wrong_dtype, node)); \
+#define TEST_WITH_WRONG_TYPE_TARGET(graph, type, granularity, wrong_dtype, target) \
+ do \
+ { \
+ graph g; \
+ g.init(); \
+ auto node = loco::must_cast<luci::CircleNode *>(target); \
+ luci::QuantizeWithMinMaxPass pass(Type::FLOAT32, type, granularity); \
+ pass.run(g.g()); \
+ auto after_node = loco::must_cast<luci::CircleNode *>(target); \
+ after_node->dtype(wrong_dtype); \
+ luci::QuantizedModelVerifier verifier(type, granularity); \
+ EXPECT_ANY_THROW(verifier.verify(g.g())); \
} while (0)
// Quantize and verify with wrong granularity
// Users can specify the test target
-#define TEST_WITH_WRONG_GRANULARITY_TARGET(graph, type, granularity, target) \
- do \
- { \
- graph g; \
- g.init(); \
- auto node = loco::must_cast<luci::CircleNode *>(target); \
- EXPECT_ANY_THROW(quantize_and_verify_with_wrong_granularity(&g, type, granularity, node)); \
+#define TEST_WITH_WRONG_GRANULARITY_TARGET(graph, type, granularity, target) \
+ do \
+ { \
+ graph g; \
+ g.init(); \
+ auto node = loco::must_cast<luci::CircleNode *>(target); \
+ luci::QuantizeWithMinMaxPass pass(Type::FLOAT32, type, granularity); \
+ pass.run(g.g()); \
+ auto after_node = loco::must_cast<luci::CircleNode *>(target); \
+ insert_scale_zp(after_node, 1.0, 1); \
+ luci::QuantizedModelVerifier verifier(type, granularity); \
+ EXPECT_ANY_THROW(verifier.verify(g.g())); \
} while (0)
// Test a local helper function
@@ -1145,6 +1317,10 @@ TEST(QuantizedModelVerifierTest, InstanceNorm)
TEST_WITH_GRAPH(InstanceNormTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(InstanceNormTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(InstanceNormTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(InstanceNormTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(InstanceNormTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(InstanceNormTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -1169,6 +1345,10 @@ TEST(QuantizedModelVerifierTest, LocalResponseNormalization)
TEST_WITH_GRAPH(LocalResponseNormalizationTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(LocalResponseNormalizationTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(LocalResponseNormalizationTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(LocalResponseNormalizationTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(LocalResponseNormalizationTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(LocalResponseNormalizationTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -1199,6 +1379,10 @@ TEST(QuantizedModelVerifierTest, Logistic)
TEST_WITH_GRAPH(LogisticTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(LogisticTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(LogisticTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(LogisticTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(LogisticTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(LogisticTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -1223,6 +1407,10 @@ TEST(QuantizedModelVerifierTest, Softmax)
TEST_WITH_GRAPH(SoftmaxTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(SoftmaxTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(SoftmaxTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(SoftmaxTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(SoftmaxTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(SoftmaxTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -1247,6 +1435,10 @@ TEST(QuantizedModelVerifierTest, SpaceToBatchND)
TEST_WITH_GRAPH(SpaceToBatchNDTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(SpaceToBatchNDTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(SpaceToBatchNDTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(SpaceToBatchNDTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(SpaceToBatchNDTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(SpaceToBatchNDTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -1271,6 +1463,10 @@ TEST(QuantizedModelVerifierTest, SpaceToDepth)
TEST_WITH_GRAPH(SpaceToDepthTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(SpaceToDepthTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(SpaceToDepthTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(SpaceToDepthTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(SpaceToDepthTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(SpaceToDepthTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -1299,6 +1495,14 @@ TEST(QuantizedModelVerifierTest, Slice)
TEST_WITH_GRAPH(SliceTestGraph<Type::S64>, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(SliceTestGraph<Type::S64>, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(SliceTestGraph<Type::S64>, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(SliceTestGraph<Type::S32>, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(SliceTestGraph<Type::S32>, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(SliceTestGraph<Type::S32>, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(SliceTestGraph<Type::S64>, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(SliceTestGraph<Type::S64>, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(SliceTestGraph<Type::S64>, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -1379,6 +1583,10 @@ TEST(QuantizedModelVerifierTest, StridedSlice)
TEST_WITH_GRAPH(StridedSliceTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(StridedSliceTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(StridedSliceTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(StridedSliceTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(StridedSliceTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(StridedSliceTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -1463,6 +1671,10 @@ TEST(QuantizedModelVerifierTest, BatchToSpaceND)
TEST_WITH_GRAPH(BatchToSpaceNDTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(BatchToSpaceNDTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(BatchToSpaceNDTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(BatchToSpaceNDTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(BatchToSpaceNDTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(BatchToSpaceNDTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -1487,6 +1699,10 @@ TEST(QuantizedModelVerifierTest, DepthToSpace)
TEST_WITH_GRAPH(DepthToSpaceTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(DepthToSpaceTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(DepthToSpaceTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(DepthToSpaceTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(DepthToSpaceTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(DepthToSpaceTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -1511,6 +1727,10 @@ TEST(QuantizedModelVerifierTest, Concatenation)
TEST_WITH_GRAPH(ConcatenationTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(ConcatenationTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(ConcatenationTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(ConcatenationTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(ConcatenationTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(ConcatenationTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -1557,6 +1777,10 @@ TEST(QuantizedModelVerifierTest, Reshape)
TEST_WITH_GRAPH(ReshapeTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(ReshapeTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(ReshapeTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(ReshapeTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(ReshapeTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(ReshapeTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -1581,6 +1805,10 @@ TEST(QuantizedModelVerifierTest, Tanh)
TEST_WITH_GRAPH(TanhTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(TanhTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(TanhTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(TanhTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(TanhTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(TanhTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -1606,6 +1834,10 @@ TEST(QuantizedModelVerifierTest, Pack)
TEST_WITH_GRAPH(PackTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(PackTestGraph, Type::S16, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(PackTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(PackTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(PackTestGraph, Type::S16, Granularity::ChannelWise);
+
// Test if Pack's qparam is propagated to the input
{
PackTestGraph g;
@@ -1640,6 +1872,10 @@ TEST(QuantizedModelVerifierTest, Pad)
TEST_WITH_GRAPH(PadTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(PadTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(PadTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(PadTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(PadTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(PadTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -1664,6 +1900,10 @@ TEST(QuantizedModelVerifierTest, PadV2)
TEST_WITH_GRAPH(PadV2TestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(PadV2TestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(PadV2TestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(PadV2TestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(PadV2TestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(PadV2TestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -1688,6 +1928,10 @@ TEST(QuantizedModelVerifierTest, MirrorPad)
TEST_WITH_GRAPH(MirrorPadTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(MirrorPadTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(MirrorPadTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(MirrorPadTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(MirrorPadTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(MirrorPadTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -1712,6 +1956,10 @@ TEST(QuantizedModelVerifierTest, Transpose)
TEST_WITH_GRAPH(TransposeTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(TransposeTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(TransposeTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(TransposeTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(TransposeTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(TransposeTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -1736,6 +1984,10 @@ TEST(QuantizedModelVerifierTest, Floor)
TEST_WITH_GRAPH(FloorTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(FloorTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(FloorTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(FloorTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(FloorTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(FloorTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -1869,11 +2121,59 @@ TEST(QuantizedModelVerifierTest, NotEqual_wrong_granularity_NEG)
SUCCEED();
}
+TEST(QuantizedModelVerifierTest, OneHot)
+{
+ TEST_WITH_GRAPH(OneHotTestGraph<Type::S32>, Type::U8, Granularity::LayerWise);
+ TEST_WITH_GRAPH(OneHotTestGraph<Type::S32>, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_GRAPH(OneHotTestGraph<Type::S32>, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_GRAPH(OneHotTestGraph<Type::S64>, Type::U8, Granularity::LayerWise);
+ TEST_WITH_GRAPH(OneHotTestGraph<Type::S64>, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_GRAPH(OneHotTestGraph<Type::S64>, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(OneHotTestGraph<Type::S32>, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(OneHotTestGraph<Type::S32>, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(OneHotTestGraph<Type::S32>, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(OneHotTestGraph<Type::S64>, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(OneHotTestGraph<Type::S64>, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(OneHotTestGraph<Type::S64>, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, OneHot_wrong_input_type_NEG)
+{
+ TEST_WITH_WRONG_TYPE(OneHotTestGraph<Type::S32>, Type::U8, Granularity::LayerWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(OneHotTestGraph<Type::S32>, Type::U8, Granularity::ChannelWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(OneHotTestGraph<Type::S32>, Type::S16, Granularity::ChannelWise, Type::U8);
+
+ TEST_WITH_WRONG_TYPE(OneHotTestGraph<Type::S64>, Type::U8, Granularity::LayerWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(OneHotTestGraph<Type::S64>, Type::U8, Granularity::ChannelWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(OneHotTestGraph<Type::S64>, Type::S16, Granularity::ChannelWise, Type::U8);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, OneHot_wrong_granularity_NEG)
+{
+ TEST_WITH_WRONG_GRANULARITY(OneHotTestGraph<Type::S32>, Type::U8, Granularity::LayerWise);
+ TEST_WITH_WRONG_GRANULARITY(OneHotTestGraph<Type::S32>, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_WRONG_GRANULARITY(OneHotTestGraph<Type::S32>, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_WRONG_GRANULARITY(OneHotTestGraph<Type::S64>, Type::U8, Granularity::LayerWise);
+ TEST_WITH_WRONG_GRANULARITY(OneHotTestGraph<Type::S64>, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_WRONG_GRANULARITY(OneHotTestGraph<Type::S64>, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
TEST(QuantizedModelVerifierTest, Div)
{
TEST_WITH_GRAPH(DivTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(DivTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(DivTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(DivTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(DivTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(DivTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -1902,6 +2202,10 @@ TEST(QuantizedModelVerifierTest, FloorDiv)
TEST_WITH_GRAPH(FloorDivTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(FloorDivTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(FloorDivTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(FloorDivTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(FloorDivTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(FloorDivTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -1930,6 +2234,10 @@ TEST(QuantizedModelVerifierTest, Rsqrt)
TEST_WITH_GRAPH(RsqrtTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(RsqrtTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(RsqrtTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(RsqrtTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(RsqrtTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(RsqrtTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -1954,6 +2262,10 @@ TEST(QuantizedModelVerifierTest, Sqrt)
TEST_WITH_GRAPH(SqrtTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(SqrtTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(SqrtTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(SqrtTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(SqrtTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(SqrtTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -1978,6 +2290,10 @@ TEST(QuantizedModelVerifierTest, Elu)
TEST_WITH_GRAPH(EluTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(EluTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(EluTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(EluTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(EluTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(EluTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -2002,6 +2318,10 @@ TEST(QuantizedModelVerifierTest, Pow)
TEST_WITH_GRAPH(PowTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(PowTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(PowTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(PowTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(PowTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(PowTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -2030,6 +2350,10 @@ TEST(QuantizedModelVerifierTest, ResizeBilinear)
TEST_WITH_GRAPH(ResizeBilinearTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(ResizeBilinearTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(ResizeBilinearTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(ResizeBilinearTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(ResizeBilinearTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(ResizeBilinearTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -2054,6 +2378,10 @@ TEST(QuantizedModelVerifierTest, ResizeNearestNeighbor)
TEST_WITH_GRAPH(ResizeNearestNeighborTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(ResizeNearestNeighborTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(ResizeNearestNeighborTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(ResizeBilinearTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(ResizeBilinearTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(ResizeBilinearTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -2099,6 +2427,93 @@ TEST(QuantizedModelVerifierTest, Unpack_wrong_granularity_NEG)
SUCCEED();
}
+TEST(QuantizedModelVerifierTest, Add)
+{
+ TEST_WITH_GRAPH(AddTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_GRAPH(AddTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_GRAPH(AddTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(AddTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(AddTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(AddTestGraph, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Add_wrong_type_NEG)
+{
+ TEST_WITH_WRONG_TYPE(AddTestGraph, Type::U8, Granularity::LayerWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(AddTestGraph, Type::U8, Granularity::ChannelWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(AddTestGraph, Type::S16, Granularity::ChannelWise, Type::U8);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Add_wrong_granularity_NEG)
+{
+ TEST_WITH_WRONG_GRANULARITY_TARGET(AddTestGraph, Type::U8, Granularity::LayerWise, g.x());
+ TEST_WITH_WRONG_GRANULARITY_TARGET(AddTestGraph, Type::U8, Granularity::ChannelWise, g.x());
+ TEST_WITH_WRONG_GRANULARITY_TARGET(AddTestGraph, Type::S16, Granularity::ChannelWise, g.x());
+
+ TEST_WITH_WRONG_GRANULARITY_TARGET(AddTestGraph, Type::U8, Granularity::LayerWise, g.y());
+ TEST_WITH_WRONG_GRANULARITY_TARGET(AddTestGraph, Type::U8, Granularity::ChannelWise, g.y());
+ TEST_WITH_WRONG_GRANULARITY_TARGET(AddTestGraph, Type::S16, Granularity::ChannelWise, g.y());
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Mul)
+{
+ TEST_WITH_GRAPH(MulTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_GRAPH(MulTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_GRAPH(MulTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(MulTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(MulTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(MulTestGraph, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Mul_wrong_type_NEG)
+{
+ TEST_WITH_WRONG_TYPE(MulTestGraph, Type::U8, Granularity::LayerWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(MulTestGraph, Type::U8, Granularity::ChannelWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(MulTestGraph, Type::S16, Granularity::ChannelWise, Type::U8);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Mul_wrong_granularity_NEG)
+{
+ TEST_WITH_WRONG_GRANULARITY_TARGET(MulTestGraph, Type::U8, Granularity::LayerWise, g.x());
+ TEST_WITH_WRONG_GRANULARITY_TARGET(MulTestGraph, Type::U8, Granularity::ChannelWise, g.x());
+ TEST_WITH_WRONG_GRANULARITY_TARGET(MulTestGraph, Type::S16, Granularity::ChannelWise, g.x());
+
+ TEST_WITH_WRONG_GRANULARITY_TARGET(MulTestGraph, Type::U8, Granularity::LayerWise, g.y());
+ TEST_WITH_WRONG_GRANULARITY_TARGET(MulTestGraph, Type::U8, Granularity::ChannelWise, g.y());
+ TEST_WITH_WRONG_GRANULARITY_TARGET(MulTestGraph, Type::S16, Granularity::ChannelWise, g.y());
+ SUCCEED();
+}
+
+// TODO Add following testcases
+//
+// CircleConv2D
+//
+// CircleDepthwiseConv2D
+//
+// CirclePRelu
+//
+// CircleTransposeConv
+//
+// CircleFullyConnected
+//
+// CircleAveragePool2D
+//
+// CircleMaxPool2D
+//
+// CircleMean
+//
+// CircleRelu
+//
+// CircleCast
+//
+
#undef TEST_WITH_GRAPH
#undef TEST_WITH_WRONG_TYPE
#undef TEST_WITH_WRONG_GRANULARITY
diff --git a/compiler/luci/pass/src/RemoveRedundantQuantizePass.cpp b/compiler/luci/pass/src/RemoveRedundantQuantizePass.cpp
new file mode 100644
index 000000000..8a10ad4a0
--- /dev/null
+++ b/compiler/luci/pass/src/RemoveRedundantQuantizePass.cpp
@@ -0,0 +1,104 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/RemoveRedundantQuantizePass.h"
+
+#include <luci/IR/CircleNode.h>
+
+/**
+ * Remove redundant quantize operations. For subsequent Quantize Ops,
+ * only the last Quantize Op is valid, so we can remove the rest of the Quantize Op.
+ *
+ * BEFORE
+ * [CircleNode_1]
+ * |
+ * [CircleQuantize, dtype_1, scale_1, zero_point_1]
+ * |
+ * [CircleQuantize, dtype_2, scale_2, zero_point_2]
+ * |
+ * [CircleNode_2]
+ *
+ * AFTER
+ * [CircleNode_1]
+ * / \
+ * / \
+ * / \
+ * / \
+ * / \
+ * [CircleQuantize, dtype_2, scale_2, zero_point_2] [CircleQuantize, dtype_1, scale_1, zero_point_1]
+ * |
+ * [CircleNode_2]
+ *
+ */
+
+namespace
+{
+
+bool remove_redundant_quantize(luci::CircleQuantize *node)
+{
+ auto pred_node = loco::must_cast<luci::CircleNode *>(node->input());
+
+ if (node->quantparam() == nullptr or pred_node->quantparam() == nullptr)
+ return false;
+
+ if (node->quantparam()->scale.size() != 1 or node->quantparam()->zerop.size() != 1 or
+ pred_node->quantparam()->scale.size() != 1 or pred_node->quantparam()->zerop.size() != 1)
+ {
+ return false;
+ }
+
+ if (node->dtype() != pred_node->dtype() or
+ pred_node->quantparam()->scale.at(0) != node->quantparam()->scale.at(0) or
+ pred_node->quantparam()->zerop.at(0) != node->quantparam()->zerop.at(0))
+ {
+ return false;
+ }
+
+ replace(node).with(pred_node);
+
+ return true;
+}
+
+bool remove_redundant_subsequent_quantize(luci::CircleQuantize *node)
+{
+ auto pred_node = dynamic_cast<luci::CircleQuantize *>(node->input());
+ if (pred_node == nullptr)
+ return remove_redundant_quantize(node);
+
+ node->input(pred_node->input());
+ return true;
+}
+
+} // namespace
+
+namespace luci
+{
+
+bool RemoveRedundantQuantizePass::run(loco::Graph *g)
+{
+ bool changed = false;
+ for (auto node : loco::postorder_traversal(loco::output_nodes(g)))
+ {
+ if (auto quantize_node = dynamic_cast<luci::CircleQuantize *>(node))
+ {
+ if (remove_redundant_subsequent_quantize(quantize_node))
+ changed = true;
+ }
+ }
+ return changed;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/RemoveRedundantQuantizePass.test.cpp b/compiler/luci/pass/src/RemoveRedundantQuantizePass.test.cpp
new file mode 100644
index 000000000..d0166bd20
--- /dev/null
+++ b/compiler/luci/pass/src/RemoveRedundantQuantizePass.test.cpp
@@ -0,0 +1,166 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/RemoveRedundantQuantizePass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <luci/test/TestIOGraph.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class QuantizeGraphlet
+{
+public:
+ QuantizeGraphlet() = default;
+
+public:
+ void init(loco::Graph *g)
+ {
+ _first_quantize = g->nodes()->create<luci::CircleQuantize>();
+ _first_quantize->dtype(loco::DataType::U8);
+ {
+ auto quantize_param = std::make_unique<luci::CircleQuantParam>();
+ quantize_param->scale = {0.5};
+ quantize_param->zerop = {0};
+ _first_quantize->quantparam(std::move(quantize_param));
+ }
+ _first_quantize->name("first_quantize");
+
+ _second_quantize = g->nodes()->create<luci::CircleQuantize>();
+ _second_quantize->dtype(loco::DataType::U8);
+ {
+ auto quantize_param = std::make_unique<luci::CircleQuantParam>();
+ quantize_param->scale = {0.5};
+ quantize_param->zerop = {0};
+ _second_quantize->quantparam(std::move(quantize_param));
+ }
+ _second_quantize->name("second_quantize");
+ }
+
+protected:
+ luci::CircleQuantize *_first_quantize = nullptr;
+ luci::CircleQuantize *_second_quantize = nullptr;
+};
+
+class RedundantSubsequentQuantizeGraph : public TestIOGraph, public QuantizeGraphlet
+{
+public:
+ RedundantSubsequentQuantizeGraph() = default;
+
+public:
+ void init(void)
+ {
+ TestIOGraph::init({1}, {1});
+ QuantizeGraphlet::init(g());
+
+ input()->dtype(loco::DataType::U8);
+ {
+ auto quantize_param = std::make_unique<luci::CircleQuantParam>();
+ quantize_param->scale = {1};
+ quantize_param->zerop = {1};
+ input()->quantparam(std::move(quantize_param));
+ }
+
+ _first_quantize->input(input());
+ _second_quantize->input(_first_quantize);
+
+ output()->from(_second_quantize);
+ output()->dtype(loco::DataType::U8);
+ }
+};
+
+class RedundantQuantizeGraph : public TestIOGraph, public QuantizeGraphlet
+{
+public:
+ RedundantQuantizeGraph() = default;
+
+public:
+ void init(void)
+ {
+ TestIOGraph::init({1}, {1});
+ QuantizeGraphlet::init(g());
+
+ input()->dtype(loco::DataType::U8);
+ {
+ auto quantize_param = std::make_unique<luci::CircleQuantParam>();
+ quantize_param->scale = {0.5};
+ quantize_param->zerop = {0};
+ input()->quantparam(std::move(quantize_param));
+ }
+
+ _first_quantize->input(input());
+
+ output()->from(_first_quantize);
+ output()->dtype(loco::DataType::U8);
+ }
+};
+
+} // namespace
+
+TEST(RemoveRedundantQuantizePass, name)
+{
+ luci::RemoveRedundantQuantizePass pass;
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
+
+TEST(RemoveRedundantQuantizePass, remove_subsequent_quantize)
+{
+ RedundantSubsequentQuantizeGraph g;
+ luci::RemoveRedundantQuantizePass pass;
+
+ g.init();
+
+ EXPECT_TRUE(pass.run(g.g()));
+
+ int count = 0;
+ for (auto node : loco::active_nodes(loco::output_nodes(g.g())))
+ {
+ if (dynamic_cast<luci::CircleQuantize *>(node))
+ {
+ count++;
+ }
+ }
+
+ ASSERT_EQ(1, count);
+}
+
+TEST(RemoveRedundantQuantizePass, remove_quantize)
+{
+ RedundantQuantizeGraph g;
+ luci::RemoveRedundantQuantizePass pass;
+
+ g.init();
+
+ EXPECT_TRUE(pass.run(g.g()));
+
+ int count = 0;
+ for (auto node : loco::active_nodes(loco::output_nodes(g.g())))
+ {
+ if (dynamic_cast<luci::CircleQuantize *>(node))
+ {
+ count++;
+ }
+ }
+
+ ASSERT_EQ(0, count);
+}
diff --git a/compiler/luci/pass/src/RemoveRedundantTransposePass.cpp b/compiler/luci/pass/src/RemoveRedundantTransposePass.cpp
index 71c51ecda..75cf72795 100644
--- a/compiler/luci/pass/src/RemoveRedundantTransposePass.cpp
+++ b/compiler/luci/pass/src/RemoveRedundantTransposePass.cpp
@@ -71,7 +71,7 @@ bool remove_consecutive_transpose_function(luci::CircleTranspose *target_node)
for (uint32_t i = 0; i < pred_perm->size<loco::DataType::S32>(); i++)
{
new_const_node->at<loco::DataType::S32>(i) =
- target_perm->at<loco::DataType::S32>(pred_perm->at<loco::DataType::S32>(i));
+ pred_perm->at<loco::DataType::S32>(target_perm->at<loco::DataType::S32>(i));
}
new_const_node->name(name + "/Transpose/perm");
diff --git a/compiler/luci/pass/src/RemoveRedundantTransposePass.test.cpp b/compiler/luci/pass/src/RemoveRedundantTransposePass.test.cpp
index e80623499..bb8e292d4 100644
--- a/compiler/luci/pass/src/RemoveRedundantTransposePass.test.cpp
+++ b/compiler/luci/pass/src/RemoveRedundantTransposePass.test.cpp
@@ -271,6 +271,31 @@ TEST(RemoveRedundantTransposePass, remove_consecutive_transpose_function_type2)
ASSERT_EQ(2, perm->at<loco::DataType::S32>(3));
}
+TEST(RemoveRedundantTransposePass, remove_consecutive_transpose_function_type3)
+{
+ auto graph = loco::make_graph();
+ create_redundunt_transpose(graph.get(), {0, 3, 2, 1}, {0, 2, 3, 1});
+
+ luci::RemoveRedundantTransposePass pass;
+ while (pass.run(graph.get()))
+ ;
+ luci::CircleTranspose *transpose_node = nullptr;
+ for (auto node : loco::active_nodes(loco::output_nodes(graph.get())))
+ {
+ auto trans = dynamic_cast<luci::CircleTranspose *>(node);
+ if (not trans)
+ continue;
+ transpose_node = trans;
+ break;
+ }
+ ASSERT_NE(nullptr, transpose_node);
+ auto perm = loco::must_cast<luci::CircleConst *>(transpose_node->perm());
+ ASSERT_EQ(0, perm->at<loco::DataType::S32>(0));
+ ASSERT_EQ(2, perm->at<loco::DataType::S32>(1));
+ ASSERT_EQ(1, perm->at<loco::DataType::S32>(2));
+ ASSERT_EQ(3, perm->at<loco::DataType::S32>(3));
+}
+
/**
* @brief Test case that first transpose output become input of operations more than one.
*/
diff --git a/compiler/luci/pass/src/RemoveUnnecessaryReshapePass.cpp b/compiler/luci/pass/src/RemoveUnnecessaryReshapePass.cpp
index 3f0c4ee82..fb46f490d 100644
--- a/compiler/luci/pass/src/RemoveUnnecessaryReshapePass.cpp
+++ b/compiler/luci/pass/src/RemoveUnnecessaryReshapePass.cpp
@@ -58,6 +58,25 @@ bool remove_no_effect_reshape(luci::CircleNode *node)
namespace luci
{
+/**
+ * BEFORE
+ * [CircleNode]
+ * |
+ * [CircleReshape]
+ * |
+ * [CircleNode]
+ *
+ * AFTER
+ * [CircleNode]
+ * | \
+ * | [CircleReshape]
+ * |
+ * [CircleNode]
+ *
+ * NOTE
+ * This pass will remove Reshape when input and output has same shape
+ */
+
bool RemoveUnnecessaryReshapePass::run(loco::Graph *g)
{
bool changed = false;
diff --git a/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.cpp b/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.cpp
index a0cc0194f..bca0a9483 100644
--- a/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.cpp
+++ b/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.cpp
@@ -26,8 +26,17 @@ namespace
luci::CircleConst *create_weights_from_gamma(luci::CircleConst *gamma)
{
- assert(gamma->rank() == 1);
- auto channel_size = gamma->dim(0).value();
+ assert(gamma->rank() == 1 or gamma->rank() == 4);
+
+ uint32_t channel_idx = gamma->rank() - 1;
+ uint32_t channel_size = gamma->dim(channel_idx).value();
+
+ // Gamma should be broadcastable in the channel direction
+ for (uint32_t i = 0; i < gamma->rank(); i++)
+ {
+ if (i != channel_idx)
+ assert(gamma->dim(i).value() == 1); // FIX is_batchnorm_mul UNLESS
+ }
auto name = gamma->name();
assert(name.length() > 0);
@@ -53,8 +62,17 @@ luci::CircleConst *create_weights_from_gamma(luci::CircleConst *gamma)
luci::CircleConst *create_bias_from_beta(luci::CircleConst *beta)
{
- assert(beta->rank() == 1);
- auto channel_size = beta->dim(0).value();
+ assert(beta->rank() == 1 or beta->rank() == 4);
+
+ uint32_t channel_idx = beta->rank() - 1;
+ uint32_t channel_size = beta->dim(channel_idx).value();
+
+ // Beta should be broadcastable in the channel direction
+ for (uint32_t i = 0; i < beta->rank(); i++)
+ {
+ if (i != channel_idx)
+ assert(beta->dim(i).value() == 1); // FIX is_batchnorm_add UNLESS
+ }
auto name = beta->name();
assert(name.length() > 0);
diff --git a/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.test.cpp b/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.test.cpp
index 903d4dcc9..bac033112 100644
--- a/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.test.cpp
+++ b/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.test.cpp
@@ -141,6 +141,37 @@ TEST(ReplaceMulAddWithDepthwiseConv, simple)
}
}
+TEST(ReplaceMulAddWithDepthwiseConv, simple_rank4)
+{
+ SimpleGraph g;
+
+ const uint32_t channel_size = 16;
+ g.gamma->shape({1, 1, 1, channel_size});
+ g.beta->shape({1, 1, 1, channel_size});
+
+ luci::ReplaceMulAddWithDepthwiseConvPass pass;
+ while (pass.run(&g.g))
+ ;
+
+ auto dwconv = dynamic_cast<luci::CircleDepthwiseConv2D *>(g.output->from());
+ EXPECT_NE(nullptr, dwconv);
+
+ auto weights = dynamic_cast<luci::CircleConst *>(dwconv->filter());
+ auto bias = dynamic_cast<luci::CircleConst *>(dwconv->bias());
+ EXPECT_NE(nullptr, weights);
+ EXPECT_EQ(4, weights->rank());
+ EXPECT_EQ(channel_size, weights->dim(3).value());
+ EXPECT_NE(nullptr, bias);
+ EXPECT_EQ(1, bias->rank());
+ EXPECT_EQ(channel_size, bias->dim(0).value());
+
+ for (int i = 0; i < channel_size; i++)
+ {
+ EXPECT_FLOAT_EQ(i, weights->at<loco::DataType::FLOAT32>(i));
+ EXPECT_FLOAT_EQ(i, bias->at<loco::DataType::FLOAT32>(i));
+ }
+}
+
TEST(ReplaceMulAddWithDepthwiseConv, wrong_op_NEG)
{
SimpleGraph g;
@@ -154,3 +185,18 @@ TEST(ReplaceMulAddWithDepthwiseConv, wrong_op_NEG)
EXPECT_EQ(false, changed);
}
+
+TEST(ReplaceMulAddWithDepthwiseConv, rank3_NEG)
+{
+ SimpleGraph g;
+
+ g.input->shape({4, 4, 16});
+ g.mul->shape({4, 4, 16});
+ g.add->shape({4, 4, 16});
+ g.output->shape({4, 4, 16});
+
+ luci::ReplaceMulAddWithDepthwiseConvPass pass;
+ auto changed = pass.run(&g.g);
+
+ EXPECT_EQ(false, changed);
+}
diff --git a/compiler/luci/pass/src/SubstituteSplitVToSplitPass.cpp b/compiler/luci/pass/src/SubstituteSplitVToSplitPass.cpp
index 9cba9a9e7..57c386d99 100644
--- a/compiler/luci/pass/src/SubstituteSplitVToSplitPass.cpp
+++ b/compiler/luci/pass/src/SubstituteSplitVToSplitPass.cpp
@@ -24,15 +24,6 @@
namespace
{
-void copy_quantparam(luci::CircleNode *dst, const luci::CircleNode *src)
-{
- auto q = src->quantparam();
- if (q == nullptr)
- dst->quantparam(nullptr);
- else
- dst->quantparam(std::make_unique<luci::CircleQuantParam>(*q));
-}
-
// SplitV is substituted to Split if the contents of size_splits are all same
// For example,
// size_splits = [32, 32] -> substitute
@@ -67,7 +58,7 @@ bool resolve_splitv(luci::CircleSplitV *sv)
split_node->split_dim(sv->split_dim());
split_node->num_split(sv->num_split());
split_node->name(sv->name());
- copy_quantparam(split_node, sv);
+ copy_quantparam(sv, split_node);
luci::add_origin(split_node, luci::get_origin(sv));
auto succs = loco::succs(sv);
@@ -78,7 +69,7 @@ bool resolve_splitv(luci::CircleSplitV *sv)
so_node->input(split_node);
so_node->index(svo->index());
so_node->name(svo->name());
- copy_quantparam(so_node, svo);
+ copy_quantparam(svo, so_node);
luci::add_origin(so_node, luci::get_origin(svo));
replace(svo).with(so_node);
diff --git a/compiler/luci/pass/src/SubstituteSqueezeToReshapePass.cpp b/compiler/luci/pass/src/SubstituteSqueezeToReshapePass.cpp
index f48763782..df7266df9 100644
--- a/compiler/luci/pass/src/SubstituteSqueezeToReshapePass.cpp
+++ b/compiler/luci/pass/src/SubstituteSqueezeToReshapePass.cpp
@@ -76,18 +76,6 @@ std::vector<uint32_t> node_shape(const luci::CircleNode *input)
}
/**
- * @brief copy quantparam of src to dst
- */
-void copy_quantparam(luci::CircleNode *dst, const luci::CircleNode *src)
-{
- auto q = src->quantparam();
- if (q == nullptr)
- dst->quantparam(nullptr);
- else
- dst->quantparam(std::make_unique<luci::CircleQuantParam>(*q));
-}
-
-/**
* @brief return CircleConst ptr with values of new_shape
*/
luci::CircleConst *create_shape_const(loco::Graph *graph, const std::vector<uint32_t> &new_shape)
@@ -142,7 +130,7 @@ bool substitute_squeeze_to_reshape(luci::CircleSqueeze *squeeze)
auto graph = squeeze->graph();
auto reshape = graph->nodes()->create<luci::CircleReshape>();
auto shape_const = create_shape_const(graph, reshape_shape);
- copy_quantparam(reshape, squeeze);
+ copy_quantparam(squeeze, reshape);
reshape->name(name + "/Reshape");
luci::add_origin(reshape, luci::get_origin(squeeze));
shape_const->name(name + "/Reshape/shape");
diff --git a/compiler/luci/pass/src/SubstituteStridedSliceToReshapePass.cpp b/compiler/luci/pass/src/SubstituteStridedSliceToReshapePass.cpp
index f50f2f54f..9e1c5a4a3 100644
--- a/compiler/luci/pass/src/SubstituteStridedSliceToReshapePass.cpp
+++ b/compiler/luci/pass/src/SubstituteStridedSliceToReshapePass.cpp
@@ -124,7 +124,7 @@ bool substitute_strided_slice_to_reshape(luci::CircleStridedSlice *ss_node)
std::bitset<32> end_mask(ss_node->end_mask());
std::bitset<32> shrink_axis_mask(ss_node->shrink_axis_mask());
- uint input_rank = input_node->rank();
+ uint32_t input_rank = input_node->rank();
for (uint32_t i = 0; i < input_rank; i++)
{
if (!input_node->dim(i).known())
diff --git a/compiler/luci/pass/src/VerifyQuantizedBiasScale.cpp b/compiler/luci/pass/src/VerifyQuantizedBiasScale.cpp
new file mode 100644
index 000000000..e65d576cd
--- /dev/null
+++ b/compiler/luci/pass/src/VerifyQuantizedBiasScale.cpp
@@ -0,0 +1,105 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "VerifyQuantizedBiasScale.h"
+
+#include <cmath>
+
+// This macro is undef at the end of the file
+#define RETURN_FALSE_UNLESS(ARG) \
+ if (not(ARG)) \
+ { \
+ return false; \
+ }
+
+namespace
+{
+
+bool same(float a, float b)
+{
+ constexpr float epsilon = 1e-10;
+ return abs(a - b) < epsilon;
+}
+
+// Check bias scale = input scale * weight scale
+// This function checks both LWQ and CWQ
+bool check_bias_scale(const loco::Node *input, const loco::Node *weights, const loco::Node *bias)
+{
+ auto input_node = loco::must_cast<const luci::CircleNode *>(input);
+ auto input_qparam = input_node->quantparam();
+ RETURN_FALSE_UNLESS(input_qparam != nullptr);
+
+ auto weights_node = loco::must_cast<const luci::CircleNode *>(weights);
+ auto weights_qparam = weights_node->quantparam();
+ RETURN_FALSE_UNLESS(weights_qparam != nullptr);
+
+ auto bias_node = loco::must_cast<const luci::CircleNode *>(bias);
+ auto bias_qparam = bias_node->quantparam();
+ RETURN_FALSE_UNLESS(bias_qparam != nullptr);
+
+ RETURN_FALSE_UNLESS(input_qparam->scale.size() == 1);
+ RETURN_FALSE_UNLESS(weights_qparam->scale.size() == bias_qparam->scale.size());
+
+ auto input_scale = input_qparam->scale[0];
+ for (uint32_t i = 0; i < weights_qparam->scale.size(); i++)
+ {
+ auto weights_scale = weights_qparam->scale[i];
+ auto bias_scale = bias_qparam->scale[i];
+ RETURN_FALSE_UNLESS(same(bias_scale, input_scale * weights_scale));
+ }
+ return true;
+}
+
+} // namespace
+
+namespace luci
+{
+
+bool VerifyQuantizedBiasScale::visit(const luci::CircleConv2D *node)
+{
+ RETURN_FALSE_UNLESS(check_bias_scale(node->input(), node->filter(), node->bias()));
+ return true;
+}
+
+bool VerifyQuantizedBiasScale::visit(const luci::CircleDepthwiseConv2D *node)
+{
+ RETURN_FALSE_UNLESS(check_bias_scale(node->input(), node->filter(), node->bias()));
+ return true;
+}
+
+bool VerifyQuantizedBiasScale::visit(const luci::CircleFullyConnected *node)
+{
+ luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
+ if (bias != nullptr)
+ {
+ RETURN_FALSE_UNLESS(check_bias_scale(node->input(), node->weights(), node->bias()));
+ }
+ return true;
+}
+
+bool VerifyQuantizedBiasScale::visit(const luci::CircleTransposeConv *node)
+{
+ luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
+ if (bias != nullptr)
+ {
+ RETURN_FALSE_UNLESS(check_bias_scale(node->outBackprop(), node->filter(), node->bias()));
+ }
+ return true;
+}
+
+} // namespace luci
+
+#undef RETURN_FALSE_UNLESS
diff --git a/compiler/luci/pass/src/VerifyQuantizedBiasScale.h b/compiler/luci/pass/src/VerifyQuantizedBiasScale.h
new file mode 100644
index 000000000..b41f78eca
--- /dev/null
+++ b/compiler/luci/pass/src/VerifyQuantizedBiasScale.h
@@ -0,0 +1,59 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_VERIFY_QUANTIZED_BIAS_SCALE_H__
+#define __LUCI_VERIFY_QUANTIZED_BIAS_SCALE_H__
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/IR/CircleNodeVisitor.h>
+
+#include <memory>
+
+namespace luci
+{
+
+/**
+ * @brief Verify the scale of quantized bias node
+ * @details
+ *
+ * Bias of CONV, DCONV, TCONV, FC layers should meet the following condition.
+ *
+ * bias scale = input scale * weights scale
+ */
+class VerifyQuantizedBiasScale : public luci::CircleNodeVisitor<bool>
+{
+public:
+ static std::shared_ptr<VerifyQuantizedBiasScale> create()
+ {
+ return std::make_shared<VerifyQuantizedBiasScale>();
+ };
+
+public:
+ bool verify(luci::CircleNode *node) { return node->accept(this); }
+
+private:
+ // Operators with bias
+ bool visit(const luci::CircleConv2D *node);
+ bool visit(const luci::CircleDepthwiseConv2D *node);
+ bool visit(const luci::CircleFullyConnected *node);
+ bool visit(const luci::CircleTransposeConv *node);
+
+ bool visit(const luci::CircleNode *) { return true; }
+};
+
+} // namespace luci
+
+#endif // __LUCI_VERIFY_QUANTIZED_BIAS_SCALE_H__
diff --git a/compiler/luci/pass/src/VerifyQuantizedNodeGranularity.cpp b/compiler/luci/pass/src/VerifyQuantizedNodeGranularity.cpp
new file mode 100644
index 000000000..8697090a7
--- /dev/null
+++ b/compiler/luci/pass/src/VerifyQuantizedNodeGranularity.cpp
@@ -0,0 +1,38 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "VerifyQuantizedNodeGranularity.h"
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/Pass/QuantizationParameters.h>
+
+#include <memory>
+
+namespace luci
+{
+
+std::shared_ptr<VerifyQuantizedNodeGranularity>
+VerifyQuantizedNodeGranularity::create(Granularity granularity)
+{
+ if (granularity == Granularity::ChannelWise)
+ return std::make_shared<VerifyQuantizedNodeChannelWiseGranularity>();
+ else if (granularity == Granularity::LayerWise)
+ return std::make_shared<VerifyQuantizedNodeLayerWiseGranularity>();
+ else
+ throw std::domain_error("Not supported Granularity type");
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/VerifyQuantizedNodeChannelWiseGranularity.h b/compiler/luci/pass/src/VerifyQuantizedNodeGranularity.h
index bf3ff2e8a..442183c18 100644
--- a/compiler/luci/pass/src/VerifyQuantizedNodeChannelWiseGranularity.h
+++ b/compiler/luci/pass/src/VerifyQuantizedNodeGranularity.h
@@ -1,5 +1,6 @@
/*
- * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
@@ -13,13 +14,15 @@
* limitations under the License.
*/
-#ifndef __LUCI_VERIFY_QUANTIZED_NODE_CHANNELWISE_GRANULARITY_H__
-#define __LUCI_VERIFY_QUANTIZED_NODE_CHANNELWISE_GRANULARITY_H__
+#ifndef __LUCI_VERIFY_QUANTIZED_NODE_GRANULARITY_H__
+#define __LUCI_VERIFY_QUANTIZED_NODE_GRANULARITY_H__
#include <luci/IR/CircleNodes.h>
#include <luci/IR/CircleNodeVisitor.h>
#include <luci/Pass/QuantizationParameters.h>
+#include <memory>
+
using Granularity = luci::QuantizationGranularity;
// This macro is undef at the end of the file
@@ -33,16 +36,19 @@ namespace luci
{
/**
- * @brief Verify the granualrity of channel-wise quantized node
+ * @brief Verify the granualrity of quantized node
* @details
*
* Targets to verify
* - node's output (i.e., node itself)
* - node's inputs
*/
-struct VerifyQuantizedNodeChannelWiseGranularity final : public luci::CircleNodeVisitor<bool>
+class VerifyQuantizedNodeGranularity : public luci::CircleNodeVisitor<bool>
{
-private:
+public:
+ static std::shared_ptr<VerifyQuantizedNodeGranularity> create(Granularity granularity);
+
+protected:
bool is_lwq(const loco::Node *node)
{
auto circle_node = loco::must_cast<const luci::CircleNode *>(node);
@@ -59,48 +65,15 @@ private:
return true;
}
- uint32_t rank(const loco::Node *node)
- {
- auto circle_node = loco::must_cast<const luci::CircleNode *>(node);
- return circle_node->rank();
- }
-
- bool is_cwq_const(const loco::Node *node, uint32_t channel_dim)
- {
- auto circle_node = loco::must_cast<const luci::CircleConst *>(node);
-
- assert(channel_dim < circle_node->rank()); // FIX_CALLER_UNLESS
- auto channel_size = circle_node->dim(channel_dim).value();
-
- if (circle_node->quantparam() == nullptr)
- return false;
-
- if (circle_node->quantparam()->quantized_dimension != static_cast<int32_t>(channel_dim))
- return false;
-
- if (circle_node->quantparam()->scale.size() != channel_size)
- return false;
-
- if (circle_node->quantparam()->zerop.size() != channel_size)
- return false;
-
- return true;
- }
-
private:
- bool visit(const luci::CircleConv2D *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->input()))
- RETURN_FALSE_UNLESS(is_cwq_const(node->filter(), 0))
- luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
- if (bias != nullptr)
- RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1))
- return true;
- }
+ virtual bool visit(const luci::CircleConv2D *node) = 0;
bool visit(const luci::CircleConcatenation *node)
{
+ // Skip granularity check for concatenation of indices
+ if (node->dtype() == loco::DataType::S32 or node->dtype() == loco::DataType::S64)
+ return true;
+
RETURN_FALSE_UNLESS(is_lwq(node))
for (uint32_t i = 0; i < node->numValues(); i++)
{
@@ -116,25 +89,9 @@ private:
return true;
}
- bool visit(const luci::CircleDepthwiseConv2D *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->input()))
- RETURN_FALSE_UNLESS(is_cwq_const(node->filter(), 3))
- luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
- if (bias != nullptr)
- RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1))
- return true;
- }
+ virtual bool visit(const luci::CircleDepthwiseConv2D *node) = 0;
- bool visit(const luci::CircleInstanceNorm *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->input()))
- RETURN_FALSE_UNLESS(is_cwq_const(node->gamma(), rank(node->gamma()) - 1))
- RETURN_FALSE_UNLESS(is_cwq_const(node->beta(), rank(node->beta()) - 1))
- return true;
- }
+ virtual bool visit(const luci::CircleInstanceNorm *node) = 0;
bool visit(const luci::CirclePack *node)
{
@@ -168,37 +125,11 @@ private:
return true;
}
- bool visit(const luci::CirclePRelu *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->input()))
- RETURN_FALSE_UNLESS(is_cwq_const(node->alpha(), rank(node->alpha()) - 1))
- return true;
- }
-
- bool visit(const luci::CircleTransposeConv *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->outBackprop()))
- RETURN_FALSE_UNLESS(is_cwq_const(node->filter(), 0))
- luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
- if (bias != nullptr)
- RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1))
+ virtual bool visit(const luci::CirclePRelu *node) = 0;
- return true;
- }
+ virtual bool visit(const luci::CircleTransposeConv *node) = 0;
- bool visit(const luci::CircleFullyConnected *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->input()))
- RETURN_FALSE_UNLESS(is_cwq_const(node->weights(), 0))
- luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
- // Bias is optional (it can be CircleOutputExclude)
- if (bias != nullptr)
- RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1))
- return true;
- }
+ virtual bool visit(const luci::CircleFullyConnected *node) = 0;
bool visit(const luci::CircleAdd *node)
{
@@ -258,6 +189,14 @@ private:
return true;
}
+ bool visit(const luci::CircleOneHot *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node));
+ RETURN_FALSE_UNLESS(is_lwq(node->off_value()));
+ RETURN_FALSE_UNLESS(is_lwq(node->on_value()));
+ return true;
+ }
+
bool visit(const luci::CircleRelu *node)
{
RETURN_FALSE_UNLESS(is_lwq(node));
@@ -480,8 +419,186 @@ private:
bool visit(const luci::CircleNode *) { return true; }
};
+class VerifyQuantizedNodeChannelWiseGranularity final : public VerifyQuantizedNodeGranularity
+{
+private:
+ uint32_t rank(const loco::Node *node)
+ {
+ auto circle_node = loco::must_cast<const luci::CircleNode *>(node);
+ return circle_node->rank();
+ }
+
+ bool is_cwq_const(const loco::Node *node, uint32_t channel_dim)
+ {
+ auto circle_node = loco::must_cast<const luci::CircleConst *>(node);
+
+ assert(channel_dim < circle_node->rank()); // FIX_CALLER_UNLESS
+ auto channel_size = circle_node->dim(channel_dim).value();
+
+ if (circle_node->quantparam() == nullptr)
+ return false;
+
+ if (circle_node->quantparam()->quantized_dimension != static_cast<int32_t>(channel_dim))
+ return false;
+
+ if (circle_node->quantparam()->scale.size() != channel_size)
+ return false;
+
+ if (circle_node->quantparam()->zerop.size() != channel_size)
+ return false;
+
+ return true;
+ }
+
+private:
+ bool visit(const luci::CircleConv2D *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ RETURN_FALSE_UNLESS(is_lwq(node->input()))
+ RETURN_FALSE_UNLESS(is_cwq_const(node->filter(), 0))
+ luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
+ if (bias != nullptr)
+ RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1))
+ return true;
+ }
+
+ bool visit(const luci::CircleDepthwiseConv2D *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ RETURN_FALSE_UNLESS(is_lwq(node->input()))
+ RETURN_FALSE_UNLESS(is_cwq_const(node->filter(), 3))
+ luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
+ if (bias != nullptr)
+ RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1))
+ return true;
+ }
+
+ bool visit(const luci::CircleInstanceNorm *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ RETURN_FALSE_UNLESS(is_lwq(node->input()))
+ RETURN_FALSE_UNLESS(is_cwq_const(node->gamma(), rank(node->gamma()) - 1))
+ RETURN_FALSE_UNLESS(is_cwq_const(node->beta(), rank(node->beta()) - 1))
+ return true;
+ }
+
+ bool visit(const luci::CirclePRelu *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ RETURN_FALSE_UNLESS(is_lwq(node->input()))
+ RETURN_FALSE_UNLESS(is_cwq_const(node->alpha(), rank(node->alpha()) - 1))
+ return true;
+ }
+
+ bool visit(const luci::CircleTransposeConv *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ RETURN_FALSE_UNLESS(is_lwq(node->outBackprop()))
+ RETURN_FALSE_UNLESS(is_cwq_const(node->filter(), 0))
+ luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
+ if (bias != nullptr)
+ RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1))
+
+ return true;
+ }
+
+ bool visit(const luci::CircleFullyConnected *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ RETURN_FALSE_UNLESS(is_lwq(node->input()))
+ RETURN_FALSE_UNLESS(is_cwq_const(node->weights(), 0))
+ luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
+ // Bias is optional (it can be CircleOutputExclude)
+ if (bias != nullptr)
+ RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1))
+ return true;
+ }
+};
+
+class VerifyQuantizedNodeLayerWiseGranularity final : public VerifyQuantizedNodeGranularity
+{
+private:
+ bool is_lwq_const(const loco::Node *node)
+ {
+ auto circle_node = loco::must_cast<const luci::CircleConst *>(node);
+
+ if (circle_node->quantparam() == nullptr)
+ return false;
+
+ if (circle_node->quantparam()->scale.size() != 1)
+ return false;
+
+ if (circle_node->quantparam()->zerop.size() != 1)
+ return false;
+
+ return true;
+ }
+
+private:
+ bool visit(const luci::CircleConv2D *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ RETURN_FALSE_UNLESS(is_lwq(node->input()))
+ RETURN_FALSE_UNLESS(is_lwq_const(node->filter()))
+ luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
+ if (bias != nullptr)
+ RETURN_FALSE_UNLESS(is_lwq_const(node->bias()))
+ return true;
+ }
+
+ bool visit(const luci::CircleDepthwiseConv2D *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ RETURN_FALSE_UNLESS(is_lwq(node->input()))
+ RETURN_FALSE_UNLESS(is_lwq_const(node->filter()))
+ luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
+ if (bias != nullptr)
+ RETURN_FALSE_UNLESS(is_lwq_const(node->bias()))
+ return true;
+ }
+
+ bool visit(const luci::CircleInstanceNorm *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ RETURN_FALSE_UNLESS(is_lwq(node->input()))
+ RETURN_FALSE_UNLESS(is_lwq_const(node->gamma()))
+ RETURN_FALSE_UNLESS(is_lwq_const(node->beta()))
+ return true;
+ }
+
+ bool visit(const luci::CirclePRelu *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ RETURN_FALSE_UNLESS(is_lwq(node->input()))
+ RETURN_FALSE_UNLESS(is_lwq_const(node->alpha()))
+ return true;
+ }
+
+ bool visit(const luci::CircleTransposeConv *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ RETURN_FALSE_UNLESS(is_lwq(node->outBackprop()))
+ RETURN_FALSE_UNLESS(is_lwq_const(node->filter()))
+ luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
+ if (bias != nullptr)
+ RETURN_FALSE_UNLESS(is_lwq_const(node->bias()))
+ return true;
+ }
+
+ bool visit(const luci::CircleFullyConnected *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ RETURN_FALSE_UNLESS(is_lwq(node->input()))
+ RETURN_FALSE_UNLESS(is_lwq_const(node->weights()))
+ luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
+ if (bias != nullptr)
+ RETURN_FALSE_UNLESS(is_lwq_const(node->bias()))
+ return true;
+ }
+};
+
} // namespace luci
#undef RETURN_FALSE_UNLESS
-#endif // __LUCI_VERIFY_QUANTIZED_NODE_CHANNELWISE_GRANULARITY_H__
+#endif // __LUCI_VERIFY_QUANTIZED_NODE_GRANULARITY_H__
diff --git a/compiler/luci/pass/src/VerifyQuantizedNodeLayerWiseGranularity.h b/compiler/luci/pass/src/VerifyQuantizedNodeLayerWiseGranularity.h
deleted file mode 100644
index 9bc8b31df..000000000
--- a/compiler/luci/pass/src/VerifyQuantizedNodeLayerWiseGranularity.h
+++ /dev/null
@@ -1,473 +0,0 @@
-/*
- * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef __LUCI_VERIFY_QUANTIZED_NODE_LAYERWISE_GRANULARITY_H__
-#define __LUCI_VERIFY_QUANTIZED_NODE_LAYERWISE_GRANULARITY_H__
-
-#include <luci/IR/CircleNodes.h>
-#include <luci/IR/CircleNodeVisitor.h>
-#include <luci/Pass/QuantizationParameters.h>
-
-using Granularity = luci::QuantizationGranularity;
-
-// This macro is undef at the end of the file
-#define RETURN_FALSE_UNLESS(ARG) \
- if (not(ARG)) \
- { \
- return false; \
- }
-
-namespace luci
-{
-
-/**
- * @brief Verify the granualrity of layer-wise quantized node
- * @details
- *
- * Targets to verify
- * - node's output (i.e., node itself)
- * - node's inputs
- */
-struct VerifyQuantizedNodeLayerWiseGranularity final : public luci::CircleNodeVisitor<bool>
-{
-private:
- bool is_lwq(const loco::Node *node)
- {
- auto circle_node = loco::must_cast<const luci::CircleNode *>(node);
-
- if (circle_node->quantparam() == nullptr)
- return false;
-
- if (circle_node->quantparam()->scale.size() != 1)
- return false;
-
- if (circle_node->quantparam()->zerop.size() != 1)
- return false;
-
- return true;
- }
-
- bool is_lwq_const(const loco::Node *node)
- {
- auto circle_node = loco::must_cast<const luci::CircleConst *>(node);
-
- if (circle_node->quantparam() == nullptr)
- return false;
-
- if (circle_node->quantparam()->scale.size() != 1)
- return false;
-
- if (circle_node->quantparam()->zerop.size() != 1)
- return false;
-
- return true;
- }
-
-private:
- bool visit(const luci::CircleConv2D *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->input()))
- RETURN_FALSE_UNLESS(is_lwq_const(node->filter()))
- luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
- if (bias != nullptr)
- RETURN_FALSE_UNLESS(is_lwq_const(node->bias()))
- return true;
- }
-
- bool visit(const luci::CircleConcatenation *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- for (uint32_t i = 0; i < node->numValues(); i++)
- {
- RETURN_FALSE_UNLESS(is_lwq(node->values(i)));
- }
- return true;
- }
-
- bool visit(const luci::CircleDepthToSpace *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->input()))
- return true;
- }
-
- bool visit(const luci::CircleDepthwiseConv2D *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->input()))
- RETURN_FALSE_UNLESS(is_lwq_const(node->filter()))
- luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
- if (bias != nullptr)
- RETURN_FALSE_UNLESS(is_lwq_const(node->bias()))
- return true;
- }
-
- bool visit(const luci::CircleInstanceNorm *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->input()))
- RETURN_FALSE_UNLESS(is_lwq_const(node->gamma()))
- RETURN_FALSE_UNLESS(is_lwq_const(node->beta()))
- return true;
- }
-
- bool visit(const luci::CirclePack *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- for (uint32_t i = 0; i < node->values_count(); i++)
- {
- RETURN_FALSE_UNLESS(is_lwq(node->values(i)));
- }
- return true;
- }
-
- bool visit(const luci::CirclePad *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->input()))
- return true;
- }
-
- bool visit(const luci::CirclePadV2 *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->input()))
- RETURN_FALSE_UNLESS(is_lwq(node->constant_values()))
- return true;
- }
-
- bool visit(const luci::CircleMirrorPad *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->input()))
- return true;
- }
-
- bool visit(const luci::CirclePRelu *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->input()))
- RETURN_FALSE_UNLESS(is_lwq_const(node->alpha()))
- return true;
- }
-
- bool visit(const luci::CircleTransposeConv *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->outBackprop()))
- RETURN_FALSE_UNLESS(is_lwq_const(node->filter()))
- luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
- if (bias != nullptr)
- RETURN_FALSE_UNLESS(is_lwq_const(node->bias()))
- return true;
- }
-
- bool visit(const luci::CircleFullyConnected *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->input()))
- RETURN_FALSE_UNLESS(is_lwq_const(node->weights()))
- luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
- if (bias != nullptr)
- RETURN_FALSE_UNLESS(is_lwq_const(node->bias()))
- return true;
- }
-
- bool visit(const luci::CircleAdd *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->x()));
- RETURN_FALSE_UNLESS(is_lwq(node->y()));
- return true;
- }
-
- bool visit(const luci::CircleAveragePool2D *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->value()));
- return true;
- }
-
- bool visit(const luci::CircleLogicalOr *)
- {
- // Logical OR has bool-type inputs and output
- // Nothing to be checked
- return true;
- }
-
- bool visit(const luci::CircleMaxPool2D *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->value()));
- return true;
- }
-
- bool visit(const luci::CircleLocalResponseNormalization *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->input()));
- return true;
- }
-
- bool visit(const luci::CircleMean *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->input()));
- return true;
- }
-
- bool visit(const luci::CircleMul *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->x()));
- RETURN_FALSE_UNLESS(is_lwq(node->y()));
- return true;
- }
-
- bool visit(const luci::CircleNotEqual *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node->x()));
- RETURN_FALSE_UNLESS(is_lwq(node->y()));
- return true;
- }
-
- bool visit(const luci::CircleRelu *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->features()));
- return true;
- }
-
- bool visit(const luci::CircleReshape *node)
- {
- auto input = loco::must_cast<const luci::CircleNode *>(node->tensor());
- bool input_quantized = input->quantparam() != nullptr;
- bool node_quantized = node->quantparam() != nullptr;
- RETURN_FALSE_UNLESS(input_quantized == node_quantized);
- RETURN_FALSE_UNLESS(not node_quantized or is_lwq(node))
- RETURN_FALSE_UNLESS(not input_quantized or is_lwq(input));
- return true;
- }
-
- bool visit(const luci::CircleLogistic *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node));
- RETURN_FALSE_UNLESS(is_lwq(node->x()));
- return true;
- }
-
- bool visit(const luci::CircleSoftmax *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node));
- RETURN_FALSE_UNLESS(is_lwq(node->logits()));
- return true;
- }
-
- bool visit(const luci::CircleSpaceToBatchND *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node));
- RETURN_FALSE_UNLESS(is_lwq(node->input()));
- return true;
- }
-
- bool visit(const luci::CircleSpaceToDepth *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node));
- RETURN_FALSE_UNLESS(is_lwq(node->input()));
- return true;
- }
-
- bool visit(const luci::CircleSlice *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node));
- RETURN_FALSE_UNLESS(is_lwq(node->input()));
- return true;
- }
-
- bool visit(const luci::CircleSplit *node)
- {
- // node's output is the input of CircleSplitOut, thus not quantized
- RETURN_FALSE_UNLESS(is_lwq(node->input()));
- return true;
- }
-
- bool visit(const luci::CircleSplitOut *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node));
- return true;
- }
-
- bool visit(const luci::CircleSplitV *node)
- {
- // node's output is the input of CircleSplitVOut, thus not quantized
- RETURN_FALSE_UNLESS(is_lwq(node->input()));
- return true;
- }
-
- bool visit(const luci::CircleSplitVOut *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node));
- return true;
- }
-
- bool visit(const luci::CircleStridedSlice *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node));
- RETURN_FALSE_UNLESS(is_lwq(node->input()));
- return true;
- }
-
- bool visit(const luci::CircleArgMax *node)
- {
- // node's output is index, thus not quantized
- RETURN_FALSE_UNLESS(is_lwq(node->input()));
- return true;
- }
-
- bool visit(const luci::CircleBatchToSpaceND *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node));
- RETURN_FALSE_UNLESS(is_lwq(node->input()));
- return true;
- }
-
- bool visit(const luci::CircleTanh *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node));
- RETURN_FALSE_UNLESS(is_lwq(node->x()));
- return true;
- }
-
- bool visit(const luci::CircleTranspose *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node));
- RETURN_FALSE_UNLESS(is_lwq(node->a()));
- return true;
- }
-
- bool visit(const luci::CircleFloor *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node));
- RETURN_FALSE_UNLESS(is_lwq(node->x()));
- return true;
- }
-
- bool visit(const luci::CircleGreater *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node->x()));
- RETURN_FALSE_UNLESS(is_lwq(node->y()));
- return true;
- }
-
- bool visit(const luci::CircleGreaterEqual *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node->x()));
- RETURN_FALSE_UNLESS(is_lwq(node->y()));
- return true;
- }
-
- bool visit(const luci::CircleDiv *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node));
- RETURN_FALSE_UNLESS(is_lwq(node->x()));
- RETURN_FALSE_UNLESS(is_lwq(node->y()));
- return true;
- }
-
- bool visit(const luci::CircleFloorDiv *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node));
- RETURN_FALSE_UNLESS(is_lwq(node->x()));
- RETURN_FALSE_UNLESS(is_lwq(node->y()));
- return true;
- }
-
- bool visit(const luci::CircleRsqrt *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node));
- RETURN_FALSE_UNLESS(is_lwq(node->x()));
- return true;
- }
-
- bool visit(const luci::CircleSqrt *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node));
- RETURN_FALSE_UNLESS(is_lwq(node->x()));
- return true;
- }
-
- bool visit(const luci::CircleElu *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node));
- RETURN_FALSE_UNLESS(is_lwq(node->features()));
- return true;
- }
-
- bool visit(const luci::CirclePow *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node));
- RETURN_FALSE_UNLESS(is_lwq(node->x()));
- RETURN_FALSE_UNLESS(is_lwq(node->y()));
- return true;
- }
-
- bool visit(const luci::CircleResizeBilinear *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node));
- RETURN_FALSE_UNLESS(is_lwq(node->input()));
- return true;
- }
-
- bool visit(const luci::CircleResizeNearestNeighbor *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node));
- RETURN_FALSE_UNLESS(is_lwq(node->input()));
- return true;
- }
-
- bool visit(const luci::CircleUnpack *node)
- {
- // node's output is the input of CircleUnpackOut, thus not quantized
- RETURN_FALSE_UNLESS(is_lwq(node->value()));
- return true;
- }
-
- bool visit(const luci::CircleUnpackOut *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node));
- return true;
- }
-
- bool visit(const luci::CircleCast *node)
- {
- auto input = loco::must_cast<const luci::CircleNode *>(node->x());
- bool input_quantized = input->quantparam() != nullptr;
- bool node_quantized = node->quantparam() != nullptr;
- RETURN_FALSE_UNLESS(not input_quantized or is_lwq(input));
- RETURN_FALSE_UNLESS(not node_quantized or is_lwq(node));
- return true;
- }
-
- // TODO: Implement more Ops
-
- bool visit(const luci::CircleNode *) { return true; }
-};
-
-} // namespace luci
-
-#undef RETURN_FALSE_UNLESS
-
-#endif // __LUCI_VERIFY_QUANTIZED_NODE_LAYERWISE_GRANULARITY_H__
diff --git a/compiler/luci/pass/src/VerifyQuantizedNodeS16Type.h b/compiler/luci/pass/src/VerifyQuantizedNodeS16Type.h
deleted file mode 100644
index eeec7b82b..000000000
--- a/compiler/luci/pass/src/VerifyQuantizedNodeS16Type.h
+++ /dev/null
@@ -1,516 +0,0 @@
-/*
- * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef __LUCI_VERIFY_QUANTIZED_NODE_S16_TYPE_H__
-#define __LUCI_VERIFY_QUANTIZED_NODE_S16_TYPE_H__
-
-#include <luci/IR/CircleNodes.h>
-#include <luci/IR/CircleNodeVisitor.h>
-
-#include <cmath>
-
-using Type = loco::DataType;
-
-// This macro is undef at the end of the file
-#define RETURN_FALSE_UNLESS(ARG) \
- if (not(ARG)) \
- { \
- return false; \
- }
-
-namespace luci
-{
-
-/**
- * @brief Verify the data type of INT16 quantized node
- * @details
- *
- * Targets to verify
- * - node's output (i.e., node itself)
- * - node's inputs
- */
-struct VerifyQuantizedNodeS16Type final : public luci::CircleNodeVisitor<bool>
-{
-private:
- bool has_type(const loco::Node *node, Type dtype)
- {
- auto circle_node = loco::must_cast<const luci::CircleNode *>(node);
- return circle_node->dtype() == dtype;
- }
-
-private:
- bool visit(const luci::CircleConv2D *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->filter(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->bias(), Type::S64))
- return true;
- }
-
- bool visit(const luci::CircleConcatenation *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- for (uint32_t i = 0; i < node->numValues(); i++)
- {
- RETURN_FALSE_UNLESS(has_type(node->values(i), Type::S16))
- }
- return true;
- }
-
- bool visit(const luci::CircleDepthToSpace *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleDepthwiseConv2D *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->filter(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->bias(), Type::S64))
- return true;
- }
-
- bool visit(const luci::CircleInstanceNorm *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->gamma(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->beta(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CirclePack *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- for (uint32_t i = 0; i < node->values_count(); i++)
- {
- RETURN_FALSE_UNLESS(has_type(node->values(i), Type::S16))
- }
- return true;
- }
-
- bool visit(const luci::CirclePad *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->paddings(), Type::S32))
- return true;
- }
-
- bool visit(const luci::CirclePadV2 *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->paddings(), Type::S32))
- RETURN_FALSE_UNLESS(has_type(node->constant_values(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleMirrorPad *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->paddings(), Type::S32))
- return true;
- }
-
- bool visit(const luci::CirclePRelu *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->alpha(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleTransposeConv *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->outBackprop(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->filter(), Type::S16))
- luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
- if (bias != nullptr)
- RETURN_FALSE_UNLESS(has_type(bias, Type::S64))
- return true;
- }
-
- bool visit(const luci::CircleFullyConnected *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->weights(), Type::S16))
- luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
- if (bias != nullptr)
- RETURN_FALSE_UNLESS(has_type(bias, Type::S64))
- return true;
- }
-
- bool visit(const luci::CircleAdd *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->y(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleAveragePool2D *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->value(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleLogicalOr *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::BOOL))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::BOOL))
- RETURN_FALSE_UNLESS(has_type(node->y(), Type::BOOL))
- return true;
- }
-
- bool visit(const luci::CircleMaxPool2D *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->value(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleLocalResponseNormalization *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleMean *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->reduction_indices(), Type::S32))
- return true;
- }
-
- bool visit(const luci::CircleMul *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->y(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleNotEqual *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::BOOL))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->y(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleRelu *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->features(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleReshape *node)
- {
- if (node->quantparam())
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->tensor(), Type::S16))
- }
- else
- {
- RETURN_FALSE_UNLESS(has_type(node->tensor(), node->dtype()))
- }
- luci::CircleConst *shape = dynamic_cast<luci::CircleConst *>(node->shape());
- if (shape != nullptr)
- RETURN_FALSE_UNLESS(has_type(shape, Type::S32))
- return true;
- }
-
- bool visit(const luci::CircleLogistic *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16))
-
- RETURN_FALSE_UNLESS(node->quantparam());
- RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == 1.0f / 32768.0f);
- RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == 0);
- return true;
- }
-
- bool visit(const luci::CircleSoftmax *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->logits(), Type::S16))
-
- RETURN_FALSE_UNLESS(node->quantparam());
- RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == 1.0f / 32767.0f);
- RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == 0);
- return true;
- }
-
- bool visit(const luci::CircleSpaceToBatchND *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleSpaceToDepth *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleSlice *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->begin(), Type::S32) || has_type(node->begin(), Type::S64))
- RETURN_FALSE_UNLESS(has_type(node->size(), Type::S32) || has_type(node->size(), Type::S64))
- return true;
- }
-
- bool visit(const luci::CircleSplit *node)
- {
- // node's output is the input of CircleSplitOut, thus not quantized
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleSplitOut *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
-
- // SplitOut has the same qparam with the input of Split
- auto split = loco::must_cast<luci::CircleSplit *>(node->input());
- auto input = loco::must_cast<luci::CircleNode *>(split->input());
- RETURN_FALSE_UNLESS(node->quantparam());
- RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == input->quantparam()->scale[0]);
- RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == input->quantparam()->zerop[0]);
- return true;
- }
-
- bool visit(const luci::CircleSplitV *node)
- {
- // node's output is the input of CircleSplitVOut, thus not quantized
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleSplitVOut *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
-
- // SplitVOut has the same qparam with the input of SplitV
- auto splitv = loco::must_cast<luci::CircleSplitV *>(node->input());
- auto input = loco::must_cast<luci::CircleNode *>(splitv->input());
- RETURN_FALSE_UNLESS(node->quantparam());
- RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == input->quantparam()->scale[0]);
- RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == input->quantparam()->zerop[0]);
- return true;
- }
-
- bool visit(const luci::CircleStridedSlice *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
-
- auto input = loco::must_cast<luci::CircleNode *>(node->input());
- RETURN_FALSE_UNLESS(node->quantparam());
- RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == input->quantparam()->scale[0]);
- RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == input->quantparam()->zerop[0]);
- return true;
- }
-
- bool visit(const luci::CircleArgMax *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, node->output_type()))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->dimension(), Type::S32) ||
- has_type(node->dimension(), Type::S64))
- return true;
- }
-
- bool visit(const luci::CircleBatchToSpaceND *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleTanh *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16))
-
- RETURN_FALSE_UNLESS(node->quantparam());
- RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == 1.0f / 32768.0f);
- RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == 0);
- return true;
- }
-
- bool visit(const luci::CircleTranspose *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->a(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->perm(), Type::S32))
- return true;
- }
-
- bool visit(const luci::CircleFloor *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16))
-
- // This checks the value of scale is an integer
- RETURN_FALSE_UNLESS(node->quantparam());
- RETURN_FALSE_UNLESS(std::roundf(node->quantparam()->scale[0]) == node->quantparam()->scale[0]);
- return true;
- }
-
- bool visit(const luci::CircleGreater *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::BOOL))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->y(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleGreaterEqual *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::BOOL))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->y(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleDiv *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->y(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleFloorDiv *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->y(), Type::S16))
-
- // This checks the value of scale is an integer
- RETURN_FALSE_UNLESS(node->quantparam());
- RETURN_FALSE_UNLESS(std::roundf(node->quantparam()->scale[0]) == node->quantparam()->scale[0]);
- return true;
- }
-
- bool visit(const luci::CircleRsqrt *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleSqrt *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleElu *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->features(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CirclePow *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->y(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleResizeBilinear *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleResizeNearestNeighbor *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleUnpack *node)
- {
- // node's output is the input of CircleUnpackOut, thus not quantized
- RETURN_FALSE_UNLESS(has_type(node->value(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleUnpackOut *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
-
- // UnpackOut has the same qparam with the input of Unpack
- auto Unpack = loco::must_cast<luci::CircleUnpack *>(node->input());
- auto input = loco::must_cast<luci::CircleNode *>(Unpack->value());
- RETURN_FALSE_UNLESS(node->quantparam() && input->quantparam());
- RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == input->quantparam()->scale[0]);
- RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == input->quantparam()->zerop[0]);
- return true;
- }
-
- bool visit(const luci::CircleCast *node)
- {
- auto *input = loco::must_cast<luci::CircleNode *>(node->x());
- RETURN_FALSE_UNLESS(has_type(input, node->in_data_type()))
-
- bool input_quantized = input->quantparam() != nullptr;
- if (input_quantized)
- RETURN_FALSE_UNLESS(has_type(input, Type::S16))
-
- RETURN_FALSE_UNLESS(has_type(node, node->out_data_type()))
-
- bool node_quantized = node->quantparam() != nullptr;
- if (node_quantized)
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- return true;
- }
-
- // TODO: Implement more Ops
-
- bool visit(const luci::CircleNode *) { return true; }
-};
-
-} // namespace luci
-
-#undef RETURN_FALSE_UNLESS
-
-#endif // __LUCI_VERIFY_QUNTIZED_NODE_S16_TYPE_H__
diff --git a/compiler/luci/pass/src/VerifyQuantizedNodeType.cpp b/compiler/luci/pass/src/VerifyQuantizedNodeType.cpp
new file mode 100644
index 000000000..4e1c062c0
--- /dev/null
+++ b/compiler/luci/pass/src/VerifyQuantizedNodeType.cpp
@@ -0,0 +1,554 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "VerifyQuantizedNodeType.h"
+
+#include <cmath>
+#include <memory>
+
+// This macro is undef at the end of the file
+#define RETURN_FALSE_UNLESS(ARG) \
+ if (not(ARG)) \
+ { \
+ return false; \
+ }
+
+namespace luci
+{
+
+std::shared_ptr<VerifyQuantizedNodeType> VerifyQuantizedNodeType::create(loco::DataType dtype)
+{
+ if (dtype == loco::DataType::U8)
+ return std::make_shared<VerifyQuantizedNodeU8Type>();
+ else if (dtype == loco::DataType::S16)
+ return std::make_shared<VerifyQuantizedNodeS16Type>();
+ else
+ throw std::domain_error("Not supported Quantized type");
+}
+
+} // namespace luci
+
+namespace luci
+{
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleAdd *node)
+{
+ return group_has_type(node, Qtype);
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleArgMax *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, node->output_type()))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->dimension(), loco::DataType::S32) ||
+ has_type(node->dimension(), loco::DataType::S64))
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleAveragePool2D *node)
+{
+ return group_has_type(node, Qtype);
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleBatchToSpaceND *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Qtype))
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleCast *node)
+{
+ auto *input = loco::must_cast<luci::CircleNode *>(node->x());
+ bool input_quantized = input->quantparam() != nullptr;
+ if (input_quantized)
+ {
+ RETURN_FALSE_UNLESS(has_type(input, node->in_data_type()))
+ RETURN_FALSE_UNLESS(has_type(input, Qtype))
+ }
+
+ bool node_quantized = node->quantparam() != nullptr;
+ if (node_quantized)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, node->out_data_type()))
+ RETURN_FALSE_UNLESS(has_type(node, Qtype))
+ }
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleConv2D *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->filter(), Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->bias(), Btype))
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleConcatenation *node)
+{
+ // Allow concatenation of indices
+ if (group_has_type(node, loco::DataType::S32) or group_has_type(node, loco::DataType::S64))
+ return true;
+
+ return group_has_type(node, Qtype);
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleDepthToSpace *node)
+{
+ return group_has_type(node, Qtype);
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleDepthwiseConv2D *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->filter(), Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->bias(), Btype))
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleDiv *node)
+{
+ return group_has_type(node, Qtype);
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleElu *node)
+{
+ return group_has_type(node, Qtype);
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleFloor *node)
+{
+ RETURN_FALSE_UNLESS(group_has_type(node, Qtype));
+
+ // This checks the value of scale is an integer
+ RETURN_FALSE_UNLESS(node->quantparam());
+ RETURN_FALSE_UNLESS(std::roundf(node->quantparam()->scale[0]) == node->quantparam()->scale[0]);
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleFloorDiv *node)
+{
+ RETURN_FALSE_UNLESS(group_has_type(node, Qtype));
+
+ // This checks the value of scale is an integer
+ RETURN_FALSE_UNLESS(node->quantparam());
+ RETURN_FALSE_UNLESS(std::roundf(node->quantparam()->scale[0]) == node->quantparam()->scale[0]);
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleFullyConnected *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->weights(), Qtype))
+ luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
+ if (bias != nullptr)
+ RETURN_FALSE_UNLESS(has_type(bias, Btype))
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleGreater *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, loco::DataType::BOOL))
+ RETURN_FALSE_UNLESS(has_type(node->x(), Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->y(), Qtype))
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleGreaterEqual *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, loco::DataType::BOOL))
+ RETURN_FALSE_UNLESS(has_type(node->x(), Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->y(), Qtype))
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleInstanceNorm *node)
+{
+ return group_has_type(node, Qtype);
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(
+ const luci::CircleLocalResponseNormalization *node)
+{
+ return group_has_type(node, Qtype);
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleLogicalOr *node)
+{
+ return group_has_type(node, loco::DataType::BOOL);
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleMaxPool2D *node)
+{
+ return group_has_type(node, Qtype);
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleMean *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->reduction_indices(), loco::DataType::S32))
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleMirrorPad *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->paddings(), loco::DataType::S32))
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleMul *node)
+{
+ return group_has_type(node, Qtype);
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleNotEqual *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, loco::DataType::BOOL))
+ RETURN_FALSE_UNLESS(has_type(node->x(), Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->y(), Qtype))
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleOneHot *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, Qtype));
+ RETURN_FALSE_UNLESS(has_type(node->indices(), loco::DataType::S32) ||
+ has_type(node->indices(), loco::DataType::S64));
+ RETURN_FALSE_UNLESS(has_type(node->depth(), loco::DataType::S32));
+ RETURN_FALSE_UNLESS(has_type(node->on_value(), Qtype));
+ RETURN_FALSE_UNLESS(has_type(node->off_value(), Qtype));
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CirclePack *node)
+{
+ return group_has_type(node, Qtype);
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CirclePad *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->paddings(), loco::DataType::S32))
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CirclePadV2 *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->paddings(), loco::DataType::S32))
+ RETURN_FALSE_UNLESS(has_type(node->constant_values(), Qtype))
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CirclePRelu *node)
+{
+ return group_has_type(node, Qtype);
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CirclePow *node)
+{
+ return group_has_type(node, Qtype);
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleRelu *node)
+{
+ return group_has_type(node, Qtype);
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleReshape *node)
+{
+ if (node->quantparam())
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->tensor(), Qtype))
+ }
+ else
+ {
+ RETURN_FALSE_UNLESS(has_type(node->tensor(), node->dtype()))
+ }
+ luci::CircleConst *shape = dynamic_cast<luci::CircleConst *>(node->shape());
+ if (shape != nullptr)
+ RETURN_FALSE_UNLESS(has_type(shape, loco::DataType::S32))
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleResizeBilinear *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Qtype))
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleResizeNearestNeighbor *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Qtype))
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleRsqrt *node)
+{
+ return group_has_type(node, Qtype);
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleSlice *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->begin(), loco::DataType::S32) ||
+ has_type(node->begin(), loco::DataType::S64))
+ RETURN_FALSE_UNLESS(has_type(node->size(), loco::DataType::S32) ||
+ has_type(node->size(), loco::DataType::S64))
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleSpaceToBatchND *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Qtype))
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleSpaceToDepth *node)
+{
+ return group_has_type(node, Qtype);
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleSplit *node)
+{
+ // node's output is the input of CircleSplitOut, thus not quantized
+ RETURN_FALSE_UNLESS(has_type(node->input(), Qtype))
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleSplitOut *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, Qtype))
+
+ // SplitOut has the same qparam with the input of Split
+ auto split = loco::must_cast<luci::CircleSplit *>(node->input());
+ auto input = loco::must_cast<luci::CircleNode *>(split->input());
+ RETURN_FALSE_UNLESS(node->quantparam());
+ RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == input->quantparam()->scale[0]);
+ RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == input->quantparam()->zerop[0]);
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleSplitV *node)
+{
+ // node's output is the input of CircleSplitVOut, thus not quantized
+ RETURN_FALSE_UNLESS(has_type(node->input(), Qtype))
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleSplitVOut *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, Qtype))
+
+ // SplitVOut has the same qparam with the input of SplitV
+ auto splitv = loco::must_cast<luci::CircleSplitV *>(node->input());
+ auto input = loco::must_cast<luci::CircleNode *>(splitv->input());
+ RETURN_FALSE_UNLESS(node->quantparam());
+ RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == input->quantparam()->scale[0]);
+ RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == input->quantparam()->zerop[0]);
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleSqrt *node)
+{
+ return group_has_type(node, Qtype);
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleStridedSlice *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Qtype))
+
+ auto input = loco::must_cast<luci::CircleNode *>(node->input());
+ RETURN_FALSE_UNLESS(node->quantparam());
+ RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == input->quantparam()->scale[0]);
+ RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == input->quantparam()->zerop[0]);
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleTranspose *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->a(), Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->perm(), loco::DataType::S32))
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleTransposeConv *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->outBackprop(), Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->filter(), Qtype))
+ luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
+ if (bias != nullptr)
+ RETURN_FALSE_UNLESS(has_type(bias, Btype))
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleUnpack *node)
+{
+ // node's output is the input of CircleUnpackOut, thus not quantized
+ RETURN_FALSE_UNLESS(has_type(node->value(), Qtype))
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleUnpackOut *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, Qtype))
+
+ // UnpackOut has the same qparam with the input of Unpack
+ auto Unpack = loco::must_cast<luci::CircleUnpack *>(node->input());
+ auto input = loco::must_cast<luci::CircleNode *>(Unpack->value());
+ RETURN_FALSE_UNLESS(node->quantparam() && input->quantparam());
+ RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == input->quantparam()->scale[0]);
+ RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == input->quantparam()->zerop[0]);
+ return true;
+}
+
+} // namespace luci
+
+namespace luci
+{
+
+bool VerifyQuantizedNodeU8Type::visit(const luci::CircleTanh *node)
+{
+ RETURN_FALSE_UNLESS(group_has_type(node, loco::DataType::U8));
+
+ RETURN_FALSE_UNLESS(node->quantparam());
+ RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == 2.0f / 256.0f);
+ RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == 128);
+ return true;
+}
+
+bool VerifyQuantizedNodeU8Type::visit(const luci::CircleLogistic *node)
+{
+ RETURN_FALSE_UNLESS(group_has_type(node, loco::DataType::U8));
+
+ RETURN_FALSE_UNLESS(node->quantparam());
+ RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == 1.0f / 256.0f);
+ RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == 0);
+ return true;
+}
+
+bool VerifyQuantizedNodeU8Type::visit(const luci::CircleSoftmax *node)
+{
+ RETURN_FALSE_UNLESS(group_has_type(node, loco::DataType::U8));
+
+ RETURN_FALSE_UNLESS(node->quantparam());
+ RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == 1.0f / 255.0f);
+ RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == 0);
+ return true;
+}
+
+} // namespace luci
+
+namespace luci
+{
+
+bool VerifyQuantizedNodeS16Type::visit(const luci::CircleTanh *node)
+{
+ RETURN_FALSE_UNLESS(group_has_type(node, loco::DataType::S16));
+
+ RETURN_FALSE_UNLESS(node->quantparam());
+ RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == 1.0f / 32768.0f);
+ RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == 0);
+ return true;
+}
+
+bool VerifyQuantizedNodeS16Type::visit(const luci::CircleLogistic *node)
+{
+ RETURN_FALSE_UNLESS(group_has_type(node, loco::DataType::S16));
+
+ RETURN_FALSE_UNLESS(node->quantparam());
+ RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == 1.0f / 32768.0f);
+ RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == 0);
+ return true;
+}
+
+bool VerifyQuantizedNodeS16Type::visit(const luci::CircleSoftmax *node)
+{
+ RETURN_FALSE_UNLESS(group_has_type(node, loco::DataType::S16));
+
+ RETURN_FALSE_UNLESS(node->quantparam());
+ RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == 1.0f / 32767.0f);
+ RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == 0);
+ return true;
+}
+
+} // namespace luci
+
+#undef RETURN_FALSE_UNLESS
diff --git a/compiler/luci/pass/src/VerifyQuantizedNodeType.h b/compiler/luci/pass/src/VerifyQuantizedNodeType.h
new file mode 100644
index 000000000..ff1acbd6f
--- /dev/null
+++ b/compiler/luci/pass/src/VerifyQuantizedNodeType.h
@@ -0,0 +1,157 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_VERIFY_QUANTIZED_NODE_TYPE_H__
+#define __LUCI_VERIFY_QUANTIZED_NODE_TYPE_H__
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/IR/CircleNodeVisitor.h>
+
+namespace luci
+{
+
+/**
+ * @brief Verify the data type of quantized node
+ * @details
+ *
+ * Targets to verify
+ * - node's output (i.e., node itself)
+ * - node's inputs
+ */
+class VerifyQuantizedNodeType
+{
+public:
+ static std::shared_ptr<VerifyQuantizedNodeType> create(loco::DataType dtype);
+
+public:
+ virtual bool verify(luci::CircleNode *node) = 0;
+};
+
+/**
+ * @brief Verify using quantization type of a node and bias
+ *
+ * @tparam Qtype Quantization type for a node (e.g. Q8, Q16, ...)
+ * @tparam Btype Bias quantization type (e.g. For Q8, S32 is used)
+ */
+template <loco::DataType Qtype, loco::DataType Btype>
+class VerifyQuantizedNodeTypeBase : public luci::CircleNodeVisitor<bool>,
+ public VerifyQuantizedNodeType
+{
+public:
+ bool verify(luci::CircleNode *node) { return node->accept(this); }
+
+protected:
+ bool has_type(const loco::Node *node, loco::DataType dtype)
+ {
+ auto circle_node = loco::must_cast<const luci::CircleNode *>(node);
+ return circle_node->dtype() == dtype;
+ }
+
+ // Check whether a node and all of its inputs have dtype or not
+ bool group_has_type(const loco::Node *node, loco::DataType dtype)
+ {
+ if (!has_type(node, dtype))
+ return false;
+
+ for (uint32_t i = 0; i < node->arity(); ++i)
+ if (!has_type(node->arg(i), dtype))
+ return false;
+
+ return true;
+ }
+
+private:
+ bool visit(const luci::CircleAdd *node);
+ bool visit(const luci::CircleArgMax *node);
+ bool visit(const luci::CircleAveragePool2D *node);
+ bool visit(const luci::CircleBatchToSpaceND *node);
+ bool visit(const luci::CircleCast *node);
+ bool visit(const luci::CircleConv2D *node);
+ bool visit(const luci::CircleConcatenation *node);
+ bool visit(const luci::CircleDepthToSpace *node);
+ bool visit(const luci::CircleDepthwiseConv2D *node);
+ bool visit(const luci::CircleDiv *node);
+ bool visit(const luci::CircleElu *node);
+ bool visit(const luci::CircleFloor *node);
+ bool visit(const luci::CircleFloorDiv *node);
+ bool visit(const luci::CircleFullyConnected *node);
+ bool visit(const luci::CircleGreater *node);
+ bool visit(const luci::CircleGreaterEqual *node);
+ bool visit(const luci::CircleInstanceNorm *node);
+ bool visit(const luci::CircleLocalResponseNormalization *node);
+ bool visit(const luci::CircleLogicalOr *node);
+ bool visit(const luci::CircleMaxPool2D *node);
+ bool visit(const luci::CircleMean *node);
+ bool visit(const luci::CircleMirrorPad *node);
+ bool visit(const luci::CircleMul *node);
+ bool visit(const luci::CircleNotEqual *node);
+ bool visit(const luci::CircleOneHot *node);
+ bool visit(const luci::CirclePack *node);
+ bool visit(const luci::CirclePad *node);
+ bool visit(const luci::CirclePadV2 *node);
+ bool visit(const luci::CirclePRelu *node);
+ bool visit(const luci::CirclePow *node);
+ bool visit(const luci::CircleRelu *node);
+ bool visit(const luci::CircleReshape *node);
+ bool visit(const luci::CircleResizeBilinear *node);
+ bool visit(const luci::CircleResizeNearestNeighbor *node);
+ bool visit(const luci::CircleRsqrt *node);
+ bool visit(const luci::CircleSlice *node);
+ bool visit(const luci::CircleSpaceToBatchND *node);
+ bool visit(const luci::CircleSpaceToDepth *node);
+ bool visit(const luci::CircleSplit *node);
+ bool visit(const luci::CircleSplitOut *node);
+ bool visit(const luci::CircleSplitV *node);
+ bool visit(const luci::CircleSplitVOut *node);
+ bool visit(const luci::CircleSqrt *node);
+ bool visit(const luci::CircleStridedSlice *node);
+ bool visit(const luci::CircleTranspose *node);
+ bool visit(const luci::CircleTransposeConv *node);
+ bool visit(const luci::CircleUnpack *node);
+ bool visit(const luci::CircleUnpackOut *node);
+
+ // NOTE below nodes has differnent implementation for Qtype/Btype and
+ // implementations exist in VerifyQuantizedNodeU8Type, VerifyQuantizedNodeS16Type
+ // bool visit(const luci::CircleLogistic *node);
+ // bool visit(const luci::CircleSoftmax *node);
+ // bool visit(const luci::CircleTanh *node);
+
+ // TODO: Implement more Ops
+
+ bool visit(const luci::CircleNode *) { return true; }
+};
+
+class VerifyQuantizedNodeU8Type
+ : public VerifyQuantizedNodeTypeBase<loco::DataType::U8, loco::DataType::S32>
+{
+private:
+ bool visit(const luci::CircleLogistic *node);
+ bool visit(const luci::CircleSoftmax *node);
+ bool visit(const luci::CircleTanh *node);
+};
+
+class VerifyQuantizedNodeS16Type
+ : public VerifyQuantizedNodeTypeBase<loco::DataType::S16, loco::DataType::S64>
+{
+private:
+ bool visit(const luci::CircleLogistic *node);
+ bool visit(const luci::CircleSoftmax *node);
+ bool visit(const luci::CircleTanh *node);
+};
+
+} // namespace luci
+
+#endif // __LUCI_VERIFY_QUANTIZED_NODE_TYPE_H__
diff --git a/compiler/luci/pass/src/VerifyQuantizedNodeU8Type.h b/compiler/luci/pass/src/VerifyQuantizedNodeU8Type.h
deleted file mode 100644
index e7dd1b072..000000000
--- a/compiler/luci/pass/src/VerifyQuantizedNodeU8Type.h
+++ /dev/null
@@ -1,518 +0,0 @@
-/*
- * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef __LUCI_VERIFY_QUANTIZED_NODE_U8_TYPE_H__
-#define __LUCI_VERIFY_QUANTIZED_NODE_U8_TYPE_H__
-
-#include <luci/IR/CircleNodes.h>
-#include <luci/IR/CircleNodeVisitor.h>
-
-#include <cmath>
-
-using Type = loco::DataType;
-
-// This macro is undef at the end of the file
-#define RETURN_FALSE_UNLESS(ARG) \
- if (not(ARG)) \
- { \
- return false; \
- }
-
-namespace luci
-{
-
-/**
- * @brief Verify the data type of UINT8 quantized node
- * @details
- *
- * Targets to verify
- * - node's output (i.e., node itself)
- * - node's inputs
- */
-struct VerifyQuantizedNodeU8Type final : public luci::CircleNodeVisitor<bool>
-{
-private:
- bool has_type(const loco::Node *node, Type dtype)
- {
- auto circle_node = loco::must_cast<const luci::CircleNode *>(node);
- return circle_node->dtype() == dtype;
- }
-
-private:
- bool visit(const luci::CircleConv2D *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->filter(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->bias(), Type::S32))
- return true;
- }
-
- bool visit(const luci::CircleConcatenation *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- for (uint32_t i = 0; i < node->numValues(); i++)
- {
- RETURN_FALSE_UNLESS(has_type(node->values(i), Type::U8))
- }
- return true;
- }
-
- bool visit(const luci::CircleDepthToSpace *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleDepthwiseConv2D *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->filter(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->bias(), Type::S32))
- return true;
- }
-
- bool visit(const luci::CircleInstanceNorm *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->gamma(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->beta(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CirclePack *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- for (uint32_t i = 0; i < node->values_count(); i++)
- {
- RETURN_FALSE_UNLESS(has_type(node->values(i), Type::U8))
- }
- return true;
- }
-
- bool visit(const luci::CirclePad *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->paddings(), Type::S32))
- return true;
- }
-
- bool visit(const luci::CirclePadV2 *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->paddings(), Type::S32))
- RETURN_FALSE_UNLESS(has_type(node->constant_values(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleMirrorPad *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->paddings(), Type::S32))
- return true;
- }
-
- bool visit(const luci::CirclePRelu *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->alpha(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleTransposeConv *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->outBackprop(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->filter(), Type::U8))
- luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
- if (bias != nullptr)
- RETURN_FALSE_UNLESS(has_type(bias, Type::S32))
- return true;
- }
-
- bool visit(const luci::CircleFullyConnected *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->weights(), Type::U8))
- luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
- if (bias != nullptr)
- RETURN_FALSE_UNLESS(has_type(bias, Type::S32))
- return true;
- }
-
- bool visit(const luci::CircleAdd *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->y(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleAveragePool2D *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->value(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleBatchToSpaceND *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleLogicalOr *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::BOOL))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::BOOL))
- RETURN_FALSE_UNLESS(has_type(node->y(), Type::BOOL))
- return true;
- }
-
- bool visit(const luci::CircleMaxPool2D *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->value(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleLocalResponseNormalization *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleMean *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->reduction_indices(), Type::S32))
- return true;
- }
-
- bool visit(const luci::CircleMul *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->y(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleNotEqual *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::BOOL))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->y(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleRelu *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->features(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleReshape *node)
- {
- if (node->quantparam())
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->tensor(), Type::U8))
- }
- else
- {
- RETURN_FALSE_UNLESS(has_type(node->tensor(), node->dtype()))
- }
- luci::CircleConst *shape = dynamic_cast<luci::CircleConst *>(node->shape());
- if (shape != nullptr)
- RETURN_FALSE_UNLESS(has_type(shape, Type::S32))
- return true;
- }
-
- bool visit(const luci::CircleLogistic *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
-
- RETURN_FALSE_UNLESS(node->quantparam());
- RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == 1.0f / 256.0f);
- RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == 0);
- return true;
- }
-
- bool visit(const luci::CircleSoftmax *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->logits(), Type::U8))
-
- RETURN_FALSE_UNLESS(node->quantparam());
- RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == 1.0f / 255.0f);
- RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == 0);
- return true;
- }
-
- bool visit(const luci::CircleSpaceToBatchND *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleSpaceToDepth *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleSlice *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->begin(), Type::S32) || has_type(node->begin(), Type::S64))
- RETURN_FALSE_UNLESS(has_type(node->size(), Type::S32) || has_type(node->size(), Type::S64))
- return true;
- }
-
- bool visit(const luci::CircleSplit *node)
- {
- // node's output is the input of CircleSplitOut, thus not quantized
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleSplitOut *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
-
- // SplitOut has the same qparam with the input of Split
- auto split = loco::must_cast<luci::CircleSplit *>(node->input());
- auto input = loco::must_cast<luci::CircleNode *>(split->input());
- RETURN_FALSE_UNLESS(node->quantparam());
- RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == input->quantparam()->scale[0]);
- RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == input->quantparam()->zerop[0]);
- return true;
- }
-
- bool visit(const luci::CircleSplitV *node)
- {
- // node's output is the input of CircleSplitVOut, thus not quantized
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleSplitVOut *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
-
- // SplitVOut has the same qparam with the input of SplitV
- auto splitv = loco::must_cast<luci::CircleSplitV *>(node->input());
- auto input = loco::must_cast<luci::CircleNode *>(splitv->input());
- RETURN_FALSE_UNLESS(node->quantparam());
- RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == input->quantparam()->scale[0]);
- RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == input->quantparam()->zerop[0]);
- return true;
- }
-
- bool visit(const luci::CircleStridedSlice *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
-
- auto input = loco::must_cast<luci::CircleNode *>(node->input());
- RETURN_FALSE_UNLESS(node->quantparam());
- RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == input->quantparam()->scale[0]);
- RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == input->quantparam()->zerop[0]);
- return true;
- }
-
- bool visit(const luci::CircleArgMax *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, node->output_type()))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->dimension(), Type::S32) ||
- has_type(node->dimension(), Type::S64))
- return true;
- }
-
- bool visit(const luci::CircleTanh *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
-
- RETURN_FALSE_UNLESS(node->quantparam());
- RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == 2.0f / 256.0f);
- RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == 128);
- return true;
- }
-
- bool visit(const luci::CircleTranspose *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->a(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->perm(), Type::S32))
- return true;
- }
-
- bool visit(const luci::CircleFloor *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
-
- // This checks the value of scale is an integer
- RETURN_FALSE_UNLESS(node->quantparam());
- RETURN_FALSE_UNLESS(std::roundf(node->quantparam()->scale[0]) == node->quantparam()->scale[0]);
- return true;
- }
-
- bool visit(const luci::CircleGreater *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::BOOL))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->y(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleGreaterEqual *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::BOOL))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->y(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleDiv *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->y(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleFloorDiv *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->y(), Type::U8))
-
- // This checks the value of scale is an integer
- RETURN_FALSE_UNLESS(node->quantparam());
- RETURN_FALSE_UNLESS(std::roundf(node->quantparam()->scale[0]) == node->quantparam()->scale[0]);
- return true;
- }
-
- bool visit(const luci::CircleRsqrt *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleSqrt *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleElu *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->features(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CirclePow *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->y(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleResizeBilinear *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleResizeNearestNeighbor *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleUnpack *node)
- {
- // node's output is the input of CircleUnpackOut, thus not quantized
- RETURN_FALSE_UNLESS(has_type(node->value(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleUnpackOut *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
-
- // UnpackOut has the same qparam with the input of Unpack
- auto Unpack = loco::must_cast<luci::CircleUnpack *>(node->input());
- auto input = loco::must_cast<luci::CircleNode *>(Unpack->value());
- RETURN_FALSE_UNLESS(node->quantparam() && input->quantparam());
- RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == input->quantparam()->scale[0]);
- RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == input->quantparam()->zerop[0]);
- return true;
- }
-
- bool visit(const luci::CircleCast *node)
- {
- auto *input = loco::must_cast<luci::CircleNode *>(node->x());
- bool input_quantized = input->quantparam() != nullptr;
- if (input_quantized)
- {
- RETURN_FALSE_UNLESS(has_type(input, node->in_data_type()))
- RETURN_FALSE_UNLESS(has_type(input, Type::U8))
- }
-
- bool node_quantized = node->quantparam() != nullptr;
- if (node_quantized)
- {
- RETURN_FALSE_UNLESS(has_type(node, node->out_data_type()))
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- }
- return true;
- }
-
- // TODO: Implement more Ops
-
- bool visit(const luci::CircleNode *) { return true; }
-};
-
-} // namespace luci
-
-#undef RETURN_FALSE_UNLESS
-
-#endif // __LUCI_VERIFY_QUNTIZED_NODE_U8_TYPE_H__
diff --git a/compiler/luci/pass/src/helpers/LayerInfoMap.cpp b/compiler/luci/pass/src/helpers/LayerInfoMap.cpp
new file mode 100644
index 000000000..ac07f9ec9
--- /dev/null
+++ b/compiler/luci/pass/src/helpers/LayerInfoMap.cpp
@@ -0,0 +1,189 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "LayerInfoMap.h"
+
+#include <luci/IR/CircleNode.h>
+
+#include <cassert>
+
+namespace luci
+{
+namespace
+{
+
+bool is_multiple_output_node(const luci::CircleNode *node)
+{
+ switch (node->opcode())
+ {
+ // The following nodes have multiple outputs. Output tensors are not produced by themselves but
+ // by the corresponding *Out nodes.
+ case luci::CircleOpcode::SPLIT:
+ case luci::CircleOpcode::SPLIT_V:
+ case luci::CircleOpcode::TOPK_V2:
+ case luci::CircleOpcode::UNIQUE:
+ case luci::CircleOpcode::UNPACK:
+ return true;
+ // TODO: Support ops
+ case luci::CircleOpcode::BIDIRECTIONAL_SEQUENCE_LSTM:
+ case luci::CircleOpcode::CUSTOM:
+ case luci::CircleOpcode::IF:
+ case luci::CircleOpcode::NON_MAX_SUPPRESSION_V4:
+ case luci::CircleOpcode::NON_MAX_SUPPRESSION_V5:
+ case luci::CircleOpcode::WHILE:
+ throw std::runtime_error("Unsupported op now");
+ default:
+ return false;
+ }
+}
+
+const luci::CircleNode *get_multi_output_node(const luci::CircleNode *node)
+{
+ if (is_multiple_output_node(node))
+ return node;
+
+ switch (node->opcode())
+ {
+ // The following nodes denote outputs of multiple-output nodes.
+ case luci::CircleOpcode::CIRCLESPLITOUT:
+ {
+ const auto split_out = loco::must_cast<const CircleSplitOut *>(node);
+ return loco::must_cast<luci::CircleNode *>(split_out->input());
+ }
+ case luci::CircleOpcode::CIRCLESPLITVOUT:
+ {
+ const auto splitv_out = loco::must_cast<const CircleSplitVOut *>(node);
+ return loco::must_cast<luci::CircleNode *>(splitv_out->input());
+ }
+ case luci::CircleOpcode::CIRCLETOPKV2OUT:
+ {
+ const auto top_kv2_out = loco::must_cast<const CircleTopKV2Out *>(node);
+ return loco::must_cast<luci::CircleNode *>(top_kv2_out->input());
+ }
+ case luci::CircleOpcode::CIRCLEUNIQUEOUT:
+ {
+ const auto unique_out = loco::must_cast<const CircleUniqueOut *>(node);
+ return loco::must_cast<luci::CircleNode *>(unique_out->input());
+ }
+ case luci::CircleOpcode::CIRCLEUNPACKOUT:
+ {
+ const auto unpack_out = loco::must_cast<const CircleUnpackOut *>(node);
+ return loco::must_cast<luci::CircleNode *>(unpack_out->input());
+ }
+ // TODO: Support these ops
+ case luci::CircleOpcode::CIRCLEBIDIRECTIONAL_SEQUENCE_LSTM_OUT:
+ case luci::CircleOpcode::CIRCLECUSTOMOUT:
+ case luci::CircleOpcode::CIRCLEIFOUT:
+ case luci::CircleOpcode::CIRCLENONMAXSUPPRESSIONV4OUT:
+ case luci::CircleOpcode::CIRCLENONMAXSUPPRESSIONV5OUT:
+ case luci::CircleOpcode::CIRCLEWHILEOUT:
+ throw std::runtime_error("Unsupported op now");
+ default:
+ return nullptr;
+ }
+}
+
+bool same_setting(const LayerInfo &left, const LayerInfo &right)
+{
+ return left.dtype == right.dtype and left.granularity == right.granularity;
+}
+
+void add_multi_output_node(LayerInfoMap &info_by_name, LayerInfo &layer_info,
+ const luci::CircleNode *node)
+{
+ assert(is_multiple_output_node(node)); // FIX_CALLER_UNLESS
+
+ const auto succs_nodes = loco::succs(node);
+ const auto name = node->name();
+
+ if (info_by_name.find(name) != info_by_name.end())
+ {
+ // Check that all outputs have equal dtype and granularity
+ for (const auto succs_node : succs_nodes)
+ {
+ const auto succs_circle_node = loco::must_cast<luci::CircleNode *>(succs_node);
+
+ const auto it = info_by_name.find(succs_circle_node->name());
+ if (it != info_by_name.end() and not same_setting(layer_info, (it->second)))
+ throw std::runtime_error("Outputs of multiple-output nodes should have equal dtype and "
+ "granularity. Check the quantization configuration file");
+ }
+ return;
+ }
+
+ // Add multiple output node to info_by_name
+ info_by_name[name] = {name, layer_info.dtype, layer_info.granularity};
+
+ // Add outputs node to info_by_name
+ for (const auto succs_node : succs_nodes)
+ {
+ const auto succs_circle_node = loco::must_cast<luci::CircleNode *>(succs_node);
+ const auto succs_circle_node_name = succs_circle_node->name();
+ info_by_name[succs_circle_node_name] = {succs_circle_node_name, layer_info.dtype,
+ layer_info.granularity};
+ }
+}
+
+} // namespace
+
+LayerInfoMap layer_info_map(loco::Graph *g, std::vector<LayerInfo> &layers_info)
+{
+ LayerInfoMap info_by_name;
+
+ for (auto &&info : layers_info)
+ {
+ auto name = info.name;
+ bool found = false;
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ auto cnode = loco::must_cast<luci::CircleNode *>(node);
+ if (cnode->opcode() == luci::CircleOpcode::CIRCLEOUTPUT)
+ continue;
+
+ if (cnode->name() == name)
+ {
+ // Check and add multiple-output node and its outputs to info_by_name
+ if (const auto multi_output = get_multi_output_node(cnode))
+ {
+ add_multi_output_node(info_by_name, info, multi_output);
+ found = true;
+ continue;
+ }
+
+ if (info_by_name.find(name) != info_by_name.end())
+ {
+ throw std::runtime_error("Duplicate layer name " + name +
+ ". Check layer names in the quantization configuration file.");
+ }
+
+ info_by_name[name] = info;
+ found = true;
+ continue;
+ }
+ }
+
+ if (not found)
+ throw std::runtime_error("No such layer named " + name +
+ ". Check layer names in the quantization configuration file.");
+ }
+
+ // TODO Check all names in layers_info exist in the info_by_name
+ // TODO Check names in info_by_name but not in layers_info are from virtual outputs
+
+ return info_by_name;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/helpers/LayerInfoMap.h b/compiler/luci/pass/src/helpers/LayerInfoMap.h
new file mode 100644
index 000000000..bb4724a50
--- /dev/null
+++ b/compiler/luci/pass/src/helpers/LayerInfoMap.h
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_PASS_HELPERS_LAYER_INFO_MAP_H__
+#define __LUCI_PASS_HELPERS_LAYER_INFO_MAP_H__
+
+#include <luci/Pass/QuantizationParameters.h>
+
+#include <unordered_map>
+
+namespace luci
+{
+
+using LayerInfoMap = std::unordered_map<std::string, luci::LayerInfo>;
+
+LayerInfoMap layer_info_map(loco::Graph *g, std::vector<LayerInfo> &layers_info);
+
+} // namespace luci
+
+#endif // __LUCI_PASS_HELPERS_LAYER_INFO_MAP_H__
diff --git a/compiler/luci/pass/src/helpers/LayerInfoMap.test.cpp b/compiler/luci/pass/src/helpers/LayerInfoMap.test.cpp
new file mode 100644
index 000000000..2ed28eda4
--- /dev/null
+++ b/compiler/luci/pass/src/helpers/LayerInfoMap.test.cpp
@@ -0,0 +1,201 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "LayerInfoMap.h"
+
+#include <luci/IR/CircleNode.h>
+#include <luci/test/TestIOGraph.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+class SoftmaxTestGraph : public luci::test::TestIOGraph
+{
+public:
+ void init(void)
+ {
+ TestIOGraph::init({32}, {32});
+ _softmax = g()->nodes()->create<luci::CircleSoftmax>();
+ {
+ _softmax->logits(input());
+ _softmax->beta(0.1);
+ _softmax->name("test");
+ }
+ output()->from(_softmax);
+ }
+
+private:
+ luci::CircleSoftmax *_softmax = nullptr;
+};
+
+class SplitAddTestGraph : public luci::test::TestIOGraph
+{
+public:
+ void init(void)
+ {
+ TestIOGraph::init({6, 1, 2}, {3, 1, 2});
+ _split_dim = g()->nodes()->create<luci::CircleConst>();
+ {
+ _split_dim->rank(1);
+ _split_dim->dtype(loco::DataType::S32);
+ _split_dim->size<loco::DataType::S32>(1);
+ _split_dim->at<loco::DataType::S32>(0);
+ _split_dim->shape({1});
+ _split_dim->name("split_dim");
+ }
+
+ _split = g()->nodes()->create<luci::CircleSplit>();
+ {
+ _split->input(input());
+ _split->num_split(2);
+ _split->split_dim(_split_dim);
+ _split->name("split0");
+ }
+
+ _split_out_1 = g()->nodes()->create<luci::CircleSplitOut>();
+ {
+ _split_out_1->input(_split);
+ _split_out_1->index(0);
+ _split_out_1->name("split0");
+ }
+
+ _split_out_2 = g()->nodes()->create<luci::CircleSplitOut>();
+ {
+ _split_out_2->input(_split);
+ _split_out_2->index(1);
+ _split_out_2->name("split1");
+ }
+
+ _add = g()->nodes()->create<luci::CircleAdd>();
+ {
+ _add->x(_split_out_1);
+ _add->y(_split_out_2);
+ _add->name("add");
+ }
+ output()->from(_add);
+ }
+
+private:
+ luci::CircleSplit *_split = nullptr;
+ luci::CircleSplitOut *_split_out_1 = nullptr;
+ luci::CircleSplitOut *_split_out_2 = nullptr;
+ luci::CircleConst *_split_dim = nullptr;
+ luci::CircleAdd *_add = nullptr;
+};
+
+} // namespace
+
+TEST(LayerInfoMapTest, simple_test)
+{
+ SoftmaxTestGraph g;
+ g.init();
+
+ luci::LayerInfo info;
+ {
+ info.name = "test";
+ info.dtype = loco::DataType::U8;
+ info.granularity = luci::QuantizationGranularity::ChannelWise;
+ }
+ std::vector<luci::LayerInfo> v;
+ v.emplace_back(info);
+ auto map = luci::layer_info_map(g.g(), v);
+
+ EXPECT_EQ("test", map["test"].name);
+ EXPECT_EQ(loco::DataType::U8, map["test"].dtype);
+ EXPECT_EQ(luci::QuantizationGranularity::ChannelWise, map["test"].granularity);
+}
+
+TEST(LayerInfoMapTest, multiple_output_node_test)
+{
+ SplitAddTestGraph g;
+ g.init();
+
+ luci::LayerInfo info;
+ {
+ info.name = "split0";
+ info.dtype = loco::DataType::U8;
+ info.granularity = luci::QuantizationGranularity::ChannelWise;
+ }
+ std::vector<luci::LayerInfo> v;
+ v.emplace_back(info);
+ auto map = luci::layer_info_map(g.g(), v);
+
+ EXPECT_EQ(map.size(), 2);
+ EXPECT_EQ("split0", map["split0"].name);
+ EXPECT_EQ("split1", map["split1"].name);
+
+ EXPECT_EQ(loco::DataType::U8, map["split0"].dtype);
+ EXPECT_EQ(luci::QuantizationGranularity::ChannelWise, map["split0"].granularity);
+}
+
+TEST(LayerInfoMapTest, invalid_layer_info_multiple_output_node_NEG)
+{
+ SplitAddTestGraph g;
+ g.init();
+
+ luci::LayerInfo info_0;
+ {
+ info_0.name = "split0";
+ info_0.dtype = loco::DataType::U8;
+ info_0.granularity = luci::QuantizationGranularity::ChannelWise;
+ }
+ luci::LayerInfo info_1;
+ {
+ info_1.name = "split1";
+ info_1.dtype = loco::DataType::S16;
+ info_1.granularity = luci::QuantizationGranularity::ChannelWise;
+ }
+ std::vector<luci::LayerInfo> v;
+ v.emplace_back(info_0);
+ v.emplace_back(info_1);
+
+ EXPECT_ANY_THROW(luci::layer_info_map(g.g(), v));
+}
+
+TEST(LayerInfoMapTest, duplicate_name_NEG)
+{
+ SoftmaxTestGraph g;
+ g.init();
+ g.input()->name("test");
+
+ luci::LayerInfo info;
+ {
+ info.name = "test";
+ info.dtype = loco::DataType::U8;
+ info.granularity = luci::QuantizationGranularity::ChannelWise;
+ }
+ std::vector<luci::LayerInfo> v;
+ v.emplace_back(info);
+ EXPECT_ANY_THROW(luci::layer_info_map(g.g(), v));
+}
+
+TEST(LayerInfoMapTest, no_name_NEG)
+{
+ SoftmaxTestGraph g;
+ g.init();
+
+ luci::LayerInfo info;
+ {
+ info.name = "noname";
+ info.dtype = loco::DataType::U8;
+ info.granularity = luci::QuantizationGranularity::ChannelWise;
+ }
+ std::vector<luci::LayerInfo> v;
+ v.emplace_back(info);
+ EXPECT_ANY_THROW(luci::layer_info_map(g.g(), v));
+}