Diffstat (limited to 'compiler/luci/pass')
-rw-r--r-- compiler/luci/pass/include/luci/CircleOptimizer.h | 2
-rw-r--r-- compiler/luci/pass/include/luci/CircleQuantizer.h | 1
-rw-r--r-- compiler/luci/pass/include/luci/DynamicBatchToSingleBatch.h | 29
-rw-r--r-- compiler/luci/pass/include/luci/Pass/DecomposeHardSwishPass.h | 37
-rw-r--r-- compiler/luci/pass/include/luci/Pass/DynamicBatchToSingleBatchPass.h | 39
-rw-r--r-- compiler/luci/pass/include/luci/Pass/FuseGeluPass.h | 39
-rw-r--r-- compiler/luci/pass/include/luci/Pass/QuantizeWeightsPass.h | 70
-rw-r--r-- compiler/luci/pass/include/luci/Pass/RequantizePass.h | 2
-rw-r--r-- compiler/luci/pass/src/CircleOptimizer.cpp | 42
-rw-r--r-- compiler/luci/pass/src/CircleQuantizer.cpp | 39
-rw-r--r-- compiler/luci/pass/src/ConvertNCHWToNHWCPass.cpp | 26
-rw-r--r-- compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp | 8
-rw-r--r-- compiler/luci/pass/src/ConvertToFakeQuantizedModelPass.cpp | 4
-rw-r--r-- compiler/luci/pass/src/DecomposeHardSwishPass.cpp | 147
-rw-r--r-- compiler/luci/pass/src/DecomposeHardSwishPass.test.cpp | 205
-rw-r--r-- compiler/luci/pass/src/DynamicBatchToSingleBatch.cpp | 51
-rw-r--r-- compiler/luci/pass/src/DynamicBatchToSingleBatchPass.cpp | 78
-rw-r--r-- compiler/luci/pass/src/DynamicBatchToSingleBatchPass.test.cpp | 126
-rw-r--r-- compiler/luci/pass/src/FoldAddV2Pass.test.cpp | 8
-rw-r--r-- compiler/luci/pass/src/FoldCastPass.test.cpp | 4
-rw-r--r-- compiler/luci/pass/src/FoldDequantizePass.test.cpp | 4
-rw-r--r-- compiler/luci/pass/src/FuseActivationFunctionPass.cpp | 5
-rw-r--r-- compiler/luci/pass/src/FuseAddWithFullyConnectedPass.test.cpp | 54
-rw-r--r-- compiler/luci/pass/src/FuseAddWithTConvPass.cpp | 3
-rw-r--r-- compiler/luci/pass/src/FuseBatchNormWithTConvPass.cpp | 6
-rw-r--r-- compiler/luci/pass/src/FuseGeluPass.cpp | 347
-rw-r--r-- compiler/luci/pass/src/FuseGeluPass.test.cpp | 251
-rw-r--r-- compiler/luci/pass/src/PropagateQParamBackwardPass.cpp | 63
-rw-r--r-- compiler/luci/pass/src/PropagateQParamBackwardPass.test.cpp | 143
-rw-r--r-- compiler/luci/pass/src/QuantizationUtils.cpp | 19
-rw-r--r-- compiler/luci/pass/src/QuantizationUtils.h | 8
-rw-r--r-- compiler/luci/pass/src/QuantizeActivation.cpp | 10
-rw-r--r-- compiler/luci/pass/src/QuantizeActivation.h | 3
-rw-r--r-- compiler/luci/pass/src/QuantizeBias.test.cpp | 47
-rw-r--r-- compiler/luci/pass/src/QuantizeDequantizeWeightsPass.cpp | 8
-rw-r--r-- compiler/luci/pass/src/QuantizePreCheckerPass.test.cpp | 1
-rw-r--r-- compiler/luci/pass/src/QuantizeWeights.cpp | 9
-rw-r--r-- compiler/luci/pass/src/QuantizeWeightsOnly.cpp | 224
-rw-r--r-- compiler/luci/pass/src/QuantizeWeightsOnly.h | 51
-rw-r--r-- compiler/luci/pass/src/QuantizeWeightsPass.cpp | 46
-rw-r--r-- compiler/luci/pass/src/QuantizeWeightsPass.test.cpp | 123
-rw-r--r-- compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp | 16
-rw-r--r-- compiler/luci/pass/src/QuantizedModelVerifier.test.cpp | 73
-rw-r--r-- compiler/luci/pass/src/ReplaceNonConstFCWithBatchMatMulPass.test.cpp | 52
-rw-r--r-- compiler/luci/pass/src/ReplaceSubWithAddPass.cpp | 2
-rw-r--r-- compiler/luci/pass/src/RequantizePass.cpp | 159
-rw-r--r-- compiler/luci/pass/src/RequantizePass.test.cpp | 156
-rw-r--r-- compiler/luci/pass/src/ResolveCustomOpMatMulPass.cpp | 58
-rw-r--r-- compiler/luci/pass/src/SubstituteSplitVToSplitPass.test.cpp | 51
-rw-r--r-- compiler/luci/pass/src/VerifyQuantizedBiasScale.cpp | 2
-rw-r--r-- compiler/luci/pass/src/VerifyQuantizedNodeGranularity.h | 14
-rw-r--r-- compiler/luci/pass/src/VerifyQuantizedNodeType.cpp | 15
-rw-r--r-- compiler/luci/pass/src/VerifyQuantizedNodeType.h | 2
-rw-r--r-- compiler/luci/pass/src/helpers/CreateCircleConst.cpp | 20
-rw-r--r-- compiler/luci/pass/src/helpers/CreateCircleConst.h | 88
-rw-r--r-- compiler/luci/pass/src/helpers/TypeMapper.h | 5
56 files changed, 2668 insertions, 427 deletions
diff --git a/compiler/luci/pass/include/luci/CircleOptimizer.h b/compiler/luci/pass/include/luci/CircleOptimizer.h
index d77e89db1..6ebacee39 100644
--- a/compiler/luci/pass/include/luci/CircleOptimizer.h
+++ b/compiler/luci/pass/include/luci/CircleOptimizer.h
@@ -63,6 +63,7 @@ public:
MakeBatchNormGammaPositive,
FuseActivationFunction,
FusePRelu,
+ FuseGelu,
ShuffleWeightTo16x1Float32,
RemoveRedundantTranspose,
ReplaceMulAddWithDepthwiseConv,
@@ -80,6 +81,7 @@ public:
RemoveUnnecessaryReshape,
TransformMinMaxToRelu6Pass,
TransformMinReluToRelu6Pass,
+ DecomposeHardSwishPass,
SubstituteStridedSliceToReshape,
SubstituteTransposeToReshape,
RemoveRedundantQuantize,
diff --git a/compiler/luci/pass/include/luci/CircleQuantizer.h b/compiler/luci/pass/include/luci/CircleQuantizer.h
index 4e7074d98..463f31790 100644
--- a/compiler/luci/pass/include/luci/CircleQuantizer.h
+++ b/compiler/luci/pass/include/luci/CircleQuantizer.h
@@ -45,6 +45,7 @@ public:
CopyQuantParam,
ForceQuantParam,
ConvertToFakeQuantizedModel,
+ QuantizeWeights,
};
enum AlgorithmParameters
diff --git a/compiler/luci/pass/include/luci/DynamicBatchToSingleBatch.h b/compiler/luci/pass/include/luci/DynamicBatchToSingleBatch.h
new file mode 100644
index 000000000..2a02777f6
--- /dev/null
+++ b/compiler/luci/pass/include/luci/DynamicBatchToSingleBatch.h
@@ -0,0 +1,29 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_DYNAMIC_BATCH_TO_SINGLE_BATCH_H__
+#define __LUCI_DYNAMIC_BATCH_TO_SINGLE_BATCH_H__
+
+#include <luci/IR/Module.h>
+
+namespace luci
+{
+
+void dynamic_batch_to_single_batch(luci::Module *);
+
+} // namespace luci
+
+#endif // __LUCI_DYNAMIC_BATCH_TO_SINGLE_BATCH_H__
diff --git a/compiler/luci/pass/include/luci/Pass/DecomposeHardSwishPass.h b/compiler/luci/pass/include/luci/Pass/DecomposeHardSwishPass.h
new file mode 100644
index 000000000..83c16bcee
--- /dev/null
+++ b/compiler/luci/pass/include/luci/Pass/DecomposeHardSwishPass.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_DECOMPOSE_HARDSWISH_PASS_H__
+#define __LUCI_DECOMPOSE_HARDSWISH_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace luci
+{
+
+/**
+ * @brief Class to decompose HardSwish to Add, Mul and Relu6
+ */
+struct DecomposeHardSwishPass final : public logo::Pass
+{
+ const char *name(void) const final { return "luci::DecomposeHardSwishPass"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace luci
+
+#endif // __LUCI_DECOMPOSE_HARDSWISH_PASS_H__
diff --git a/compiler/luci/pass/include/luci/Pass/DynamicBatchToSingleBatchPass.h b/compiler/luci/pass/include/luci/Pass/DynamicBatchToSingleBatchPass.h
new file mode 100644
index 000000000..b3598c986
--- /dev/null
+++ b/compiler/luci/pass/include/luci/Pass/DynamicBatchToSingleBatchPass.h
@@ -0,0 +1,39 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_DYNAMIC_BATCH_TO_SINGLE_BATCH_PASS_H__
+#define __LUCI_DYNAMIC_BATCH_TO_SINGLE_BATCH_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace luci
+{
+
+/**
+ * @brief Pass to convert dynamic batch to single batch
+ */
+class DynamicBatchToSingleBatchPass : public logo::Pass
+{
+public:
+ virtual const char *name(void) const { return "luci::DynamicBatchToSingleBatchPass"; }
+
+public:
+ bool run(loco::Graph *graph);
+};
+
+} // namespace luci
+
+#endif //__LUCI_DYNAMIC_BATCH_TO_SINGLE_BATCH_PASS_H__
diff --git a/compiler/luci/pass/include/luci/Pass/FuseGeluPass.h b/compiler/luci/pass/include/luci/Pass/FuseGeluPass.h
new file mode 100644
index 000000000..5fa23036c
--- /dev/null
+++ b/compiler/luci/pass/include/luci/Pass/FuseGeluPass.h
@@ -0,0 +1,39 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_FUSE_GELU_PASS_H__
+#define __LUCI_FUSE_GELU_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace luci
+{
+
+/**
+ * @brief Class to fuse certain pattern of subgraph into CircleGelu
+ *
+ * For detailed subgraph pattern to be fused, please check its implementation.
+ */
+struct FuseGeluPass final : public logo::Pass
+{
+ const char *name(void) const final { return "luci::FuseGeluPass"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace luci
+
+#endif // __LUCI_FUSE_GELU_PASS_H__
diff --git a/compiler/luci/pass/include/luci/Pass/QuantizeWeightsPass.h b/compiler/luci/pass/include/luci/Pass/QuantizeWeightsPass.h
new file mode 100644
index 000000000..646597312
--- /dev/null
+++ b/compiler/luci/pass/include/luci/Pass/QuantizeWeightsPass.h
@@ -0,0 +1,70 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_QUANTIZE_WEIGHTS_PASS_H__
+#define __LUCI_QUANTIZE_WEIGHTS_PASS_H__
+
+#include <loco.h>
+
+#include <logo/Pass.h>
+
+#include <luci/Pass/QuantizationParameters.h>
+
+namespace luci
+{
+
+/**
+ * @brief Pass to quantize weights
+ */
+class QuantizeWeightsPass : public logo::Pass
+{
+public:
+ struct Context
+ {
+ loco::DataType input_model_dtype = loco::DataType::Unknown;
+ loco::DataType output_model_dtype = loco::DataType::Unknown;
+ QuantizationGranularity granularity = QuantizationGranularity::ChannelWise;
+ };
+
+public:
+ QuantizeWeightsPass(std::unique_ptr<Context> &&ctx) : _ctx{std::move(ctx)}
+ {
+ // DO NOTHING
+ }
+
+public:
+ QuantizeWeightsPass(loco::DataType input_model_dtype, loco::DataType output_model_dtype,
+ QuantizationGranularity granularity)
+ {
+ _ctx = std::make_unique<Context>();
+ {
+ _ctx->input_model_dtype = input_model_dtype;
+ _ctx->output_model_dtype = output_model_dtype;
+ _ctx->granularity = granularity;
+ }
+ }
+ virtual const char *name(void) const { return "luci::QuantizeWeightsPass"; }
+
+public:
+ bool run(loco::Graph *graph);
+
+private:
+ std::unique_ptr<Context> _ctx;
+};
+
+} // namespace luci
+
+#endif //__LUCI_QUANTIZE_WEIGHTS_PASS_H__
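For reference, a minimal usage sketch of the QuantizeWeightsPass API declared above; it mirrors how CircleQuantizer.cpp builds the Context further down in this commit. The wrapper function name and the concrete dtypes are illustrative only, and `graph` is assumed to be a valid loco::Graph obtained elsewhere.

#include <luci/Pass/QuantizeWeightsPass.h>

#include <memory>

// Quantize float32 weights to int8 with channel-wise granularity.
void quantize_weights_example(loco::Graph *graph)
{
  auto ctx = std::make_unique<luci::QuantizeWeightsPass::Context>();
  {
    ctx->input_model_dtype = loco::DataType::FLOAT32;
    ctx->output_model_dtype = loco::DataType::S8;
    ctx->granularity = luci::QuantizationGranularity::ChannelWise;
  }

  luci::QuantizeWeightsPass pass(std::move(ctx));
  pass.run(graph);
}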
diff --git a/compiler/luci/pass/include/luci/Pass/RequantizePass.h b/compiler/luci/pass/include/luci/Pass/RequantizePass.h
index c6c424f1b..50b9073b5 100644
--- a/compiler/luci/pass/include/luci/Pass/RequantizePass.h
+++ b/compiler/luci/pass/include/luci/Pass/RequantizePass.h
@@ -27,7 +27,7 @@ namespace luci
{
/**
- * @brief Pass to quantize weights
+ * @brief Pass to re-quantize graph (ex: int8 -> uint8)
*/
class RequantizePass : public logo::Pass
{
diff --git a/compiler/luci/pass/src/CircleOptimizer.cpp b/compiler/luci/pass/src/CircleOptimizer.cpp
index 5e1613ad9..b011581af 100644
--- a/compiler/luci/pass/src/CircleOptimizer.cpp
+++ b/compiler/luci/pass/src/CircleOptimizer.cpp
@@ -39,6 +39,7 @@
#include "luci/Pass/FuseMeanWithMeanPass.h"
#include "luci/Pass/FusePreActivationBatchNormPass.h"
#include "luci/Pass/FusePReluPass.h"
+#include "luci/Pass/FuseGeluPass.h"
#include "luci/Pass/FuseTransposeWithMeanPass.h"
#include "luci/Pass/MakeBatchNormGammaPositivePass.h"
#include "luci/Pass/RemoveDuplicateConstPass.h"
@@ -70,6 +71,7 @@
#include "luci/Pass/SubstituteTransposeToReshapePass.h"
#include "luci/Pass/TransformMinMaxToRelu6Pass.h"
#include "luci/Pass/TransformMinReluToRelu6Pass.h"
+#include "luci/Pass/DecomposeHardSwishPass.h"
#include "luci/Pass/UnrollUnidirectionalSequenceLSTMPass.h"
// TODO add more passes
@@ -137,7 +139,8 @@ bool OptimizeOptionsImpl::query(Algorithm algo)
}
// TODO Make a struct for args
-void convert_nchw_to_nhwc(loco::Graph *g, bool preserve_input, bool preserve_output, bool fuse_fc)
+void convert_nchw_to_nhwc(loco::Graph *g, bool preserve_input, bool preserve_output, bool fuse_fc,
+ bool fuse_gelu)
{
logo::Phase phase;
@@ -160,6 +163,12 @@ void convert_nchw_to_nhwc(loco::Graph *g, bool preserve_input, bool preserve_out
if (fuse_fc)
phase.emplace_back(std::make_unique<luci::FuseAddWithFullyConnectedPass>());
+ // Fuse decomposed ops to Gelu Op
+ // Why here? ConvertNCHWToNHWCPass inserts additional Ops, so it is better to fuse
+ // Gelu in advance.
+ if (fuse_gelu)
+ phase.emplace_back(std::make_unique<luci::FuseGeluPass>());
+
phase.emplace_back(
std::make_unique<luci::ConvertNCHWToNHWCPass>(preserve_input, preserve_output));
@@ -216,8 +225,9 @@ void CircleOptimizer::optimize(loco::Graph *g) const
_options->param(Options::AlgorithmParameters::NCHW_to_NHWC_output_shape) != "true";
bool fuse_fc = _options->query(Options::Algorithm::FuseAddWithFullyConnected);
+ bool fuse_gelu = _options->query(Options::Algorithm::FuseGelu);
- convert_nchw_to_nhwc(g, preserve_input, preserve_output, fuse_fc);
+ convert_nchw_to_nhwc(g, preserve_input, preserve_output, fuse_fc, fuse_gelu);
}
/* TRANSFORM DECLARATION BEGIN */
@@ -283,6 +293,10 @@ void CircleOptimizer::optimize(loco::Graph *g) const
{
phase.emplace_back(std::make_unique<FusePReluPass>());
}
+ if (_options->query(Options::Algorithm::FuseGelu))
+ {
+ phase.emplace_back(std::make_unique<FuseGeluPass>());
+ }
if (_options->query(Options::Algorithm::FuseTransposeWithMean))
{
phase.emplace_back(std::make_unique<FuseTransposeWithMeanPass>());
@@ -319,14 +333,6 @@ void CircleOptimizer::optimize(loco::Graph *g) const
{
phase.emplace_back(std::make_unique<luci::FoldSparseToDensePass>());
}
- if (_options->query(Options::Algorithm::ForwardReshapeToUnaryOp))
- {
- phase.emplace_back(std::make_unique<luci::ForwardReshapeToUnaryOpPass>());
- }
- if (_options->query(Options::Algorithm::ForwardTransposeOp))
- {
- phase.emplace_back(std::make_unique<luci::ForwardTransposeOpPass>());
- }
if (_options->query(Options::Algorithm::FusePreActivationBatchNorm))
{
phase.emplace_back(std::make_unique<luci::FusePreActivationBatchNormPass>());
@@ -428,10 +434,26 @@ void CircleOptimizer::optimize(loco::Graph *g) const
{
phase.emplace_back(std::make_unique<luci::TransformMinReluToRelu6Pass>());
}
+ if (_options->query(Options::Algorithm::DecomposeHardSwishPass))
+ {
+ phase.emplace_back(std::make_unique<luci::DecomposeHardSwishPass>());
+ }
if (_options->query(Options::Algorithm::UnrollUnidirSeqLSTM))
{
phase.emplace_back(std::make_unique<luci::UnrollUnidirectionalSequenceLSTMPass>());
}
+ // Forward Reshape/Transpose is done after
+ // 1. SubstituteXXXToReshape
+ // 2. RemoveRedundantReshape/Transpose
+ // See https://github.com/Samsung/ONE/pull/10596 for more details
+ if (_options->query(Options::Algorithm::ForwardReshapeToUnaryOp))
+ {
+ phase.emplace_back(std::make_unique<luci::ForwardReshapeToUnaryOpPass>());
+ }
+ if (_options->query(Options::Algorithm::ForwardTransposeOp))
+ {
+ phase.emplace_back(std::make_unique<luci::ForwardTransposeOpPass>());
+ }
/* TRANSFORM DECLARATION END */
diff --git a/compiler/luci/pass/src/CircleQuantizer.cpp b/compiler/luci/pass/src/CircleQuantizer.cpp
index 3ffa1180c..9039a839f 100644
--- a/compiler/luci/pass/src/CircleQuantizer.cpp
+++ b/compiler/luci/pass/src/CircleQuantizer.cpp
@@ -26,6 +26,7 @@
#include "luci/Pass/QuantizePreCheckerPass.h"
#include "luci/Pass/QuantizeWithMinMaxPass.h"
#include "luci/Pass/QuantizeDequantizeWeightsPass.h"
+#include "luci/Pass/QuantizeWeightsPass.h"
#include "luci/Pass/CircleShapeInferencePass.h"
#include "luci/Pass/CircleTypeInferencePass.h"
@@ -439,14 +440,14 @@ void CircleQuantizer::quantize(loco::Graph *g) const
throw std::runtime_error("Unsupported granularity. List of supported granularity: " +
to_string(qwmm_supported_granularity));
- for (auto dtype : input_type_vec)
+ for (const auto &dtype : input_type_vec)
{
if (!in_array(to_lower_case(dtype), qwmm_supported_input_type))
throw std::runtime_error("Unsupported input type. List of supported input types: " +
to_string(qwmm_supported_input_type));
}
- for (auto dtype : output_type_vec)
+ for (const auto &dtype : output_type_vec)
{
if (!in_array(to_lower_case(dtype), qwmm_supported_output_type))
throw std::runtime_error("Unsupported output type. List of supported output types: " +
@@ -536,6 +537,40 @@ void CircleQuantizer::quantize(loco::Graph *g) const
verifier.verify(g);
}
+ if (_options->query(Options::Algorithm::QuantizeWeights))
+ {
+ static const std::vector<std::string> qw_supported_input_model_dtype{"float32"};
+ static const std::vector<std::string> qw_supported_output_model_dtype{"int8", "int16"};
+ static const std::vector<std::string> qw_supported_granularity{"channel"};
+
+ auto input_model_dtype =
+ _options->param(Options::AlgorithmParameters::Quantize_input_model_dtype);
+ auto output_model_dtype =
+ _options->param(Options::AlgorithmParameters::Quantize_output_model_dtype);
+ auto granularity = _options->param(Options::AlgorithmParameters::Quantize_granularity);
+
+ if (!in_array(to_lower_case(input_model_dtype), qw_supported_input_model_dtype))
+ throw std::runtime_error("Unsupported input type. List of supported input type: " +
+ to_string(qw_supported_input_model_dtype));
+
+ if (!in_array(to_lower_case(output_model_dtype), qw_supported_output_model_dtype))
+ throw std::runtime_error("Unsupported output type. List of supported output type: " +
+ to_string(qw_supported_output_model_dtype));
+
+ if (!in_array(to_lower_case(granularity), qw_supported_granularity))
+ throw std::runtime_error("Unsupported granularity. List of supported granularity: " +
+ to_string(qw_supported_granularity));
+ auto ctx = std::make_unique<luci::QuantizeWeightsPass::Context>();
+ {
+ ctx->input_model_dtype = str_to_dtype(input_model_dtype);
+ ctx->output_model_dtype = str_to_dtype(output_model_dtype);
+ ctx->granularity = str_to_granularity(granularity);
+ }
+ luci::QuantizeWeightsPass weights_quantizer(std::move(ctx));
+
+ weights_quantizer.run(g);
+ }
+
// Requantize
if (_options->query(Options::Algorithm::Requantize))
{
diff --git a/compiler/luci/pass/src/ConvertNCHWToNHWCPass.cpp b/compiler/luci/pass/src/ConvertNCHWToNHWCPass.cpp
index 99e1e2939..ac4320246 100644
--- a/compiler/luci/pass/src/ConvertNCHWToNHWCPass.cpp
+++ b/compiler/luci/pass/src/ConvertNCHWToNHWCPass.cpp
@@ -55,16 +55,18 @@ bool broadcastable(const luci::CircleConst *from, const luci::CircleNode *to)
return true;
}
-// Expand node to rank 4
+// Return node with rank 4
// node should have rank less than or equal to 4
-void expand_to_rank_4(luci::CircleConst *node)
+// 1 is inserted to the front of shape if rank is less than 4
+// For example, [2] -> [1, 1, 1, 2]
+luci::CircleConst *expand_to_rank_4(luci::CircleConst *node)
{
auto original_rank = node->rank();
assert(original_rank <= 4); // FIX_CALLER_UNLESS
if (original_rank == 4)
- return;
+ return node;
std::vector<uint32_t> original_shape;
for (uint32_t i = 0; i < original_rank; i++)
@@ -72,12 +74,17 @@ void expand_to_rank_4(luci::CircleConst *node)
original_shape.emplace_back(node->dim(i).value());
}
- node->rank(4);
+ auto cloned = luci::clone(node);
+ cloned->name(cloned->name() + "_rank4");
+
+ cloned->rank(4);
for (uint32_t i = 0; i < (4 - original_rank); i++)
- node->dim(i) = 1;
+ cloned->dim(i) = 1;
for (uint32_t i = 0; i < original_rank; i++)
- node->dim(i + (4 - original_rank)) = original_shape.at(i);
+ cloned->dim(i + (4 - original_rank)) = original_shape.at(i);
+
+ return cloned;
}
bool is_output(const loco::Node *node)
@@ -564,7 +571,7 @@ bool is_NCHW_with_const(const luci::CircleMul *node, luci::CircleNode *&pred_nod
if (not broadcastable(multiplier, node))
return false;
- expand_to_rank_4(multiplier);
+ multiplier = expand_to_rank_4(multiplier);
return true;
}
@@ -602,7 +609,7 @@ bool is_NCHW_with_const(const luci::CircleAdd *node, luci::CircleNode *&pred_nod
if (not broadcastable(beta, node))
return false;
- expand_to_rank_4(beta);
+ beta = expand_to_rank_4(beta);
return true;
}
@@ -834,6 +841,8 @@ class ConvertNCHWToNHWC final : public luci::CircleNodeMutableVisitor<bool>
bool visit(luci::CircleElu *node) { return convert_unary_features<luci::CircleElu>(node); }
+ bool visit(luci::CircleGelu *node) { return convert_unary_features<luci::CircleGelu>(node); }
+
bool visit(luci::CircleLeakyRelu *node)
{
return convert_unary_features<luci::CircleLeakyRelu>(node);
@@ -1510,6 +1519,7 @@ bool ConvertNCHWToNHWCPass::run(loco::Graph *g)
case luci::CircleOpcode::ADD:
case luci::CircleOpcode::CONCATENATION:
case luci::CircleOpcode::ELU:
+ case luci::CircleOpcode::GELU:
case luci::CircleOpcode::LEAKY_RELU:
case luci::CircleOpcode::LOGISTIC:
case luci::CircleOpcode::MAXIMUM:
diff --git a/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp b/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp
index fd326518e..85648cf2c 100644
--- a/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp
+++ b/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp
@@ -535,6 +535,8 @@ public:
luci::CircleMaximum *max = nullptr;
};
+static constexpr std::initializer_list<uint32_t> kDefaultShape = {1, 16, 1, 1};
+
class MeanGraph final : public SimpleGraph
{
protected:
@@ -577,7 +579,7 @@ public:
private:
bool _keep_dims = true;
std::vector<int32_t> _axes = {2, 3};
- std::initializer_list<uint32_t> _shape = {1, 16, 1, 1};
+ std::initializer_list<uint32_t> _shape = kDefaultShape;
};
class MinimumGraph final : public SimpleGraph
@@ -876,7 +878,7 @@ public:
private:
bool _keep_dims = true;
std::vector<int32_t> _axes = {2, 3};
- std::initializer_list<uint32_t> _shape = {1, 16, 1, 1};
+ std::initializer_list<uint32_t> _shape = kDefaultShape;
};
class ReduceMinGraph final : public SimpleGraph
@@ -921,7 +923,7 @@ public:
private:
bool _keep_dims = true;
std::vector<int32_t> _axes = {2, 3};
- std::initializer_list<uint32_t> _shape = {1, 16, 1, 1};
+ std::initializer_list<uint32_t> _shape = kDefaultShape;
};
class ReluGraph final : public SimpleGraph
diff --git a/compiler/luci/pass/src/ConvertToFakeQuantizedModelPass.cpp b/compiler/luci/pass/src/ConvertToFakeQuantizedModelPass.cpp
index aacfce3d0..ae5ab1519 100644
--- a/compiler/luci/pass/src/ConvertToFakeQuantizedModelPass.cpp
+++ b/compiler/luci/pass/src/ConvertToFakeQuantizedModelPass.cpp
@@ -198,6 +198,7 @@ struct FakeQuantize final : public luci::CircleNodeMutableVisitor<void>
void visit(luci::CircleDepthwiseConv2D *node) { fq_activation(node); }
void visit(luci::CircleDiv *node) { fq_activation(node); }
void visit(luci::CircleFullyConnected *node) { fq_activation(node); }
+ void visit(luci::CircleGelu *node) { fq_activation(node); }
void visit(luci::CircleInstanceNorm *node) { fq_activation(node); }
void visit(luci::CircleLeakyRelu *node) { fq_activation(node); }
void visit(luci::CircleLogistic *node) { fq_activation(node); }
@@ -217,6 +218,9 @@ struct FakeQuantize final : public luci::CircleNodeMutableVisitor<void>
void visit(luci::CircleRsqrt *node) { fq_activation(node); }
void visit(luci::CircleSoftmax *node) { fq_activation(node); }
void visit(luci::CircleSqrt *node) { fq_activation(node); }
+ void visit(luci::CircleSquaredDifference *node) { fq_activation(node); }
+ void visit(luci::CircleSub *node) { fq_activation(node); }
+ void visit(luci::CircleSum *node) { fq_activation(node); }
void visit(luci::CircleTanh *node) { fq_activation(node); }
void visit(luci::CircleTransposeConv *node) { fq_activation(node); }
diff --git a/compiler/luci/pass/src/DecomposeHardSwishPass.cpp b/compiler/luci/pass/src/DecomposeHardSwishPass.cpp
new file mode 100644
index 000000000..bd99d2de0
--- /dev/null
+++ b/compiler/luci/pass/src/DecomposeHardSwishPass.cpp
@@ -0,0 +1,147 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/DecomposeHardSwishPass.h"
+
+#include "helpers/NodeFiller.h"
+#include "helpers/TypeMapper.h"
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/Profile/CircleNodeOrigin.h>
+
+namespace
+{
+/**
+ * BEFORE
+ * [CircleNode]
+ * |
+ * |
+ * [CircleHardSwish]
+ * |
+ * |
+ * [CircleNode]
+ *
+ *
+ * AFTER
+ *
+ * [CircleNode] [CircleConst]
+ * | \ /
+ * | \ /
+ * | [CircleAdd]
+ * | |
+ * | |
+ * \ [CircleRelu6] [CircleConst]
+ * \ \ /
+ * \ \ /
+ * \ [CircleMul]
+ * \ /
+ * \ /
+ * [CircleMul]
+ * |
+ * |
+ * [CircleNode]
+ *
+ */
+bool decompose_hardswish(luci::CircleHardSwish *hardswish)
+{
+ if (not hardswish)
+ return false;
+
+ if (hardswish->dtype() != loco::DataType::FLOAT32)
+ return false;
+
+ auto g = hardswish->graph();
+
+ auto name = hardswish->name();
+ assert(name.length() > 0);
+
+ // Create a const for CircleAdd operation
+ auto add_const = g->nodes()->create<luci::CircleConst>();
+ add_const->shape({}); // scalar
+ add_const->dtype(loco::DataType::FLOAT32);
+ add_const->rank(0);
+ add_const->size<loco::DataType::FLOAT32>(1);
+ add_const->at<loco::DataType::FLOAT32>(0) = 3.;
+ add_const->name(name + "/Add/const");
+ luci::add_origin(add_const, luci::get_origin(hardswish));
+
+ // Create an Add operation
+ auto add = g->nodes()->create<luci::CircleAdd>();
+ add->fusedActivationFunction(luci::FusedActFunc::NONE);
+ add->x(hardswish->features());
+ add->y(add_const);
+ add->name(name + "/Add");
+ luci::add_origin(add, luci::get_origin(hardswish));
+
+ // Create a Relu6 operation
+ auto relu6 = g->nodes()->create<luci::CircleRelu6>();
+ relu6->features(add);
+ relu6->name(name + "/Relu6");
+ luci::add_origin(relu6, luci::get_origin(hardswish));
+
+ // Create a const for CircleMul operation
+ auto mul_const = g->nodes()->create<luci::CircleConst>();
+ mul_const->shape({}); // scalar
+ mul_const->dtype(loco::DataType::FLOAT32);
+ mul_const->rank(0);
+ mul_const->size<loco::DataType::FLOAT32>(1);
+ mul_const->at<loco::DataType::FLOAT32>(0) = 1. / 6.;
+ mul_const->name(name + "/Mul/const");
+ luci::add_origin(mul_const, luci::get_origin(hardswish));
+
+ // Create first Mul operation
+ auto mul1 = g->nodes()->create<luci::CircleMul>();
+ mul1->fusedActivationFunction(luci::FusedActFunc::NONE);
+ mul1->x(relu6);
+ mul1->y(mul_const);
+ mul1->name(name + "/Mul1");
+ luci::add_origin(mul1, luci::get_origin(hardswish));
+
+ // Create second Mul operation
+ auto mul2 = g->nodes()->create<luci::CircleMul>();
+ mul2->fusedActivationFunction(luci::FusedActFunc::NONE);
+ mul2->x(hardswish->features());
+ mul2->y(mul1);
+ mul2->name(name + "/Mul2");
+ luci::add_origin(mul2, luci::get_origin(hardswish));
+
+ replace(hardswish).with(mul2);
+
+ return true;
+}
+
+} // namespace
+
+namespace luci
+{
+
+bool DecomposeHardSwishPass::run(loco::Graph *g)
+{
+ bool changed = false;
+
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ if (auto hardswish = dynamic_cast<luci::CircleHardSwish *>(node))
+ {
+ if (decompose_hardswish(hardswish))
+ changed = true;
+ }
+ }
+
+ return changed;
+}
+
+} // namespace luci
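As the diagram above shows, the pass rewrites HardSwish(x) into x * Relu6(x + 3) * (1/6), which is the standard HardSwish definition. Below is a standalone sanity check of that identity, mapping one arithmetic step to each node the pass creates; the helper names are illustrative and not part of the patch.

#include <algorithm>
#include <cassert>
#include <cmath>

float relu6(float x) { return std::min(std::max(x, 0.0f), 6.0f); }

// Closed-form HardSwish definition.
float hardswish_reference(float x) { return x * relu6(x + 3.0f) / 6.0f; }

// Same computation, step by step, as built by decompose_hardswish().
float hardswish_decomposed(float x)
{
  float add = x + 3.0f;             // CircleAdd with scalar const 3
  float act = relu6(add);           // CircleRelu6
  float mul1 = act * (1.0f / 6.0f); // first CircleMul with scalar const 1/6
  return x * mul1;                  // second CircleMul with the original input
}

int main()
{
  for (float x : {-4.0f, -1.5f, 0.0f, 2.0f, 5.0f})
    assert(std::fabs(hardswish_reference(x) - hardswish_decomposed(x)) < 1e-6f);
  return 0;
}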
diff --git a/compiler/luci/pass/src/DecomposeHardSwishPass.test.cpp b/compiler/luci/pass/src/DecomposeHardSwishPass.test.cpp
new file mode 100644
index 000000000..d51a07fdc
--- /dev/null
+++ b/compiler/luci/pass/src/DecomposeHardSwishPass.test.cpp
@@ -0,0 +1,205 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/DecomposeHardSwishPass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+/**
+ * HardSwish graph
+ *
+ * [CircleInput]
+ * |
+ * |
+ * [CircleHardSwish]
+ * |
+ * |
+ * [CircleOutput]
+ */
+struct HardSwishGraph
+{
+ loco::Graph _g;
+ luci::CircleInput *_input = nullptr;
+ luci::CircleHardSwish *_hardswish = nullptr;
+ luci::CircleOutput *_output = nullptr;
+};
+
+class DecomposeHardSwishPass : public ::testing::Test
+{
+protected:
+ void MakeGraph()
+ {
+ const int N = 1;
+ const int H = 4;
+ const int W = 4;
+ const int C = 3;
+
+ // graph input and output
+ auto graph_input = _hardswish_g._g.inputs()->create();
+ auto graph_output = _hardswish_g._g.outputs()->create();
+
+ // CircleInput
+ _hardswish_g._input = _hardswish_g._g.nodes()->create<luci::CircleInput>();
+ _hardswish_g._input->index(graph_input->index());
+ _hardswish_g._input->shape({N, H, W, C});
+ _hardswish_g._input->dtype(loco::DataType::FLOAT32);
+ _hardswish_g._input->name("input");
+
+ // CircleHardSwish
+ _hardswish_g._hardswish = _hardswish_g._g.nodes()->create<luci::CircleHardSwish>();
+ _hardswish_g._hardswish->features(_hardswish_g._input);
+ _hardswish_g._hardswish->shape({N, H, W, C});
+ _hardswish_g._hardswish->dtype(loco::DataType::FLOAT32);
+ _hardswish_g._hardswish->name("hardswish");
+
+ // CircleOutput
+ _hardswish_g._output = _hardswish_g._g.nodes()->create<luci::CircleOutput>();
+ _hardswish_g._output->index(graph_output->index());
+ _hardswish_g._output->from(_hardswish_g._hardswish);
+ _hardswish_g._output->shape({N, H, W, C});
+ _hardswish_g._output->dtype(loco::DataType::FLOAT32);
+ _hardswish_g._output->name("output");
+ }
+
+ void MakeInt32Graph()
+ {
+ const int N = 1;
+ const int H = 4;
+ const int W = 4;
+ const int C = 3;
+
+ // graph input and output
+ auto graph_input = _hardswish_int32_g._g.inputs()->create();
+ auto graph_output = _hardswish_int32_g._g.outputs()->create();
+
+ // CircleInput
+ _hardswish_int32_g._input = _hardswish_int32_g._g.nodes()->create<luci::CircleInput>();
+ _hardswish_int32_g._input->index(graph_input->index());
+ _hardswish_int32_g._input->shape({N, H, W, C});
+ _hardswish_int32_g._input->dtype(loco::DataType::S32);
+ _hardswish_int32_g._input->name("input");
+
+ // CircleHardSwish
+ _hardswish_int32_g._hardswish = _hardswish_int32_g._g.nodes()->create<luci::CircleHardSwish>();
+ _hardswish_int32_g._hardswish->features(_hardswish_int32_g._input);
+ _hardswish_int32_g._hardswish->shape({N, H, W, C});
+ _hardswish_int32_g._hardswish->dtype(loco::DataType::S32);
+ _hardswish_int32_g._hardswish->name("hardswish");
+
+ // CircleOutput
+ _hardswish_int32_g._output = _hardswish_int32_g._g.nodes()->create<luci::CircleOutput>();
+ _hardswish_int32_g._output->index(graph_output->index());
+ _hardswish_int32_g._output->from(_hardswish_int32_g._hardswish);
+ _hardswish_int32_g._output->shape({N, H, W, C});
+ _hardswish_int32_g._output->dtype(loco::DataType::S32);
+ _hardswish_int32_g._output->name("output");
+ }
+
+ virtual void SetUp()
+ {
+ MakeGraph();
+ MakeInt32Graph();
+ }
+
+protected:
+ luci::DecomposeHardSwishPass _pass;
+ HardSwishGraph _hardswish_g;
+ HardSwishGraph _hardswish_int32_g;
+};
+
+} // namespace
+
+TEST_F(DecomposeHardSwishPass, name)
+{
+ auto const name = _pass.name();
+ ASSERT_NE(nullptr, name);
+}
+
+/**
+ * Decomposed graph looks like below.
+ *
+ * [CircleInput] [CircleConst]
+ * | \ /
+ * | \ /
+ * | [CircleAdd]
+ * | |
+ * | |
+ * \ [CircleRelu6] [CircleConst]
+ * \ \ /
+ * \ \ /
+ * \ [CircleMul]
+ * \ /
+ * \ /
+ * [CircleMul]
+ * |
+ * |
+ * [CircleOutput]
+ *
+ */
+TEST_F(DecomposeHardSwishPass, simple_test)
+{
+ auto ret = _pass.run(&_hardswish_g._g);
+ EXPECT_TRUE(ret);
+
+ auto mul2 = dynamic_cast<luci::CircleMul *>(_hardswish_g._output->from());
+ EXPECT_NE(nullptr, mul2);
+
+ auto input2 = dynamic_cast<luci::CircleInput *>(mul2->x());
+ EXPECT_NE(nullptr, input2);
+
+ auto mul1 = dynamic_cast<luci::CircleMul *>(mul2->y());
+ EXPECT_NE(nullptr, mul1);
+
+ auto relu6 = dynamic_cast<luci::CircleRelu6 *>(mul1->x());
+ EXPECT_NE(nullptr, relu6);
+
+ auto mul_const = dynamic_cast<luci::CircleConst *>(mul1->y());
+ EXPECT_NE(nullptr, mul_const);
+ EXPECT_FLOAT_EQ(1. / 6., mul_const->at<loco::DataType::FLOAT32>(0));
+
+ auto add = dynamic_cast<luci::CircleAdd *>(relu6->features());
+ EXPECT_NE(nullptr, add);
+
+ auto input1 = dynamic_cast<luci::CircleInput *>(add->x());
+ EXPECT_NE(nullptr, input1);
+
+ auto add_const = dynamic_cast<luci::CircleConst *>(add->y());
+ EXPECT_NE(nullptr, add_const);
+ EXPECT_FLOAT_EQ(3., add_const->at<loco::DataType::FLOAT32>(0));
+}
+
+TEST_F(DecomposeHardSwishPass, check_last_node)
+{
+ auto ret = _pass.run(&_hardswish_g._g);
+ EXPECT_TRUE(ret);
+
+ auto hardswish = dynamic_cast<luci::CircleHardSwish *>(_hardswish_g._output->from());
+ EXPECT_EQ(nullptr, hardswish);
+}
+
+TEST_F(DecomposeHardSwishPass, wrong_condition_NEG)
+{
+ auto ret = _pass.run(&_hardswish_int32_g._g);
+ EXPECT_FALSE(ret);
+
+ auto hardswish = dynamic_cast<luci::CircleHardSwish *>(_hardswish_int32_g._output->from());
+ EXPECT_NE(nullptr, hardswish);
+}
diff --git a/compiler/luci/pass/src/DynamicBatchToSingleBatch.cpp b/compiler/luci/pass/src/DynamicBatchToSingleBatch.cpp
new file mode 100644
index 000000000..86876063a
--- /dev/null
+++ b/compiler/luci/pass/src/DynamicBatchToSingleBatch.cpp
@@ -0,0 +1,51 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/DynamicBatchToSingleBatch.h"
+
+#include "luci/Pass/DynamicBatchToSingleBatchPass.h"
+#include "luci/Pass/CircleShapeInferencePass.h"
+
+#include "ProgressReporter.h"
+
+#include <logo/Phase.h>
+
+namespace luci
+{
+
+void dynamic_batch_to_single_batch(luci::Module *m)
+{
+ assert(m); // FIX CALLER UNLESS
+
+ for (uint32_t i = 0; i < m->size(); i++)
+ {
+ auto g = m->graph(i);
+
+ logo::Phase phase;
+
+ phase.emplace_back(std::make_unique<luci::DynamicBatchToSingleBatchPass>());
+
+ // Needed to infer shapes of other nodes
+ phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>());
+
+ ProgressReporter prog(g, logo::PhaseStrategy::Saturate);
+ logo::PhaseRunner<logo::PhaseStrategy::Saturate> phase_runner{g};
+ phase_runner.attach(&prog);
+ phase_runner.run(phase);
+ }
+}
+
+} // namespace luci
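A minimal usage sketch of the helper declared in DynamicBatchToSingleBatch.h. It assumes `module` is a valid luci::Module obtained elsewhere (e.g. from an importer); the wrapper function name is illustrative only.

#include <luci/DynamicBatchToSingleBatch.h>

// Rewrites every graph input whose batch (first) dimension is unknown to a
// single batch, then re-runs shape inference on each graph in the module.
// Note: the underlying pass throws if an input with an unknown batch
// dimension is not rank 4.
void make_single_batch(luci::Module *module)
{
  luci::dynamic_batch_to_single_batch(module);
}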
diff --git a/compiler/luci/pass/src/DynamicBatchToSingleBatchPass.cpp b/compiler/luci/pass/src/DynamicBatchToSingleBatchPass.cpp
new file mode 100644
index 000000000..59a9f5ab3
--- /dev/null
+++ b/compiler/luci/pass/src/DynamicBatchToSingleBatchPass.cpp
@@ -0,0 +1,78 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/DynamicBatchToSingleBatchPass.h"
+
+#include <luci/IR/CircleNode.h>
+#include <loco.h>
+
+namespace luci
+{
+
+bool DynamicBatchToSingleBatchPass::run(loco::Graph *g)
+{
+ assert(g); // FIX CALLER UNLESS
+
+ bool changed = false;
+
+ auto graph_inputs = g->inputs();
+
+ // Assume the first dimension is batch dimension
+ const uint32_t BATCH_DIM = 0;
+
+ for (auto node : loco::input_nodes(g))
+ {
+ auto input_node = loco::must_cast<luci::CircleInput *>(node);
+
+ if (input_node->rank() == 0)
+ continue;
+
+ // Skip if batch dimension is known
+ if (input_node->dim(BATCH_DIM).known())
+ continue;
+
+ if (input_node->rank() != 4)
+ {
+ // Limit use only for rank 4 inputs (for NHWC and NCHW)
+ // TODO Enable this if necessary
+ throw std::runtime_error("First dimension of input is unknown, but its rank is not 4.");
+ }
+
+ // 'set' will make the dimension known
+ input_node->dim(BATCH_DIM).set(1);
+
+ // Update graph input
+ auto graph_input = graph_inputs->at(input_node->index());
+ auto graph_input_shape = graph_input->shape();
+ auto tensor_shape = std::make_unique<loco::TensorShape>();
+ {
+ tensor_shape->rank(graph_input_shape->rank());
+ for (uint32_t i = 0; i < tensor_shape->rank(); i++)
+ {
+ tensor_shape->dim(i) = graph_input_shape->dim(i);
+ }
+ tensor_shape->dim(BATCH_DIM).set(1);
+ }
+
+ graph_input->shape(std::move(tensor_shape));
+
+ changed = true;
+ }
+
+ return changed;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/DynamicBatchToSingleBatchPass.test.cpp b/compiler/luci/pass/src/DynamicBatchToSingleBatchPass.test.cpp
new file mode 100644
index 000000000..f19f57d17
--- /dev/null
+++ b/compiler/luci/pass/src/DynamicBatchToSingleBatchPass.test.cpp
@@ -0,0 +1,126 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/DynamicBatchToSingleBatchPass.h"
+
+#include <loco.h>
+
+#include <luci/IR/CircleNodes.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+std::unique_ptr<loco::TensorShape> make_tshape(std::initializer_list<uint32_t> dims)
+{
+ auto tensor_shape = std::make_unique<loco::TensorShape>();
+ {
+ tensor_shape->rank(dims.size());
+ uint32_t axis = 0;
+ for (auto it = dims.begin(); it != dims.end(); ++it)
+ {
+ tensor_shape->dim(axis++) = *it;
+ }
+ }
+
+ return std::move(tensor_shape);
+}
+
+} // namespace
+
+TEST(DynamicBatchToSingleBatchPassTest, simple)
+{
+ luci::DynamicBatchToSingleBatchPass pass;
+
+ auto g = loco::make_graph();
+
+ auto graph_input = g->inputs()->create();
+ {
+ auto tensor_shape = make_tshape({1, 5, 5, 3});
+ tensor_shape->dim(0).unset();
+ graph_input->shape(std::move(tensor_shape));
+ }
+
+ // Create graph input node
+ auto input = g->nodes()->create<luci::CircleInput>();
+ {
+ input->index(0);
+ input->shape({1, 5, 5, 3});
+ input->dim(0).unset();
+ }
+
+ EXPECT_FALSE(graph_input->shape()->dim(0).known());
+ EXPECT_FALSE(input->dim(0).known());
+
+ EXPECT_TRUE(pass.run(g.get()));
+
+ // Check batch dimension is now known and set to 1
+ EXPECT_TRUE(graph_input->shape()->dim(0).known());
+ EXPECT_EQ(1, graph_input->shape()->dim(0));
+ EXPECT_TRUE(input->dim(0).known());
+ EXPECT_EQ(1, input->dim(0));
+}
+
+TEST(DynamicBatchToSingleBatchPassTest, simple_NEG)
+{
+ luci::DynamicBatchToSingleBatchPass pass;
+
+ auto g = loco::make_graph();
+
+ auto graph_input = g->inputs()->create();
+ {
+ graph_input->shape({1, 5, 5, 3});
+ }
+
+ // Create graph input node
+ auto input = g->nodes()->create<luci::CircleInput>();
+ {
+ input->index(0);
+ input->shape({1, 5, 5, 3});
+ }
+
+ EXPECT_FALSE(pass.run(g.get()));
+}
+
+// Remove this test if we support rank 1 in this pass
+TEST(DynamicBatchToSingleBatchPassTest, rank1_NEG)
+{
+ luci::DynamicBatchToSingleBatchPass pass;
+
+ auto g = loco::make_graph();
+
+ auto graph_input = g->inputs()->create();
+ {
+ auto tensor_shape = make_tshape({1});
+ tensor_shape->dim(0).unset();
+ graph_input->shape(std::move(tensor_shape));
+ }
+
+ // Create graph input node
+ auto input = g->nodes()->create<luci::CircleInput>();
+ {
+ input->index(0);
+ input->shape({1});
+ input->dim(0).unset();
+ }
+
+ EXPECT_FALSE(graph_input->shape()->dim(0).known());
+ EXPECT_FALSE(input->dim(0).known());
+
+ // Rank 1 is unsupported for now
+ EXPECT_ANY_THROW(pass.run(g.get()));
+}
diff --git a/compiler/luci/pass/src/FoldAddV2Pass.test.cpp b/compiler/luci/pass/src/FoldAddV2Pass.test.cpp
index 438d7f077..200fcc093 100644
--- a/compiler/luci/pass/src/FoldAddV2Pass.test.cpp
+++ b/compiler/luci/pass/src/FoldAddV2Pass.test.cpp
@@ -44,10 +44,10 @@ template <loco::DataType T> class FoldAddV2Test : public luci::ConstantFoldingAd
public:
FoldAddV2Test(std::initializer_list<uint32_t> shape) : luci::ConstantFoldingAddTestGraph(shape, T)
{
- _addV2 = _g.nodes()->create<luci::CircleCustom>(2, 1);
- _x = _g.nodes()->create<luci::CircleConst>();
- _y = _g.nodes()->create<luci::CircleConst>();
- _addV2_out = _g.nodes()->create<luci::CircleCustomOut>();
+ _addV2 = _g.nodes()->template create<luci::CircleCustom>(2, 1);
+ _x = _g.nodes()->template create<luci::CircleConst>();
+ _y = _g.nodes()->template create<luci::CircleConst>();
+ _addV2_out = _g.nodes()->template create<luci::CircleCustomOut>();
_addV2->dtype(T);
_x->dtype(T);
diff --git a/compiler/luci/pass/src/FoldCastPass.test.cpp b/compiler/luci/pass/src/FoldCastPass.test.cpp
index 5911adf11..da33e4379 100644
--- a/compiler/luci/pass/src/FoldCastPass.test.cpp
+++ b/compiler/luci/pass/src/FoldCastPass.test.cpp
@@ -31,8 +31,8 @@ public:
FoldCastTest(std::initializer_list<uint32_t> shape)
: luci::ConstantFoldingAddTestGraph(shape, ToT)
{
- _cast = _g.nodes()->create<luci::CircleCast>();
- _x = _g.nodes()->create<luci::CircleConst>();
+ _cast = _g.nodes()->template create<luci::CircleCast>();
+ _x = _g.nodes()->template create<luci::CircleConst>();
_cast->dtype(ToT);
_x->dtype(FromT);
diff --git a/compiler/luci/pass/src/FoldDequantizePass.test.cpp b/compiler/luci/pass/src/FoldDequantizePass.test.cpp
index fb5b6adc0..87dff5dc0 100644
--- a/compiler/luci/pass/src/FoldDequantizePass.test.cpp
+++ b/compiler/luci/pass/src/FoldDequantizePass.test.cpp
@@ -32,8 +32,8 @@ public:
loco::Node *createFoldedPattern() override
{
- _dequantize = _g.nodes()->create<luci::CircleDequantize>();
- _input = _g.nodes()->create<luci::CircleConst>();
+ _dequantize = _g.nodes()->template create<luci::CircleDequantize>();
+ _input = _g.nodes()->template create<luci::CircleConst>();
_dequantize->dtype(loco::DataType::FLOAT32);
_input->dtype(DT);
diff --git a/compiler/luci/pass/src/FuseActivationFunctionPass.cpp b/compiler/luci/pass/src/FuseActivationFunctionPass.cpp
index d83973cd5..868ccd140 100644
--- a/compiler/luci/pass/src/FuseActivationFunctionPass.cpp
+++ b/compiler/luci/pass/src/FuseActivationFunctionPass.cpp
@@ -42,6 +42,11 @@ bool fuse_activation_function(luci::CircleNode *node)
// This will skip fuse for concat as luci-interpreter doesn't support this yet
if (dynamic_cast<luci::CircleConcatenation *>(pred_node) != nullptr)
return false;
+ // TODO remove this work-around
+ // This will skip fuse for TransposeConv as backends do not support this yet
+ // NOTE remove this when XpSepActFromTransposeConvOpPass is removed
+ if (dynamic_cast<luci::CircleTransposeConv *>(pred_node) != nullptr)
+ return false;
auto fused_act = node_with_fused_act->fusedActivationFunction();
diff --git a/compiler/luci/pass/src/FuseAddWithFullyConnectedPass.test.cpp b/compiler/luci/pass/src/FuseAddWithFullyConnectedPass.test.cpp
index 300796594..b132c6bd9 100644
--- a/compiler/luci/pass/src/FuseAddWithFullyConnectedPass.test.cpp
+++ b/compiler/luci/pass/src/FuseAddWithFullyConnectedPass.test.cpp
@@ -16,6 +16,8 @@
#include "luci/Pass/FuseAddWithFullyConnectedPass.h"
+#include "helpers/CreateCircleConst.h"
+
#include <luci/IR/CircleNodes.h>
#include <luci/test/TestIOGraph.h>
@@ -27,52 +29,6 @@ namespace
using namespace luci::test;
-// TODO Reduce duplicate codes in ResolveCustomOpMatMulPass.cpp
-template <typename T>
-luci::CircleConst *create_const_node(loco::Graph *g, const loco::DataType dtype,
- const std::vector<uint32_t> &shape,
- const std::vector<T> &values)
-{
- auto node = g->nodes()->create<luci::CircleConst>();
- node->dtype(dtype);
- node->rank(shape.size());
-
- uint32_t size = 1;
- for (uint32_t i = 0; i < shape.size(); ++i)
- {
- node->dim(i) = shape.at(i);
- size *= shape.at(i);
- }
- node->shape_status(luci::ShapeStatus::VALID);
-
-#define INIT_VALUES(DT) \
- { \
- node->size<DT>(size); \
- for (uint32_t i = 0; i < values.size(); ++i) \
- node->at<DT>(i) = values[i]; \
- }
-
- switch (dtype)
- {
- case loco::DataType::U8:
- INIT_VALUES(loco::DataType::U8);
- break;
- case loco::DataType::S16:
- INIT_VALUES(loco::DataType::S16);
- break;
- case loco::DataType::S32:
- INIT_VALUES(loco::DataType::S32);
- break;
- case loco::DataType::FLOAT32:
- INIT_VALUES(loco::DataType::FLOAT32)
- break;
- default:
- INTERNAL_EXN("create_const_node called with unsupported type");
- break;
- }
- return node;
-}
-
/**
* Simple graph for test
*
@@ -95,10 +51,10 @@ public:
void init(loco::Graph *g)
{
std::vector<float> weights_val(16 * 4);
- _fc_f = create_const_node(g, loco::DataType::FLOAT32, {16, 4}, weights_val);
+ _fc_f = luci::create_const_node(g, loco::DataType::FLOAT32, {16, 4}, weights_val);
std::vector<float> bias_val(16);
- _fc_b = create_const_node(g, loco::DataType::FLOAT32, {1, 16}, bias_val);
+ _fc_b = luci::create_const_node(g, loco::DataType::FLOAT32, {1, 16}, bias_val);
_fc = g->nodes()->create<luci::CircleFullyConnected>();
_fc->weights(_fc_f);
@@ -111,7 +67,7 @@ public:
std::vector<float> addition_val;
for (uint32_t i = 0; i < 16; i++)
addition_val.push_back(static_cast<float>(i));
- _add_c = create_const_node(g, loco::DataType::FLOAT32, {1, 16}, addition_val);
+ _add_c = luci::create_const_node(g, loco::DataType::FLOAT32, {1, 16}, addition_val);
_add = g->nodes()->create<luci::CircleAdd>();
_add->x(_fc);
diff --git a/compiler/luci/pass/src/FuseAddWithTConvPass.cpp b/compiler/luci/pass/src/FuseAddWithTConvPass.cpp
index 852bc8b63..d8e9f11f5 100644
--- a/compiler/luci/pass/src/FuseAddWithTConvPass.cpp
+++ b/compiler/luci/pass/src/FuseAddWithTConvPass.cpp
@@ -44,6 +44,9 @@ namespace
*/
bool fuse_add_with_tconv(luci::CircleTransposeConv *tconv)
{
+ // skip if tconv has fused activation
+ if (tconv->fusedActivationFunction() != luci::FusedActFunc::NONE)
+ return false;
// check whether it has bias or not. This optimization works only if it doesn't.
auto bias = dynamic_cast<luci::CircleOutputExclude *>(tconv->bias());
if (not bias)
diff --git a/compiler/luci/pass/src/FuseBatchNormWithTConvPass.cpp b/compiler/luci/pass/src/FuseBatchNormWithTConvPass.cpp
index 265a8398b..919ce6edc 100644
--- a/compiler/luci/pass/src/FuseBatchNormWithTConvPass.cpp
+++ b/compiler/luci/pass/src/FuseBatchNormWithTConvPass.cpp
@@ -87,6 +87,9 @@ bool fused_batch_norm_with_tconv(luci::CircleAdd *add)
return false;
if (not luci::fill(&scale, &tconv).with_commutative_args_of(mul))
return false;
+ // skip if tconv has fused activation
+ if (tconv->fusedActivationFunction() != luci::FusedActFunc::NONE)
+ return false;
// check scale and shift constant attributes
// TODO maybe rank check is not needed
@@ -215,6 +218,9 @@ bool fused_batch_norm_with_tconv(luci::CircleAdd *add)
fused_tconv->stride()->h(tconv->stride()->h());
fused_tconv->stride()->w(tconv->stride()->w());
fused_tconv->name(name + "/TransposeConv");
+ // TODO set activation from Add and remove adding following Relu/Relu6 Op
+ // when all of our backends supports fused activation of TransposeConv
+ fused_tconv->fusedActivationFunction(luci::FusedActFunc::NONE);
luci::add_origin(fused_tconv,
luci::composite_origin(
{luci::get_origin(add), luci::get_origin(mul), luci::get_origin(tconv)}));
diff --git a/compiler/luci/pass/src/FuseGeluPass.cpp b/compiler/luci/pass/src/FuseGeluPass.cpp
new file mode 100644
index 000000000..e3e7cecb3
--- /dev/null
+++ b/compiler/luci/pass/src/FuseGeluPass.cpp
@@ -0,0 +1,347 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/FuseGeluPass.h"
+#include "helpers/NodeFiller.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <luci/Profile/CircleNodeOrigin.h>
+#include <luci/Service/CircleNodeClone.h>
+
+#include <cmath>
+
+#include <cassert>
+
+// Helper to fuse Gelu
+namespace
+{
+
+// Float comparison
+bool same(float a, float b) { return fabs(a - b) < 1e-5; }
+
+class GeluPatternBase
+{
+public:
+ GeluPatternBase(luci::CircleMul *candidate) { _pattern_last_node = candidate; }
+
+ virtual ~GeluPatternBase() = default;
+
+public:
+ virtual bool matched() = 0;
+
+public:
+ luci::CircleNode *_ifm = nullptr;
+ luci::CircleMul *_mul_sqrt = nullptr;
+ luci::CircleCustom *_erf = nullptr;
+ luci::CircleCustomOut *_erf_out = nullptr;
+ luci::CircleAdd *_add_one = nullptr;
+ luci::CircleMul *_mul = nullptr;
+ luci::CircleMul *_mul_half = nullptr;
+ luci::CircleConst *_const_sqrt = nullptr;
+ luci::CircleConst *_const_one = nullptr;
+ luci::CircleConst *_const_half = nullptr;
+ luci::CircleMul *_pattern_last_node = nullptr;
+};
+
+/**
+ * Below diagram shows Gelu pattern to fuse.
+ * - Gelu(x) = 0.5 * x * (1.0 + erf(x / sqrt(2.0)))
+ * - the below pattern will be replaced with one Gelu
+ *
+ * [In]
+ * |
+ * V
+ * +---- ifm
+ * | |
+ * | V
+ * | mul_sqrt (1/sqrt(2) = 0.707106..)
+ * | |
+ * | V
+ * | erf
+ * | |
+ * | V
+ * | add_one (1.0)
+ * | |
+ * | V
+ * +---> mul
+ * |
+ * V
+ * mul_half (0.5)
+ * |
+ * V
+ * [Out]
+ *
+ */
+class GeluPattern1 final : public GeluPatternBase
+{
+public:
+ GeluPattern1(luci::CircleMul *candidate) : GeluPatternBase(candidate)
+ {
+ assert(candidate);
+ _mul_half = candidate;
+ }
+
+public:
+ bool matched() override;
+};
+
+/**
+ * Below diagram shows Gelu pattern to fuse.
+ * - Gelu(x) = 0.5 * x * (1.0 + erf(x / sqrt(2.0)))
+ * - the below pattern will be replaced with one Gelu
+ *
+ * [In]
+ * |
+ * V
+ * +----------- ifm
+ * | |
+ * | V
+ * | mul_sqrt (1/sqrt(2) = 0.707106..)
+ * | |
+ * | V
+ * | erf
+ * mul_half (0.5) |
+ * | V
+ * | add_one (1.0)
+ * | |
+ * | V
+ * +----------> mul
+ * |
+ * |
+ * V
+ * [Out]
+ *
+ */
+class GeluPattern2 final : public GeluPatternBase
+{
+public:
+ GeluPattern2(luci::CircleMul *candidate) : GeluPatternBase(candidate)
+ {
+ assert(candidate);
+ _mul = candidate;
+ }
+
+ ~GeluPattern2() override = default;
+
+public:
+ bool matched() override;
+};
+
+#define CHECK_OR_FALSE(condition) \
+ if (not(condition)) \
+ return false;
+
+bool GeluPattern1::matched()
+{
+ // check pattern
+ CHECK_OR_FALSE(luci::fill(&_mul, &_const_half).with_commutative_args_of(_mul_half));
+ CHECK_OR_FALSE(luci::fill(&_ifm, &_add_one).with_commutative_args_of(_mul));
+ CHECK_OR_FALSE(luci::fill(&_erf_out, &_const_one).with_commutative_args_of(_add_one));
+
+ if (auto erf = dynamic_cast<luci::CircleCustom *>(_erf_out->input()))
+ _erf = erf;
+
+ CHECK_OR_FALSE(_erf != nullptr);
+
+ // Check erf
+ CHECK_OR_FALSE(_erf->custom_code() == "Erf");
+ CHECK_OR_FALSE(_erf->numInputs() == 1);
+ CHECK_OR_FALSE(_erf->numOutputs() == 1);
+
+ if (auto mul_sqrt = dynamic_cast<luci::CircleMul *>(_erf->inputs(0)))
+ _mul_sqrt = mul_sqrt;
+
+ CHECK_OR_FALSE(_mul_sqrt != nullptr);
+
+ CHECK_OR_FALSE(luci::fill(&_ifm, &_const_sqrt).with_commutative_args_of(_mul_sqrt));
+
+ CHECK_OR_FALSE(_mul_sqrt->x() == _ifm);
+ CHECK_OR_FALSE(_mul->x() == _ifm);
+
+ // Check Activation to be NONE
+ CHECK_OR_FALSE(_mul_sqrt->fusedActivationFunction() == luci::FusedActFunc::NONE);
+ CHECK_OR_FALSE(_add_one->fusedActivationFunction() == luci::FusedActFunc::NONE);
+ CHECK_OR_FALSE(_mul->fusedActivationFunction() == luci::FusedActFunc::NONE);
+ CHECK_OR_FALSE(_mul_half->fusedActivationFunction() == luci::FusedActFunc::NONE);
+
+ // check _const_sqrt condition
+ CHECK_OR_FALSE(_const_sqrt->dtype() == loco::DataType::FLOAT32);
+ CHECK_OR_FALSE(_const_sqrt->size<loco::DataType::FLOAT32>() == 1);
+ CHECK_OR_FALSE(::same(_const_sqrt->at<loco::DataType::FLOAT32>(0), sqrtf(0.5f)));
+
+ // check if _const_half is 0.5 (fp32)
+ CHECK_OR_FALSE(_const_half->dtype() == loco::DataType::FLOAT32);
+ CHECK_OR_FALSE(_const_half->size<loco::DataType::FLOAT32>() == 1);
+ CHECK_OR_FALSE(_const_half->at<loco::DataType::FLOAT32>(0) == 0.5);
+
+ // check _const_one condition
+ CHECK_OR_FALSE(_const_one->dtype() == loco::DataType::FLOAT32);
+ CHECK_OR_FALSE(_const_one->size<loco::DataType::FLOAT32>() == 1);
+ CHECK_OR_FALSE(_const_one->at<loco::DataType::FLOAT32>(0) == 1);
+
+ return true;
+}
+
+bool GeluPattern2::matched()
+{
+ // check pattern
+ CHECK_OR_FALSE(luci::fill(&_mul_half, &_add_one).with_commutative_args_of(_mul));
+ CHECK_OR_FALSE(luci::fill(&_ifm, &_const_half).with_commutative_args_of(_mul_half));
+ CHECK_OR_FALSE(luci::fill(&_erf_out, &_const_one).with_commutative_args_of(_add_one));
+
+ CHECK_OR_FALSE(_mul_half->x() == _ifm);
+
+ if (auto erf = dynamic_cast<luci::CircleCustom *>(_erf_out->input()))
+ _erf = erf;
+
+ CHECK_OR_FALSE(_erf != nullptr);
+
+ // Check erf
+ CHECK_OR_FALSE(_erf->custom_code() == "Erf");
+ CHECK_OR_FALSE(_erf->numInputs() == 1);
+ CHECK_OR_FALSE(_erf->numOutputs() == 1);
+
+ if (auto mul_sqrt = dynamic_cast<luci::CircleMul *>(_erf->inputs(0)))
+ _mul_sqrt = mul_sqrt;
+
+ CHECK_OR_FALSE(_mul_sqrt != nullptr);
+
+ CHECK_OR_FALSE(luci::fill(&_ifm, &_const_sqrt).with_commutative_args_of(_mul_sqrt));
+
+ CHECK_OR_FALSE(_mul_sqrt->x() == _ifm);
+
+ // Check Activation to be NONE
+ CHECK_OR_FALSE(_mul_sqrt->fusedActivationFunction() == luci::FusedActFunc::NONE);
+ CHECK_OR_FALSE(_add_one->fusedActivationFunction() == luci::FusedActFunc::NONE);
+ CHECK_OR_FALSE(_mul->fusedActivationFunction() == luci::FusedActFunc::NONE);
+ CHECK_OR_FALSE(_mul_half->fusedActivationFunction() == luci::FusedActFunc::NONE);
+
+ // check _const_sqrt condition
+ CHECK_OR_FALSE(_const_sqrt->dtype() == loco::DataType::FLOAT32);
+ CHECK_OR_FALSE(_const_sqrt->size<loco::DataType::FLOAT32>() == 1);
+ CHECK_OR_FALSE(::same(_const_sqrt->at<loco::DataType::FLOAT32>(0), sqrtf(0.5f)));
+
+ // check if _const_half is 0.5 (fp32)
+ CHECK_OR_FALSE(_const_half->dtype() == loco::DataType::FLOAT32);
+ CHECK_OR_FALSE(_const_half->size<loco::DataType::FLOAT32>() == 1);
+ CHECK_OR_FALSE(_const_half->at<loco::DataType::FLOAT32>(0) == 0.5);
+
+ // check _const_one condition
+ CHECK_OR_FALSE(_const_one->dtype() == loco::DataType::FLOAT32);
+ CHECK_OR_FALSE(_const_one->size<loco::DataType::FLOAT32>() == 1);
+ CHECK_OR_FALSE(_const_one->at<loco::DataType::FLOAT32>(0) == 1);
+
+ return true;
+}
+
+#undef CHECK_OR_FALSE
+
+class FuseGelu final
+{
+public:
+ FuseGelu(const GeluPatternBase *p) : _p(p) {}
+
+public:
+ void apply(void);
+
+private:
+ luci::CircleGelu *create_gelu(loco::Graph *graph);
+
+private:
+ const GeluPatternBase *_p;
+};
+
+luci::CircleGelu *FuseGelu::create_gelu(loco::Graph *graph)
+{
+ assert(graph);
+
+ auto gelu = graph->nodes()->create<luci::CircleGelu>();
+ gelu->features(_p->_ifm);
+ // TODO Support approximate = True pattern
+ gelu->approximate(false);
+ gelu->name(_p->_pattern_last_node->name() + "_gelu");
+ return gelu;
+}
+
+void FuseGelu::apply()
+{
+ auto graph = _p->_pattern_last_node->graph();
+
+ auto gelu = create_gelu(graph);
+
+ // set origin
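+  // (combine the origins of all matched nodes so the new Gelu node keeps
+  //  their source-op provenance for profiling)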
+ std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
+ luci::get_origin(_p->_mul_sqrt), luci::get_origin(_p->_erf), luci::get_origin(_p->_add_one),
+ luci::get_origin(_p->_mul), luci::get_origin(_p->_mul_half)};
+
+ luci::add_origin(gelu, luci::composite_origin(origin_vec));
+
+ replace(_p->_pattern_last_node).with(gelu);
+}
+
+} // namespace
+
+namespace
+{
+
+bool fuse_gelu(luci::CircleMul *mul)
+{
+ assert(mul);
+
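+  // `mul` is tried as the last node of each pattern:
+  // mul_half for GeluPattern1 and mul for GeluPattern2.
+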
+ // check first pattern
+ GeluPattern1 pattern(mul);
+ if (pattern.matched())
+ {
+ FuseGelu fuse(&pattern);
+ fuse.apply();
+ return true;
+ }
+
+ // check second pattern
+ GeluPattern2 pattern2(mul);
+ if (pattern2.matched())
+ {
+ FuseGelu fuse(&pattern2);
+ fuse.apply();
+ return true;
+ }
+ return false;
+}
+
+} // namespace
+
+namespace luci
+{
+
+bool FuseGeluPass::run(loco::Graph *g)
+{
+ bool changed = false;
+
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ auto mul = dynamic_cast<luci::CircleMul *>(node);
+ if (not mul)
+ continue;
+
+ if (fuse_gelu(mul))
+ changed = true;
+ }
+
+ return changed;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/FuseGeluPass.test.cpp b/compiler/luci/pass/src/FuseGeluPass.test.cpp
new file mode 100644
index 000000000..db6f6993a
--- /dev/null
+++ b/compiler/luci/pass/src/FuseGeluPass.test.cpp
@@ -0,0 +1,251 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/FuseGeluPass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <luci/test/TestIOGraph.h>
+
+#include <cmath>
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class GeluGraphlet
+{
+public:
+ GeluGraphlet() = default;
+
+ void init(loco::Graph *g)
+ {
+ _ifm = g->nodes()->create<luci::CircleAbs>();
+ _mul_sqrt = g->nodes()->create<luci::CircleMul>();
+ _erf = g->nodes()->create<luci::CircleCustom>(1, 1);
+ _erf_out = g->nodes()->create<luci::CircleCustomOut>();
+ _add_one = g->nodes()->create<luci::CircleAdd>();
+ _mul = g->nodes()->create<luci::CircleMul>();
+ _mul_half = g->nodes()->create<luci::CircleMul>();
+ _const_sqrt = g->nodes()->create<luci::CircleConst>();
+ _const_one = g->nodes()->create<luci::CircleConst>();
+ _const_half = g->nodes()->create<luci::CircleConst>();
+
+ _mul->fusedActivationFunction(luci::FusedActFunc::NONE);
+ _mul_sqrt->fusedActivationFunction(luci::FusedActFunc::NONE);
+ _mul_half->fusedActivationFunction(luci::FusedActFunc::NONE);
+ _add_one->fusedActivationFunction(luci::FusedActFunc::NONE);
+
+ _ifm->name("ifm");
+ _mul_sqrt->name("mul_sqrt");
+ _erf->name("erf");
+ _erf_out->name("erf_out");
+ _add_one->name("add_one");
+ _mul->name("mul");
+ _mul_half->name("mul_half");
+ _const_one->name("const_one");
+ _const_sqrt->name("const_sqrt");
+ _const_half->name("const_half");
+
+ _erf->custom_code("Erf");
+
+ _const_sqrt->dtype(loco::DataType::FLOAT32);
+ _const_sqrt->size<loco::DataType::FLOAT32>(1);
+ _const_sqrt->shape({1});
+ _const_sqrt->at<loco::DataType::FLOAT32>(0) = sqrtf(0.5f);
+ _const_sqrt->shape_status(luci::ShapeStatus::VALID);
+
+ _const_one->dtype(loco::DataType::FLOAT32);
+ _const_one->size<loco::DataType::FLOAT32>(1);
+ _const_one->shape({1});
+ _const_one->at<loco::DataType::FLOAT32>(0) = 1.0;
+ _const_one->shape_status(luci::ShapeStatus::VALID);
+
+ _const_half->dtype(loco::DataType::FLOAT32);
+ _const_half->size<loco::DataType::FLOAT32>(1);
+ _const_half->shape({1});
+ _const_half->at<loco::DataType::FLOAT32>(0) = 0.5;
+ _const_half->shape_status(luci::ShapeStatus::VALID);
+ }
+
+ void invalid_half() { _const_half->at<loco::DataType::FLOAT32>(0) = 0.1; }
+ void invalid_act() { _add_one->fusedActivationFunction(luci::FusedActFunc::RELU); }
+
+protected:
+ luci::CircleAbs *_ifm = nullptr;
+ luci::CircleMul *_mul_sqrt = nullptr;
+ luci::CircleCustom *_erf = nullptr;
+ luci::CircleCustomOut *_erf_out = nullptr;
+ luci::CircleAdd *_add_one = nullptr;
+ luci::CircleMul *_mul = nullptr;
+ luci::CircleMul *_mul_half = nullptr;
+ luci::CircleConst *_const_sqrt = nullptr;
+ luci::CircleConst *_const_one = nullptr;
+ luci::CircleConst *_const_half = nullptr;
+};
+
+class FuseGeluTestGraph1 : public TestIOGraph, public GeluGraphlet
+{
+public:
+ FuseGeluTestGraph1() = default;
+
+ void init(void)
+ {
+ TestIOGraph::init({1}, {1});
+ GeluGraphlet::init(g());
+
+ _ifm->x(input());
+ _mul_sqrt->x(_ifm);
+ _mul_sqrt->y(_const_sqrt);
+ _erf->inputs(0, _mul_sqrt);
+ _erf_out->input(_erf);
+ _add_one->x(_erf_out);
+ _add_one->y(_const_one);
+ _mul->x(_ifm);
+ _mul->y(_add_one);
+ _mul_half->x(_mul);
+ _mul_half->y(_const_half);
+
+ output()->from(_mul_half);
+ }
+};
+
+class FuseGeluTestGraph2 : public TestIOGraph, public GeluGraphlet
+{
+public:
+ FuseGeluTestGraph2() = default;
+
+ void init(void)
+ {
+ TestIOGraph::init({1}, {1});
+ GeluGraphlet::init(g());
+
+ _ifm->x(input());
+ _mul_sqrt->x(_ifm);
+ _mul_sqrt->y(_const_sqrt);
+ _erf->inputs(0, _mul_sqrt);
+ _erf_out->input(_erf);
+ _add_one->x(_erf_out);
+ _add_one->y(_const_one);
+ _mul_half->x(_ifm);
+ _mul_half->y(_const_half);
+ _mul->x(_mul_half);
+ _mul->y(_add_one);
+
+ output()->from(_mul);
+ }
+};
+
+class FuseGeluTestNegGraph : public TestIOGraph, public GeluGraphlet
+{
+public:
+ FuseGeluTestNegGraph() = default;
+
+ void init(void)
+ {
+ TestIOGraph::init({1}, {1});
+ GeluGraphlet::init(g());
+
+ _ifm->x(input());
+ _mul_sqrt->x(_ifm);
+ // NOTE y is incorrect (should be _const_sqrt)
+ _mul_sqrt->y(_ifm);
+ _erf->inputs(0, _mul_sqrt);
+ _erf_out->input(_erf);
+ _add_one->x(_erf_out);
+ _add_one->y(_const_one);
+ _mul->x(_ifm);
+ _mul->y(_add_one);
+ _mul_half->x(_mul);
+ _mul_half->y(_const_half);
+
+ output()->from(_mul_half);
+ }
+};
+
+} // namespace
+
+TEST(FuseGeluPassTest, name)
+{
+ luci::FuseGeluPass pass;
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
+
+TEST(FuseGeluPassTest, fuse_pattern1)
+{
+ FuseGeluTestGraph1 g;
+ luci::FuseGeluPass pass;
+
+ g.init();
+
+ EXPECT_TRUE(pass.run(g.g()));
+}
+
+TEST(FuseGeluPassTest, fuse_pattern2)
+{
+ FuseGeluTestGraph2 g;
+ luci::FuseGeluPass pass;
+
+ g.init();
+
+ EXPECT_TRUE(pass.run(g.g()));
+}
+
+TEST(FuseGeluPassTest, fuse_invalid_half_NEG)
+{
+ FuseGeluTestNegGraph g;
+ luci::FuseGeluPass pass;
+
+ g.init();
+ g.invalid_half();
+
+ EXPECT_FALSE(pass.run(g.g()));
+}
+
+TEST(FuseGeluPassTest, fuse_pattern2_invalid_half_NEG)
+{
+ FuseGeluTestGraph2 g;
+ luci::FuseGeluPass pass;
+
+ g.init();
+ g.invalid_half();
+
+ EXPECT_FALSE(pass.run(g.g()));
+}
+
+TEST(FuseGeluPassTest, fuse_invalid_act_NEG)
+{
+ FuseGeluTestNegGraph g;
+ luci::FuseGeluPass pass;
+
+ g.init();
+ g.invalid_act();
+
+ EXPECT_FALSE(pass.run(g.g()));
+}
+
+TEST(FuseGeluPassTest, fuse_NEG)
+{
+ FuseGeluTestNegGraph g;
+ luci::FuseGeluPass pass;
+
+ g.init();
+
+ EXPECT_FALSE(pass.run(g.g()));
+}
diff --git a/compiler/luci/pass/src/PropagateQParamBackwardPass.cpp b/compiler/luci/pass/src/PropagateQParamBackwardPass.cpp
index e8fa2a478..18617e3b7 100644
--- a/compiler/luci/pass/src/PropagateQParamBackwardPass.cpp
+++ b/compiler/luci/pass/src/PropagateQParamBackwardPass.cpp
@@ -28,6 +28,25 @@
namespace
{
+// Return true if node is a virtual node
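+// (Virtual nodes are luci-internal IR nodes such as graph I/O, constants, and
+// the per-output nodes of multi-output Ops; they are declared with CIRCLE_VNODE
+// in CircleNodes.lst.)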
+bool virtual_op(const luci::CircleOpcode opcode)
+{
+ switch (opcode)
+ {
+#define CIRCLE_NODE(OPCODE, CIRCLE_CLASS) \
+ case luci::CircleOpcode::OPCODE: \
+ return false;
+#define CIRCLE_VNODE(OPCODE, CIRCLE_CLASS) \
+ case luci::CircleOpcode::OPCODE: \
+ return true;
+#include <luci/IR/CircleNodes.lst>
+#undef CIRCLE_NODE
+#undef CIRCLE_VNODE
+ default:
+ throw std::runtime_error("Unknown opcode detected");
+ }
+}
+
void quant_const_values(luci::CircleConst *const_node, float scaling_factor, float zerop,
loco::DataType quant_type)
{
@@ -448,6 +467,50 @@ struct PropagateQParamBackward final : public luci::CircleNodeMutableVisitor<voi
void visit(luci::CirclePack *node) { propagate_pack_quantparam(node); }
void visit(luci::CirclePadV2 *node) { propagate_pad_v2_quantparam(node); }
+
+ // Propagate qparam for non-value changing Ops
+ // (ex: Reshape, Transpose, etc.)
+ // TODO Add more Ops
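+  //
+  // Example: if Reshape has qparam (scale 2.0, zerop 2) and its single-user
+  // producer Conv has (scale 1.0, zerop 1), Conv's qparam is overwritten with
+  // (2.0, 2) so that both tensors share the same quantization parameters.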
+
+ void visit(luci::CircleReshape *node)
+ {
+ auto input_node = loco::must_cast<luci::CircleNode *>(node->tensor());
+
+ // Do not propagate qparam if input node has multiple users
+ if (loco::succs(input_node).size() > 1)
+ return;
+
+ const auto input_opcode = input_node->opcode();
+
+ // Do not propagate qparam if input node is virtual Op (except CIRCLEINPUT)
+    // Why? It is not safe to propagate qparam to some virtual nodes
+    // (e.g., const nodes, multi-out nodes), so we block them for now.
+ // TODO Revisit this condition
+ if (virtual_op(input_opcode) and input_opcode != luci::CircleOpcode::CIRCLEINPUT)
+ return;
+
+ overwrite_quantparam(node, input_node);
+ }
+
+ void visit(luci::CircleTranspose *node)
+ {
+ auto input_node = loco::must_cast<luci::CircleNode *>(node->a());
+
+ // Do not propagate qparam if input node has multiple users
+ if (loco::succs(input_node).size() > 1)
+ return;
+
+ const auto input_opcode = input_node->opcode();
+
+ // Do not propagate qparam if input node is virtual Op (except CIRCLEINPUT)
+    // Why? It is not safe to propagate qparam to some virtual nodes
+    // (e.g., const nodes, multi-out nodes), so we block them for now.
+ // TODO Revisit this condition
+ if (virtual_op(input_opcode) and input_opcode != luci::CircleOpcode::CIRCLEINPUT)
+ return;
+
+ overwrite_quantparam(node, input_node);
+ }
};
} // namespace
diff --git a/compiler/luci/pass/src/PropagateQParamBackwardPass.test.cpp b/compiler/luci/pass/src/PropagateQParamBackwardPass.test.cpp
index 33af70449..04573cc45 100644
--- a/compiler/luci/pass/src/PropagateQParamBackwardPass.test.cpp
+++ b/compiler/luci/pass/src/PropagateQParamBackwardPass.test.cpp
@@ -129,6 +129,119 @@ public:
CircleOutput *output = nullptr;
};
+/**
+ * BEFORE
+ *
+ * [Input]
+ * |
+ * [Conv] (qparam 1)
+ * |
+ * [Reshape] (qparam 2)
+ * |
+ * [Output]
+ *
+ * AFTER
+ *
+ * [Input]
+ * |
+ * [Conv] (qparam 2)
+ * |
+ * [Reshape] (qparam 2)
+ * |
+ * [Output]
+ */
+class ConvReshapeGraph
+{
+public:
+ ConvReshapeGraph()
+ {
+ input = g.nodes()->create<luci::CircleInput>();
+ conv = g.nodes()->create<luci::CircleConv2D>();
+ reshape = g.nodes()->create<luci::CircleReshape>();
+ output = g.nodes()->create<luci::CircleOutput>();
+
+ auto graph_input = g.inputs()->create();
+ input->index(graph_input->index());
+ auto graph_output = g.outputs()->create();
+ output->index(graph_output->index());
+
+ set_qparam(conv, 2.0, 2);
+ set_qparam(reshape, 1.0, 1);
+
+ conv->input(input);
+ reshape->tensor(conv);
+ output->from(reshape);
+ }
+
+public:
+ loco::Graph g;
+ luci::CircleInput *input = nullptr;
+ luci::CircleConv2D *conv = nullptr;
+ luci::CircleReshape *reshape = nullptr;
+ luci::CircleOutput *output = nullptr;
+};
+
+/**
+ * BEFORE
+ *
+ * [Input]
+ * |
+ * [Conv] (qparam 1)
+ * |
+ * +---------------------+
+ * | |
+ * [Reshape] (qparam 2) [Output]
+ * |
+ * [Output]
+ *
+ * AFTER (qparam is not propagated as Conv has multiple users)
+ *
+ * [Input]
+ * |
+ * [Conv] (qparam 1)
+ * |
+ * +---------------------+
+ * | |
+ * [Reshape] (qparam 2) [Output]
+ * |
+ * [Output]
+ */
+class ConvReshapeMultiOutGraph
+{
+public:
+ ConvReshapeMultiOutGraph()
+ {
+ input = g.nodes()->create<luci::CircleInput>();
+ conv = g.nodes()->create<luci::CircleConv2D>();
+ reshape = g.nodes()->create<luci::CircleReshape>();
+ output1 = g.nodes()->create<luci::CircleOutput>();
+ output2 = g.nodes()->create<luci::CircleOutput>();
+
+ auto graph_input = g.inputs()->create();
+ input->index(graph_input->index());
+ auto graph_output1 = g.outputs()->create();
+ output1->index(graph_output1->index());
+ auto graph_output2 = g.outputs()->create();
+ output2->index(graph_output2->index());
+
+ set_qparam(conv, 2.0, 2);
+ set_qparam(reshape, 1.0, 1);
+
+ conv->input(input);
+ reshape->tensor(conv);
+ output1->from(reshape);
+ output2->from(conv);
+ }
+
+public:
+ loco::Graph g;
+ luci::CircleInput *input = nullptr;
+ luci::CircleConv2D *conv = nullptr;
+ luci::CircleReshape *reshape = nullptr;
+ luci::CircleOutput *output1 = nullptr;
+ luci::CircleOutput *output2 = nullptr;
+};
+
} // namespace
TEST(PropagateQParamBackwardPassTest, name)
@@ -165,3 +278,33 @@ TEST(PropagateQParamBackwardPassTest, subsequent_propagation)
EXPECT_EQ(3.0, graph.input->quantparam()->scale[0]);
EXPECT_EQ(3, graph.input->quantparam()->zerop[0]);
}
+
+TEST(PropagateQParamBackwardPassTest, reshape)
+{
+ ConvReshapeGraph graph;
+
+ EXPECT_NE(graph.conv->quantparam()->scale, graph.reshape->quantparam()->scale);
+ EXPECT_NE(graph.conv->quantparam()->zerop, graph.reshape->quantparam()->zerop);
+
+ luci::PropagateQParamBackwardPass pass(loco::DataType::U8);
+
+ pass.run(&graph.g);
+
+ EXPECT_EQ(graph.conv->quantparam()->scale, graph.reshape->quantparam()->scale);
+ EXPECT_EQ(graph.conv->quantparam()->zerop, graph.reshape->quantparam()->zerop);
+}
+
+TEST(PropagateQParamBackwardPassTest, reshape_multi_use_NEG)
+{
+ ConvReshapeMultiOutGraph graph;
+
+ EXPECT_NE(graph.conv->quantparam()->scale, graph.reshape->quantparam()->scale);
+ EXPECT_NE(graph.conv->quantparam()->zerop, graph.reshape->quantparam()->zerop);
+
+ luci::PropagateQParamBackwardPass pass(loco::DataType::U8);
+
+ pass.run(&graph.g);
+
+ EXPECT_NE(graph.conv->quantparam()->scale, graph.reshape->quantparam()->scale);
+ EXPECT_NE(graph.conv->quantparam()->zerop, graph.reshape->quantparam()->zerop);
+}
diff --git a/compiler/luci/pass/src/QuantizationUtils.cpp b/compiler/luci/pass/src/QuantizationUtils.cpp
index 45d229a0b..3e3cdde34 100644
--- a/compiler/luci/pass/src/QuantizationUtils.cpp
+++ b/compiler/luci/pass/src/QuantizationUtils.cpp
@@ -73,14 +73,14 @@ void asymmetric_wquant_with_minmax_per_layer(CircleConst *node, float min, float
}
void symmetric_wquant_with_minmax_per_layer(CircleConst *node, float min, float max,
- float &scaling_factor, int64_t &zp, float &nudged_min,
+ float &scaling_factor, float &nudged_min,
float &nudged_max)
{
const int32_t kMaxScale = std::numeric_limits<int16_t>::max();
const int32_t kMinScale = -kMaxScale;
uint32_t size = node->size<loco::DataType::FLOAT32>();
- compute_sym_scale_zp(min, max, scaling_factor, zp, nudged_min, nudged_max);
+ compute_sym_scale(min, max, scaling_factor, nudged_min, nudged_max);
const float scaling_factor_inv = 1.0 / scaling_factor;
std::vector<int32_t> quantized_values(size);
for (uint32_t i = 0; i < size; ++i)
@@ -101,12 +101,14 @@ void symmetric_wquant_with_minmax_per_layer(CircleConst *node, float min, float
}
}
-void compute_sym_scale_zp(float min, float max, float &scaling_factor, int64_t &zp,
- float &nudged_min, float &nudged_max)
+void compute_sym_scale(float min, float max, float &scaling_factor, float &nudged_min,
+ float &nudged_max, loco::DataType out_type)
{
assert(min <= max);
+ assert(out_type == loco::DataType::S8 || out_type == loco::DataType::S16);
- const int32_t kMaxScale = std::numeric_limits<int16_t>::max();
+ const int32_t kMaxScale = (out_type == loco::DataType::S16) ? std::numeric_limits<int16_t>::max()
+ : std::numeric_limits<int8_t>::max();
const int32_t kMinScale = -kMaxScale;
const double qmin_double = kMinScale;
const double qmax_double = kMaxScale;
@@ -126,10 +128,9 @@ void compute_sym_scale_zp(float min, float max, float &scaling_factor, int64_t &
: scale_factor_from_max_side;
// protect scale from being very low to avoid overflow/underflow
- if (scaling_factor < 1e-8)
- scaling_factor = 1e-8;
+ const float kMinScalingFactor = (out_type == loco::DataType::S16) ? 1e-8 : 1e-5;
+ scaling_factor = std::max(scaling_factor, kMinScalingFactor);
- zp = 0;
nudged_min = static_cast<float>(qmin_double * scaling_factor);
nudged_max = static_cast<float>(qmax_double * scaling_factor);
}
@@ -424,7 +425,7 @@ void quant_const(luci::CircleConst *node, loco::DataType quant_type)
nudged_max);
break;
case loco::DataType::S16:
- symmetric_wquant_with_minmax_per_layer(node, min, max, scaling_factor, zp, nudged_min,
+ symmetric_wquant_with_minmax_per_layer(node, min, max, scaling_factor, nudged_min,
nudged_max);
break;
default:
diff --git a/compiler/luci/pass/src/QuantizationUtils.h b/compiler/luci/pass/src/QuantizationUtils.h
index 0720c9839..93c4045b5 100644
--- a/compiler/luci/pass/src/QuantizationUtils.h
+++ b/compiler/luci/pass/src/QuantizationUtils.h
@@ -23,9 +23,9 @@
namespace luci
{
-// Compute scale/zp using given min/max for symmetric quantization (int16)
-void compute_sym_scale_zp(float min, float max, float &scaling_factor, int64_t &zp,
- float &nudged_min, float &nudged_max);
+// Compute scale using given min/max for symmetric quantization (int8/int16)
+void compute_sym_scale(float min, float max, float &scaling_factor, float &nudged_min,
+ float &nudged_max, loco::DataType out_type = loco::DataType::S16);
// Compute scale/zp using given min/max for asymmetric quantization (uint8)
void compute_asym_scale_zp(float min, float max, float &scaling_factor, int64_t &zp,
@@ -40,7 +40,7 @@ void asymmetric_wquant_with_minmax_per_layer(CircleConst *node, float min, float
// Symmetric per-layer quantization of weights (const tensor) using given min/max values
// NOTE: in-place update of node data
void symmetric_wquant_with_minmax_per_layer(CircleConst *node, float min, float max,
- float &scaling_factor, int64_t &zp, float &nudged_min,
+ float &scaling_factor, float &nudged_min,
float &nudged_max);
// Helper function to get channel dimension
diff --git a/compiler/luci/pass/src/QuantizeActivation.cpp b/compiler/luci/pass/src/QuantizeActivation.cpp
index 214e61c1e..913450083 100644
--- a/compiler/luci/pass/src/QuantizeActivation.cpp
+++ b/compiler/luci/pass/src/QuantizeActivation.cpp
@@ -78,7 +78,7 @@ void QuantizeActivation::visit(luci::CircleNode *node)
}
else
{
- compute_sym_scale_zp(min, max, scaling_factor, zp, nudged_min, nudged_max);
+ compute_sym_scale(min, max, scaling_factor, nudged_min, nudged_max);
node->dtype(loco::DataType::S16);
}
@@ -171,7 +171,10 @@ void QuantizeConstInputActivation::visit(luci::CircleNode *node)
auto input_node = node->arg(i);
auto const_node = dynamic_cast<luci::CircleConst *>(input_node);
if (const_node != nullptr)
- throw std::runtime_error("Unsupported Op for const inputs");
+ {
+ std::string msg = "Unsupported Op for const inputs: " + node->name();
+ throw std::runtime_error(msg);
+ }
}
}
@@ -221,6 +224,7 @@ QUANTIZE_SINGLE_CONST_INPUT(luci::CircleElu, features)
QUANTIZE_SINGLE_CONST_INPUT(luci::CircleExp, x)
QUANTIZE_SINGLE_CONST_INPUT(luci::CircleFloor, x)
QUANTIZE_SINGLE_CONST_INPUT(luci::CircleGather, params)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleGelu, features)
QUANTIZE_SINGLE_CONST_INPUT(luci::CircleLocalResponseNormalization, input)
QUANTIZE_SINGLE_CONST_INPUT(luci::CircleLogistic, x)
QUANTIZE_SINGLE_CONST_INPUT(luci::CircleMean, input)
@@ -242,6 +246,7 @@ QUANTIZE_SINGLE_CONST_INPUT(luci::CircleSpaceToDepth, input)
QUANTIZE_SINGLE_CONST_INPUT(luci::CircleSplit, input)
QUANTIZE_SINGLE_CONST_INPUT(luci::CircleSplitV, input)
QUANTIZE_SINGLE_CONST_INPUT(luci::CircleSqrt, x)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleSqueeze, input)
QUANTIZE_SINGLE_CONST_INPUT(luci::CircleStridedSlice, input)
QUANTIZE_SINGLE_CONST_INPUT(luci::CircleSum, input)
QUANTIZE_SINGLE_CONST_INPUT(luci::CircleTanh, x)
@@ -256,6 +261,7 @@ QUANTIZE_TWO_CONST_INPUTS(luci::CircleBatchMatMul, x, y)
QUANTIZE_TWO_CONST_INPUTS(luci::CircleDiv, x, y)
QUANTIZE_TWO_CONST_INPUTS(luci::CircleEqual, x, y)
QUANTIZE_TWO_CONST_INPUTS(luci::CircleFloorDiv, x, y)
+QUANTIZE_TWO_CONST_INPUTS(luci::CircleFloorMod, x, y)
QUANTIZE_TWO_CONST_INPUTS(luci::CircleGreater, x, y)
QUANTIZE_TWO_CONST_INPUTS(luci::CircleGreaterEqual, x, y)
QUANTIZE_TWO_CONST_INPUTS(luci::CircleLess, x, y)
diff --git a/compiler/luci/pass/src/QuantizeActivation.h b/compiler/luci/pass/src/QuantizeActivation.h
index c6c991a76..ba3bc59f2 100644
--- a/compiler/luci/pass/src/QuantizeActivation.h
+++ b/compiler/luci/pass/src/QuantizeActivation.h
@@ -111,6 +111,7 @@ private:
void visit(luci::CircleExp *node);
void visit(luci::CircleFloor *node);
void visit(luci::CircleGather *node);
+ void visit(luci::CircleGelu *node);
void visit(luci::CircleLocalResponseNormalization *node);
void visit(luci::CircleLogistic *node);
void visit(luci::CircleMean *node);
@@ -132,6 +133,7 @@ private:
void visit(luci::CircleSplit *node);
void visit(luci::CircleSplitV *node);
void visit(luci::CircleSqrt *node);
+ void visit(luci::CircleSqueeze *node);
void visit(luci::CircleStridedSlice *node);
void visit(luci::CircleSum *node);
void visit(luci::CircleTanh *node);
@@ -146,6 +148,7 @@ private:
void visit(luci::CircleDiv *node);
void visit(luci::CircleEqual *node);
void visit(luci::CircleFloorDiv *node);
+ void visit(luci::CircleFloorMod *node);
void visit(luci::CircleGreater *node);
void visit(luci::CircleGreaterEqual *node);
void visit(luci::CircleLess *node);
diff --git a/compiler/luci/pass/src/QuantizeBias.test.cpp b/compiler/luci/pass/src/QuantizeBias.test.cpp
index 0104a191b..9030f59e9 100644
--- a/compiler/luci/pass/src/QuantizeBias.test.cpp
+++ b/compiler/luci/pass/src/QuantizeBias.test.cpp
@@ -16,6 +16,8 @@
#include "QuantizeBias.h"
+#include "helpers/CreateCircleConst.h"
+
#include <luci/test/TestIOGraph.h>
#include <luci/IR/CircleNodes.h>
#include <luci/IR/CircleQuantParam.h>
@@ -29,51 +31,6 @@ namespace
using namespace luci::test;
-// TODO Reduce duplicate codes in ResolveCustomOpMatMulPass.cpp
-template <typename T>
-luci::CircleConst *create_const_node(loco::Graph *g, const loco::DataType dtype,
- const std::vector<uint32_t> &shape, T value)
-{
- auto node = g->nodes()->create<luci::CircleConst>();
- node->dtype(dtype);
- node->rank(shape.size());
-
- uint32_t size = 1;
- for (uint32_t i = 0; i < shape.size(); ++i)
- {
- node->dim(i) = shape.at(i);
- size *= shape.at(i);
- }
- node->shape_status(luci::ShapeStatus::VALID);
-
-#define INIT_VALUES(DT) \
- { \
- node->size<DT>(size); \
- for (uint32_t i = 0; i < size; ++i) \
- node->at<DT>(i) = value; \
- }
-
- switch (dtype)
- {
- case loco::DataType::U8:
- INIT_VALUES(loco::DataType::U8);
- break;
- case loco::DataType::S16:
- INIT_VALUES(loco::DataType::S16);
- break;
- case loco::DataType::S32:
- INIT_VALUES(loco::DataType::S32);
- break;
- case loco::DataType::FLOAT32:
- INIT_VALUES(loco::DataType::FLOAT32)
- break;
- default:
- INTERNAL_EXN("create_const_node called with unsupported type");
- break;
- }
- return node;
-}
-
/**
* Simple graph for test
*
diff --git a/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.cpp b/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.cpp
index ef047d35d..f8989c9e0 100644
--- a/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.cpp
+++ b/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.cpp
@@ -110,8 +110,8 @@ void cal_minmax_per_channel(CircleConst *node, std::vector<float> &min, std::vec
}
void sym_wquant_per_channel(CircleConst *node, std::vector<float> &min, std::vector<float> &max,
- std::vector<float> &scaling_factor, std::vector<int64_t> &zp,
- std::vector<float> &nudged_min, std::vector<float> &nudged_max)
+ std::vector<float> &scaling_factor, std::vector<float> &nudged_min,
+ std::vector<float> &nudged_max)
{
assert(node->dtype() == loco::DataType::FLOAT32);
const int32_t kMaxScale = std::numeric_limits<int16_t>::max();
@@ -122,7 +122,7 @@ void sym_wquant_per_channel(CircleConst *node, std::vector<float> &min, std::vec
for (size_t i = 0; i < min.size(); ++i)
{
- compute_sym_scale_zp(min[i], max[i], scaling_factor[i], zp[i], nudged_min[i], nudged_max[i]);
+ compute_sym_scale(min[i], max[i], scaling_factor[i], nudged_min[i], nudged_max[i]);
}
auto quantize = [&](uint32_t *indices, loco::TensorShape &dimension, int channel_dim_index) {
@@ -322,7 +322,7 @@ private:
}
else
{
- sym_wquant_per_channel(weights, min, max, scaling_factor, zp, nudged_min, nudged_max);
+ sym_wquant_per_channel(weights, min, max, scaling_factor, nudged_min, nudged_max);
sym_wdequant_per_channel(weights, scaling_factor);
}
diff --git a/compiler/luci/pass/src/QuantizePreCheckerPass.test.cpp b/compiler/luci/pass/src/QuantizePreCheckerPass.test.cpp
index 788353cd8..8f6a96f33 100644
--- a/compiler/luci/pass/src/QuantizePreCheckerPass.test.cpp
+++ b/compiler/luci/pass/src/QuantizePreCheckerPass.test.cpp
@@ -206,6 +206,7 @@ public:
transpose_conv->outBackprop(input_1);
transpose_conv->filter(filter);
transpose_conv->inputSizes(input_sizes);
+ transpose_conv->fusedActivationFunction(luci::FusedActFunc::NONE);
if (make_valid)
{
diff --git a/compiler/luci/pass/src/QuantizeWeights.cpp b/compiler/luci/pass/src/QuantizeWeights.cpp
index 29cdaffff..59329c19e 100644
--- a/compiler/luci/pass/src/QuantizeWeights.cpp
+++ b/compiler/luci/pass/src/QuantizeWeights.cpp
@@ -92,9 +92,8 @@ void asym_wquant_per_channel(CircleConst *node, std::vector<float> &min,
// TODO Reduce duplicate code with QuantizeDequantizeWeights
void sym_wquant_per_channel(CircleConst *node, std::vector<float> &min, std::vector<float> &max,
- std::vector<float> &scaling_factor, std::vector<int64_t> &zp,
- std::vector<float> &nudged_min, std::vector<float> &nudged_max,
- int32_t &channel_dim_index)
+ std::vector<float> &scaling_factor, std::vector<float> &nudged_min,
+ std::vector<float> &nudged_max, int32_t &channel_dim_index)
{
assert(node->dtype() == loco::DataType::FLOAT32);
const int32_t kMaxScale = std::numeric_limits<int16_t>::max();
@@ -105,7 +104,7 @@ void sym_wquant_per_channel(CircleConst *node, std::vector<float> &min, std::vec
for (size_t i = 0; i < min.size(); ++i)
{
- compute_sym_scale_zp(min[i], max[i], scaling_factor[i], zp[i], nudged_min[i], nudged_max[i]);
+ compute_sym_scale(min[i], max[i], scaling_factor[i], nudged_min[i], nudged_max[i]);
}
auto quantize = [&](uint32_t *indices, loco::TensorShape &dimension, int channel_dim_index) {
@@ -383,7 +382,7 @@ void QuantizeWeights::quantize_weights(luci::CircleConst *weights)
}
else
{
- sym_wquant_per_channel(weights, min, max, scaling_factor, zp, nudged_min, nudged_max,
+ sym_wquant_per_channel(weights, min, max, scaling_factor, nudged_min, nudged_max,
channel_dim_index);
}
diff --git a/compiler/luci/pass/src/QuantizeWeightsOnly.cpp b/compiler/luci/pass/src/QuantizeWeightsOnly.cpp
new file mode 100644
index 000000000..e69a7b6a8
--- /dev/null
+++ b/compiler/luci/pass/src/QuantizeWeightsOnly.cpp
@@ -0,0 +1,224 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "QuantizeWeightsOnly.h"
+#include "QuantizationUtils.h"
+
+#include <luci/Service/Nodes/CircleConst.h>
+#include <luci/Log.h>
+
+#include <cmath>
+#include <vector>
+#include <functional>
+#include <limits>
+
+using namespace luci;
+
+namespace
+{
+
+using IterFunc = std::function<void(uint32_t *, loco::TensorShape &, int32_t)>;
+
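+// Call `func` for every element of the const tensor, iterating over its
+// 4-D (padded) shape; `func` receives the current index, the tensor shape,
+// and the channel dimension index.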
+void iterate_per_channel(CircleConst *node, int32_t &channel_dim_index, IterFunc func)
+{
+ loco::TensorShape dimension;
+ dimension.rank(4);
+ uint32_t indices[4] = {
+ 0,
+ };
+
+ if (!get_channel_dim_index(node, dimension, channel_dim_index))
+ {
+ assert(false);
+ return;
+ }
+
+ for (indices[0] = 0; indices[0] < dimension.dim(0).value(); indices[0]++)
+ {
+ for (indices[1] = 0; indices[1] < dimension.dim(1).value(); indices[1]++)
+ {
+ for (indices[2] = 0; indices[2] < dimension.dim(2).value(); indices[2]++)
+ {
+ for (indices[3] = 0; indices[3] < dimension.dim(3).value(); indices[3]++)
+ {
+ func(indices, dimension, channel_dim_index);
+ }
+ }
+ }
+ }
+}
+
+// TODO Reduce duplicate code with QuantizeDequantizeWeights
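+// Symmetric per-channel weight quantization: each value is clamped to its
+// channel's [nudged_min, nudged_max] and mapped to round(value / scale), then
+// stored in-place as out_type (S8 or S16).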
+template <loco::DataType out_type>
+void sym_wquant_per_channel(CircleConst *node, std::vector<float> &min, std::vector<float> &max,
+ std::vector<float> &scaling_factor, std::vector<float> &nudged_min,
+ std::vector<float> &nudged_max, int32_t &channel_dim_index)
+{
+ assert(node->dtype() == loco::DataType::FLOAT32);
+ assert(out_type == loco::DataType::S8 || out_type == loco::DataType::S16);
+ const int32_t kMaxScale = (out_type == loco::DataType::S8) ? std::numeric_limits<int8_t>::max()
+ : std::numeric_limits<int16_t>::max();
+ const int32_t kMinScale = -kMaxScale;
+
+ uint32_t size = node->size<loco::DataType::FLOAT32>();
+ std::vector<int32_t> quantized_values(size);
+
+ for (size_t i = 0; i < min.size(); ++i)
+ {
+ compute_sym_scale(min[i], max[i], scaling_factor[i], nudged_min[i], nudged_max[i], out_type);
+ }
+
+ auto quantize = [&](uint32_t *indices, loco::TensorShape &dimension, int channel_dim_index) {
+ int channel_idx = indices[channel_dim_index];
+ const float scaling_factor_inv = 1.0 / scaling_factor[channel_idx];
+ auto data = node->at<loco::DataType::FLOAT32>(cal_offset(dimension, indices));
+ data = data < nudged_min[channel_idx] ? nudged_min[channel_idx] : data;
+ data = data > nudged_max[channel_idx] ? nudged_max[channel_idx] : data;
+ quantized_values[cal_offset(dimension, indices)] =
+ static_cast<int32_t>(std::round(data * scaling_factor_inv));
+ };
+
+ iterate_per_channel(node, channel_dim_index, quantize);
+
+ node->dtype(out_type); // change the type of tensor
+ node->size<out_type>(size); // resize tensor
+ for (uint32_t i = 0; i < size; ++i)
+ {
+ node->at<out_type>(i) = std::min(kMaxScale, std::max(kMinScale, quantized_values[i]));
+ }
+}
+
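+// Compute per-channel min/max of a FLOAT32 const tensor along its channel
+// dimension. `min` and `max` are resized to the number of channels.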
+void cal_minmax_per_channel(CircleConst *node, std::vector<float> &min, std::vector<float> &max,
+ int32_t &channel_dim_index)
+{
+ loco::TensorShape dimension;
+ dimension.rank(4);
+
+ if (!get_channel_dim_index(node, dimension, channel_dim_index))
+ {
+ throw std::runtime_error("Failed to find channel index in " + node->name());
+ }
+ auto size = dimension.dim(channel_dim_index).value();
+
+ std::vector<bool> has_min_max_value(size, false);
+ min.resize(size);
+ max.resize(size);
+
+ auto cal_minmax = [&](uint32_t *indices, loco::TensorShape &dimension, int channel_dim_index) {
+ int channel_idx = indices[channel_dim_index];
+ auto data = node->at<loco::DataType::FLOAT32>(cal_offset(dimension, indices));
+ if (has_min_max_value[channel_idx])
+ {
+ min[channel_idx] = data < min[channel_idx] ? data : min[channel_idx];
+ max[channel_idx] = data > max[channel_idx] ? data : max[channel_idx];
+ }
+ else
+ {
+ min[channel_idx] = data;
+ max[channel_idx] = data;
+ has_min_max_value[channel_idx] = true;
+ }
+ };
+
+ iterate_per_channel(node, channel_dim_index, cal_minmax);
+}
+
+} // namespace
+
+namespace luci
+{
+
+void QuantizeWeightsOnly::quantize_weights(luci::CircleConst *weights)
+{
+ // Find min/max per channel-wise
+ if (granularity == QuantizationGranularity::ChannelWise)
+ {
+ auto quantparam = weights->quantparam();
+ if (quantparam == nullptr)
+ {
+ // Find min/max on the fly
+ // NOTE This is for the case when QuantizeDequantizeWeights is skipped
+ // TODO Reduce duplicate codes
+ std::vector<float> min;
+ std::vector<float> max;
+ int32_t channel_dim_index = 0;
+
+ cal_minmax_per_channel(weights, min, max, channel_dim_index);
+
+ std::vector<float> nudged_min(min.size());
+ std::vector<float> nudged_max(min.size());
+ std::vector<float> scaling_factor(min.size());
+ std::vector<int64_t> zp(min.size());
+
+ if (output_type == loco::DataType::S8)
+ {
+ sym_wquant_per_channel<loco::DataType::S8>(weights, min, max, scaling_factor, nudged_min,
+ nudged_max, channel_dim_index);
+ }
+ else if (output_type == loco::DataType::S16)
+ {
+ sym_wquant_per_channel<loco::DataType::S16>(weights, min, max, scaling_factor, nudged_min,
+ nudged_max, channel_dim_index);
+ }
+ else
+ {
+ throw std::runtime_error("Weights-only quantization supports s8 and s16");
+ }
+
+ auto quantparam = std::make_unique<CircleQuantParam>();
+ quantparam->scale = scaling_factor;
+ quantparam->zerop = zp;
+ quantparam->quantized_dimension = channel_dim_index;
+ weights->quantparam(std::move(quantparam));
+
+ return;
+ }
+ }
+ else
+ throw std::runtime_error("Weights-only quantization does not support layer-wise");
+}
+
+void QuantizeWeightsOnly::visit(luci::CircleConv2D *node)
+{
+ LOGGER(l);
+ INFO(l) << "QuantizeWeightsOnly visits node: " << node->name() << std::endl;
+
+ auto weights = loco::must_cast<luci::CircleConst *>(node->filter());
+ if (!is_quantized(weights))
+ {
+ auto new_weights = luci::clone(weights);
+ node->filter(new_weights);
+ quantize_weights(new_weights);
+ }
+}
+
+void QuantizeWeightsOnly::visit(luci::CircleDepthwiseConv2D *node)
+{
+ LOGGER(l);
+ INFO(l) << "QuantizeWeightsOnly visits node: " << node->name() << std::endl;
+
+ auto weights = loco::must_cast<luci::CircleConst *>(node->filter());
+ if (!is_quantized(weights))
+ {
+ auto new_weights = luci::clone(weights);
+ node->filter(new_weights);
+ quantize_weights(new_weights);
+ }
+}
+
+void QuantizeWeightsOnly::visit(luci::CircleNode *) {}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/QuantizeWeightsOnly.h b/compiler/luci/pass/src/QuantizeWeightsOnly.h
new file mode 100644
index 000000000..ff6ad3261
--- /dev/null
+++ b/compiler/luci/pass/src/QuantizeWeightsOnly.h
@@ -0,0 +1,51 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_QUANTIZE_WEIGHTS_ONLY_H__
+#define __LUCI_QUANTIZE_WEIGHTS_ONLY_H__
+
+#include <luci/Pass/QuantizationParameters.h>
+#include <luci/IR/CircleNodeVisitor.h>
+
+namespace luci
+{
+
+/**
+ * @brief QuantizeWeightsOnly quantizes tensors for weights
+ * @details Find min/max values on the fly and then quantize
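+ * @note  Only channel-wise granularity is supported; filters of Conv2D and
+ *        DepthwiseConv2D are quantized to S8 or S16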
+ */
+struct QuantizeWeightsOnly final : public luci::CircleNodeMutableVisitor<void>
+{
+ QuantizeWeightsOnly(loco::DataType input, loco::DataType output, QuantizationGranularity gr)
+ : input_type(input), output_type(output), granularity(gr)
+ {
+ }
+
+ loco::DataType input_type;
+ loco::DataType output_type;
+ QuantizationGranularity granularity;
+
+private:
+ void quantize_weights(luci::CircleConst *weights);
+
+ void visit(luci::CircleConv2D *node);
+ void visit(luci::CircleDepthwiseConv2D *node);
+ void visit(luci::CircleNode *);
+};
+
+} // namespace luci
+
+#endif // __LUCI_QUANTIZE_WEIGHTS_ONLY_H__
diff --git a/compiler/luci/pass/src/QuantizeWeightsPass.cpp b/compiler/luci/pass/src/QuantizeWeightsPass.cpp
new file mode 100644
index 000000000..9ac203e77
--- /dev/null
+++ b/compiler/luci/pass/src/QuantizeWeightsPass.cpp
@@ -0,0 +1,46 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/QuantizeWeightsPass.h"
+#include "QuantizeWeightsOnly.h"
+#include "QuantizationUtils.h"
+
+#include <luci/Log.h>
+
+namespace luci
+{
+
+bool QuantizeWeightsPass::run(loco::Graph *g)
+{
+ LOGGER(l);
+ INFO(l) << "QuantizeWeightsPass Start" << std::endl;
+
+ if (_ctx->input_model_dtype != loco::DataType::FLOAT32)
+ throw std::runtime_error("Weights-only quantization supports float32 input only");
+
+ // Quantize weights
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+ QuantizeWeightsOnly qw(_ctx->input_model_dtype, _ctx->output_model_dtype, _ctx->granularity);
+ circle_node->accept(&qw);
+ }
+
+ INFO(l) << "QuantizeWeightsPass End" << std::endl;
+ return false; // one time run
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/QuantizeWeightsPass.test.cpp b/compiler/luci/pass/src/QuantizeWeightsPass.test.cpp
new file mode 100644
index 000000000..058e029ab
--- /dev/null
+++ b/compiler/luci/pass/src/QuantizeWeightsPass.test.cpp
@@ -0,0 +1,123 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/QuantizeWeightsPass.h"
+#include <luci/IR/CircleNodes.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+struct QuantizeWeightsPassTest : public ::testing::Test
+{
+ /**
+ * nconv graph
+ *
+ * [CircleInput]
+ * |
+ * |
+ * [CircleConv2D]
+ * |
+ * |
+ * [CircleOutput]
+ */
+ void MakeGraph()
+ {
+ const int N = 1;
+ const int H = 4;
+ const int W = 4;
+ const int C = 3; // IC = OC
+
+ // graph input and output
+ auto graph_input = _g.inputs()->create();
+ auto graph_output = _g.outputs()->create();
+
+ // CircleInput
+ auto input = _g.nodes()->create<luci::CircleInput>();
+ input->index(graph_input->index());
+ input->shape({N, H, W, C});
+ input->dtype(loco::DataType::FLOAT32);
+ input->name("input");
+
+ // CircleConv2D
+ auto conv = _g.nodes()->create<luci::CircleConv2D>();
+ conv->input(input);
+ auto bias = _g.nodes()->create<luci::CircleConst>();
+ bias->dtype(loco::DataType::FLOAT32);
+ bias->shape({C});
+ bias->name("conv_bias");
+ conv->bias(bias);
+ auto weight = _g.nodes()->create<luci::CircleConst>();
+ weight->dtype(loco::DataType::FLOAT32);
+ weight->shape({C, H, W, C});
+ weight->size<loco::DataType::FLOAT32>(C * H * W * C);
+ conv->filter(weight);
+ conv->padding(luci::Padding::SAME);
+ conv->fusedActivationFunction(luci::FusedActFunc::NONE);
+ conv->dtype(loco::DataType::FLOAT32);
+ conv->name("nconv");
+
+ // CircleOutput
+ auto output = _g.nodes()->create<luci::CircleOutput>();
+ output->index(graph_output->index());
+ output->from(conv);
+ output->shape({N, H, W, C});
+ output->dtype(loco::DataType::FLOAT32);
+ output->name("output");
+ }
+ virtual void SetUp() { MakeGraph(); }
+ loco::Graph _g;
+};
+
+} // namespace
+
+TEST_F(QuantizeWeightsPassTest, name)
+{
+ luci::QuantizeWeightsPass pass(loco::DataType::FLOAT32, loco::DataType::S8,
+ luci::QuantizationGranularity::ChannelWise);
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
+
+TEST_F(QuantizeWeightsPassTest, name_ctx)
+{
+ auto ctx = std::make_unique<luci::QuantizeWeightsPass::Context>();
+ {
+ ctx->input_model_dtype = loco::DataType::FLOAT32;
+ ctx->output_model_dtype = loco::DataType::S8;
+ ctx->granularity = luci::QuantizationGranularity::ChannelWise;
+ }
+
+ luci::QuantizeWeightsPass pass(std::move(ctx));
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
+
+TEST_F(QuantizeWeightsPassTest, run_input_U8_NEG)
+{
+ luci::QuantizeWeightsPass pass(loco::DataType::U8, loco::DataType::S8,
+ luci::QuantizationGranularity::ChannelWise);
+ EXPECT_THROW(pass.run(&_g), std::runtime_error);
+}
+
+TEST_F(QuantizeWeightsPassTest, run_output_f32_NEG)
+{
+ luci::QuantizeWeightsPass pass(loco::DataType::FLOAT32, loco::DataType::FLOAT32,
+ luci::QuantizationGranularity::ChannelWise);
+ EXPECT_THROW(pass.run(&_g), std::runtime_error);
+}
diff --git a/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp b/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp
index c68e06712..4f4edaf36 100644
--- a/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp
+++ b/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp
@@ -101,7 +101,7 @@ luci::CircleQuantize *create_quantize_op(luci::CircleNode *node, loco::DataType
else
{
assert(out_type == loco::DataType::S16);
- compute_sym_scale_zp(min, max, scaling_factor, zp, nudged_min, nudged_max);
+ compute_sym_scale(min, max, scaling_factor, nudged_min, nudged_max);
}
auto quantparam = std::make_unique<CircleQuantParam>();
@@ -271,6 +271,7 @@ private:
INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleFloor, x)
INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleFullyConnected, input)
INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleGather, params)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleGelu, features)
INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleInstanceNorm, input)
INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleLeakyRelu, features)
INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleLocalResponseNormalization, input)
@@ -433,7 +434,7 @@ void QuantizeWithMinMaxPass::set_input_type(loco::Graph *g) const
else
{
assert(user_given_dtype == loco::DataType::S16);
- compute_sym_scale_zp(min, max, scaling_factor, zp, nudged_min, nudged_max);
+ compute_sym_scale(min, max, scaling_factor, nudged_min, nudged_max);
}
input->quantparam()->scale[0] = scaling_factor;
input->quantparam()->zerop[0] = zp;
@@ -479,15 +480,15 @@ void QuantizeWithMinMaxPass::set_output_type(loco::Graph *g) const
if (user_given_dtype == loco::DataType::FLOAT32)
{
auto dequant_op = create_dequantize(from);
- loco::replace(from).with(dequant_op);
dequant_op->input(from);
+ output->from(dequant_op);
}
else
{
// Insert Quantize Op for non-float32 output_type
auto quant_op = create_quantize_op(from, user_given_dtype);
- loco::replace(from).with(quant_op);
quant_op->input(from);
+ output->from(quant_op);
// TODO Set a proper origin (Quantize should have its own Origin)
luci::add_origin(quant_op, luci::get_origin(from));
@@ -629,6 +630,13 @@ bool QuantizeWithMinMaxPass::run(loco::Graph *g)
for (auto node : loco::active_nodes(loco::output_nodes(g)))
{
auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+
+    // At this point, all activations should already be quantized.
+    // Un-quantized nodes are not quantization targets (e.g., int32 tensors),
+    // so we skip them.
+ if (circle_node->quantparam() == nullptr)
+ continue;
+
QuantizeSpecialActivation qsa(_ctx->input_model_dtype, quantize_dtype(circle_node));
circle_node->accept(&qsa);
}
diff --git a/compiler/luci/pass/src/QuantizedModelVerifier.test.cpp b/compiler/luci/pass/src/QuantizedModelVerifier.test.cpp
index 05ec31727..ae02edb3d 100644
--- a/compiler/luci/pass/src/QuantizedModelVerifier.test.cpp
+++ b/compiler/luci/pass/src/QuantizedModelVerifier.test.cpp
@@ -66,6 +66,8 @@ template <Type T> luci::CircleConst *create_dummy_const(loco::Graph *g, luci::te
// Fill with index
node->at<T>(i) = static_cast<int16_t>(i);
break;
+ default:
+ break;
}
}
}
@@ -470,15 +472,15 @@ public:
void init(void) override
{
TestIOGraph::init({32}, {32});
- _begin = g()->nodes()->create<luci::CircleConst>();
+ _begin = g()->nodes()->template create<luci::CircleConst>();
{
_begin->dtype(indexT);
}
- _size = g()->nodes()->create<luci::CircleConst>();
+ _size = g()->nodes()->template create<luci::CircleConst>();
{
_size->dtype(indexT);
}
- _slice = g()->nodes()->create<luci::CircleSlice>();
+ _slice = g()->nodes()->template create<luci::CircleSlice>();
{
_slice->input(input());
_slice->begin(_begin);
@@ -595,6 +597,31 @@ private:
luci::CircleConst *_strides = nullptr;
};
+class SumTestGraph final : public SimpleTestGraph
+{
+public:
+ void init(void) override
+ {
+ TestIOGraph::init({4, 3, 2}, {2});
+
+ _axis = create_const<Type::S32, int32_t>(g(), {2}, {1, 0});
+ _sum = g()->nodes()->create<luci::CircleSum>();
+ {
+ _sum->input(input());
+ _sum->reduction_indices(_axis);
+ _sum->name("test");
+ _sum->keep_dims(false);
+ }
+ output()->from(_sum);
+
+ set_minmax_to_non_const(g(), -1, 1);
+ }
+
+private:
+ luci::CircleSum *_sum = nullptr;
+ luci::CircleConst *_axis = nullptr;
+};
+
class ReshapeTestGraph final : public SimpleTestGraph
{
public:
@@ -669,11 +696,11 @@ public:
TestIOGraph::init({32}, {1});
// output dtype is float by default, but ArgMax should have indexType (s32/s64)
output()->dtype(indexT);
- _dimension = g()->nodes()->create<luci::CircleConst>();
+ _dimension = g()->nodes()->template create<luci::CircleConst>();
{
_dimension->dtype(indexT);
}
- _argmax = g()->nodes()->create<luci::CircleArgMax>();
+ _argmax = g()->nodes()->template create<luci::CircleArgMax>();
{
_argmax->input(input());
_argmax->dimension(_dimension);
@@ -978,7 +1005,7 @@ public:
TestIOGraph::init({32}, {32});
output()->dtype(loco::DataType::BOOL);
_y = create_dummy_const<Type::FLOAT32>(g(), {32});
- _op = g()->nodes()->create<Op>();
+ _op = g()->nodes()->template create<Op>();
{
_op->x(input());
_op->y(_y);
@@ -1011,7 +1038,7 @@ public:
input()->dtype(loco::DataType::BOOL);
output()->dtype(loco::DataType::BOOL);
_y = create_dummy_const<Type::BOOL>(g(), {32});
- _op = g()->nodes()->create<Op>();
+ _op = g()->nodes()->template create<Op>();
{
_op->x(input());
_op->y(_y);
@@ -1315,7 +1342,7 @@ public:
TypedTestGraph::init(T, {32}, {32});
_const = create_dummy_const<T>(g(), {32});
- _mul = g()->nodes()->create<luci::CircleMul>();
+ _mul = g()->nodes()->template create<luci::CircleMul>();
{
_mul->x(input());
_mul->y(_const);
@@ -1370,7 +1397,7 @@ public:
TypedTestGraph::init(T, {32}, {32});
_const = create_dummy_const<T>(g(), {32});
- _add = g()->nodes()->create<luci::CircleAdd>();
+ _add = g()->nodes()->template create<luci::CircleAdd>();
{
_add->x(input());
_add->y(_const);
@@ -1786,6 +1813,34 @@ TEST(QuantizedModelVerifierTest, StridedSlice_wrong_granularity_NEG)
SUCCEED();
}
+TEST(QuantizedModelVerifierTest, Sum)
+{
+ TEST_WITH_GRAPH(SumTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_GRAPH(SumTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_GRAPH(SumTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(SumTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(SumTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(SumTestGraph, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Sum_wrong_type_NEG)
+{
+ TEST_WITH_WRONG_TYPE(SumTestGraph, Type::U8, Granularity::LayerWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(SumTestGraph, Type::U8, Granularity::ChannelWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(SumTestGraph, Type::S16, Granularity::ChannelWise, Type::U8);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Sum_wrong_granularity_NEG)
+{
+ TEST_WITH_WRONG_GRANULARITY(SumTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_WRONG_GRANULARITY(SumTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_WRONG_GRANULARITY(SumTestGraph, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
TEST(QuantizedModelVerifierTest, ArgMax)
{
TEST_WITH_GRAPH(ArgMaxTestGraph<Type::S32>, Type::U8, Granularity::LayerWise);
diff --git a/compiler/luci/pass/src/ReplaceNonConstFCWithBatchMatMulPass.test.cpp b/compiler/luci/pass/src/ReplaceNonConstFCWithBatchMatMulPass.test.cpp
index 93024f3f7..194893f01 100644
--- a/compiler/luci/pass/src/ReplaceNonConstFCWithBatchMatMulPass.test.cpp
+++ b/compiler/luci/pass/src/ReplaceNonConstFCWithBatchMatMulPass.test.cpp
@@ -16,6 +16,8 @@
#include "luci/Pass/ReplaceNonConstFCWithBatchMatMulPass.h"
+#include "helpers/CreateCircleConst.h"
+
#include <luci/test/TestIOGraph.h>
#include <luci/IR/CircleNodes.h>
@@ -26,52 +28,6 @@ namespace
using namespace luci::test;
-// TODO Reduce duplicate codes in ResolveCustomOpMatMulPass.cpp
-template <typename T>
-luci::CircleConst *create_const_node(loco::Graph *g, const loco::DataType dtype,
- const std::vector<uint32_t> &shape,
- const std::vector<T> &values)
-{
- auto node = g->nodes()->create<luci::CircleConst>();
- node->dtype(dtype);
- node->rank(shape.size());
-
- uint32_t size = 1;
- for (uint32_t i = 0; i < shape.size(); ++i)
- {
- node->dim(i) = shape.at(i);
- size *= shape.at(i);
- }
- node->shape_status(luci::ShapeStatus::VALID);
-
-#define INIT_VALUES(DT) \
- { \
- node->size<DT>(size); \
- for (uint32_t i = 0; i < values.size(); ++i) \
- node->at<DT>(i) = values[i]; \
- }
-
- switch (dtype)
- {
- case loco::DataType::U8:
- INIT_VALUES(loco::DataType::U8);
- break;
- case loco::DataType::S16:
- INIT_VALUES(loco::DataType::S16);
- break;
- case loco::DataType::S32:
- INIT_VALUES(loco::DataType::S32);
- break;
- case loco::DataType::FLOAT32:
- INIT_VALUES(loco::DataType::FLOAT32)
- break;
- default:
- INTERNAL_EXN("create_const_node called with unsupported type");
- break;
- }
- return node;
-}
-
/**
* Simple graph for test
*
@@ -104,7 +60,7 @@ public:
_tr_y = g->nodes()->create<luci::CircleTranspose>();
_tr_y->a(_y);
std::vector<int32_t> tr_val = {1, 0};
- _tr_y->perm(create_const_node(g, loco::DataType::S32, {2}, tr_val));
+ _tr_y->perm(luci::create_const_node(g, loco::DataType::S32, {2}, tr_val));
_fc = g->nodes()->create<luci::CircleFullyConnected>();
_fc->input(_x);
@@ -114,7 +70,7 @@ public:
_fc->shape(r_shape);
auto l = _fc->dim(_fc->rank() - 1).value();
std::vector<float> bias_val(l, bv);
- _fc->bias(create_const_node(g, loco::DataType::FLOAT32, {l}, bias_val));
+ _fc->bias(luci::create_const_node(g, loco::DataType::FLOAT32, {l}, bias_val));
_fc->name("fc");
}
diff --git a/compiler/luci/pass/src/ReplaceSubWithAddPass.cpp b/compiler/luci/pass/src/ReplaceSubWithAddPass.cpp
index 6bd83f5c5..f9102d836 100644
--- a/compiler/luci/pass/src/ReplaceSubWithAddPass.cpp
+++ b/compiler/luci/pass/src/ReplaceSubWithAddPass.cpp
@@ -17,6 +17,7 @@
#include "luci/Pass/ReplaceSubWithAddPass.h"
#include <luci/IR/CircleNodes.h>
+#include <luci/Profile/CircleNodeOrigin.h>
#include <luci/Service/Nodes/CircleConst.h>
namespace
@@ -47,6 +48,7 @@ bool replace_sub_with_const_rhs(luci::CircleSub *sub)
add->y(neg_const_rhs);
add->name(sub->name());
add->fusedActivationFunction(sub->fusedActivationFunction());
+ luci::add_origin(add, luci::get_origin(sub));
loco::replace(sub).with(add);
return true;
}
diff --git a/compiler/luci/pass/src/RequantizePass.cpp b/compiler/luci/pass/src/RequantizePass.cpp
index a56536251..77c55324a 100644
--- a/compiler/luci/pass/src/RequantizePass.cpp
+++ b/compiler/luci/pass/src/RequantizePass.cpp
@@ -32,37 +32,9 @@ namespace luci
namespace
{
-// Check if the node is the bias of Conv2D, DepthwiseConv2D, or FullyConnected layer
-bool is_bias(CircleConst *node)
-{
- if (node == nullptr)
- return false;
-
- auto succs = loco::succs(node);
- if (succs.size() != 1) // assume bias is used by only one node
- return false;
-
- for (auto out : succs)
- {
- auto conv = dynamic_cast<CircleConv2D *>(out);
- if (conv != nullptr && conv->bias() == node)
- return true;
-
- auto dw_conv = dynamic_cast<CircleDepthwiseConv2D *>(out);
- if (dw_conv != nullptr && dw_conv->bias() == node)
- return true;
-
- auto fc = dynamic_cast<CircleFullyConnected *>(out);
- if (fc != nullptr && fc->bias() == node)
- return true;
-
- auto tconv = dynamic_cast<CircleTransposeConv *>(out);
- if (tconv != nullptr && tconv->bias() == node)
- return true;
- }
- return false;
-}
-
+// Requantize Non-const node from int8 to uint8
+// Original values: -128 ~ 127
+// After requantization: 0 ~ 255
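+// Shifting the range from [-128, 127] to [0, 255] corresponds to adding 128
+// to the zero point while keeping the scale, so the real values represented
+// by the tensor are unchanged.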
void requant_nonconst_int8_to_uint8(CircleNode *circle_node)
{
assert(circle_node->dtype() == loco::DataType::S8);
@@ -107,99 +79,48 @@ void requant_const_int8_to_uint8(CircleConst *node)
}
}
+#define RETURN_UNLESS(cond) \
+ if (not(cond)) \
+ return;
+
/**
- * @brief RequantizeNonConst requantizes tensors for activations
+ * @brief Requantize int8 quantized tensors to uint8 tensors
*/
-struct RequantizeNonConst final : public luci::CircleNodeMutableVisitor<bool>
+struct RequantizeS8ToU8 final : public luci::CircleNodeMutableVisitor<void>
{
- RequantizeNonConst(loco::DataType input, loco::DataType output)
- : _input_type(input), _output_type(output)
- {
- }
-
- loco::DataType _input_type;
- loco::DataType _output_type;
-
- // Requantize input tensors of each node
- bool visit(luci::CircleNode *node)
+ // Requantize non-const tensors
+ void visit(luci::CircleNode *node)
{
LOGGER(l);
- INFO(l) << "RequantizeNonConst visit node: " << node->name() << std::endl;
- auto arity = node->arity();
- for (uint32_t i = 0; i < arity; i++)
- {
- auto input_node = node->arg(i);
- auto circle_node = loco::must_cast<luci::CircleNode *>(input_node);
+ INFO(l) << "RequantizeS8ToU8 visit non-const node: " << node->name() << std::endl;
- // Check if this was quantized (only quantized tensors are requantized)
- if (circle_node->quantparam() == nullptr)
- continue;
+ // Ignore non-quantized tensors
+ RETURN_UNLESS(node->quantparam() != nullptr);
- // Check if this is already requantized
- if (circle_node->dtype() == _output_type)
- continue;
+ // Check dtype is int8
+ RETURN_UNLESS(node->dtype() == loco::DataType::S8);
- // Check if this is not const (only non-const is requantized in this function)
- auto circle_const = dynamic_cast<CircleConst *>(circle_node);
- if (circle_const != nullptr)
- continue;
-
- if (_input_type == loco::DataType::S8 && _output_type == loco::DataType::U8)
- requant_nonconst_int8_to_uint8(circle_node);
- }
- return false;
- }
-};
-
-/**
- * @brief RequantizeConst requantizes tensors for weights
- */
-struct RequantizeConst final : public luci::CircleNodeMutableVisitor<bool>
-{
- RequantizeConst(loco::DataType input, loco::DataType output)
- : _input_type(input), _output_type(output)
- {
+ requant_nonconst_int8_to_uint8(node);
}
- loco::DataType _input_type;
- loco::DataType _output_type;
-
- // Requantize input tensors of each node
- bool visit(luci::CircleNode *node)
+ // Requantize const tensors
+ void visit(luci::CircleConst *node)
{
LOGGER(l);
- INFO(l) << "RequantizeConst visit node: " << node->name() << std::endl;
- auto arity = node->arity();
- for (uint32_t i = 0; i < arity; i++)
- {
- auto input_node = node->arg(i);
- auto circle_node = loco::must_cast<luci::CircleNode *>(input_node);
+ INFO(l) << "RequantizeS8ToU8 visit const node: " << node->name() << std::endl;
- // Check if this was quantized (only quantized tensors are requantized)
- if (circle_node->quantparam() == nullptr)
- continue;
+ // Ignore non-quantized tensors
+ RETURN_UNLESS(node->quantparam() != nullptr);
- // Check if this is already requantized
- if (circle_node->dtype() == _output_type)
- continue;
+ // Check dtype is int8
+ RETURN_UNLESS(node->dtype() == loco::DataType::S8);
- // Check if this is const (only const is requantized in this function)
- auto circle_const = dynamic_cast<CircleConst *>(circle_node);
- if (circle_const == nullptr)
- continue;
-
- // Check if this is not bias
- // bias is not requantized when int8 -> uint8
- if (is_bias(circle_const))
- continue;
-
- if (_input_type == loco::DataType::S8 && _output_type == loco::DataType::U8)
- requant_const_int8_to_uint8(circle_const);
- }
- return false;
+ requant_const_int8_to_uint8(node);
}
};
+#undef RETURN_UNLESS
+
} // namespace
bool RequantizePass::run(loco::Graph *g)
@@ -207,20 +128,21 @@ bool RequantizePass::run(loco::Graph *g)
LOGGER(l);
INFO(l) << "RequantizePass Start" << std::endl;
- // Requantize non-const (activations)
- for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ // Input: int8 model
+ // Output: uint8 model
+ if (_input_dtype == loco::DataType::S8 and _output_dtype == loco::DataType::U8)
{
- RequantizeNonConst rqnc(_input_dtype, _output_dtype);
- auto circle_node = loco::must_cast<luci::CircleNode *>(node);
- circle_node->accept(&rqnc);
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ RequantizeS8ToU8 rq;
+ auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+ circle_node->accept(&rq);
+ }
}
-
- // Requantize const (including weights, constants)
- for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ else
{
- RequantizeConst rqc(_input_dtype, _output_dtype);
- auto circle_node = loco::must_cast<luci::CircleNode *>(node);
- circle_node->accept(&rqc);
+ // Ignore other cases
+ return false;
}
// Update output dtype
@@ -228,7 +150,8 @@ bool RequantizePass::run(loco::Graph *g)
for (auto node : loco::output_nodes(g))
{
auto circle_node = loco::must_cast<luci::CircleOutput *>(node);
- if (static_cast<luci::CircleNode *>(circle_node->from())->dtype() == _output_dtype)
+ auto from_node = loco::must_cast<luci::CircleNode *>(circle_node->from());
+ if (from_node->dtype() == _output_dtype)
{
circle_node->dtype(_output_dtype);
auto graph_output = graph_outputs->at(circle_node->index());
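
For reference, the arithmetic behind the S8-to-U8 requantization above follows the affine-quantization identity real = scale * (q - zero_point): shifting both the stored value and the zero point by +128 leaves every represented real value unchanged, which is how int8's -128 ~ 127 range lands on uint8's 0 ~ 255. A minimal standalone sketch of that shift (illustrative helper names, not the pass's actual functions):

#include <cstddef>
#include <cstdint>
#include <vector>

// Affine quantization: real = scale * (q - zero_point).
// Adding 128 to both q and zero_point preserves every real value,
// so int8 [-128, 127] maps exactly onto uint8 [0, 255].
struct QParam
{
  std::vector<float> scale;
  std::vector<int64_t> zerop;
};

// Non-const (activation) tensor: only metadata changes (dtype S8 -> U8, zerop += 128).
void requant_qparam_s8_to_u8(QParam &qp)
{
  for (auto &zp : qp.zerop)
    zp += 128;
}

// Const tensor: the stored values are shifted together with the zero point.
std::vector<uint8_t> requant_values_s8_to_u8(const std::vector<int8_t> &src)
{
  std::vector<uint8_t> dst(src.size());
  for (std::size_t i = 0; i < src.size(); ++i)
    dst[i] = static_cast<uint8_t>(static_cast<int32_t>(src[i]) + 128);
  return dst;
}
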
diff --git a/compiler/luci/pass/src/RequantizePass.test.cpp b/compiler/luci/pass/src/RequantizePass.test.cpp
index d26743c9d..a9293ce27 100644
--- a/compiler/luci/pass/src/RequantizePass.test.cpp
+++ b/compiler/luci/pass/src/RequantizePass.test.cpp
@@ -16,11 +16,167 @@
#include "luci/Pass/RequantizePass.h"
+#include "helpers/CreateCircleConst.h"
+
+#include <luci/test/TestIOGraph.h>
+#include <luci/IR/CircleNodes.h>
+#include <luci/IR/CircleQuantParam.h>
+
+#include <vector>
+
#include <gtest/gtest.h>
+using namespace luci;
+using namespace luci::test;
+
+namespace
+{
+
+/**
+ * Simple graph for test
+ *
+ * BEFORE
+ *
+ * [IFM (S8)] [W (S8)] [B (S32)]
+ * | | |
+ * +-------+--------+
+ * |
+ * V
+ * [FC]
+ * |
+ * V
+ * [OFM(S8)]
+ *
+ * AFTER
+ *
+ * [IFM (U8)] [W (U8)] [B (S32)]
+ * | | |
+ * +-------+--------+
+ * |
+ * V
+ * [FC]
+ * |
+ * V
+ * [OFM(U8)]
+ */
+struct S8FCGraphlet
+{
+public:
+ S8FCGraphlet() = default;
+ virtual ~S8FCGraphlet() = default;
+
+ void init(loco::Graph *g, const ShapeU32 out_shape, const ShapeU32 w_shape,
+ const ShapeU32 bias_shape)
+ {
+ _fc = g->nodes()->create<CircleFullyConnected>();
+ _fc->input(_x);
+ _x->dtype(loco::DataType::S8);
+ {
+ auto quantparam = std::make_unique<CircleQuantParam>();
+ quantparam->scale.push_back(1.0);
+ quantparam->zerop.push_back(0);
+ quantparam->quantized_dimension = 0;
+ _x->quantparam(std::move(quantparam));
+ }
+
+ _weights = create_const_node<int8_t>(g, loco::DataType::S8, w_shape, 1.0);
+ {
+ auto w_qparam = std::make_unique<CircleQuantParam>();
+ std::vector<float> w_scale(_weights->dim(0).value(), 1.0);
+ std::vector<int64_t> w_zp(_weights->dim(0).value(), 0);
+ w_qparam->scale = w_scale;
+ w_qparam->zerop = w_zp;
+ w_qparam->quantized_dimension = 0;
+ _weights->quantparam(std::move(w_qparam));
+ }
+ _fc->weights(_weights);
+
+ _bias = create_const_node<int32_t>(g, loco::DataType::S32, bias_shape, 1.0);
+ {
+ auto b_qparam = std::make_unique<CircleQuantParam>();
+ const auto bias_size = _bias->size<loco::DataType::S32>();
+ std::vector<float> b_scale(bias_size, 1.0);
+ std::vector<int64_t> b_zp(bias_size, 0);
+ b_qparam->scale = b_scale;
+ b_qparam->zerop = b_zp;
+ b_qparam->quantized_dimension = 0;
+ _bias->quantparam(std::move(b_qparam));
+ }
+
+ _fc->fusedActivationFunction(luci::FusedActFunc::NONE);
+ _fc->dtype(loco::DataType::S8);
+ _fc->shape(out_shape);
+ _fc->bias(_bias);
+ _fc->name("fc");
+ {
+ auto quantparam = std::make_unique<CircleQuantParam>();
+ quantparam->scale.push_back(1.0);
+ quantparam->zerop.push_back(0);
+ quantparam->quantized_dimension = 0;
+ _fc->quantparam(std::move(quantparam));
+ }
+ }
+
+public:
+ CircleFullyConnected *_fc = nullptr;
+ CircleInput *_x = nullptr;
+ CircleConst *_weights = nullptr;
+ CircleConst *_bias = nullptr;
+};
+
+struct S8FCGraph final : public TestIGraphlet, public TestOGraphlet, public S8FCGraphlet
+{
+ void init(const ShapeU32 in_shape, const ShapeU32 w_shape, const ShapeU32 out_shape,
+ const ShapeU32 bias_shape)
+ {
+ TestIGraphlet::init(g(), in_shape);
+ TestOGraphlet::init(g(), out_shape);
+ _x = input();
+ S8FCGraphlet::init(g(), out_shape, w_shape, bias_shape);
+ output()->from(_fc);
+ }
+};
+
+class RequantizeS8ToU8FCTest : public ::testing::Test
+{
+public:
+ S8FCGraph g;
+};
+
+} // namespace
+
TEST(RequantizePassTest, name)
{
luci::RequantizePass pass(loco::DataType::FLOAT32, loco::DataType::U8);
auto const name = pass.name();
ASSERT_NE(nullptr, name);
}
+
+TEST_F(RequantizeS8ToU8FCTest, FC)
+{
+  g.init({1, 18, 80} /* ifm shape */, {256, 80} /* weights shape */, {18, 256} /* ofm shape */,
+         {1, 256} /* bias shape */);
+
+ luci::RequantizePass rq(loco::DataType::S8, loco::DataType::U8);
+ rq.run(g.g());
+
+ EXPECT_EQ(loco::DataType::U8, g._x->dtype());
+ EXPECT_EQ(loco::DataType::U8, g._fc->dtype());
+ EXPECT_EQ(loco::DataType::U8, g._weights->dtype());
+ EXPECT_EQ(loco::DataType::S32, g._bias->dtype());
+}
+
+TEST_F(RequantizeS8ToU8FCTest, FC_wrong_dtype_NEG)
+{
+  g.init({1, 18, 80} /* ifm shape */, {256, 80} /* weights shape */, {18, 256} /* ofm shape */,
+         {1, 256} /* bias shape */);
+
+ // Wrong dtype
+ luci::RequantizePass rq(loco::DataType::U8, loco::DataType::S8);
+ rq.run(g.g());
+
+ EXPECT_EQ(loco::DataType::S8, g._x->dtype());
+ EXPECT_EQ(loco::DataType::S8, g._fc->dtype());
+ EXPECT_EQ(loco::DataType::S8, g._weights->dtype());
+ EXPECT_EQ(loco::DataType::S32, g._bias->dtype());
+}
diff --git a/compiler/luci/pass/src/ResolveCustomOpMatMulPass.cpp b/compiler/luci/pass/src/ResolveCustomOpMatMulPass.cpp
index f61882796..add55f66c 100644
--- a/compiler/luci/pass/src/ResolveCustomOpMatMulPass.cpp
+++ b/compiler/luci/pass/src/ResolveCustomOpMatMulPass.cpp
@@ -16,6 +16,8 @@
#include "luci/Pass/ResolveCustomOpMatMulPass.h"
+#include "helpers/CreateCircleConst.h"
+
#include <loco/IR/DataTypeTraits.h>
#include <luci/IR/CircleNodes.h>
@@ -29,51 +31,6 @@
namespace
{
-template <typename T>
-luci::CircleConst *create_const_node(loco::Graph *g, const loco::DataType dtype,
- const std::vector<uint32_t> &shape,
- const std::vector<T> &values)
-{
- auto node = g->nodes()->create<luci::CircleConst>();
- node->dtype(dtype);
- node->rank(shape.size());
-
- uint32_t size = 1;
- for (uint32_t i = 0; i < shape.size(); ++i)
- {
- node->dim(i) = shape.at(i);
- size *= shape.at(i);
- }
- node->shape_status(luci::ShapeStatus::VALID);
-
-#define INIT_VALUES(DT) \
- { \
- node->size<DT>(size); \
- for (uint32_t i = 0; i < values.size(); ++i) \
- node->at<DT>(i) = values[i]; \
- }
-
- switch (dtype)
- {
- case loco::DataType::U8:
- INIT_VALUES(loco::DataType::U8);
- break;
- case loco::DataType::S16:
- INIT_VALUES(loco::DataType::S16);
- break;
- case loco::DataType::S32:
- INIT_VALUES(loco::DataType::S32);
- break;
- case loco::DataType::FLOAT32:
- INIT_VALUES(loco::DataType::FLOAT32)
- break;
- default:
- INTERNAL_EXN("create_const_node called with unsupported type");
- break;
- }
- return node;
-}
-
bool resolve_matmul(luci::CircleCustom *cop)
{
#define CHECK_OR_FALSE(condition) \
@@ -121,11 +78,12 @@ bool resolve_matmul(luci::CircleCustom *cop)
if (transpose_a)
{
// Create a permutation constant node
- std::vector<uint32_t> perm;
- for (uint32_t i = 0; i < circle_lhs->rank(); ++i)
+ std::vector<int32_t> perm;
+ const auto lhs_rank = static_cast<int32_t>(circle_lhs->rank());
+ for (int32_t i = 0; i < lhs_rank; ++i)
perm.push_back(i);
std::swap(perm[circle_lhs->rank() - 1], perm[circle_lhs->rank() - 2]);
- auto perm_node = create_const_node(graph, S32, {circle_lhs->rank()}, perm);
+ auto perm_node = luci::create_const_node(graph, S32, {circle_lhs->rank()}, perm);
perm_node->name(name + "/lhs/Transpose/perm");
// Now make a transpose node
auto transpose_node = graph->nodes()->create<luci::CircleTranspose>();
@@ -141,8 +99,8 @@ bool resolve_matmul(luci::CircleCustom *cop)
// in row-major order, thus we need to convert between them.
if (!transpose_b)
{
- const std::vector<uint32_t> perm{1, 0};
- auto perm_node = create_const_node(graph, S32, {2}, perm);
+ const std::vector<int32_t> perm{1, 0};
+ auto perm_node = luci::create_const_node(graph, S32, {2}, perm);
perm_node->name(name + "/rhs/Transpose/perm");
auto transpose_node = graph->nodes()->create<luci::CircleTranspose>();
transpose_node->a(rhs);
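
The permutation constants above are now built through the shared luci::create_const_node helper, whose element type T selects the CircleConst storage accessor via TypeMapper<T>; that is why perm switched from uint32_t to int32_t to agree with the S32 constant being created. A small standalone sketch of the permutation itself, using plain STL types and an illustrative function name:

#include <algorithm>
#include <cstdint>
#include <vector>

// Identity permutation with the last two axes swapped, e.g. rank 3 -> {0, 2, 1}.
// This is the "perm" fed to CircleTranspose when a MatMul operand must be transposed.
std::vector<int32_t> last_two_axes_swapped(uint32_t rank)
{
  std::vector<int32_t> perm;
  for (int32_t i = 0; i < static_cast<int32_t>(rank); ++i)
    perm.push_back(i);
  if (rank >= 2)
    std::swap(perm[rank - 1], perm[rank - 2]);
  return perm;
}
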
diff --git a/compiler/luci/pass/src/SubstituteSplitVToSplitPass.test.cpp b/compiler/luci/pass/src/SubstituteSplitVToSplitPass.test.cpp
index 6e30103f9..43f9cc116 100644
--- a/compiler/luci/pass/src/SubstituteSplitVToSplitPass.test.cpp
+++ b/compiler/luci/pass/src/SubstituteSplitVToSplitPass.test.cpp
@@ -16,6 +16,8 @@
#include "luci/Pass/SubstituteSplitVToSplitPass.h"
+#include "helpers/CreateCircleConst.h"
+
#include <luci/test/TestIOGraph.h>
#include <gtest/gtest.h>
@@ -30,51 +32,6 @@ const int C = 32;
const int H = 8;
const int W = 8;
-// Reduce duplicate codes in ResolveCustomOpMatMulPass.cpp
-template <typename T>
-luci::CircleConst *create_const_node(loco::Graph *g, const loco::DataType dtype,
- const std::vector<uint32_t> &shape,
- const std::vector<T> &values)
-{
- auto node = g->nodes()->create<luci::CircleConst>();
- node->dtype(dtype);
- node->rank(shape.size());
-
- uint32_t size = 1;
- for (uint32_t i = 0; i < shape.size(); ++i)
- {
- node->dim(i) = shape.at(i);
- size *= shape.at(i);
- }
- node->shape_status(luci::ShapeStatus::VALID);
-
-#define INIT_VALUES(DT) \
- { \
- node->size<DT>(size); \
- for (uint32_t i = 0; i < values.size(); ++i) \
- node->at<DT>(i) = values[i]; \
- }
-
- switch (dtype)
- {
- case loco::DataType::U8:
- INIT_VALUES(loco::DataType::U8);
- break;
- case loco::DataType::S16:
- INIT_VALUES(loco::DataType::S16);
- break;
- case loco::DataType::S32:
- INIT_VALUES(loco::DataType::S32);
- break;
- case loco::DataType::FLOAT32:
- INIT_VALUES(loco::DataType::FLOAT32)
- break;
- default:
- INTERNAL_EXN("create_const_node called with unsupported type");
- break;
- }
- return node;
-}
/**
* graph having SplitV operator
*
@@ -95,10 +52,10 @@ public:
void init(loco::Graph *g)
{
const std::vector<int32_t> splits{16, 16};
- auto size_splits = create_const_node(g, loco::DataType::S32, {2}, splits);
+ auto size_splits = luci::create_const_node(g, loco::DataType::S32, {2}, splits);
const std::vector<int32_t> dim{3};
- auto split_dim = create_const_node(g, loco::DataType::S32, {1}, dim);
+ auto split_dim = luci::create_const_node(g, loco::DataType::S32, {1}, dim);
_sv = g->nodes()->create<luci::CircleSplitV>();
_sv->size_splits(size_splits);
diff --git a/compiler/luci/pass/src/VerifyQuantizedBiasScale.cpp b/compiler/luci/pass/src/VerifyQuantizedBiasScale.cpp
index e65d576cd..d40c19b9b 100644
--- a/compiler/luci/pass/src/VerifyQuantizedBiasScale.cpp
+++ b/compiler/luci/pass/src/VerifyQuantizedBiasScale.cpp
@@ -31,7 +31,7 @@ namespace
bool same(float a, float b)
{
constexpr float epsilon = 1e-10;
- return abs(a - b) < epsilon;
+ return std::abs(a - b) < epsilon;
}
// Check bias scale = input scale * weight scale
diff --git a/compiler/luci/pass/src/VerifyQuantizedNodeGranularity.h b/compiler/luci/pass/src/VerifyQuantizedNodeGranularity.h
index 6bf7ff698..cc618bf0e 100644
--- a/compiler/luci/pass/src/VerifyQuantizedNodeGranularity.h
+++ b/compiler/luci/pass/src/VerifyQuantizedNodeGranularity.h
@@ -298,6 +298,13 @@ private:
return true;
}
+ bool visit(const luci::CircleSum *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node));
+ RETURN_FALSE_UNLESS(is_lwq(node->input()));
+ return true;
+ }
+
bool visit(const luci::CircleArgMax *node)
{
// node's output is index, thus not quantized
@@ -333,6 +340,13 @@ private:
return true;
}
+ bool visit(const luci::CircleGelu *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node));
+ RETURN_FALSE_UNLESS(is_lwq(node->features()));
+ return true;
+ }
+
bool visit(const luci::CircleGreater *node)
{
RETURN_FALSE_UNLESS(is_lwq(node->x()));
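
The two new granularity visits (CircleGelu, CircleSum) follow the surrounding pattern: both the node and its input must be layer-wise quantized, i.e. carry exactly one scale/zero-point pair. A hedged sketch of that predicate (the function name below is illustrative; the header's own helper is is_lwq):

#include <luci/IR/CircleNodes.h>
#include <luci/IR/CircleQuantParam.h>

// Layer-wise quantization: the tensor carries exactly one scale and one zero point.
// (Sketch of what the granularity check relies on; name is illustrative.)
bool has_layerwise_qparam(const luci::CircleNode *node)
{
  const auto *qparam = node->quantparam();
  if (qparam == nullptr)
    return false;
  return qparam->scale.size() == 1 && qparam->zerop.size() == 1;
}
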
diff --git a/compiler/luci/pass/src/VerifyQuantizedNodeType.cpp b/compiler/luci/pass/src/VerifyQuantizedNodeType.cpp
index 3ce32555b..4bad9522b 100644
--- a/compiler/luci/pass/src/VerifyQuantizedNodeType.cpp
+++ b/compiler/luci/pass/src/VerifyQuantizedNodeType.cpp
@@ -181,6 +181,12 @@ bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleFullyCon
}
template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleGelu *node)
+{
+ return group_has_type(node, Qtype);
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleGreater *node)
{
RETURN_FALSE_UNLESS(has_type(node, loco::DataType::BOOL))
@@ -454,6 +460,15 @@ bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleStridedS
}
template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleSum *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->reduction_indices(), loco::DataType::S32))
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleTranspose *node)
{
RETURN_FALSE_UNLESS(has_type(node, Qtype))
diff --git a/compiler/luci/pass/src/VerifyQuantizedNodeType.h b/compiler/luci/pass/src/VerifyQuantizedNodeType.h
index 789d3c7cd..03f1e1d86 100644
--- a/compiler/luci/pass/src/VerifyQuantizedNodeType.h
+++ b/compiler/luci/pass/src/VerifyQuantizedNodeType.h
@@ -88,6 +88,7 @@ private:
bool visit(const luci::CircleFloor *node);
bool visit(const luci::CircleFloorDiv *node);
bool visit(const luci::CircleFullyConnected *node);
+ bool visit(const luci::CircleGelu *node);
bool visit(const luci::CircleGreater *node);
bool visit(const luci::CircleGreaterEqual *node);
bool visit(const luci::CircleInstanceNorm *node);
@@ -119,6 +120,7 @@ private:
bool visit(const luci::CircleSplitVOut *node);
bool visit(const luci::CircleSqrt *node);
bool visit(const luci::CircleStridedSlice *node);
+ bool visit(const luci::CircleSum *node);
bool visit(const luci::CircleTranspose *node);
bool visit(const luci::CircleTransposeConv *node);
bool visit(const luci::CircleUnpack *node);
diff --git a/compiler/luci/pass/src/helpers/CreateCircleConst.cpp b/compiler/luci/pass/src/helpers/CreateCircleConst.cpp
new file mode 100644
index 000000000..bf1b0baf7
--- /dev/null
+++ b/compiler/luci/pass/src/helpers/CreateCircleConst.cpp
@@ -0,0 +1,20 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CreateCircleConst.h"
+
+// NOTE Do NOT delete this file; it forces the compiler to verify that
+// 'CreateCircleConst.h' is complete (self-contained).
diff --git a/compiler/luci/pass/src/helpers/CreateCircleConst.h b/compiler/luci/pass/src/helpers/CreateCircleConst.h
new file mode 100644
index 000000000..89c1a47be
--- /dev/null
+++ b/compiler/luci/pass/src/helpers/CreateCircleConst.h
@@ -0,0 +1,88 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_PASS_HELPERS_CREATE_CIRCLE_CONST_H__
+#define __LUCI_PASS_HELPERS_CREATE_CIRCLE_CONST_H__
+
+#include <luci/IR/CircleNodes.h>
+
+#include "TypeMapper.h"
+
+#include <vector>
+
+namespace luci
+{
+
+// Create a CircleConst filled with a single value
+// Never returns nullptr
+// TODO Remove dtype from the argument
+template <typename T>
+CircleConst *create_const_node(loco::Graph *g, const loco::DataType dtype,
+ const std::vector<uint32_t> &shape, const T value)
+{
+ auto node = g->nodes()->create<CircleConst>();
+ node->dtype(dtype);
+ node->rank(shape.size());
+
+ uint32_t size = 1;
+ for (uint32_t i = 0; i < shape.size(); ++i)
+ {
+ node->dim(i) = shape.at(i);
+ size *= shape.at(i);
+ }
+ node->shape_status(ShapeStatus::VALID);
+
+ node->size<TypeMapper<T>::get()>(size);
+ for (uint32_t i = 0; i < size; i++)
+ {
+ node->at<TypeMapper<T>::get()>(i) = value;
+ }
+
+ return node;
+}
+
+// Create a CircleConst filled with the given values
+// Never returns nullptr
+// TODO Remove dtype from the argument
+template <typename T>
+luci::CircleConst *create_const_node(loco::Graph *g, const loco::DataType dtype,
+ const std::vector<uint32_t> &shape,
+ const std::vector<T> &values)
+{
+ auto node = g->nodes()->create<luci::CircleConst>();
+ node->dtype(dtype);
+ node->rank(shape.size());
+
+ uint32_t size = 1;
+ for (uint32_t i = 0; i < shape.size(); ++i)
+ {
+ node->dim(i) = shape.at(i);
+ size *= shape.at(i);
+ }
+ node->shape_status(luci::ShapeStatus::VALID);
+
+ node->size<TypeMapper<T>::get()>(size);
+ for (uint32_t i = 0; i < size; i++)
+ {
+ node->at<TypeMapper<T>::get()>(i) = values[i];
+ }
+
+ return node;
+}
+
+} // namespace luci
+
+#endif // __LUCI_PASS_HELPERS_CREATE_CIRCLE_CONST_H__
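
A usage sketch of the two overloads defined above (the graph, shapes, and node names below are illustrative only):

#include "CreateCircleConst.h"

#include <luci/IR/CircleNodes.h>

#include <vector>

void build_example_consts(loco::Graph *g)
{
  // Scalar-fill overload: a 2x3 FLOAT32 constant where every element is 1.0f.
  auto ones = luci::create_const_node<float>(g, loco::DataType::FLOAT32, {2, 3}, 1.0f);
  ones->name("example/ones");

  // Vector overload: a rank-1 S32 constant holding an explicit permutation.
  const std::vector<int32_t> perm{1, 0};
  auto perm_node = luci::create_const_node<int32_t>(g, loco::DataType::S32, {2}, perm);
  perm_node->name("example/perm");
}
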
diff --git a/compiler/luci/pass/src/helpers/TypeMapper.h b/compiler/luci/pass/src/helpers/TypeMapper.h
index 90760e95b..a3e27d259 100644
--- a/compiler/luci/pass/src/helpers/TypeMapper.h
+++ b/compiler/luci/pass/src/helpers/TypeMapper.h
@@ -14,6 +14,9 @@
* limitations under the License.
*/
+#ifndef __LUCI_PASS_HELPERS_TYPE_MAPPER_H__
+#define __LUCI_PASS_HELPERS_TYPE_MAPPER_H__
+
#include <loco/IR/DataType.h>
#include <cstdint>
@@ -75,3 +78,5 @@ template <> struct TypeMapper<int64_t>
};
} // namespace luci
+
+#endif // __LUCI_PASS_HELPERS_TYPE_MAPPER_H__
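
As used in CreateCircleConst.h, TypeMapper<T>::get() resolves the loco::DataType at compile time (it appears as a template argument there), so callers can sanity-check the element-type mapping statically. A minimal sketch, assuming the specializations declared in this header:

#include "TypeMapper.h"

#include <loco/IR/DataType.h>

// TypeMapper<T>::get() is constexpr, so mismatches between the C++ element
// type and the loco::DataType can be caught with static_assert.
static_assert(luci::TypeMapper<int32_t>::get() == loco::DataType::S32,
              "int32_t elements are stored as S32");
static_assert(luci::TypeMapper<float>::get() == loco::DataType::FLOAT32,
              "float elements are stored as FLOAT32");
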