diff options
Diffstat (limited to 'compiler/luci/pass')
56 files changed, 2668 insertions, 427 deletions
diff --git a/compiler/luci/pass/include/luci/CircleOptimizer.h b/compiler/luci/pass/include/luci/CircleOptimizer.h index d77e89db1..6ebacee39 100644 --- a/compiler/luci/pass/include/luci/CircleOptimizer.h +++ b/compiler/luci/pass/include/luci/CircleOptimizer.h @@ -63,6 +63,7 @@ public: MakeBatchNormGammaPositive, FuseActivationFunction, FusePRelu, + FuseGelu, ShuffleWeightTo16x1Float32, RemoveRedundantTranspose, ReplaceMulAddWithDepthwiseConv, @@ -80,6 +81,7 @@ public: RemoveUnnecessaryReshape, TransformMinMaxToRelu6Pass, TransformMinReluToRelu6Pass, + DecomposeHardSwishPass, SubstituteStridedSliceToReshape, SubstituteTransposeToReshape, RemoveRedundantQuantize, diff --git a/compiler/luci/pass/include/luci/CircleQuantizer.h b/compiler/luci/pass/include/luci/CircleQuantizer.h index 4e7074d98..463f31790 100644 --- a/compiler/luci/pass/include/luci/CircleQuantizer.h +++ b/compiler/luci/pass/include/luci/CircleQuantizer.h @@ -45,6 +45,7 @@ public: CopyQuantParam, ForceQuantParam, ConvertToFakeQuantizedModel, + QuantizeWeights, }; enum AlgorithmParameters diff --git a/compiler/luci/pass/include/luci/DynamicBatchToSingleBatch.h b/compiler/luci/pass/include/luci/DynamicBatchToSingleBatch.h new file mode 100644 index 000000000..2a02777f6 --- /dev/null +++ b/compiler/luci/pass/include/luci/DynamicBatchToSingleBatch.h @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __LUCI_DYNAMIC_BATCH_TO_SINGLE_BATCH_H__ +#define __LUCI_DYNAMIC_BATCH_TO_SINGLE_BATCH_H__ + +#include <luci/IR/Module.h> + +namespace luci +{ + +void dynamic_batch_to_single_batch(luci::Module *); + +} // namespace luci + +#endif // __LUCI_DYNAMIC_BATCH_TO_SINGLE_BATCH_H__ diff --git a/compiler/luci/pass/include/luci/Pass/DecomposeHardSwishPass.h b/compiler/luci/pass/include/luci/Pass/DecomposeHardSwishPass.h new file mode 100644 index 000000000..83c16bcee --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/DecomposeHardSwishPass.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __LUCI_DECOMPOSE_HARDSWISH_PASS_H__ +#define __LUCI_DECOMPOSE_HARDSWISH_PASS_H__ + +#include <logo/Pass.h> + +namespace luci +{ + +/** + * @brief Class to decompose HardSwish to Add, Mul and Relu6 + */ +struct DecomposeHardSwishPass final : public logo::Pass +{ + const char *name(void) const final { return "luci::DecomposeHardSwishPass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_DECOMPOSE_HARDSWISH_PASS_H__ diff --git a/compiler/luci/pass/include/luci/Pass/DynamicBatchToSingleBatchPass.h b/compiler/luci/pass/include/luci/Pass/DynamicBatchToSingleBatchPass.h new file mode 100644 index 000000000..b3598c986 --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/DynamicBatchToSingleBatchPass.h @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __LUCI_DYNAMIC_BATCH_TO_SINGLE_BATCH_PASS_H__ +#define __LUCI_DYNAMIC_BATCH_TO_SINGLE_BATCH_PASS_H__ + +#include <logo/Pass.h> + +namespace luci +{ + +/** + * @brief Pass to convert dynamic batch to single batch + */ +class DynamicBatchToSingleBatchPass : public logo::Pass +{ +public: + virtual const char *name(void) const { return "luci::DynamicBatchToSingleBatchPass"; } + +public: + bool run(loco::Graph *graph); +}; + +} // namespace luci + +#endif //__LUCI_DYNAMIC_BATCH_TO_SINGLE_BATCH_PASS_H__ diff --git a/compiler/luci/pass/include/luci/Pass/FuseGeluPass.h b/compiler/luci/pass/include/luci/Pass/FuseGeluPass.h new file mode 100644 index 000000000..5fa23036c --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/FuseGeluPass.h @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_FUSE_GELU_PASS_H__ +#define __LUCI_FUSE_GELU_PASS_H__ + +#include <logo/Pass.h> + +namespace luci +{ + +/** + * @brief Class to fuse certain pattern of subgraph into CircleGelu + * + * For detailed subgraph pattern to be fused, please check its implementation. 
+ */ +struct FuseGeluPass final : public logo::Pass +{ + const char *name(void) const final { return "luci::FuseGeluPass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_FUSE_GELU_PASS_H__ diff --git a/compiler/luci/pass/include/luci/Pass/QuantizeWeightsPass.h b/compiler/luci/pass/include/luci/Pass/QuantizeWeightsPass.h new file mode 100644 index 000000000..646597312 --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/QuantizeWeightsPass.h @@ -0,0 +1,70 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef __LUCI_QUANTIZE_WEIGHTS_PASS_H__ +#define __LUCI_QUANTIZE_WEIGHTS_PASS_H__ + +#include <loco.h> + +#include <logo/Pass.h> + +#include <luci/Pass/QuantizationParameters.h> + +namespace luci +{ + +/** + * @brief Pass to quantize weights + */ +class QuantizeWeightsPass : public logo::Pass +{ +public: + struct Context + { + loco::DataType input_model_dtype = loco::DataType::Unknown; + loco::DataType output_model_dtype = loco::DataType::Unknown; + QuantizationGranularity granularity = QuantizationGranularity::ChannelWise; + }; + +public: + QuantizeWeightsPass(std::unique_ptr<Context> &&ctx) : _ctx{std::move(ctx)} + { + // DO NOTHING + } + +public: + QuantizeWeightsPass(loco::DataType input_model_dtype, loco::DataType output_model_dtype, + QuantizationGranularity granularity) + { + _ctx = std::make_unique<Context>(); + { + _ctx->input_model_dtype = input_model_dtype; + _ctx->output_model_dtype = output_model_dtype; + _ctx->granularity = granularity; + } + } + virtual const char *name(void) const { return "luci::QuantizeWeightsPass"; } + +public: + bool run(loco::Graph *graph); + +private: + std::unique_ptr<Context> _ctx; +}; + +} // namespace luci + +#endif //__LUCI_QUANTIZE_WEIGHTS_PASS_H__ diff --git a/compiler/luci/pass/include/luci/Pass/RequantizePass.h b/compiler/luci/pass/include/luci/Pass/RequantizePass.h index c6c424f1b..50b9073b5 100644 --- a/compiler/luci/pass/include/luci/Pass/RequantizePass.h +++ b/compiler/luci/pass/include/luci/Pass/RequantizePass.h @@ -27,7 +27,7 @@ namespace luci { /** - * @brief Pass to quantize weights + * @brief Pass to re-quantize graph (ex: int8 -> uint8) */ class RequantizePass : public logo::Pass { diff --git a/compiler/luci/pass/src/CircleOptimizer.cpp b/compiler/luci/pass/src/CircleOptimizer.cpp index 5e1613ad9..b011581af 100644 --- a/compiler/luci/pass/src/CircleOptimizer.cpp +++ b/compiler/luci/pass/src/CircleOptimizer.cpp @@ -39,6 +39,7 @@ #include "luci/Pass/FuseMeanWithMeanPass.h" #include 
"luci/Pass/FusePreActivationBatchNormPass.h" #include "luci/Pass/FusePReluPass.h" +#include "luci/Pass/FuseGeluPass.h" #include "luci/Pass/FuseTransposeWithMeanPass.h" #include "luci/Pass/MakeBatchNormGammaPositivePass.h" #include "luci/Pass/RemoveDuplicateConstPass.h" @@ -70,6 +71,7 @@ #include "luci/Pass/SubstituteTransposeToReshapePass.h" #include "luci/Pass/TransformMinMaxToRelu6Pass.h" #include "luci/Pass/TransformMinReluToRelu6Pass.h" +#include "luci/Pass/DecomposeHardSwishPass.h" #include "luci/Pass/UnrollUnidirectionalSequenceLSTMPass.h" // TODO add more passes @@ -137,7 +139,8 @@ bool OptimizeOptionsImpl::query(Algorithm algo) } // TODO Make a struct for args -void convert_nchw_to_nhwc(loco::Graph *g, bool preserve_input, bool preserve_output, bool fuse_fc) +void convert_nchw_to_nhwc(loco::Graph *g, bool preserve_input, bool preserve_output, bool fuse_fc, + bool fuse_gelu) { logo::Phase phase; @@ -160,6 +163,12 @@ void convert_nchw_to_nhwc(loco::Graph *g, bool preserve_input, bool preserve_out if (fuse_fc) phase.emplace_back(std::make_unique<luci::FuseAddWithFullyConnectedPass>()); + // Fuse decomposed ops to Gelu Op + // Why here? ConvertNCHWToNHWCPass inserts additional Ops, so it is better to fuse + // Gelu in advance. 
+ if (fuse_gelu) + phase.emplace_back(std::make_unique<luci::FuseGeluPass>()); + phase.emplace_back( std::make_unique<luci::ConvertNCHWToNHWCPass>(preserve_input, preserve_output)); @@ -216,8 +225,9 @@ void CircleOptimizer::optimize(loco::Graph *g) const _options->param(Options::AlgorithmParameters::NCHW_to_NHWC_output_shape) != "true"; bool fuse_fc = _options->query(Options::Algorithm::FuseAddWithFullyConnected); + bool fuse_gelu = _options->query(Options::Algorithm::FuseGelu); - convert_nchw_to_nhwc(g, preserve_input, preserve_output, fuse_fc); + convert_nchw_to_nhwc(g, preserve_input, preserve_output, fuse_fc, fuse_gelu); } /* TRANSFORM DECLARATION BEGIN */ @@ -283,6 +293,10 @@ void CircleOptimizer::optimize(loco::Graph *g) const { phase.emplace_back(std::make_unique<FusePReluPass>()); } + if (_options->query(Options::Algorithm::FuseGelu)) + { + phase.emplace_back(std::make_unique<FuseGeluPass>()); + } if (_options->query(Options::Algorithm::FuseTransposeWithMean)) { phase.emplace_back(std::make_unique<FuseTransposeWithMeanPass>()); @@ -319,14 +333,6 @@ void CircleOptimizer::optimize(loco::Graph *g) const { phase.emplace_back(std::make_unique<luci::FoldSparseToDensePass>()); } - if (_options->query(Options::Algorithm::ForwardReshapeToUnaryOp)) - { - phase.emplace_back(std::make_unique<luci::ForwardReshapeToUnaryOpPass>()); - } - if (_options->query(Options::Algorithm::ForwardTransposeOp)) - { - phase.emplace_back(std::make_unique<luci::ForwardTransposeOpPass>()); - } if (_options->query(Options::Algorithm::FusePreActivationBatchNorm)) { phase.emplace_back(std::make_unique<luci::FusePreActivationBatchNormPass>()); @@ -428,10 +434,26 @@ void CircleOptimizer::optimize(loco::Graph *g) const { phase.emplace_back(std::make_unique<luci::TransformMinReluToRelu6Pass>()); } + if (_options->query(Options::Algorithm::DecomposeHardSwishPass)) + { + phase.emplace_back(std::make_unique<luci::DecomposeHardSwishPass>()); + } if 
(_options->query(Options::Algorithm::UnrollUnidirSeqLSTM)) { phase.emplace_back(std::make_unique<luci::UnrollUnidirectionalSequenceLSTMPass>()); } + // Forward Reshape/Transpose is done after + // 1. SubstituteXXXToReshape + // 2. RemoveRedundantReshape/Transpose + // See https://github.com/Samsung/ONE/pull/10596 for more details + if (_options->query(Options::Algorithm::ForwardReshapeToUnaryOp)) + { + phase.emplace_back(std::make_unique<luci::ForwardReshapeToUnaryOpPass>()); + } + if (_options->query(Options::Algorithm::ForwardTransposeOp)) + { + phase.emplace_back(std::make_unique<luci::ForwardTransposeOpPass>()); + } /* TRANSFORM DECLARATION END */ diff --git a/compiler/luci/pass/src/CircleQuantizer.cpp b/compiler/luci/pass/src/CircleQuantizer.cpp index 3ffa1180c..9039a839f 100644 --- a/compiler/luci/pass/src/CircleQuantizer.cpp +++ b/compiler/luci/pass/src/CircleQuantizer.cpp @@ -26,6 +26,7 @@ #include "luci/Pass/QuantizePreCheckerPass.h" #include "luci/Pass/QuantizeWithMinMaxPass.h" #include "luci/Pass/QuantizeDequantizeWeightsPass.h" +#include "luci/Pass/QuantizeWeightsPass.h" #include "luci/Pass/CircleShapeInferencePass.h" #include "luci/Pass/CircleTypeInferencePass.h" @@ -439,14 +440,14 @@ void CircleQuantizer::quantize(loco::Graph *g) const throw std::runtime_error("Unsupported granularity. List of supported granularity: " + to_string(qwmm_supported_granularity)); - for (auto dtype : input_type_vec) + for (const auto &dtype : input_type_vec) { if (!in_array(to_lower_case(dtype), qwmm_supported_input_type)) throw std::runtime_error("Unsupported input type. List of supported input types: " + to_string(qwmm_supported_input_type)); } - for (auto dtype : output_type_vec) + for (const auto &dtype : output_type_vec) { if (!in_array(to_lower_case(dtype), qwmm_supported_output_type)) throw std::runtime_error("Unsupported output type. 
List of supported output types: " + @@ -536,6 +537,40 @@ void CircleQuantizer::quantize(loco::Graph *g) const verifier.verify(g); } + if (_options->query(Options::Algorithm::QuantizeWeights)) + { + static const std::vector<std::string> qw_supported_input_model_dtype{"float32"}; + static const std::vector<std::string> qw_supported_output_model_dtype{"int8", "int16"}; + static const std::vector<std::string> qw_supported_granularity{"channel"}; + + auto input_model_dtype = + _options->param(Options::AlgorithmParameters::Quantize_input_model_dtype); + auto output_model_dtype = + _options->param(Options::AlgorithmParameters::Quantize_output_model_dtype); + auto granularity = _options->param(Options::AlgorithmParameters::Quantize_granularity); + + if (!in_array(to_lower_case(input_model_dtype), qw_supported_input_model_dtype)) + throw std::runtime_error("Unsupported input type. List of supported input type: " + + to_string(qw_supported_input_model_dtype)); + + if (!in_array(to_lower_case(output_model_dtype), qw_supported_output_model_dtype)) + throw std::runtime_error("Unsupported output type. List of supported output type: " + + to_string(qw_supported_output_model_dtype)); + + if (!in_array(to_lower_case(granularity), qw_supported_granularity)) + throw std::runtime_error("Unsupported granularity. 
List of supported granularity: " + + to_string(qw_supported_granularity)); + auto ctx = std::make_unique<luci::QuantizeWeightsPass::Context>(); + { + ctx->input_model_dtype = str_to_dtype(input_model_dtype); + ctx->output_model_dtype = str_to_dtype(output_model_dtype); + ctx->granularity = str_to_granularity(granularity); + } + luci::QuantizeWeightsPass weights_quantizer(std::move(ctx)); + + weights_quantizer.run(g); + } + // Requantize if (_options->query(Options::Algorithm::Requantize)) { diff --git a/compiler/luci/pass/src/ConvertNCHWToNHWCPass.cpp b/compiler/luci/pass/src/ConvertNCHWToNHWCPass.cpp index 99e1e2939..ac4320246 100644 --- a/compiler/luci/pass/src/ConvertNCHWToNHWCPass.cpp +++ b/compiler/luci/pass/src/ConvertNCHWToNHWCPass.cpp @@ -55,16 +55,18 @@ bool broadcastable(const luci::CircleConst *from, const luci::CircleNode *to) return true; } -// Expand node to rank 4 +// Return node with rank 4 // node should have rank less than or equal to 4 -void expand_to_rank_4(luci::CircleConst *node) +// 1 is inserted to the front of shape if rank is less than 4 +// For example, [2] -> [1, 1, 1, 2] +luci::CircleConst *expand_to_rank_4(luci::CircleConst *node) { auto original_rank = node->rank(); assert(original_rank <= 4); // FIX_CALLER_UNLESS if (original_rank == 4) - return; + return node; std::vector<uint32_t> original_shape; for (uint32_t i = 0; i < original_rank; i++) @@ -72,12 +74,17 @@ void expand_to_rank_4(luci::CircleConst *node) original_shape.emplace_back(node->dim(i).value()); } - node->rank(4); + auto cloned = luci::clone(node); + cloned->name(cloned->name() + "_rank4"); + + cloned->rank(4); for (uint32_t i = 0; i < (4 - original_rank); i++) - node->dim(i) = 1; + cloned->dim(i) = 1; for (uint32_t i = 0; i < original_rank; i++) - node->dim(i + (4 - original_rank)) = original_shape.at(i); + cloned->dim(i + (4 - original_rank)) = original_shape.at(i); + + return cloned; } bool is_output(const loco::Node *node) @@ -564,7 +571,7 @@ bool 
is_NCHW_with_const(const luci::CircleMul *node, luci::CircleNode *&pred_nod if (not broadcastable(multiplier, node)) return false; - expand_to_rank_4(multiplier); + multiplier = expand_to_rank_4(multiplier); return true; } @@ -602,7 +609,7 @@ bool is_NCHW_with_const(const luci::CircleAdd *node, luci::CircleNode *&pred_nod if (not broadcastable(beta, node)) return false; - expand_to_rank_4(beta); + beta = expand_to_rank_4(beta); return true; } @@ -834,6 +841,8 @@ class ConvertNCHWToNHWC final : public luci::CircleNodeMutableVisitor<bool> bool visit(luci::CircleElu *node) { return convert_unary_features<luci::CircleElu>(node); } + bool visit(luci::CircleGelu *node) { return convert_unary_features<luci::CircleGelu>(node); } + bool visit(luci::CircleLeakyRelu *node) { return convert_unary_features<luci::CircleLeakyRelu>(node); @@ -1510,6 +1519,7 @@ bool ConvertNCHWToNHWCPass::run(loco::Graph *g) case luci::CircleOpcode::ADD: case luci::CircleOpcode::CONCATENATION: case luci::CircleOpcode::ELU: + case luci::CircleOpcode::GELU: case luci::CircleOpcode::LEAKY_RELU: case luci::CircleOpcode::LOGISTIC: case luci::CircleOpcode::MAXIMUM: diff --git a/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp b/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp index fd326518e..85648cf2c 100644 --- a/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp +++ b/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp @@ -535,6 +535,8 @@ public: luci::CircleMaximum *max = nullptr; }; +static constexpr std::initializer_list<uint32_t> kDefaultShape = {1, 16, 1, 1}; + class MeanGraph final : public SimpleGraph { protected: @@ -577,7 +579,7 @@ public: private: bool _keep_dims = true; std::vector<int32_t> _axes = {2, 3}; - std::initializer_list<uint32_t> _shape = {1, 16, 1, 1}; + std::initializer_list<uint32_t> _shape = kDefaultShape; }; class MinimumGraph final : public SimpleGraph @@ -876,7 +878,7 @@ public: private: bool _keep_dims = true; std::vector<int32_t> _axes = {2, 3}; - 
std::initializer_list<uint32_t> _shape = {1, 16, 1, 1}; + std::initializer_list<uint32_t> _shape = kDefaultShape; }; class ReduceMinGraph final : public SimpleGraph @@ -921,7 +923,7 @@ public: private: bool _keep_dims = true; std::vector<int32_t> _axes = {2, 3}; - std::initializer_list<uint32_t> _shape = {1, 16, 1, 1}; + std::initializer_list<uint32_t> _shape = kDefaultShape; }; class ReluGraph final : public SimpleGraph diff --git a/compiler/luci/pass/src/ConvertToFakeQuantizedModelPass.cpp b/compiler/luci/pass/src/ConvertToFakeQuantizedModelPass.cpp index aacfce3d0..ae5ab1519 100644 --- a/compiler/luci/pass/src/ConvertToFakeQuantizedModelPass.cpp +++ b/compiler/luci/pass/src/ConvertToFakeQuantizedModelPass.cpp @@ -198,6 +198,7 @@ struct FakeQuantize final : public luci::CircleNodeMutableVisitor<void> void visit(luci::CircleDepthwiseConv2D *node) { fq_activation(node); } void visit(luci::CircleDiv *node) { fq_activation(node); } void visit(luci::CircleFullyConnected *node) { fq_activation(node); } + void visit(luci::CircleGelu *node) { fq_activation(node); } void visit(luci::CircleInstanceNorm *node) { fq_activation(node); } void visit(luci::CircleLeakyRelu *node) { fq_activation(node); } void visit(luci::CircleLogistic *node) { fq_activation(node); } @@ -217,6 +218,9 @@ struct FakeQuantize final : public luci::CircleNodeMutableVisitor<void> void visit(luci::CircleRsqrt *node) { fq_activation(node); } void visit(luci::CircleSoftmax *node) { fq_activation(node); } void visit(luci::CircleSqrt *node) { fq_activation(node); } + void visit(luci::CircleSquaredDifference *node) { fq_activation(node); } + void visit(luci::CircleSub *node) { fq_activation(node); } + void visit(luci::CircleSum *node) { fq_activation(node); } void visit(luci::CircleTanh *node) { fq_activation(node); } void visit(luci::CircleTransposeConv *node) { fq_activation(node); } diff --git a/compiler/luci/pass/src/DecomposeHardSwishPass.cpp b/compiler/luci/pass/src/DecomposeHardSwishPass.cpp new file 
mode 100644 index 000000000..bd99d2de0 --- /dev/null +++ b/compiler/luci/pass/src/DecomposeHardSwishPass.cpp @@ -0,0 +1,147 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/DecomposeHardSwishPass.h" + +#include "helpers/NodeFiller.h" +#include "helpers/TypeMapper.h" + +#include <luci/IR/CircleNodes.h> +#include <luci/Profile/CircleNodeOrigin.h> + +namespace +{ +/** + * BEFORE + * [CircleNode] + * | + * | + * [CircleHardSwish] + * | + * | + * [CircleNode] + * + * + * AFTER + * + * [CircleNode] [CircleConst] + * | \ / + * | \ / + * | [CircleAdd] + * | | + * | | + * \ [CircleRelu6] [CircleConst] + * \ \ / + * \ \ / + * \ [CircleMul] + * \ / + * \ / + * [CircleMul] + * | + * | + * [CircleNode] + * + */ +bool decompose_hardswish(luci::CircleHardSwish *hardswish) +{ + if (not hardswish) + return false; + + if (hardswish->dtype() != loco::DataType::FLOAT32) + return false; + + auto g = hardswish->graph(); + + auto name = hardswish->name(); + assert(name.length() > 0); + + // Create a const for CircleAdd operation + auto add_const = g->nodes()->create<luci::CircleConst>(); + add_const->shape({}); // scalar + add_const->dtype(loco::DataType::FLOAT32); + add_const->rank(0); + add_const->size<loco::DataType::FLOAT32>(1); + add_const->at<loco::DataType::FLOAT32>(0) = 3.; + add_const->name(name + "/Add/const"); + luci::add_origin(add_const, 
luci::get_origin(hardswish)); + + // Create an Add operation + auto add = g->nodes()->create<luci::CircleAdd>(); + add->fusedActivationFunction(luci::FusedActFunc::NONE); + add->x(hardswish->features()); + add->y(add_const); + add->name(name + "/Add"); + luci::add_origin(add, luci::get_origin(hardswish)); + + // Create a Relu6 operation + auto relu6 = g->nodes()->create<luci::CircleRelu6>(); + relu6->features(add); + relu6->name(name + "/Relu6"); + luci::add_origin(relu6, luci::get_origin(hardswish)); + + // Create a const for CircleMul operation + auto mul_const = g->nodes()->create<luci::CircleConst>(); + mul_const->shape({}); // scalar + mul_const->dtype(loco::DataType::FLOAT32); + mul_const->rank(0); + mul_const->size<loco::DataType::FLOAT32>(1); + mul_const->at<loco::DataType::FLOAT32>(0) = 1. / 6.; + mul_const->name(name + "/Mul/const"); + luci::add_origin(mul_const, luci::get_origin(hardswish)); + + // Create first Mul operation + auto mul1 = g->nodes()->create<luci::CircleMul>(); + mul1->fusedActivationFunction(luci::FusedActFunc::NONE); + mul1->x(relu6); + mul1->y(mul_const); + mul1->name(name + "/Mul1"); + luci::add_origin(mul1, luci::get_origin(hardswish)); + + // Create second Mul operation + auto mul2 = g->nodes()->create<luci::CircleMul>(); + mul2->fusedActivationFunction(luci::FusedActFunc::NONE); + mul2->x(hardswish->features()); + mul2->y(mul1); + mul2->name(name + "/Mul2"); + luci::add_origin(mul2, luci::get_origin(hardswish)); + + replace(hardswish).with(mul2); + + return true; +} + +} // namespace + +namespace luci +{ + +bool DecomposeHardSwishPass::run(loco::Graph *g) +{ + bool changed = false; + + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + if (auto hardswish = dynamic_cast<luci::CircleHardSwish *>(node)) + { + if (decompose_hardswish(hardswish)) + changed = true; + } + } + + return changed; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/DecomposeHardSwishPass.test.cpp 
b/compiler/luci/pass/src/DecomposeHardSwishPass.test.cpp new file mode 100644 index 000000000..d51a07fdc --- /dev/null +++ b/compiler/luci/pass/src/DecomposeHardSwishPass.test.cpp @@ -0,0 +1,205 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/DecomposeHardSwishPass.h" + +#include <luci/IR/CircleNodes.h> + +#include <gtest/gtest.h> + +namespace +{ + +/** + * HardSwish graph + * + * [CircleInput] + * | + * | + * [CircleHardSwish] + * | + * | + * [CircleOutput] + */ +struct HardSwishGraph +{ + loco::Graph _g; + luci::CircleInput *_input = nullptr; + luci::CircleHardSwish *_hardswish = nullptr; + luci::CircleOutput *_output = nullptr; +}; + +class DecomposeHardSwishPass : public ::testing::Test +{ +protected: + void MakeGraph() + { + const int N = 1; + const int H = 4; + const int W = 4; + const int C = 3; + + // graph input and output + auto graph_input = _hardswish_g._g.inputs()->create(); + auto graph_output = _hardswish_g._g.outputs()->create(); + + // CircleInput + _hardswish_g._input = _hardswish_g._g.nodes()->create<luci::CircleInput>(); + _hardswish_g._input->index(graph_input->index()); + _hardswish_g._input->shape({N, H, W, C}); + _hardswish_g._input->dtype(loco::DataType::FLOAT32); + _hardswish_g._input->name("input"); + + // CircleHardSwish + _hardswish_g._hardswish = _hardswish_g._g.nodes()->create<luci::CircleHardSwish>(); + 
_hardswish_g._hardswish->features(_hardswish_g._input); + _hardswish_g._hardswish->shape({N, H, W, C}); + _hardswish_g._hardswish->dtype(loco::DataType::FLOAT32); + _hardswish_g._hardswish->name("hardswish"); + + // CircleOutput + _hardswish_g._output = _hardswish_g._g.nodes()->create<luci::CircleOutput>(); + _hardswish_g._output->index(graph_output->index()); + _hardswish_g._output->from(_hardswish_g._hardswish); + _hardswish_g._output->shape({N, H, W, C}); + _hardswish_g._output->dtype(loco::DataType::FLOAT32); + _hardswish_g._output->name("output"); + } + + void MakeInt32Graph() + { + const int N = 1; + const int H = 4; + const int W = 4; + const int C = 3; + + // graph input and output + auto graph_input = _hardswish_int32_g._g.inputs()->create(); + auto graph_output = _hardswish_int32_g._g.outputs()->create(); + + // CircleInput + _hardswish_int32_g._input = _hardswish_int32_g._g.nodes()->create<luci::CircleInput>(); + _hardswish_int32_g._input->index(graph_input->index()); + _hardswish_int32_g._input->shape({N, H, W, C}); + _hardswish_int32_g._input->dtype(loco::DataType::S32); + _hardswish_int32_g._input->name("input"); + + // CircleHardSwish + _hardswish_int32_g._hardswish = _hardswish_int32_g._g.nodes()->create<luci::CircleHardSwish>(); + _hardswish_int32_g._hardswish->features(_hardswish_int32_g._input); + _hardswish_int32_g._hardswish->shape({N, H, W, C}); + _hardswish_int32_g._hardswish->dtype(loco::DataType::S32); + _hardswish_int32_g._hardswish->name("hardswish"); + + // CircleOutput + _hardswish_int32_g._output = _hardswish_int32_g._g.nodes()->create<luci::CircleOutput>(); + _hardswish_int32_g._output->index(graph_output->index()); + _hardswish_int32_g._output->from(_hardswish_int32_g._hardswish); + _hardswish_int32_g._output->shape({N, H, W, C}); + _hardswish_int32_g._output->dtype(loco::DataType::S32); + _hardswish_int32_g._output->name("output"); + } + + virtual void SetUp() + { + MakeGraph(); + MakeInt32Graph(); + } + +protected: + 
luci::DecomposeHardSwishPass _pass; + HardSwishGraph _hardswish_g; + HardSwishGraph _hardswish_int32_g; +}; + +} // namespace + +TEST_F(DecomposeHardSwishPass, name) +{ + auto const name = _pass.name(); + ASSERT_NE(nullptr, name); +} + +/** + * Decomposed graph looks like below. + * + * [CircleInput] [CircleConst] + * | \ / + * | \ / + * | [CircleAdd] + * | | + * | | + * \ [CircleRelu6] [CircleConst] + * \ \ / + * \ \ / + * \ [CircleMul] + * \ / + * \ / + * [CircleMul] + * | + * | + * [CircleOutput] + * + */ +TEST_F(DecomposeHardSwishPass, simple_test) +{ + auto ret = _pass.run(&_hardswish_g._g); + EXPECT_TRUE(ret); + + auto mul2 = dynamic_cast<luci::CircleMul *>(_hardswish_g._output->from()); + EXPECT_NE(nullptr, mul2); + + auto input2 = dynamic_cast<luci::CircleInput *>(mul2->x()); + EXPECT_NE(nullptr, input2); + + auto mul1 = dynamic_cast<luci::CircleMul *>(mul2->y()); + EXPECT_NE(nullptr, mul1); + + auto relu6 = dynamic_cast<luci::CircleRelu6 *>(mul1->x()); + EXPECT_NE(nullptr, relu6); + + auto mul_const = dynamic_cast<luci::CircleConst *>(mul1->y()); + EXPECT_NE(nullptr, mul_const); + EXPECT_FLOAT_EQ(1. 
/ 6., mul_const->at<loco::DataType::FLOAT32>(0)); + + auto add = dynamic_cast<luci::CircleAdd *>(relu6->features()); + EXPECT_NE(nullptr, add); + + auto input1 = dynamic_cast<luci::CircleInput *>(add->x()); + EXPECT_NE(nullptr, input1); + + auto add_const = dynamic_cast<luci::CircleConst *>(add->y()); + EXPECT_NE(nullptr, add_const); + EXPECT_FLOAT_EQ(3., add_const->at<loco::DataType::FLOAT32>(0)); +} + +TEST_F(DecomposeHardSwishPass, check_last_node) +{ + auto ret = _pass.run(&_hardswish_g._g); + EXPECT_TRUE(ret); + + auto hardswish = dynamic_cast<luci::CircleHardSwish *>(_hardswish_g._output->from()); + EXPECT_EQ(nullptr, hardswish); +} + +TEST_F(DecomposeHardSwishPass, wrong_condition_NEG) +{ + auto ret = _pass.run(&_hardswish_int32_g._g); + EXPECT_FALSE(ret); + + auto hardswish = dynamic_cast<luci::CircleHardSwish *>(_hardswish_g._output->from()); + EXPECT_NE(nullptr, hardswish); +} diff --git a/compiler/luci/pass/src/DynamicBatchToSingleBatch.cpp b/compiler/luci/pass/src/DynamicBatchToSingleBatch.cpp new file mode 100644 index 000000000..86876063a --- /dev/null +++ b/compiler/luci/pass/src/DynamicBatchToSingleBatch.cpp @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "luci/DynamicBatchToSingleBatch.h" + +#include "luci/Pass/DynamicBatchToSingleBatchPass.h" +#include "luci/Pass/CircleShapeInferencePass.h" + +#include "ProgressReporter.h" + +#include <logo/Phase.h> + +namespace luci +{ + +void dynamic_batch_to_single_batch(luci::Module *m) +{ + assert(m); // FIX CALLER UNLESS + + for (uint32_t i = 0; i < m->size(); i++) + { + auto g = m->graph(i); + + logo::Phase phase; + + phase.emplace_back(std::make_unique<luci::DynamicBatchToSingleBatchPass>()); + + // Needed to infer shapes of other nodes + phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>()); + + ProgressReporter prog(g, logo::PhaseStrategy::Saturate); + logo::PhaseRunner<logo::PhaseStrategy::Saturate> phase_runner{g}; + phase_runner.attach(&prog); + phase_runner.run(phase); + } +} + +} // namespace luci diff --git a/compiler/luci/pass/src/DynamicBatchToSingleBatchPass.cpp b/compiler/luci/pass/src/DynamicBatchToSingleBatchPass.cpp new file mode 100644 index 000000000..59a9f5ab3 --- /dev/null +++ b/compiler/luci/pass/src/DynamicBatchToSingleBatchPass.cpp @@ -0,0 +1,78 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "luci/Pass/DynamicBatchToSingleBatchPass.h" + +#include <luci/IR/CircleNode.h> +#include <loco.h> + +namespace luci +{ + +bool DynamicBatchToSingleBatchPass::run(loco::Graph *g) +{ + assert(g); // FIX CALLER UNLESS + + bool changed = false; + + auto graph_inputs = g->inputs(); + + // Assume the first dimension is batch dimension + const uint32_t BATCH_DIM = 0; + + for (auto node : loco::input_nodes(g)) + { + auto input_node = loco::must_cast<luci::CircleInput *>(node); + + if (input_node->rank() == 0) + continue; + + // Skip if batch dimension is known + if (input_node->dim(BATCH_DIM).known()) + continue; + + if (input_node->rank() != 4) + { + // Limit use only for rank 4 inputs (for NHWC and NCHW) + // TODO Enable this if necessary + throw std::runtime_error("First dimension of input is unknown, but its rank is not 4."); + } + + // 'set' will make the dimension known + input_node->dim(BATCH_DIM).set(1); + + // Update graph input + auto graph_input = graph_inputs->at(input_node->index()); + auto graph_input_shape = graph_input->shape(); + auto tensor_shape = std::make_unique<loco::TensorShape>(); + { + tensor_shape->rank(graph_input_shape->rank()); + for (uint32_t i = 0; i < tensor_shape->rank(); i++) + { + tensor_shape->dim(i) = graph_input_shape->dim(i); + } + tensor_shape->dim(BATCH_DIM).set(1); + } + + graph_input->shape(std::move(tensor_shape)); + + changed = true; + } + + return changed; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/DynamicBatchToSingleBatchPass.test.cpp b/compiler/luci/pass/src/DynamicBatchToSingleBatchPass.test.cpp new file mode 100644 index 000000000..f19f57d17 --- /dev/null +++ b/compiler/luci/pass/src/DynamicBatchToSingleBatchPass.test.cpp @@ -0,0 +1,126 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/DynamicBatchToSingleBatchPass.h" + +#include <loco.h> + +#include <luci/IR/CircleNodes.h> + +#include <gtest/gtest.h> + +namespace +{ + +/* Build a TensorShape whose rank is dims.size() and whose i-th dimension is the i-th value of dims */ +std::unique_ptr<loco::TensorShape> make_tshape(std::initializer_list<uint32_t> dims) +{ + auto tensor_shape = std::make_unique<loco::TensorShape>(); + { + tensor_shape->rank(dims.size()); + uint32_t axis = 0; + for (auto it = dims.begin(); it != dims.end(); ++it) + { + tensor_shape->dim(axis++) = *it; + } + } + + /* Return the local by name; wrapping it in std::move would only inhibit copy elision */ + return tensor_shape; +} + +} // namespace + +TEST(DynamicBatchToSingleBatchPassTest, simple) +{ + luci::DynamicBatchToSingleBatchPass pass; + + auto g = loco::make_graph(); + + auto graph_input = g->inputs()->create(); + { + auto tensor_shape = make_tshape({1, 5, 5, 3}); + tensor_shape->dim(0).unset(); + graph_input->shape(std::move(tensor_shape)); + } + + // Create the CircleInput node bound to graph input 0 + auto input = g->nodes()->create<luci::CircleInput>(); + { + input->index(0); + input->shape({1, 5, 5, 3}); + input->dim(0).unset(); + } + + EXPECT_FALSE(graph_input->shape()->dim(0).known()); + EXPECT_FALSE(input->dim(0).known()); + + EXPECT_TRUE(pass.run(g.get())); + + // Check input is known + EXPECT_TRUE(graph_input->shape()->dim(0).known()); + EXPECT_EQ(1, graph_input->shape()->dim(0)); + EXPECT_TRUE(input->dim(0).known()); + EXPECT_EQ(1, input->dim(0)); +} + +TEST(DynamicBatchToSingleBatchPassTest, simple_NEG) +{ + luci::DynamicBatchToSingleBatchPass pass; + + auto g = loco::make_graph(); + + auto graph_input = g->inputs()->create(); + { + graph_input->shape({1, 
5, 5, 3}); + } + + // Create nodes to make relu traversed first + auto input = g->nodes()->create<luci::CircleInput>(); + { + input->index(0); + input->shape({1, 5, 5, 3}); + } + + EXPECT_FALSE(pass.run(g.get())); +} + +// Remove this test if we support rank 1 in this pass +TEST(DynamicBatchToSingleBatchPassTest, rank1_NEG) +{ + luci::DynamicBatchToSingleBatchPass pass; + + auto g = loco::make_graph(); + + auto graph_input = g->inputs()->create(); + { + auto tensor_shape = make_tshape({1}); + tensor_shape->dim(0).unset(); + graph_input->shape(std::move(tensor_shape)); + } + + // Create nodes to make relu traversed first + auto input = g->nodes()->create<luci::CircleInput>(); + { + input->index(0); + input->shape({1}); + input->dim(0).unset(); + } + + EXPECT_FALSE(graph_input->shape()->dim(0).known()); + EXPECT_FALSE(input->dim(0).known()); + + // Rank 1 is unsupported for now + EXPECT_ANY_THROW(pass.run(g.get())); +} diff --git a/compiler/luci/pass/src/FoldAddV2Pass.test.cpp b/compiler/luci/pass/src/FoldAddV2Pass.test.cpp index 438d7f077..200fcc093 100644 --- a/compiler/luci/pass/src/FoldAddV2Pass.test.cpp +++ b/compiler/luci/pass/src/FoldAddV2Pass.test.cpp @@ -44,10 +44,10 @@ template <loco::DataType T> class FoldAddV2Test : public luci::ConstantFoldingAd public: FoldAddV2Test(std::initializer_list<uint32_t> shape) : luci::ConstantFoldingAddTestGraph(shape, T) { - _addV2 = _g.nodes()->create<luci::CircleCustom>(2, 1); - _x = _g.nodes()->create<luci::CircleConst>(); - _y = _g.nodes()->create<luci::CircleConst>(); - _addV2_out = _g.nodes()->create<luci::CircleCustomOut>(); + _addV2 = _g.nodes()->template create<luci::CircleCustom>(2, 1); + _x = _g.nodes()->template create<luci::CircleConst>(); + _y = _g.nodes()->template create<luci::CircleConst>(); + _addV2_out = _g.nodes()->template create<luci::CircleCustomOut>(); _addV2->dtype(T); _x->dtype(T); diff --git a/compiler/luci/pass/src/FoldCastPass.test.cpp b/compiler/luci/pass/src/FoldCastPass.test.cpp index 
5911adf11..da33e4379 100644 --- a/compiler/luci/pass/src/FoldCastPass.test.cpp +++ b/compiler/luci/pass/src/FoldCastPass.test.cpp @@ -31,8 +31,8 @@ public: FoldCastTest(std::initializer_list<uint32_t> shape) : luci::ConstantFoldingAddTestGraph(shape, ToT) { - _cast = _g.nodes()->create<luci::CircleCast>(); - _x = _g.nodes()->create<luci::CircleConst>(); + _cast = _g.nodes()->template create<luci::CircleCast>(); + _x = _g.nodes()->template create<luci::CircleConst>(); _cast->dtype(ToT); _x->dtype(FromT); diff --git a/compiler/luci/pass/src/FoldDequantizePass.test.cpp b/compiler/luci/pass/src/FoldDequantizePass.test.cpp index fb5b6adc0..87dff5dc0 100644 --- a/compiler/luci/pass/src/FoldDequantizePass.test.cpp +++ b/compiler/luci/pass/src/FoldDequantizePass.test.cpp @@ -32,8 +32,8 @@ public: loco::Node *createFoldedPattern() override { - _dequantize = _g.nodes()->create<luci::CircleDequantize>(); - _input = _g.nodes()->create<luci::CircleConst>(); + _dequantize = _g.nodes()->template create<luci::CircleDequantize>(); + _input = _g.nodes()->template create<luci::CircleConst>(); _dequantize->dtype(loco::DataType::FLOAT32); _input->dtype(DT); diff --git a/compiler/luci/pass/src/FuseActivationFunctionPass.cpp b/compiler/luci/pass/src/FuseActivationFunctionPass.cpp index d83973cd5..868ccd140 100644 --- a/compiler/luci/pass/src/FuseActivationFunctionPass.cpp +++ b/compiler/luci/pass/src/FuseActivationFunctionPass.cpp @@ -42,6 +42,11 @@ bool fuse_activation_function(luci::CircleNode *node) // This will skip fuse for concat as luci-interpreter doesn't support this yet if (dynamic_cast<luci::CircleConcatenation *>(pred_node) != nullptr) return false; + // TODO remove this work-around + // This will skip fuse for TransposeConv as backends does not support this yet + // NOTE remove this when XpSepActFromTransposeConvOpPass is removed + if (dynamic_cast<luci::CircleTransposeConv *>(pred_node) != nullptr) + return false; auto fused_act = 
node_with_fused_act->fusedActivationFunction(); diff --git a/compiler/luci/pass/src/FuseAddWithFullyConnectedPass.test.cpp b/compiler/luci/pass/src/FuseAddWithFullyConnectedPass.test.cpp index 300796594..b132c6bd9 100644 --- a/compiler/luci/pass/src/FuseAddWithFullyConnectedPass.test.cpp +++ b/compiler/luci/pass/src/FuseAddWithFullyConnectedPass.test.cpp @@ -16,6 +16,8 @@ #include "luci/Pass/FuseAddWithFullyConnectedPass.h" +#include "helpers/CreateCircleConst.h" + #include <luci/IR/CircleNodes.h> #include <luci/test/TestIOGraph.h> @@ -27,52 +29,6 @@ namespace using namespace luci::test; -// TODO Reduce duplicate codes in ResolveCustomOpMatMulPass.cpp -template <typename T> -luci::CircleConst *create_const_node(loco::Graph *g, const loco::DataType dtype, - const std::vector<uint32_t> &shape, - const std::vector<T> &values) -{ - auto node = g->nodes()->create<luci::CircleConst>(); - node->dtype(dtype); - node->rank(shape.size()); - - uint32_t size = 1; - for (uint32_t i = 0; i < shape.size(); ++i) - { - node->dim(i) = shape.at(i); - size *= shape.at(i); - } - node->shape_status(luci::ShapeStatus::VALID); - -#define INIT_VALUES(DT) \ - { \ - node->size<DT>(size); \ - for (uint32_t i = 0; i < values.size(); ++i) \ - node->at<DT>(i) = values[i]; \ - } - - switch (dtype) - { - case loco::DataType::U8: - INIT_VALUES(loco::DataType::U8); - break; - case loco::DataType::S16: - INIT_VALUES(loco::DataType::S16); - break; - case loco::DataType::S32: - INIT_VALUES(loco::DataType::S32); - break; - case loco::DataType::FLOAT32: - INIT_VALUES(loco::DataType::FLOAT32) - break; - default: - INTERNAL_EXN("create_const_node called with unsupported type"); - break; - } - return node; -} - /** * Simple graph for test * @@ -95,10 +51,10 @@ public: void init(loco::Graph *g) { std::vector<float> weights_val(16 * 4); - _fc_f = create_const_node(g, loco::DataType::FLOAT32, {16, 4}, weights_val); + _fc_f = luci::create_const_node(g, loco::DataType::FLOAT32, {16, 4}, weights_val); 
std::vector<float> bias_val(16); - _fc_b = create_const_node(g, loco::DataType::FLOAT32, {1, 16}, bias_val); + _fc_b = luci::create_const_node(g, loco::DataType::FLOAT32, {1, 16}, bias_val); _fc = g->nodes()->create<luci::CircleFullyConnected>(); _fc->weights(_fc_f); @@ -111,7 +67,7 @@ public: std::vector<float> addition_val; for (uint32_t i = 0; i < 16; i++) addition_val.push_back(static_cast<float>(i)); - _add_c = create_const_node(g, loco::DataType::FLOAT32, {1, 16}, addition_val); + _add_c = luci::create_const_node(g, loco::DataType::FLOAT32, {1, 16}, addition_val); _add = g->nodes()->create<luci::CircleAdd>(); _add->x(_fc); diff --git a/compiler/luci/pass/src/FuseAddWithTConvPass.cpp b/compiler/luci/pass/src/FuseAddWithTConvPass.cpp index 852bc8b63..d8e9f11f5 100644 --- a/compiler/luci/pass/src/FuseAddWithTConvPass.cpp +++ b/compiler/luci/pass/src/FuseAddWithTConvPass.cpp @@ -44,6 +44,9 @@ namespace */ bool fuse_add_with_tconv(luci::CircleTransposeConv *tconv) { + // skip if tconv has fused activation + if (tconv->fusedActivationFunction() != luci::FusedActFunc::NONE) + return false; // check whether it has bias or not. This optimization works only if it doesn't. 
auto bias = dynamic_cast<luci::CircleOutputExclude *>(tconv->bias()); if (not bias) diff --git a/compiler/luci/pass/src/FuseBatchNormWithTConvPass.cpp b/compiler/luci/pass/src/FuseBatchNormWithTConvPass.cpp index 265a8398b..919ce6edc 100644 --- a/compiler/luci/pass/src/FuseBatchNormWithTConvPass.cpp +++ b/compiler/luci/pass/src/FuseBatchNormWithTConvPass.cpp @@ -87,6 +87,9 @@ bool fused_batch_norm_with_tconv(luci::CircleAdd *add) return false; if (not luci::fill(&scale, &tconv).with_commutative_args_of(mul)) return false; + // skip if tconv has fused activation + if (tconv->fusedActivationFunction() != luci::FusedActFunc::NONE) + return false; // check scale and shift constant attributes // TODO maybe rank check is not needed @@ -215,6 +218,9 @@ bool fused_batch_norm_with_tconv(luci::CircleAdd *add) fused_tconv->stride()->h(tconv->stride()->h()); fused_tconv->stride()->w(tconv->stride()->w()); fused_tconv->name(name + "/TransposeConv"); + // TODO set activation from Add and remove adding following Relu/Relu6 Op + // when all of our backends supports fused activation of TransposeConv + fused_tconv->fusedActivationFunction(luci::FusedActFunc::NONE); luci::add_origin(fused_tconv, luci::composite_origin( {luci::get_origin(add), luci::get_origin(mul), luci::get_origin(tconv)})); diff --git a/compiler/luci/pass/src/FuseGeluPass.cpp b/compiler/luci/pass/src/FuseGeluPass.cpp new file mode 100644 index 000000000..e3e7cecb3 --- /dev/null +++ b/compiler/luci/pass/src/FuseGeluPass.cpp @@ -0,0 +1,347 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/FuseGeluPass.h" +#include "helpers/NodeFiller.h" + +#include <luci/IR/CircleNodes.h> + +#include <luci/Profile/CircleNodeOrigin.h> +#include <luci/Service/CircleNodeClone.h> + +#include <cmath> + +#include <cassert> + +// Helper to fuse Gelu +namespace +{ + +// Float comparison +bool same(float a, float b) { return fabs(a - b) < 1e-5; } + +class GeluPatternBase +{ +public: + GeluPatternBase(luci::CircleMul *candidate) { _pattern_last_node = candidate; } + + virtual ~GeluPatternBase() = default; + +public: + virtual bool matched() = 0; + +public: + luci::CircleNode *_ifm = nullptr; + luci::CircleMul *_mul_sqrt = nullptr; + luci::CircleCustom *_erf = nullptr; + luci::CircleCustomOut *_erf_out = nullptr; + luci::CircleAdd *_add_one = nullptr; + luci::CircleMul *_mul = nullptr; + luci::CircleMul *_mul_half = nullptr; + luci::CircleConst *_const_sqrt = nullptr; + luci::CircleConst *_const_one = nullptr; + luci::CircleConst *_const_half = nullptr; + luci::CircleMul *_pattern_last_node = nullptr; +}; + +/** + * Below diagram shows Gelu pattern to fuse. + * - Gelu(x) = 0.5 * x * (1.0 + erf(x / sqrt(2.0))) + * - the below pattern will be replaced with one Gelu + * + * [In] + * | + * V + * +---- ifm + * | | + * | V + * | mul_sqrt (1/sqrt(2) = 0.707106..) 
+ * | | + * | V + * | erf + * | | + * | V + * | add_one (1.0) + * | | + * | V + * +---> mul + * | + * V + * mul_half (0.5) + * | + * V + * [Out] + * + */ +class GeluPattern1 final : public GeluPatternBase +{ +public: + GeluPattern1(luci::CircleMul *candidate) : GeluPatternBase(candidate) + { + assert(candidate); + _mul_half = candidate; + } + +public: + bool matched() override; +}; + +/** + * Below diagram shows Gelu pattern to fuse. + * - Gelu(x) = 0.5 * x * (1.0 + erf(x / sqrt(2.0))) + * - the below pattern will be replaced with one Gelu + * + * [In] + * | + * V + * +----------- ifm + * | | + * | V + * | mul_sqrt (1/sqrt(2) = 0.707106..) + * | | + * | V + * | erf + * mul_half (0.5) | + * | V + * | add_one (1.0) + * | | + * | V + * +----------> mul + * | + * | + * V + * [Out] + * + */ +class GeluPattern2 final : public GeluPatternBase +{ +public: + GeluPattern2(luci::CircleMul *candidate) : GeluPatternBase(candidate) + { + assert(candidate); + _mul = candidate; + } + + ~GeluPattern2() override = default; + +public: + bool matched() override; +}; + +#define CHECK_OR_FALSE(condition) \ + if (not(condition)) \ + return false; + +bool GeluPattern1::matched() +{ + // check pattern + CHECK_OR_FALSE(luci::fill(&_mul, &_const_half).with_commutative_args_of(_mul_half)); + CHECK_OR_FALSE(luci::fill(&_ifm, &_add_one).with_commutative_args_of(_mul)); + CHECK_OR_FALSE(luci::fill(&_erf_out, &_const_one).with_commutative_args_of(_add_one)); + + if (auto erf = dynamic_cast<luci::CircleCustom *>(_erf_out->input())) + _erf = erf; + + CHECK_OR_FALSE(_erf != nullptr); + + // Check erf + CHECK_OR_FALSE(_erf->custom_code() == "Erf"); + CHECK_OR_FALSE(_erf->numInputs() == 1); + CHECK_OR_FALSE(_erf->numOutputs() == 1); + + if (auto mul_sqrt = dynamic_cast<luci::CircleMul *>(_erf->inputs(0))) + _mul_sqrt = mul_sqrt; + + CHECK_OR_FALSE(_mul_sqrt != nullptr); + + CHECK_OR_FALSE(luci::fill(&_ifm, &_const_sqrt).with_commutative_args_of(_mul_sqrt)); + + CHECK_OR_FALSE(_mul_sqrt->x() == _ifm); 
+ CHECK_OR_FALSE(_mul->x() == _ifm); + + // Check Activation to be NONE + CHECK_OR_FALSE(_mul_sqrt->fusedActivationFunction() == luci::FusedActFunc::NONE); + CHECK_OR_FALSE(_add_one->fusedActivationFunction() == luci::FusedActFunc::NONE); + CHECK_OR_FALSE(_mul->fusedActivationFunction() == luci::FusedActFunc::NONE); + CHECK_OR_FALSE(_mul_half->fusedActivationFunction() == luci::FusedActFunc::NONE); + + // check _const_sqrt condition + CHECK_OR_FALSE(_const_sqrt->dtype() == loco::DataType::FLOAT32); + CHECK_OR_FALSE(_const_sqrt->size<loco::DataType::FLOAT32>() == 1); + CHECK_OR_FALSE(::same(_const_sqrt->at<loco::DataType::FLOAT32>(0), sqrtf(0.5f))); + + // check if _const_half is 0.5 (fp32) + CHECK_OR_FALSE(_const_half->dtype() == loco::DataType::FLOAT32); + CHECK_OR_FALSE(_const_half->size<loco::DataType::FLOAT32>() == 1); + CHECK_OR_FALSE(_const_half->at<loco::DataType::FLOAT32>(0) == 0.5); + + // check _const_one condition + CHECK_OR_FALSE(_const_one->dtype() == loco::DataType::FLOAT32); + CHECK_OR_FALSE(_const_one->size<loco::DataType::FLOAT32>() == 1); + CHECK_OR_FALSE(_const_one->at<loco::DataType::FLOAT32>(0) == 1); + + return true; +} + +bool GeluPattern2::matched() +{ + // check pattern + CHECK_OR_FALSE(luci::fill(&_mul_half, &_add_one).with_commutative_args_of(_mul)); + CHECK_OR_FALSE(luci::fill(&_ifm, &_const_half).with_commutative_args_of(_mul_half)); + CHECK_OR_FALSE(luci::fill(&_erf_out, &_const_one).with_commutative_args_of(_add_one)); + + CHECK_OR_FALSE(_mul_half->x() == _ifm); + + if (auto erf = dynamic_cast<luci::CircleCustom *>(_erf_out->input())) + _erf = erf; + + CHECK_OR_FALSE(_erf != nullptr); + + // Check erf + CHECK_OR_FALSE(_erf->custom_code() == "Erf"); + CHECK_OR_FALSE(_erf->numInputs() == 1); + CHECK_OR_FALSE(_erf->numOutputs() == 1); + + if (auto mul_sqrt = dynamic_cast<luci::CircleMul *>(_erf->inputs(0))) + _mul_sqrt = mul_sqrt; + + CHECK_OR_FALSE(_mul_sqrt != nullptr); + + CHECK_OR_FALSE(luci::fill(&_ifm, 
&_const_sqrt).with_commutative_args_of(_mul_sqrt)); + + CHECK_OR_FALSE(_mul_sqrt->x() == _ifm); + + // Check Activation to be NONE + CHECK_OR_FALSE(_mul_sqrt->fusedActivationFunction() == luci::FusedActFunc::NONE); + CHECK_OR_FALSE(_add_one->fusedActivationFunction() == luci::FusedActFunc::NONE); + CHECK_OR_FALSE(_mul->fusedActivationFunction() == luci::FusedActFunc::NONE); + CHECK_OR_FALSE(_mul_half->fusedActivationFunction() == luci::FusedActFunc::NONE); + + // check _const_sqrt condition + CHECK_OR_FALSE(_const_sqrt->dtype() == loco::DataType::FLOAT32); + CHECK_OR_FALSE(_const_sqrt->size<loco::DataType::FLOAT32>() == 1); + CHECK_OR_FALSE(::same(_const_sqrt->at<loco::DataType::FLOAT32>(0), sqrtf(0.5f))); + + // check if _const_half is 0.5 (fp32) + CHECK_OR_FALSE(_const_half->dtype() == loco::DataType::FLOAT32); + CHECK_OR_FALSE(_const_half->size<loco::DataType::FLOAT32>() == 1); + CHECK_OR_FALSE(_const_half->at<loco::DataType::FLOAT32>(0) == 0.5); + + // check _const_one condition + CHECK_OR_FALSE(_const_one->dtype() == loco::DataType::FLOAT32); + CHECK_OR_FALSE(_const_one->size<loco::DataType::FLOAT32>() == 1); + CHECK_OR_FALSE(_const_one->at<loco::DataType::FLOAT32>(0) == 1); + + return true; +} + +#undef CHECK_OR_FALSE + +class FuseGelu final +{ +public: + FuseGelu(const GeluPatternBase *p) : _p(p) {} + +public: + void apply(void); + +private: + luci::CircleGelu *create_gelu(loco::Graph *graph); + +private: + const GeluPatternBase *_p; +}; + +luci::CircleGelu *FuseGelu::create_gelu(loco::Graph *graph) +{ + assert(graph); + + auto gelu = graph->nodes()->create<luci::CircleGelu>(); + gelu->features(_p->_ifm); + // TODO Support approximate = True pattern + gelu->approximate(false); + gelu->name(_p->_pattern_last_node->name() + "_gelu"); + return gelu; +} + +void FuseGelu::apply() +{ + auto graph = _p->_pattern_last_node->graph(); + + auto gelu = create_gelu(graph); + + // set origin + std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{ + 
luci::get_origin(_p->_mul_sqrt), luci::get_origin(_p->_erf), luci::get_origin(_p->_add_one), + luci::get_origin(_p->_mul), luci::get_origin(_p->_mul_half)}; + + luci::add_origin(gelu, luci::composite_origin(origin_vec)); + + replace(_p->_pattern_last_node).with(gelu); +} + +} // namespace + +namespace +{ + +bool fuse_gelu(luci::CircleMul *mul) +{ + assert(mul); + + // check first pattern + GeluPattern1 pattern(mul); + if (pattern.matched()) + { + FuseGelu fuse(&pattern); + fuse.apply(); + return true; + } + + // check second pattern + GeluPattern2 pattern2(mul); + if (pattern2.matched()) + { + FuseGelu fuse(&pattern2); + fuse.apply(); + return true; + } + return false; +} + +} // namespace + +namespace luci +{ + +bool FuseGeluPass::run(loco::Graph *g) +{ + bool changed = false; + + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + auto mul = dynamic_cast<luci::CircleMul *>(node); + if (not mul) + continue; + + if (fuse_gelu(mul)) + changed = true; + } + + return changed; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/FuseGeluPass.test.cpp b/compiler/luci/pass/src/FuseGeluPass.test.cpp new file mode 100644 index 000000000..db6f6993a --- /dev/null +++ b/compiler/luci/pass/src/FuseGeluPass.test.cpp @@ -0,0 +1,251 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "luci/Pass/FuseGeluPass.h" + +#include <luci/IR/CircleNodes.h> + +#include <luci/test/TestIOGraph.h> + +#include <cmath> +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class GeluGraphlet +{ +public: + GeluGraphlet() = default; + + void init(loco::Graph *g) + { + _ifm = g->nodes()->create<luci::CircleAbs>(); + _mul_sqrt = g->nodes()->create<luci::CircleMul>(); + _erf = g->nodes()->create<luci::CircleCustom>(1, 1); + _erf_out = g->nodes()->create<luci::CircleCustomOut>(); + _add_one = g->nodes()->create<luci::CircleAdd>(); + _mul = g->nodes()->create<luci::CircleMul>(); + _mul_half = g->nodes()->create<luci::CircleMul>(); + _const_sqrt = g->nodes()->create<luci::CircleConst>(); + _const_one = g->nodes()->create<luci::CircleConst>(); + _const_half = g->nodes()->create<luci::CircleConst>(); + + _mul->fusedActivationFunction(luci::FusedActFunc::NONE); + _mul_sqrt->fusedActivationFunction(luci::FusedActFunc::NONE); + _mul_half->fusedActivationFunction(luci::FusedActFunc::NONE); + _add_one->fusedActivationFunction(luci::FusedActFunc::NONE); + + _ifm->name("ifm"); + _mul_sqrt->name("mul_sqrt"); + _erf->name("erf"); + _erf_out->name("erf_out"); + _add_one->name("add_one"); + _mul->name("mul"); + _mul_half->name("mul_half"); + _const_one->name("const_one"); + _const_sqrt->name("const_sqrt"); + _const_half->name("const_half"); + + _erf->custom_code("Erf"); + + _const_sqrt->dtype(loco::DataType::FLOAT32); + _const_sqrt->size<loco::DataType::FLOAT32>(1); + _const_sqrt->shape({1}); + _const_sqrt->at<loco::DataType::FLOAT32>(0) = sqrtf(0.5f); + _const_sqrt->shape_status(luci::ShapeStatus::VALID); + + _const_one->dtype(loco::DataType::FLOAT32); + _const_one->size<loco::DataType::FLOAT32>(1); + _const_one->shape({1}); + _const_one->at<loco::DataType::FLOAT32>(0) = 1.0; + _const_one->shape_status(luci::ShapeStatus::VALID); + + _const_half->dtype(loco::DataType::FLOAT32); + _const_half->size<loco::DataType::FLOAT32>(1); + 
_const_half->shape({1}); + _const_half->at<loco::DataType::FLOAT32>(0) = 0.5; + _const_half->shape_status(luci::ShapeStatus::VALID); + } + + void invalid_half() { _const_half->at<loco::DataType::FLOAT32>(0) = 0.1; } + void invalid_act() { _add_one->fusedActivationFunction(luci::FusedActFunc::RELU); } + +protected: + luci::CircleAbs *_ifm = nullptr; + luci::CircleMul *_mul_sqrt = nullptr; + luci::CircleCustom *_erf = nullptr; + luci::CircleCustomOut *_erf_out = nullptr; + luci::CircleAdd *_add_one = nullptr; + luci::CircleMul *_mul = nullptr; + luci::CircleMul *_mul_half = nullptr; + luci::CircleConst *_const_sqrt = nullptr; + luci::CircleConst *_const_one = nullptr; + luci::CircleConst *_const_half = nullptr; +}; + +class FuseGeluTestGraph1 : public TestIOGraph, public GeluGraphlet +{ +public: + FuseGeluTestGraph1() = default; + + void init(void) + { + TestIOGraph::init({1}, {1}); + GeluGraphlet::init(g()); + + _ifm->x(input()); + _mul_sqrt->x(_ifm); + _mul_sqrt->y(_const_sqrt); + _erf->inputs(0, _mul_sqrt); + _erf_out->input(_erf); + _add_one->x(_erf_out); + _add_one->y(_const_one); + _mul->x(_ifm); + _mul->y(_add_one); + _mul_half->x(_mul); + _mul_half->y(_const_half); + + output()->from(_mul_half); + } +}; + +class FuseGeluTestGraph2 : public TestIOGraph, public GeluGraphlet +{ +public: + FuseGeluTestGraph2() = default; + + void init(void) + { + TestIOGraph::init({1}, {1}); + GeluGraphlet::init(g()); + + _ifm->x(input()); + _mul_sqrt->x(_ifm); + _mul_sqrt->y(_const_sqrt); + _erf->inputs(0, _mul_sqrt); + _erf_out->input(_erf); + _add_one->x(_erf_out); + _add_one->y(_const_one); + _mul_half->x(_ifm); + _mul_half->y(_const_half); + _mul->x(_mul_half); + _mul->y(_add_one); + + output()->from(_mul); + } +}; + +class FuseGeluTestNegGraph : public TestIOGraph, public GeluGraphlet +{ +public: + FuseGeluTestNegGraph() = default; + + void init(void) + { + TestIOGraph::init({1}, {1}); + GeluGraphlet::init(g()); + + _ifm->x(input()); + _mul_sqrt->x(_ifm); + // NOTE y is 
incorrect (should be _const_sqrt) + _mul_sqrt->y(_ifm); + _erf->inputs(0, _mul_sqrt); + _erf_out->input(_erf); + _add_one->x(_erf_out); + _add_one->y(_const_one); + _mul->x(_ifm); + _mul->y(_add_one); + _mul_half->x(_mul); + _mul_half->y(_const_half); + + output()->from(_mul_half); + } +}; + +} // namespace + +TEST(FuseGeluPassTest, name) +{ + luci::FuseGeluPass pass; + auto const name = pass.name(); + ASSERT_NE(nullptr, name); +} + +TEST(FuseGeluPassTest, fuse_pattern1) +{ + FuseGeluTestGraph1 g; + luci::FuseGeluPass pass; + + g.init(); + + EXPECT_TRUE(pass.run(g.g())); +} + +TEST(FuseGeluPassTest, fuse_pattern2) +{ + FuseGeluTestGraph2 g; + luci::FuseGeluPass pass; + + g.init(); + + EXPECT_TRUE(pass.run(g.g())); +} + +TEST(FuseGeluPassTest, fuse_invalid_half_NEG) +{ + /* Start from the valid pattern-1 graph so that the wrong 0.5 constant is the only mismatch */ + FuseGeluTestGraph1 g; + luci::FuseGeluPass pass; + + g.init(); + g.invalid_half(); + + EXPECT_FALSE(pass.run(g.g())); +} + +TEST(FuseGeluPassTest, fuse_pattern2_invalid_half_NEG) +{ + FuseGeluTestGraph2 g; + luci::FuseGeluPass pass; + + g.init(); + g.invalid_half(); + + EXPECT_FALSE(pass.run(g.g())); +} + +TEST(FuseGeluPassTest, fuse_invalid_act_NEG) +{ + /* Start from the valid pattern-1 graph so that the fused RELU on add_one is the only mismatch */ + FuseGeluTestGraph1 g; + luci::FuseGeluPass pass; + + g.init(); + g.invalid_act(); + + EXPECT_FALSE(pass.run(g.g())); +} + +TEST(FuseGeluPassTest, fuse_NEG) +{ + FuseGeluTestNegGraph g; + luci::FuseGeluPass pass; + + g.init(); + + EXPECT_FALSE(pass.run(g.g())); +} diff --git a/compiler/luci/pass/src/PropagateQParamBackwardPass.cpp b/compiler/luci/pass/src/PropagateQParamBackwardPass.cpp index e8fa2a478..18617e3b7 100644 --- a/compiler/luci/pass/src/PropagateQParamBackwardPass.cpp +++ b/compiler/luci/pass/src/PropagateQParamBackwardPass.cpp @@ -28,6 +28,25 @@ namespace { +// Return true if node is a virtual node +bool virtual_op(const luci::CircleOpcode opcode) +{ + switch (opcode) + { +#define CIRCLE_NODE(OPCODE, CIRCLE_CLASS) \ + case luci::CircleOpcode::OPCODE: \ + return false; +#define CIRCLE_VNODE(OPCODE, CIRCLE_CLASS) \ + case 
luci::CircleOpcode::OPCODE: \ + return true; +#include <luci/IR/CircleNodes.lst> +#undef CIRCLE_NODE +#undef CIRCLE_VNODE + default: + throw std::runtime_error("Unknown opcode detected"); + } +} + void quant_const_values(luci::CircleConst *const_node, float scaling_factor, float zerop, loco::DataType quant_type) { @@ -448,6 +467,50 @@ struct PropagateQParamBackward final : public luci::CircleNodeMutableVisitor<voi void visit(luci::CirclePack *node) { propagate_pack_quantparam(node); } void visit(luci::CirclePadV2 *node) { propagate_pad_v2_quantparam(node); } + + // Propagate qparam for non-value changing Ops + // (ex: Reshape, Transpose, etc.) + // TODO Add more Ops + + void visit(luci::CircleReshape *node) + { + auto input_node = loco::must_cast<luci::CircleNode *>(node->tensor()); + + // Do not propagate qparam if input node has multiple users + if (loco::succs(input_node).size() > 1) + return; + + const auto input_opcode = input_node->opcode(); + + // Do not propagate qparam if input node is virtual Op (except CIRCLEINPUT) + // Why? It is not safe to propagate qparam to some virtual nodes. For example, + // const node, multi-out nodes. Let's block them for now. + // TODO Revisit this condition + if (virtual_op(input_opcode) and input_opcode != luci::CircleOpcode::CIRCLEINPUT) + return; + + overwrite_quantparam(node, input_node); + } + + void visit(luci::CircleTranspose *node) + { + auto input_node = loco::must_cast<luci::CircleNode *>(node->a()); + + // Do not propagate qparam if input node has multiple users + if (loco::succs(input_node).size() > 1) + return; + + const auto input_opcode = input_node->opcode(); + + // Do not propagate qparam if input node is virtual Op (except CIRCLEINPUT) + // Why? It is not safe to propagate qparam to some virtual nodes. For example, + // const node, multi-out nodes. Let's block them for now. 
+ // TODO Revisit this condition + if (virtual_op(input_opcode) and input_opcode != luci::CircleOpcode::CIRCLEINPUT) + return; + + overwrite_quantparam(node, input_node); + } }; } // namespace diff --git a/compiler/luci/pass/src/PropagateQParamBackwardPass.test.cpp b/compiler/luci/pass/src/PropagateQParamBackwardPass.test.cpp index 33af70449..04573cc45 100644 --- a/compiler/luci/pass/src/PropagateQParamBackwardPass.test.cpp +++ b/compiler/luci/pass/src/PropagateQParamBackwardPass.test.cpp @@ -129,6 +129,119 @@ public: CircleOutput *output = nullptr; }; +/** + * BEFORE + * + * [Input] + * | + * [Conv] (qparam 1) + * | + * [Reshape] (qparam 2) + * | + * [Output] + * + * AFTER + * + * [Input] + * | + * [Conv] (qparam 2) + * | + * [Reshape] (qparam 2) + * | + * [Output] + */ +class ConvReshapeGraph +{ +public: + ConvReshapeGraph() + { + input = g.nodes()->create<luci::CircleInput>(); + conv = g.nodes()->create<luci::CircleConv2D>(); + reshape = g.nodes()->create<luci::CircleReshape>(); + output = g.nodes()->create<luci::CircleOutput>(); + + auto graph_input = g.inputs()->create(); + input->index(graph_input->index()); + auto graph_output = g.outputs()->create(); + output->index(graph_output->index()); + + set_qparam(conv, 2.0, 2); + set_qparam(reshape, 1.0, 1); + + conv->input(input); + reshape->tensor(conv); + output->from(reshape); + } + +public: + loco::Graph g; + luci::CircleInput *input = nullptr; + luci::CircleConv2D *conv = nullptr; + luci::CircleReshape *reshape = nullptr; + luci::CircleOutput *output = nullptr; +}; + +/** + * BEFORE + * + * [Input] + * | + * [Conv] (qparam 1) + * | + * +---------------------+ + * | | + * [Reshape] (qparam 2) [Output] + * | + * [Output] + * + * AFTER (qparam is not propagated as Conv has multiple users) + * + * [Input] + * | + * [Conv] (qparam 1) + * | + * +---------------------+ + * | | + * [Reshape] (qparam 2) [Output] + * | + * [Output] + */ +class ConvReshapeMultiOutGraph +{ +public: + ConvReshapeMultiOutGraph() + { + input 
= g.nodes()->create<luci::CircleInput>(); + conv = g.nodes()->create<luci::CircleConv2D>(); + reshape = g.nodes()->create<luci::CircleReshape>(); + output1 = g.nodes()->create<luci::CircleOutput>(); + output2 = g.nodes()->create<luci::CircleOutput>(); + + auto graph_input = g.inputs()->create(); + input->index(graph_input->index()); + auto graph_output1 = g.outputs()->create(); + output1->index(graph_output1->index()); + auto graph_output2 = g.outputs()->create(); + output2->index(graph_output2->index()); + + set_qparam(conv, 2.0, 2); + set_qparam(reshape, 1.0, 1); + + conv->input(input); + reshape->tensor(conv); + output1->from(reshape); + output2->from(conv); + } + +public: + loco::Graph g; + luci::CircleInput *input = nullptr; + luci::CircleConv2D *conv = nullptr; + luci::CircleReshape *reshape = nullptr; + luci::CircleOutput *output1 = nullptr; + luci::CircleOutput *output2 = nullptr; +}; + } // namespace TEST(PropagateQParamBackwardPassTest, name) @@ -165,3 +278,33 @@ TEST(PropagateQParamBackwardPassTest, subsequent_propagation) EXPECT_EQ(3.0, graph.input->quantparam()->scale[0]); EXPECT_EQ(3, graph.input->quantparam()->zerop[0]); } + +TEST(PropagateQParamBackwardPassTest, reshape) +{ + ConvReshapeGraph graph; + + EXPECT_NE(graph.conv->quantparam()->scale, graph.reshape->quantparam()->scale); + EXPECT_NE(graph.conv->quantparam()->zerop, graph.reshape->quantparam()->zerop); + + luci::PropagateQParamBackwardPass pass(loco::DataType::U8); + + pass.run(&graph.g); + + EXPECT_EQ(graph.conv->quantparam()->scale, graph.reshape->quantparam()->scale); + EXPECT_EQ(graph.conv->quantparam()->zerop, graph.reshape->quantparam()->zerop); +} + +TEST(PropagateQParamBackwardPassTest, reshape_multi_use_NEG) +{ + ConvReshapeMultiOutGraph graph; + + EXPECT_NE(graph.conv->quantparam()->scale, graph.reshape->quantparam()->scale); + EXPECT_NE(graph.conv->quantparam()->zerop, graph.reshape->quantparam()->zerop); + + luci::PropagateQParamBackwardPass pass(loco::DataType::U8); + + 
pass.run(&graph.g); + + EXPECT_NE(graph.conv->quantparam()->scale, graph.reshape->quantparam()->scale); + EXPECT_NE(graph.conv->quantparam()->zerop, graph.reshape->quantparam()->zerop); +} diff --git a/compiler/luci/pass/src/QuantizationUtils.cpp b/compiler/luci/pass/src/QuantizationUtils.cpp index 45d229a0b..3e3cdde34 100644 --- a/compiler/luci/pass/src/QuantizationUtils.cpp +++ b/compiler/luci/pass/src/QuantizationUtils.cpp @@ -73,14 +73,14 @@ void asymmetric_wquant_with_minmax_per_layer(CircleConst *node, float min, float } void symmetric_wquant_with_minmax_per_layer(CircleConst *node, float min, float max, - float &scaling_factor, int64_t &zp, float &nudged_min, + float &scaling_factor, float &nudged_min, float &nudged_max) { const int32_t kMaxScale = std::numeric_limits<int16_t>::max(); const int32_t kMinScale = -kMaxScale; uint32_t size = node->size<loco::DataType::FLOAT32>(); - compute_sym_scale_zp(min, max, scaling_factor, zp, nudged_min, nudged_max); + compute_sym_scale(min, max, scaling_factor, nudged_min, nudged_max); const float scaling_factor_inv = 1.0 / scaling_factor; std::vector<int32_t> quantized_values(size); for (uint32_t i = 0; i < size; ++i) @@ -101,12 +101,14 @@ void symmetric_wquant_with_minmax_per_layer(CircleConst *node, float min, float } } -void compute_sym_scale_zp(float min, float max, float &scaling_factor, int64_t &zp, - float &nudged_min, float &nudged_max) +void compute_sym_scale(float min, float max, float &scaling_factor, float &nudged_min, + float &nudged_max, loco::DataType out_type) { assert(min <= max); + assert(out_type == loco::DataType::S8 || out_type == loco::DataType::S16); - const int32_t kMaxScale = std::numeric_limits<int16_t>::max(); + const int32_t kMaxScale = (out_type == loco::DataType::S16) ? 
std::numeric_limits<int16_t>::max() + : std::numeric_limits<int8_t>::max(); const int32_t kMinScale = -kMaxScale; const double qmin_double = kMinScale; const double qmax_double = kMaxScale; @@ -126,10 +128,9 @@ void compute_sym_scale_zp(float min, float max, float &scaling_factor, int64_t & : scale_factor_from_max_side; // protect scale from being very low to avoid overflow/underflow - if (scaling_factor < 1e-8) - scaling_factor = 1e-8; + const float kMinScalingFactor = (out_type == loco::DataType::S16) ? 1e-8 : 1e-5; + scaling_factor = std::max(scaling_factor, kMinScalingFactor); - zp = 0; nudged_min = static_cast<float>(qmin_double * scaling_factor); nudged_max = static_cast<float>(qmax_double * scaling_factor); } @@ -424,7 +425,7 @@ void quant_const(luci::CircleConst *node, loco::DataType quant_type) nudged_max); break; case loco::DataType::S16: - symmetric_wquant_with_minmax_per_layer(node, min, max, scaling_factor, zp, nudged_min, + symmetric_wquant_with_minmax_per_layer(node, min, max, scaling_factor, nudged_min, nudged_max); break; default: diff --git a/compiler/luci/pass/src/QuantizationUtils.h b/compiler/luci/pass/src/QuantizationUtils.h index 0720c9839..93c4045b5 100644 --- a/compiler/luci/pass/src/QuantizationUtils.h +++ b/compiler/luci/pass/src/QuantizationUtils.h @@ -23,9 +23,9 @@ namespace luci { -// Compute scale/zp using given min/max for symmetric quantization (int16) -void compute_sym_scale_zp(float min, float max, float &scaling_factor, int64_t &zp, - float &nudged_min, float &nudged_max); +// Compute scale using given min/max for symmetric quantization (int8/int16) +void compute_sym_scale(float min, float max, float &scaling_factor, float &nudged_min, + float &nudged_max, loco::DataType out_type = loco::DataType::S16); // Compute scale/zp using given min/max for asymmetric quantization (uint8) void compute_asym_scale_zp(float min, float max, float &scaling_factor, int64_t &zp, @@ -40,7 +40,7 @@ void 
asymmetric_wquant_with_minmax_per_layer(CircleConst *node, float min, float // Symmetric per-layer quantization of weights (const tensor) using given min/max values // NOTE: in-place update of node data void symmetric_wquant_with_minmax_per_layer(CircleConst *node, float min, float max, - float &scaling_factor, int64_t &zp, float &nudged_min, + float &scaling_factor, float &nudged_min, float &nudged_max); // Helper function to get channel dimension diff --git a/compiler/luci/pass/src/QuantizeActivation.cpp b/compiler/luci/pass/src/QuantizeActivation.cpp index 214e61c1e..913450083 100644 --- a/compiler/luci/pass/src/QuantizeActivation.cpp +++ b/compiler/luci/pass/src/QuantizeActivation.cpp @@ -78,7 +78,7 @@ void QuantizeActivation::visit(luci::CircleNode *node) } else { - compute_sym_scale_zp(min, max, scaling_factor, zp, nudged_min, nudged_max); + compute_sym_scale(min, max, scaling_factor, nudged_min, nudged_max); node->dtype(loco::DataType::S16); } @@ -171,7 +171,10 @@ void QuantizeConstInputActivation::visit(luci::CircleNode *node) auto input_node = node->arg(i); auto const_node = dynamic_cast<luci::CircleConst *>(input_node); if (const_node != nullptr) - throw std::runtime_error("Unsupported Op for const inputs"); + { + std::string msg = "Unsupported Op for const inputs: " + node->name(); + throw std::runtime_error(msg); + } } } @@ -221,6 +224,7 @@ QUANTIZE_SINGLE_CONST_INPUT(luci::CircleElu, features) QUANTIZE_SINGLE_CONST_INPUT(luci::CircleExp, x) QUANTIZE_SINGLE_CONST_INPUT(luci::CircleFloor, x) QUANTIZE_SINGLE_CONST_INPUT(luci::CircleGather, params) +QUANTIZE_SINGLE_CONST_INPUT(luci::CircleGelu, features) QUANTIZE_SINGLE_CONST_INPUT(luci::CircleLocalResponseNormalization, input) QUANTIZE_SINGLE_CONST_INPUT(luci::CircleLogistic, x) QUANTIZE_SINGLE_CONST_INPUT(luci::CircleMean, input) @@ -242,6 +246,7 @@ QUANTIZE_SINGLE_CONST_INPUT(luci::CircleSpaceToDepth, input) QUANTIZE_SINGLE_CONST_INPUT(luci::CircleSplit, input) 
QUANTIZE_SINGLE_CONST_INPUT(luci::CircleSplitV, input) QUANTIZE_SINGLE_CONST_INPUT(luci::CircleSqrt, x) +QUANTIZE_SINGLE_CONST_INPUT(luci::CircleSqueeze, input) QUANTIZE_SINGLE_CONST_INPUT(luci::CircleStridedSlice, input) QUANTIZE_SINGLE_CONST_INPUT(luci::CircleSum, input) QUANTIZE_SINGLE_CONST_INPUT(luci::CircleTanh, x) @@ -256,6 +261,7 @@ QUANTIZE_TWO_CONST_INPUTS(luci::CircleBatchMatMul, x, y) QUANTIZE_TWO_CONST_INPUTS(luci::CircleDiv, x, y) QUANTIZE_TWO_CONST_INPUTS(luci::CircleEqual, x, y) QUANTIZE_TWO_CONST_INPUTS(luci::CircleFloorDiv, x, y) +QUANTIZE_TWO_CONST_INPUTS(luci::CircleFloorMod, x, y) QUANTIZE_TWO_CONST_INPUTS(luci::CircleGreater, x, y) QUANTIZE_TWO_CONST_INPUTS(luci::CircleGreaterEqual, x, y) QUANTIZE_TWO_CONST_INPUTS(luci::CircleLess, x, y) diff --git a/compiler/luci/pass/src/QuantizeActivation.h b/compiler/luci/pass/src/QuantizeActivation.h index c6c991a76..ba3bc59f2 100644 --- a/compiler/luci/pass/src/QuantizeActivation.h +++ b/compiler/luci/pass/src/QuantizeActivation.h @@ -111,6 +111,7 @@ private: void visit(luci::CircleExp *node); void visit(luci::CircleFloor *node); void visit(luci::CircleGather *node); + void visit(luci::CircleGelu *node); void visit(luci::CircleLocalResponseNormalization *node); void visit(luci::CircleLogistic *node); void visit(luci::CircleMean *node); @@ -132,6 +133,7 @@ private: void visit(luci::CircleSplit *node); void visit(luci::CircleSplitV *node); void visit(luci::CircleSqrt *node); + void visit(luci::CircleSqueeze *node); void visit(luci::CircleStridedSlice *node); void visit(luci::CircleSum *node); void visit(luci::CircleTanh *node); @@ -146,6 +148,7 @@ private: void visit(luci::CircleDiv *node); void visit(luci::CircleEqual *node); void visit(luci::CircleFloorDiv *node); + void visit(luci::CircleFloorMod *node); void visit(luci::CircleGreater *node); void visit(luci::CircleGreaterEqual *node); void visit(luci::CircleLess *node); diff --git a/compiler/luci/pass/src/QuantizeBias.test.cpp 
b/compiler/luci/pass/src/QuantizeBias.test.cpp index 0104a191b..9030f59e9 100644 --- a/compiler/luci/pass/src/QuantizeBias.test.cpp +++ b/compiler/luci/pass/src/QuantizeBias.test.cpp @@ -16,6 +16,8 @@ #include "QuantizeBias.h" +#include "helpers/CreateCircleConst.h" + #include <luci/test/TestIOGraph.h> #include <luci/IR/CircleNodes.h> #include <luci/IR/CircleQuantParam.h> @@ -29,51 +31,6 @@ namespace using namespace luci::test; -// TODO Reduce duplicate codes in ResolveCustomOpMatMulPass.cpp -template <typename T> -luci::CircleConst *create_const_node(loco::Graph *g, const loco::DataType dtype, - const std::vector<uint32_t> &shape, T value) -{ - auto node = g->nodes()->create<luci::CircleConst>(); - node->dtype(dtype); - node->rank(shape.size()); - - uint32_t size = 1; - for (uint32_t i = 0; i < shape.size(); ++i) - { - node->dim(i) = shape.at(i); - size *= shape.at(i); - } - node->shape_status(luci::ShapeStatus::VALID); - -#define INIT_VALUES(DT) \ - { \ - node->size<DT>(size); \ - for (uint32_t i = 0; i < size; ++i) \ - node->at<DT>(i) = value; \ - } - - switch (dtype) - { - case loco::DataType::U8: - INIT_VALUES(loco::DataType::U8); - break; - case loco::DataType::S16: - INIT_VALUES(loco::DataType::S16); - break; - case loco::DataType::S32: - INIT_VALUES(loco::DataType::S32); - break; - case loco::DataType::FLOAT32: - INIT_VALUES(loco::DataType::FLOAT32) - break; - default: - INTERNAL_EXN("create_const_node called with unsupported type"); - break; - } - return node; -} - /** * Simple graph for test * diff --git a/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.cpp b/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.cpp index ef047d35d..f8989c9e0 100644 --- a/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.cpp +++ b/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.cpp @@ -110,8 +110,8 @@ void cal_minmax_per_channel(CircleConst *node, std::vector<float> &min, std::vec } void sym_wquant_per_channel(CircleConst *node, std::vector<float> &min, 
std::vector<float> &max, - std::vector<float> &scaling_factor, std::vector<int64_t> &zp, - std::vector<float> &nudged_min, std::vector<float> &nudged_max) + std::vector<float> &scaling_factor, std::vector<float> &nudged_min, + std::vector<float> &nudged_max) { assert(node->dtype() == loco::DataType::FLOAT32); const int32_t kMaxScale = std::numeric_limits<int16_t>::max(); @@ -122,7 +122,7 @@ void sym_wquant_per_channel(CircleConst *node, std::vector<float> &min, std::vec for (size_t i = 0; i < min.size(); ++i) { - compute_sym_scale_zp(min[i], max[i], scaling_factor[i], zp[i], nudged_min[i], nudged_max[i]); + compute_sym_scale(min[i], max[i], scaling_factor[i], nudged_min[i], nudged_max[i]); } auto quantize = [&](uint32_t *indices, loco::TensorShape &dimension, int channel_dim_index) { @@ -322,7 +322,7 @@ private: } else { - sym_wquant_per_channel(weights, min, max, scaling_factor, zp, nudged_min, nudged_max); + sym_wquant_per_channel(weights, min, max, scaling_factor, nudged_min, nudged_max); sym_wdequant_per_channel(weights, scaling_factor); } diff --git a/compiler/luci/pass/src/QuantizePreCheckerPass.test.cpp b/compiler/luci/pass/src/QuantizePreCheckerPass.test.cpp index 788353cd8..8f6a96f33 100644 --- a/compiler/luci/pass/src/QuantizePreCheckerPass.test.cpp +++ b/compiler/luci/pass/src/QuantizePreCheckerPass.test.cpp @@ -206,6 +206,7 @@ public: transpose_conv->outBackprop(input_1); transpose_conv->filter(filter); transpose_conv->inputSizes(input_sizes); + transpose_conv->fusedActivationFunction(luci::FusedActFunc::NONE); if (make_valid) { diff --git a/compiler/luci/pass/src/QuantizeWeights.cpp b/compiler/luci/pass/src/QuantizeWeights.cpp index 29cdaffff..59329c19e 100644 --- a/compiler/luci/pass/src/QuantizeWeights.cpp +++ b/compiler/luci/pass/src/QuantizeWeights.cpp @@ -92,9 +92,8 @@ void asym_wquant_per_channel(CircleConst *node, std::vector<float> &min, // TODO Reduce duplicate code with QuantizeDequantizeWeights void sym_wquant_per_channel(CircleConst *node, 
std::vector<float> &min, std::vector<float> &max, - std::vector<float> &scaling_factor, std::vector<int64_t> &zp, - std::vector<float> &nudged_min, std::vector<float> &nudged_max, - int32_t &channel_dim_index) + std::vector<float> &scaling_factor, std::vector<float> &nudged_min, + std::vector<float> &nudged_max, int32_t &channel_dim_index) { assert(node->dtype() == loco::DataType::FLOAT32); const int32_t kMaxScale = std::numeric_limits<int16_t>::max(); @@ -105,7 +104,7 @@ void sym_wquant_per_channel(CircleConst *node, std::vector<float> &min, std::vec for (size_t i = 0; i < min.size(); ++i) { - compute_sym_scale_zp(min[i], max[i], scaling_factor[i], zp[i], nudged_min[i], nudged_max[i]); + compute_sym_scale(min[i], max[i], scaling_factor[i], nudged_min[i], nudged_max[i]); } auto quantize = [&](uint32_t *indices, loco::TensorShape &dimension, int channel_dim_index) { @@ -383,7 +382,7 @@ void QuantizeWeights::quantize_weights(luci::CircleConst *weights) } else { - sym_wquant_per_channel(weights, min, max, scaling_factor, zp, nudged_min, nudged_max, + sym_wquant_per_channel(weights, min, max, scaling_factor, nudged_min, nudged_max, channel_dim_index); } diff --git a/compiler/luci/pass/src/QuantizeWeightsOnly.cpp b/compiler/luci/pass/src/QuantizeWeightsOnly.cpp new file mode 100644 index 000000000..e69a7b6a8 --- /dev/null +++ b/compiler/luci/pass/src/QuantizeWeightsOnly.cpp @@ -0,0 +1,224 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "QuantizeWeightsOnly.h" +#include "QuantizationUtils.h" + +#include <luci/Service/Nodes/CircleConst.h> +#include <luci/Log.h> + +#include <cmath> +#include <vector> +#include <functional> +#include <limits> + +using namespace luci; + +namespace +{ + +using IterFunc = std::function<void(uint32_t *, loco::TensorShape &, int32_t)>; + +void iterate_per_channel(CircleConst *node, int32_t &channel_dim_index, IterFunc func) +{ + loco::TensorShape dimension; + dimension.rank(4); + uint32_t indices[4] = { + 0, + }; + + if (!get_channel_dim_index(node, dimension, channel_dim_index)) + { + assert(false); + return; + } + + for (indices[0] = 0; indices[0] < dimension.dim(0).value(); indices[0]++) + { + for (indices[1] = 0; indices[1] < dimension.dim(1).value(); indices[1]++) + { + for (indices[2] = 0; indices[2] < dimension.dim(2).value(); indices[2]++) + { + for (indices[3] = 0; indices[3] < dimension.dim(3).value(); indices[3]++) + { + func(indices, dimension, channel_dim_index); + } + } + } + } +} + +// TODO Reduce duplicate code with QuantizeDequantizeWeights +template <loco::DataType out_type> +void sym_wquant_per_channel(CircleConst *node, std::vector<float> &min, std::vector<float> &max, + std::vector<float> &scaling_factor, std::vector<float> &nudged_min, + std::vector<float> &nudged_max, int32_t &channel_dim_index) +{ + assert(node->dtype() == loco::DataType::FLOAT32); + assert(out_type == loco::DataType::S8 || out_type == loco::DataType::S16); + const int32_t kMaxScale = (out_type == loco::DataType::S8) ? 
std::numeric_limits<int8_t>::max() + : std::numeric_limits<int16_t>::max(); + const int32_t kMinScale = -kMaxScale; + + uint32_t size = node->size<loco::DataType::FLOAT32>(); + std::vector<int32_t> quantized_values(size); + + for (size_t i = 0; i < min.size(); ++i) + { + compute_sym_scale(min[i], max[i], scaling_factor[i], nudged_min[i], nudged_max[i], out_type); + } + + auto quantize = [&](uint32_t *indices, loco::TensorShape &dimension, int channel_dim_index) { + int channel_idx = indices[channel_dim_index]; + const float scaling_factor_inv = 1.0 / scaling_factor[channel_idx]; + auto data = node->at<loco::DataType::FLOAT32>(cal_offset(dimension, indices)); + data = data < nudged_min[channel_idx] ? nudged_min[channel_idx] : data; + data = data > nudged_max[channel_idx] ? nudged_max[channel_idx] : data; + quantized_values[cal_offset(dimension, indices)] = + static_cast<int32_t>(std::round(data * scaling_factor_inv)); + }; + + iterate_per_channel(node, channel_dim_index, quantize); + + node->dtype(out_type); // change the type of tensor + node->size<out_type>(size); // resize tensor + for (uint32_t i = 0; i < size; ++i) + { + node->at<out_type>(i) = std::min(kMaxScale, std::max(kMinScale, quantized_values[i])); + } +} + +void cal_minmax_per_channel(CircleConst *node, std::vector<float> &min, std::vector<float> &max, + int32_t &channel_dim_index) +{ + loco::TensorShape dimension; + dimension.rank(4); + + if (!get_channel_dim_index(node, dimension, channel_dim_index)) + { + throw std::runtime_error("Failed to find channel index in " + node->name()); + } + auto size = dimension.dim(channel_dim_index).value(); + + std::vector<bool> has_min_max_value(size, false); + min.resize(size); + max.resize(size); + + auto cal_minmax = [&](uint32_t *indices, loco::TensorShape &dimension, int channel_dim_index) { + int channel_idx = indices[channel_dim_index]; + auto data = node->at<loco::DataType::FLOAT32>(cal_offset(dimension, indices)); + if (has_min_max_value[channel_idx]) + { + 
min[channel_idx] = data < min[channel_idx] ? data : min[channel_idx]; + max[channel_idx] = data > max[channel_idx] ? data : max[channel_idx]; + } + else + { + min[channel_idx] = data; + max[channel_idx] = data; + has_min_max_value[channel_idx] = true; + } + }; + + iterate_per_channel(node, channel_dim_index, cal_minmax); +} + +} // namespace + +namespace luci +{ + +void QuantizeWeightsOnly::quantize_weights(luci::CircleConst *weights) +{ + // Find min/max per channel-wise + if (granularity == QuantizationGranularity::ChannelWise) + { + auto quantparam = weights->quantparam(); + if (quantparam == nullptr) + { + // Find min/max on the fly + // NOTE This is for the case when QuantizeDequantizeWeights is skipped + // TODO Reduce duplicate codes + std::vector<float> min; + std::vector<float> max; + int32_t channel_dim_index = 0; + + cal_minmax_per_channel(weights, min, max, channel_dim_index); + + std::vector<float> nudged_min(min.size()); + std::vector<float> nudged_max(min.size()); + std::vector<float> scaling_factor(min.size()); + std::vector<int64_t> zp(min.size()); + + if (output_type == loco::DataType::S8) + { + sym_wquant_per_channel<loco::DataType::S8>(weights, min, max, scaling_factor, nudged_min, + nudged_max, channel_dim_index); + } + else if (output_type == loco::DataType::S16) + { + sym_wquant_per_channel<loco::DataType::S16>(weights, min, max, scaling_factor, nudged_min, + nudged_max, channel_dim_index); + } + else + { + throw std::runtime_error("Weights-only quantization supports s8 and s16"); + } + + auto quantparam = std::make_unique<CircleQuantParam>(); + quantparam->scale = scaling_factor; + quantparam->zerop = zp; + quantparam->quantized_dimension = channel_dim_index; + weights->quantparam(std::move(quantparam)); + + return; + } + } + else + throw std::runtime_error("Weights-only quantization does not support layer-wise"); +} + +void QuantizeWeightsOnly::visit(luci::CircleConv2D *node) +{ + LOGGER(l); + INFO(l) << "QuantizeWeightsOnly visits node: " 
<< node->name() << std::endl; + + auto weights = loco::must_cast<luci::CircleConst *>(node->filter()); + if (!is_quantized(weights)) + { + auto new_weights = luci::clone(weights); + node->filter(new_weights); + quantize_weights(new_weights); + } +} + +void QuantizeWeightsOnly::visit(luci::CircleDepthwiseConv2D *node) +{ + LOGGER(l); + INFO(l) << "QuantizeWeightsOnly visits node: " << node->name() << std::endl; + + auto weights = loco::must_cast<luci::CircleConst *>(node->filter()); + if (!is_quantized(weights)) + { + auto new_weights = luci::clone(weights); + node->filter(new_weights); + quantize_weights(new_weights); + } +} + +void QuantizeWeightsOnly::visit(luci::CircleNode *) {} + +} // namespace luci diff --git a/compiler/luci/pass/src/QuantizeWeightsOnly.h b/compiler/luci/pass/src/QuantizeWeightsOnly.h new file mode 100644 index 000000000..ff6ad3261 --- /dev/null +++ b/compiler/luci/pass/src/QuantizeWeightsOnly.h @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __LUCI_QUANTIZE_WEIGHTS_ONLY_H__ +#define __LUCI_QUANTIZE_WEIGHTS_ONLY_H__ + +#include <luci/Pass/QuantizationParameters.h> +#include <luci/IR/CircleNodeVisitor.h> + +namespace luci +{ + +/** + * @brief QuantizeWeightsOnly quantizes tensors for weights + * @details Find min/max values on the fly and then quantize + */ +struct QuantizeWeightsOnly final : public luci::CircleNodeMutableVisitor<void> +{ + QuantizeWeightsOnly(loco::DataType input, loco::DataType output, QuantizationGranularity gr) + : input_type(input), output_type(output), granularity(gr) + { + } + + loco::DataType input_type; + loco::DataType output_type; + QuantizationGranularity granularity; + +private: + void quantize_weights(luci::CircleConst *weights); + + void visit(luci::CircleConv2D *node); + void visit(luci::CircleDepthwiseConv2D *node); + void visit(luci::CircleNode *); +}; + +} // namespace luci + +#endif // __LUCI_QUANTIZE_WEIGHTS_ONLY_H__ diff --git a/compiler/luci/pass/src/QuantizeWeightsPass.cpp b/compiler/luci/pass/src/QuantizeWeightsPass.cpp new file mode 100644 index 000000000..9ac203e77 --- /dev/null +++ b/compiler/luci/pass/src/QuantizeWeightsPass.cpp @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "luci/Pass/QuantizeWeightsPass.h" +#include "QuantizeWeightsOnly.h" +#include "QuantizationUtils.h" + +#include <luci/Log.h> + +namespace luci +{ + +bool QuantizeWeightsPass::run(loco::Graph *g) +{ + LOGGER(l); + INFO(l) << "QuantizeWeightsPass Start" << std::endl; + + if (_ctx->input_model_dtype != loco::DataType::FLOAT32) + throw std::runtime_error("Weights-only quantization supports float32 input only"); + + // Quantize weights + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + auto circle_node = loco::must_cast<luci::CircleNode *>(node); + QuantizeWeightsOnly qw(_ctx->input_model_dtype, _ctx->output_model_dtype, _ctx->granularity); + circle_node->accept(&qw); + } + + INFO(l) << "QuantizeWeightsPass End" << std::endl; + return false; // one time run +} + +} // namespace luci diff --git a/compiler/luci/pass/src/QuantizeWeightsPass.test.cpp b/compiler/luci/pass/src/QuantizeWeightsPass.test.cpp new file mode 100644 index 000000000..058e029ab --- /dev/null +++ b/compiler/luci/pass/src/QuantizeWeightsPass.test.cpp @@ -0,0 +1,123 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "luci/Pass/QuantizeWeightsPass.h" +#include <luci/IR/CircleNodes.h> + +#include <gtest/gtest.h> + +namespace +{ +struct QuantizeWeightsPassTest : public ::testing::Test +{ + /** + * nconv graph + * + * [CircleInput] + * | + * | + * [CircleConv2D] + * | + * | + * [CircleOutput] + */ + void MakeGraph() + { + const int N = 1; + const int H = 4; + const int W = 4; + const int C = 3; // IC = OC + + // graph input and output + auto graph_input = _g.inputs()->create(); + auto graph_output = _g.outputs()->create(); + + // CircleInput + auto input = _g.nodes()->create<luci::CircleInput>(); + input->index(graph_input->index()); + input->shape({N, H, W, C}); + input->dtype(loco::DataType::FLOAT32); + input->name("input"); + + // CircleConv2D + auto conv = _g.nodes()->create<luci::CircleConv2D>(); + conv->input(input); + auto bias = _g.nodes()->create<luci::CircleConst>(); + bias->dtype(loco::DataType::FLOAT32); + bias->shape({C}); + bias->name("conv_bias"); + conv->bias(bias); + auto weight = _g.nodes()->create<luci::CircleConst>(); + weight->dtype(loco::DataType::FLOAT32); + weight->shape({C, H, W, C}); + weight->size<loco::DataType::FLOAT32>(C * H * W * C); + conv->filter(weight); + conv->padding(luci::Padding::SAME); + conv->fusedActivationFunction(luci::FusedActFunc::NONE); + conv->dtype(loco::DataType::FLOAT32); + conv->name("nconv"); + + // CircleOutput + auto output = _g.nodes()->create<luci::CircleOutput>(); + output->index(graph_output->index()); + output->from(conv); + output->shape({N, H, W, C}); + output->dtype(loco::DataType::FLOAT32); + output->name("output"); + } + virtual void SetUp() { MakeGraph(); } + loco::Graph _g; +}; + +} // namespace + +TEST_F(QuantizeWeightsPassTest, name) +{ + luci::QuantizeWeightsPass pass(loco::DataType::FLOAT32, loco::DataType::S8, + luci::QuantizationGranularity::ChannelWise); + auto const name = pass.name(); + ASSERT_NE(nullptr, name); +} + +TEST_F(QuantizeWeightsPassTest, name_ctx) +{ + auto ctx = 
std::make_unique<luci::QuantizeWeightsPass::Context>(); + { + ctx->input_model_dtype = loco::DataType::FLOAT32; + ctx->output_model_dtype = loco::DataType::S8; + ctx->granularity = luci::QuantizationGranularity::ChannelWise; + } + + luci::QuantizeWeightsPass pass(std::move(ctx)); + auto const name = pass.name(); + ASSERT_NE(nullptr, name); +} + +TEST_F(QuantizeWeightsPassTest, run_input_U8_NEG) +{ + loco::Graph g; + luci::QuantizeWeightsPass pass(loco::DataType::U8, loco::DataType::S8, + luci::QuantizationGranularity::ChannelWise); + EXPECT_THROW(pass.run(&_g), std::runtime_error); +} + +TEST_F(QuantizeWeightsPassTest, run_output_f32_NEG) +{ + loco::Graph g; + luci::QuantizeWeightsPass pass(loco::DataType::FLOAT32, loco::DataType::FLOAT32, + luci::QuantizationGranularity::ChannelWise); + EXPECT_THROW(pass.run(&_g), std::runtime_error); +} diff --git a/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp b/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp index c68e06712..4f4edaf36 100644 --- a/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp +++ b/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp @@ -101,7 +101,7 @@ luci::CircleQuantize *create_quantize_op(luci::CircleNode *node, loco::DataType else { assert(out_type == loco::DataType::S16); - compute_sym_scale_zp(min, max, scaling_factor, zp, nudged_min, nudged_max); + compute_sym_scale(min, max, scaling_factor, nudged_min, nudged_max); } auto quantparam = std::make_unique<CircleQuantParam>(); @@ -271,6 +271,7 @@ private: INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleFloor, x) INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleFullyConnected, input) INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleGather, params) + INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleGelu, features) INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleInstanceNorm, input) INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleLeakyRelu, features) INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleLocalResponseNormalization, input) @@ -433,7 +434,7 @@ void QuantizeWithMinMaxPass::set_input_type(loco::Graph *g) 
const else { assert(user_given_dtype == loco::DataType::S16); - compute_sym_scale_zp(min, max, scaling_factor, zp, nudged_min, nudged_max); + compute_sym_scale(min, max, scaling_factor, nudged_min, nudged_max); } input->quantparam()->scale[0] = scaling_factor; input->quantparam()->zerop[0] = zp; @@ -479,15 +480,15 @@ void QuantizeWithMinMaxPass::set_output_type(loco::Graph *g) const if (user_given_dtype == loco::DataType::FLOAT32) { auto dequant_op = create_dequantize(from); - loco::replace(from).with(dequant_op); dequant_op->input(from); + output->from(dequant_op); } else { // Insert Quantize Op for non-float32 output_type auto quant_op = create_quantize_op(from, user_given_dtype); - loco::replace(from).with(quant_op); quant_op->input(from); + output->from(quant_op); // TODO Set a proper origin (Quantize should have its own Origin) luci::add_origin(quant_op, luci::get_origin(from)); @@ -629,6 +630,13 @@ bool QuantizeWithMinMaxPass::run(loco::Graph *g) for (auto node : loco::active_nodes(loco::output_nodes(g))) { auto circle_node = loco::must_cast<luci::CircleNode *>(node); + + // At this point, all activations have to be quantized. 
+ // Un-quantized nodes are not the quantization target (ex: int32 tensor), + // so we skip them + if (circle_node->quantparam() == nullptr) + continue; + QuantizeSpecialActivation qsa(_ctx->input_model_dtype, quantize_dtype(circle_node)); circle_node->accept(&qsa); } diff --git a/compiler/luci/pass/src/QuantizedModelVerifier.test.cpp b/compiler/luci/pass/src/QuantizedModelVerifier.test.cpp index 05ec31727..ae02edb3d 100644 --- a/compiler/luci/pass/src/QuantizedModelVerifier.test.cpp +++ b/compiler/luci/pass/src/QuantizedModelVerifier.test.cpp @@ -66,6 +66,8 @@ template <Type T> luci::CircleConst *create_dummy_const(loco::Graph *g, luci::te // Fill with index node->at<T>(i) = static_cast<int16_t>(i); break; + default: + break; } } } @@ -470,15 +472,15 @@ public: void init(void) override { TestIOGraph::init({32}, {32}); - _begin = g()->nodes()->create<luci::CircleConst>(); + _begin = g()->nodes()->template create<luci::CircleConst>(); { _begin->dtype(indexT); } - _size = g()->nodes()->create<luci::CircleConst>(); + _size = g()->nodes()->template create<luci::CircleConst>(); { _size->dtype(indexT); } - _slice = g()->nodes()->create<luci::CircleSlice>(); + _slice = g()->nodes()->template create<luci::CircleSlice>(); { _slice->input(input()); _slice->begin(_begin); @@ -595,6 +597,31 @@ private: luci::CircleConst *_strides = nullptr; }; +class SumTestGraph final : public SimpleTestGraph +{ +public: + void init(void) override + { + TestIOGraph::init({4, 3, 2}, {2}); + + _axis = create_const<Type::S32, int32_t>(g(), {2}, {1, 0}); + _sum = g()->nodes()->create<luci::CircleSum>(); + { + _sum->input(input()); + _sum->reduction_indices(_axis); + _sum->name("test"); + _sum->keep_dims(false); + } + output()->from(_sum); + + set_minmax_to_non_const(g(), -1, 1); + } + +private: + luci::CircleSum *_sum = nullptr; + luci::CircleConst *_axis = nullptr; +}; + class ReshapeTestGraph final : public SimpleTestGraph { public: @@ -669,11 +696,11 @@ public: TestIOGraph::init({32}, {1}); // 
output dtype is float by default, but ArgMax should have indexType (s32/s64) output()->dtype(indexT); - _dimension = g()->nodes()->create<luci::CircleConst>(); + _dimension = g()->nodes()->template create<luci::CircleConst>(); { _dimension->dtype(indexT); } - _argmax = g()->nodes()->create<luci::CircleArgMax>(); + _argmax = g()->nodes()->template create<luci::CircleArgMax>(); { _argmax->input(input()); _argmax->dimension(_dimension); @@ -978,7 +1005,7 @@ public: TestIOGraph::init({32}, {32}); output()->dtype(loco::DataType::BOOL); _y = create_dummy_const<Type::FLOAT32>(g(), {32}); - _op = g()->nodes()->create<Op>(); + _op = g()->nodes()->template create<Op>(); { _op->x(input()); _op->y(_y); @@ -1011,7 +1038,7 @@ public: input()->dtype(loco::DataType::BOOL); output()->dtype(loco::DataType::BOOL); _y = create_dummy_const<Type::BOOL>(g(), {32}); - _op = g()->nodes()->create<Op>(); + _op = g()->nodes()->template create<Op>(); { _op->x(input()); _op->y(_y); @@ -1315,7 +1342,7 @@ public: TypedTestGraph::init(T, {32}, {32}); _const = create_dummy_const<T>(g(), {32}); - _mul = g()->nodes()->create<luci::CircleMul>(); + _mul = g()->nodes()->template create<luci::CircleMul>(); { _mul->x(input()); _mul->y(_const); @@ -1370,7 +1397,7 @@ public: TypedTestGraph::init(T, {32}, {32}); _const = create_dummy_const<T>(g(), {32}); - _add = g()->nodes()->create<luci::CircleAdd>(); + _add = g()->nodes()->template create<luci::CircleAdd>(); { _add->x(input()); _add->y(_const); @@ -1786,6 +1813,34 @@ TEST(QuantizedModelVerifierTest, StridedSlice_wrong_granularity_NEG) SUCCEED(); } +TEST(QuantizedModelVerifierTest, Sum) +{ + TEST_WITH_GRAPH(SumTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_GRAPH(SumTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_GRAPH(SumTestGraph, Type::S16, Granularity::ChannelWise); + + TEST_WITH_LAYER_INFO(SumTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_LAYER_INFO(SumTestGraph, Type::U8, Granularity::ChannelWise); + 
TEST_WITH_LAYER_INFO(SumTestGraph, Type::S16, Granularity::ChannelWise); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, Sum_wrong_type_NEG) +{ + TEST_WITH_WRONG_TYPE(SumTestGraph, Type::U8, Granularity::LayerWise, Type::S16); + TEST_WITH_WRONG_TYPE(SumTestGraph, Type::U8, Granularity::ChannelWise, Type::S16); + TEST_WITH_WRONG_TYPE(SumTestGraph, Type::S16, Granularity::ChannelWise, Type::U8); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, Sum_wrong_granularity_NEG) +{ + TEST_WITH_WRONG_GRANULARITY(SumTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_WRONG_GRANULARITY(SumTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_WRONG_GRANULARITY(SumTestGraph, Type::S16, Granularity::ChannelWise); + SUCCEED(); +} + TEST(QuantizedModelVerifierTest, ArgMax) { TEST_WITH_GRAPH(ArgMaxTestGraph<Type::S32>, Type::U8, Granularity::LayerWise); diff --git a/compiler/luci/pass/src/ReplaceNonConstFCWithBatchMatMulPass.test.cpp b/compiler/luci/pass/src/ReplaceNonConstFCWithBatchMatMulPass.test.cpp index 93024f3f7..194893f01 100644 --- a/compiler/luci/pass/src/ReplaceNonConstFCWithBatchMatMulPass.test.cpp +++ b/compiler/luci/pass/src/ReplaceNonConstFCWithBatchMatMulPass.test.cpp @@ -16,6 +16,8 @@ #include "luci/Pass/ReplaceNonConstFCWithBatchMatMulPass.h" +#include "helpers/CreateCircleConst.h" + #include <luci/test/TestIOGraph.h> #include <luci/IR/CircleNodes.h> @@ -26,52 +28,6 @@ namespace using namespace luci::test; -// TODO Reduce duplicate codes in ResolveCustomOpMatMulPass.cpp -template <typename T> -luci::CircleConst *create_const_node(loco::Graph *g, const loco::DataType dtype, - const std::vector<uint32_t> &shape, - const std::vector<T> &values) -{ - auto node = g->nodes()->create<luci::CircleConst>(); - node->dtype(dtype); - node->rank(shape.size()); - - uint32_t size = 1; - for (uint32_t i = 0; i < shape.size(); ++i) - { - node->dim(i) = shape.at(i); - size *= shape.at(i); - } - node->shape_status(luci::ShapeStatus::VALID); - -#define INIT_VALUES(DT) 
\ - { \ - node->size<DT>(size); \ - for (uint32_t i = 0; i < values.size(); ++i) \ - node->at<DT>(i) = values[i]; \ - } - - switch (dtype) - { - case loco::DataType::U8: - INIT_VALUES(loco::DataType::U8); - break; - case loco::DataType::S16: - INIT_VALUES(loco::DataType::S16); - break; - case loco::DataType::S32: - INIT_VALUES(loco::DataType::S32); - break; - case loco::DataType::FLOAT32: - INIT_VALUES(loco::DataType::FLOAT32) - break; - default: - INTERNAL_EXN("create_const_node called with unsupported type"); - break; - } - return node; -} - /** * Simple graph for test * @@ -104,7 +60,7 @@ public: _tr_y = g->nodes()->create<luci::CircleTranspose>(); _tr_y->a(_y); std::vector<int32_t> tr_val = {1, 0}; - _tr_y->perm(create_const_node(g, loco::DataType::S32, {2}, tr_val)); + _tr_y->perm(luci::create_const_node(g, loco::DataType::S32, {2}, tr_val)); _fc = g->nodes()->create<luci::CircleFullyConnected>(); _fc->input(_x); @@ -114,7 +70,7 @@ public: _fc->shape(r_shape); auto l = _fc->dim(_fc->rank() - 1).value(); std::vector<float> bias_val(l, bv); - _fc->bias(create_const_node(g, loco::DataType::FLOAT32, {l}, bias_val)); + _fc->bias(luci::create_const_node(g, loco::DataType::FLOAT32, {l}, bias_val)); _fc->name("fc"); } diff --git a/compiler/luci/pass/src/ReplaceSubWithAddPass.cpp b/compiler/luci/pass/src/ReplaceSubWithAddPass.cpp index 6bd83f5c5..f9102d836 100644 --- a/compiler/luci/pass/src/ReplaceSubWithAddPass.cpp +++ b/compiler/luci/pass/src/ReplaceSubWithAddPass.cpp @@ -17,6 +17,7 @@ #include "luci/Pass/ReplaceSubWithAddPass.h" #include <luci/IR/CircleNodes.h> +#include <luci/Profile/CircleNodeOrigin.h> #include <luci/Service/Nodes/CircleConst.h> namespace @@ -47,6 +48,7 @@ bool replace_sub_with_const_rhs(luci::CircleSub *sub) add->y(neg_const_rhs); add->name(sub->name()); add->fusedActivationFunction(sub->fusedActivationFunction()); + luci::add_origin(add, luci::get_origin(sub)); loco::replace(sub).with(add); return true; } diff --git 
a/compiler/luci/pass/src/RequantizePass.cpp b/compiler/luci/pass/src/RequantizePass.cpp index a56536251..77c55324a 100644 --- a/compiler/luci/pass/src/RequantizePass.cpp +++ b/compiler/luci/pass/src/RequantizePass.cpp @@ -32,37 +32,9 @@ namespace luci namespace { -// Check if the node is the bias of Conv2D, DepthwiseConv2D, or FullyConnected layer -bool is_bias(CircleConst *node) -{ - if (node == nullptr) - return false; - - auto succs = loco::succs(node); - if (succs.size() != 1) // assume bias is used by only one node - return false; - - for (auto out : succs) - { - auto conv = dynamic_cast<CircleConv2D *>(out); - if (conv != nullptr && conv->bias() == node) - return true; - - auto dw_conv = dynamic_cast<CircleDepthwiseConv2D *>(out); - if (dw_conv != nullptr && dw_conv->bias() == node) - return true; - - auto fc = dynamic_cast<CircleFullyConnected *>(out); - if (fc != nullptr && fc->bias() == node) - return true; - - auto tconv = dynamic_cast<CircleTransposeConv *>(out); - if (tconv != nullptr && tconv->bias() == node) - return true; - } - return false; -} - +// Requantize Non-const node from int8 to uint8 +// Original values: -128 ~ 127 +// After requantization: 0 ~ 255 void requant_nonconst_int8_to_uint8(CircleNode *circle_node) { assert(circle_node->dtype() == loco::DataType::S8); @@ -107,99 +79,48 @@ void requant_const_int8_to_uint8(CircleConst *node) } } +#define RETURN_UNLESS(cond) \ + if (not(cond)) \ + return; + /** - * @brief RequantizeNonConst requantizes tensors for activations + * @brief Requantize int8 quantized tensors to uint8 tensors */ -struct RequantizeNonConst final : public luci::CircleNodeMutableVisitor<bool> +struct RequantizeS8ToU8 final : public luci::CircleNodeMutableVisitor<void> { - RequantizeNonConst(loco::DataType input, loco::DataType output) - : _input_type(input), _output_type(output) - { - } - - loco::DataType _input_type; - loco::DataType _output_type; - - // Requantize input tensors of each node - bool visit(luci::CircleNode 
*node) + // Requantize non-const tensors + void visit(luci::CircleNode *node) { LOGGER(l); - INFO(l) << "RequantizeNonConst visit node: " << node->name() << std::endl; - auto arity = node->arity(); - for (uint32_t i = 0; i < arity; i++) - { - auto input_node = node->arg(i); - auto circle_node = loco::must_cast<luci::CircleNode *>(input_node); + INFO(l) << "RequantizeS8ToU8 visit non-const node: " << node->name() << std::endl; - // Check if this was quantized (only quantized tensors are requantized) - if (circle_node->quantparam() == nullptr) - continue; + // Ignore non-quantized tensors + RETURN_UNLESS(node->quantparam() != nullptr); - // Check if this is already requantized - if (circle_node->dtype() == _output_type) - continue; + // Check dtype is int8 + RETURN_UNLESS(node->dtype() == loco::DataType::S8); - // Check if this is not const (only non-const is requantized in this function) - auto circle_const = dynamic_cast<CircleConst *>(circle_node); - if (circle_const != nullptr) - continue; - - if (_input_type == loco::DataType::S8 && _output_type == loco::DataType::U8) - requant_nonconst_int8_to_uint8(circle_node); - } - return false; - } -}; - -/** - * @brief RequantizeConst requantizes tensors for weights - */ -struct RequantizeConst final : public luci::CircleNodeMutableVisitor<bool> -{ - RequantizeConst(loco::DataType input, loco::DataType output) - : _input_type(input), _output_type(output) - { + requant_nonconst_int8_to_uint8(node); } - loco::DataType _input_type; - loco::DataType _output_type; - - // Requantize input tensors of each node - bool visit(luci::CircleNode *node) + // Requantize const tensors + void visit(luci::CircleConst *node) { LOGGER(l); - INFO(l) << "RequantizeConst visit node: " << node->name() << std::endl; - auto arity = node->arity(); - for (uint32_t i = 0; i < arity; i++) - { - auto input_node = node->arg(i); - auto circle_node = loco::must_cast<luci::CircleNode *>(input_node); + INFO(l) << "RequantizeS8ToU8 visit const node: " << 
node->name() << std::endl; - // Check if this was quantized (only quantized tensors are requantized) - if (circle_node->quantparam() == nullptr) - continue; + // Ignore non-quantized tensors + RETURN_UNLESS(node->quantparam() != nullptr); - // Check if this is already requantized - if (circle_node->dtype() == _output_type) - continue; + // Check dtype is int8 + RETURN_UNLESS(node->dtype() == loco::DataType::S8); - // Check if this is const (only const is requantized in this function) - auto circle_const = dynamic_cast<CircleConst *>(circle_node); - if (circle_const == nullptr) - continue; - - // Check if this is not bias - // bias is not requantized when int8 -> uint8 - if (is_bias(circle_const)) - continue; - - if (_input_type == loco::DataType::S8 && _output_type == loco::DataType::U8) - requant_const_int8_to_uint8(circle_const); - } - return false; + requant_const_int8_to_uint8(node); } }; +#undef RETURN_UNLESS + } // namespace bool RequantizePass::run(loco::Graph *g) @@ -207,20 +128,21 @@ bool RequantizePass::run(loco::Graph *g) LOGGER(l); INFO(l) << "RequantizePass Start" << std::endl; - // Requantize non-const (activations) - for (auto node : loco::active_nodes(loco::output_nodes(g))) + // Input: int8 model + // Output: uint8 model + if (_input_dtype == loco::DataType::S8 and _output_dtype == loco::DataType::U8) { - RequantizeNonConst rqnc(_input_dtype, _output_dtype); - auto circle_node = loco::must_cast<luci::CircleNode *>(node); - circle_node->accept(&rqnc); + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + RequantizeS8ToU8 rq; + auto circle_node = loco::must_cast<luci::CircleNode *>(node); + circle_node->accept(&rq); + } } - - // Requantize const (including weights, constants) - for (auto node : loco::active_nodes(loco::output_nodes(g))) + else { - RequantizeConst rqc(_input_dtype, _output_dtype); - auto circle_node = loco::must_cast<luci::CircleNode *>(node); - circle_node->accept(&rqc); + // Ignore other cases + return false; } // 
Update output dtype @@ -228,7 +150,8 @@ bool RequantizePass::run(loco::Graph *g) for (auto node : loco::output_nodes(g)) { auto circle_node = loco::must_cast<luci::CircleOutput *>(node); - if (static_cast<luci::CircleNode *>(circle_node->from())->dtype() == _output_dtype) + auto from_node = loco::must_cast<luci::CircleNode *>(circle_node->from()); + if (from_node->dtype() == _output_dtype) { circle_node->dtype(_output_dtype); auto graph_output = graph_outputs->at(circle_node->index()); diff --git a/compiler/luci/pass/src/RequantizePass.test.cpp b/compiler/luci/pass/src/RequantizePass.test.cpp index d26743c9d..a9293ce27 100644 --- a/compiler/luci/pass/src/RequantizePass.test.cpp +++ b/compiler/luci/pass/src/RequantizePass.test.cpp @@ -16,11 +16,167 @@ #include "luci/Pass/RequantizePass.h" +#include "helpers/CreateCircleConst.h" + +#include <luci/test/TestIOGraph.h> +#include <luci/IR/CircleNodes.h> +#include <luci/IR/CircleQuantParam.h> + +#include <vector> + #include <gtest/gtest.h> +using namespace luci; +using namespace luci::test; + +namespace +{ + +/** + * Simple graph for test + * + * BEFORE + * + * [IFM (S8)] [W (S8)] [B (S32)] + * | | | + * +-------+--------+ + * | + * V + * [FC] + * | + * V + * [OFM(S8)] + * + * AFTER + * + * [IFM (U8)] [W (U8)] [B (S32)] + * | | | + * +-------+--------+ + * | + * V + * [FC] + * | + * V + * [OFM(U8)] + */ +struct S8FCGraphlet +{ +public: + S8FCGraphlet() = default; + virtual ~S8FCGraphlet() = default; + + void init(loco::Graph *g, const ShapeU32 out_shape, const ShapeU32 w_shape, + const ShapeU32 bias_shape) + { + _fc = g->nodes()->create<CircleFullyConnected>(); + _fc->input(_x); + _x->dtype(loco::DataType::S8); + { + auto quantparam = std::make_unique<CircleQuantParam>(); + quantparam->scale.push_back(1.0); + quantparam->zerop.push_back(0); + quantparam->quantized_dimension = 0; + _x->quantparam(std::move(quantparam)); + } + + _weights = create_const_node<int8_t>(g, loco::DataType::S8, w_shape, 1.0); + { + auto w_qparam = 
std::make_unique<CircleQuantParam>(); + std::vector<float> w_scale(_weights->dim(0).value(), 1.0); + std::vector<int64_t> w_zp(_weights->dim(0).value(), 0); + w_qparam->scale = w_scale; + w_qparam->zerop = w_zp; + w_qparam->quantized_dimension = 0; + _weights->quantparam(std::move(w_qparam)); + } + _fc->weights(_weights); + + _bias = create_const_node<int32_t>(g, loco::DataType::S32, bias_shape, 1.0); + { + auto b_qparam = std::make_unique<CircleQuantParam>(); + const auto bias_size = _bias->size<loco::DataType::S32>(); + std::vector<float> b_scale(bias_size, 1.0); + std::vector<int64_t> b_zp(bias_size, 0); + b_qparam->scale = b_scale; + b_qparam->zerop = b_zp; + b_qparam->quantized_dimension = 0; + _bias->quantparam(std::move(b_qparam)); + } + + _fc->fusedActivationFunction(luci::FusedActFunc::NONE); + _fc->dtype(loco::DataType::S8); + _fc->shape(out_shape); + _fc->bias(_bias); + _fc->name("fc"); + { + auto quantparam = std::make_unique<CircleQuantParam>(); + quantparam->scale.push_back(1.0); + quantparam->zerop.push_back(0); + quantparam->quantized_dimension = 0; + _fc->quantparam(std::move(quantparam)); + } + } + +public: + CircleFullyConnected *_fc = nullptr; + CircleInput *_x = nullptr; + CircleConst *_weights = nullptr; + CircleConst *_bias = nullptr; +}; + +struct S8FCGraph final : public TestIGraphlet, public TestOGraphlet, public S8FCGraphlet +{ + void init(const ShapeU32 in_shape, const ShapeU32 w_shape, const ShapeU32 out_shape, + const ShapeU32 bias_shape) + { + TestIGraphlet::init(g(), in_shape); + TestOGraphlet::init(g(), out_shape); + _x = input(); + S8FCGraphlet::init(g(), out_shape, w_shape, bias_shape); + output()->from(_fc); + } +}; + +class RequantizeS8ToU8FCTest : public ::testing::Test +{ +public: + S8FCGraph g; +}; + +} // namespace + TEST(RequantizePassTest, name) { luci::RequantizePass pass(loco::DataType::FLOAT32, loco::DataType::U8); auto const name = pass.name(); ASSERT_NE(nullptr, name); } + +TEST_F(RequantizeS8ToU8FCTest, FC) +{ + 
g.init({1, 18, 80} /* ifm shape */, {256, 80} /* weights shape*/, {18, 256} /* ofm shape */, + {1, 256} /* bias shape*/); + + luci::RequantizePass rq(loco::DataType::S8, loco::DataType::U8); + rq.run(g.g()); + + EXPECT_EQ(loco::DataType::U8, g._x->dtype()); + EXPECT_EQ(loco::DataType::U8, g._fc->dtype()); + EXPECT_EQ(loco::DataType::U8, g._weights->dtype()); + EXPECT_EQ(loco::DataType::S32, g._bias->dtype()); +} + +TEST_F(RequantizeS8ToU8FCTest, FC_wrong_dtype_NEG) +{ + g.init({1, 18, 80} /* ifm shape */, {256, 80} /* weights shape*/, {18, 256} /* ofm shape */, + {1, 256} /* bias shape*/); + + // Wrong dtype + luci::RequantizePass rq(loco::DataType::U8, loco::DataType::S8); + rq.run(g.g()); + + EXPECT_EQ(loco::DataType::S8, g._x->dtype()); + EXPECT_EQ(loco::DataType::S8, g._fc->dtype()); + EXPECT_EQ(loco::DataType::S8, g._weights->dtype()); + EXPECT_EQ(loco::DataType::S32, g._bias->dtype()); +} diff --git a/compiler/luci/pass/src/ResolveCustomOpMatMulPass.cpp b/compiler/luci/pass/src/ResolveCustomOpMatMulPass.cpp index f61882796..add55f66c 100644 --- a/compiler/luci/pass/src/ResolveCustomOpMatMulPass.cpp +++ b/compiler/luci/pass/src/ResolveCustomOpMatMulPass.cpp @@ -16,6 +16,8 @@ #include "luci/Pass/ResolveCustomOpMatMulPass.h" +#include "helpers/CreateCircleConst.h" + #include <loco/IR/DataTypeTraits.h> #include <luci/IR/CircleNodes.h> @@ -29,51 +31,6 @@ namespace { -template <typename T> -luci::CircleConst *create_const_node(loco::Graph *g, const loco::DataType dtype, - const std::vector<uint32_t> &shape, - const std::vector<T> &values) -{ - auto node = g->nodes()->create<luci::CircleConst>(); - node->dtype(dtype); - node->rank(shape.size()); - - uint32_t size = 1; - for (uint32_t i = 0; i < shape.size(); ++i) - { - node->dim(i) = shape.at(i); - size *= shape.at(i); - } - node->shape_status(luci::ShapeStatus::VALID); - -#define INIT_VALUES(DT) \ - { \ - node->size<DT>(size); \ - for (uint32_t i = 0; i < values.size(); ++i) \ - node->at<DT>(i) = values[i]; \ - } - 
- switch (dtype) - { - case loco::DataType::U8: - INIT_VALUES(loco::DataType::U8); - break; - case loco::DataType::S16: - INIT_VALUES(loco::DataType::S16); - break; - case loco::DataType::S32: - INIT_VALUES(loco::DataType::S32); - break; - case loco::DataType::FLOAT32: - INIT_VALUES(loco::DataType::FLOAT32) - break; - default: - INTERNAL_EXN("create_const_node called with unsupported type"); - break; - } - return node; -} - bool resolve_matmul(luci::CircleCustom *cop) { #define CHECK_OR_FALSE(condition) \ @@ -121,11 +78,12 @@ bool resolve_matmul(luci::CircleCustom *cop) if (transpose_a) { // Create a permutation constant node - std::vector<uint32_t> perm; - for (uint32_t i = 0; i < circle_lhs->rank(); ++i) + std::vector<int32_t> perm; + const auto lhs_rank = static_cast<int32_t>(circle_lhs->rank()); + for (int32_t i = 0; i < lhs_rank; ++i) perm.push_back(i); std::swap(perm[circle_lhs->rank() - 1], perm[circle_lhs->rank() - 2]); - auto perm_node = create_const_node(graph, S32, {circle_lhs->rank()}, perm); + auto perm_node = luci::create_const_node(graph, S32, {circle_lhs->rank()}, perm); perm_node->name(name + "/lhs/Transpose/perm"); // Now make a transpose node auto transpose_node = graph->nodes()->create<luci::CircleTranspose>(); @@ -141,8 +99,8 @@ bool resolve_matmul(luci::CircleCustom *cop) // in row-major order, thus we need to convert between them. 
if (!transpose_b) { - const std::vector<uint32_t> perm{1, 0}; - auto perm_node = create_const_node(graph, S32, {2}, perm); + const std::vector<int32_t> perm{1, 0}; + auto perm_node = luci::create_const_node(graph, S32, {2}, perm); perm_node->name(name + "/rhs/Transpose/perm"); auto transpose_node = graph->nodes()->create<luci::CircleTranspose>(); transpose_node->a(rhs); diff --git a/compiler/luci/pass/src/SubstituteSplitVToSplitPass.test.cpp b/compiler/luci/pass/src/SubstituteSplitVToSplitPass.test.cpp index 6e30103f9..43f9cc116 100644 --- a/compiler/luci/pass/src/SubstituteSplitVToSplitPass.test.cpp +++ b/compiler/luci/pass/src/SubstituteSplitVToSplitPass.test.cpp @@ -16,6 +16,8 @@ #include "luci/Pass/SubstituteSplitVToSplitPass.h" +#include "helpers/CreateCircleConst.h" + #include <luci/test/TestIOGraph.h> #include <gtest/gtest.h> @@ -30,51 +32,6 @@ const int C = 32; const int H = 8; const int W = 8; -// Reduce duplicate codes in ResolveCustomOpMatMulPass.cpp -template <typename T> -luci::CircleConst *create_const_node(loco::Graph *g, const loco::DataType dtype, - const std::vector<uint32_t> &shape, - const std::vector<T> &values) -{ - auto node = g->nodes()->create<luci::CircleConst>(); - node->dtype(dtype); - node->rank(shape.size()); - - uint32_t size = 1; - for (uint32_t i = 0; i < shape.size(); ++i) - { - node->dim(i) = shape.at(i); - size *= shape.at(i); - } - node->shape_status(luci::ShapeStatus::VALID); - -#define INIT_VALUES(DT) \ - { \ - node->size<DT>(size); \ - for (uint32_t i = 0; i < values.size(); ++i) \ - node->at<DT>(i) = values[i]; \ - } - - switch (dtype) - { - case loco::DataType::U8: - INIT_VALUES(loco::DataType::U8); - break; - case loco::DataType::S16: - INIT_VALUES(loco::DataType::S16); - break; - case loco::DataType::S32: - INIT_VALUES(loco::DataType::S32); - break; - case loco::DataType::FLOAT32: - INIT_VALUES(loco::DataType::FLOAT32) - break; - default: - INTERNAL_EXN("create_const_node called with unsupported type"); - break; - } - 
return node; -} /** * graph having SplitV operator * @@ -95,10 +52,10 @@ public: void init(loco::Graph *g) { const std::vector<int32_t> splits{16, 16}; - auto size_splits = create_const_node(g, loco::DataType::S32, {2}, splits); + auto size_splits = luci::create_const_node(g, loco::DataType::S32, {2}, splits); const std::vector<int32_t> dim{3}; - auto split_dim = create_const_node(g, loco::DataType::S32, {1}, dim); + auto split_dim = luci::create_const_node(g, loco::DataType::S32, {1}, dim); _sv = g->nodes()->create<luci::CircleSplitV>(); _sv->size_splits(size_splits); diff --git a/compiler/luci/pass/src/VerifyQuantizedBiasScale.cpp b/compiler/luci/pass/src/VerifyQuantizedBiasScale.cpp index e65d576cd..d40c19b9b 100644 --- a/compiler/luci/pass/src/VerifyQuantizedBiasScale.cpp +++ b/compiler/luci/pass/src/VerifyQuantizedBiasScale.cpp @@ -31,7 +31,7 @@ namespace bool same(float a, float b) { constexpr float epsilon = 1e-10; - return abs(a - b) < epsilon; + return std::abs(a - b) < epsilon; } // Check bias scale = input scale * weight scale diff --git a/compiler/luci/pass/src/VerifyQuantizedNodeGranularity.h b/compiler/luci/pass/src/VerifyQuantizedNodeGranularity.h index 6bf7ff698..cc618bf0e 100644 --- a/compiler/luci/pass/src/VerifyQuantizedNodeGranularity.h +++ b/compiler/luci/pass/src/VerifyQuantizedNodeGranularity.h @@ -298,6 +298,13 @@ private: return true; } + bool visit(const luci::CircleSum *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)); + RETURN_FALSE_UNLESS(is_lwq(node->input())); + return true; + } + bool visit(const luci::CircleArgMax *node) { // node's output is index, thus not quantized @@ -333,6 +340,13 @@ private: return true; } + bool visit(const luci::CircleGelu *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)); + RETURN_FALSE_UNLESS(is_lwq(node->features())); + return true; + } + bool visit(const luci::CircleGreater *node) { RETURN_FALSE_UNLESS(is_lwq(node->x())); diff --git a/compiler/luci/pass/src/VerifyQuantizedNodeType.cpp 
b/compiler/luci/pass/src/VerifyQuantizedNodeType.cpp index 3ce32555b..4bad9522b 100644 --- a/compiler/luci/pass/src/VerifyQuantizedNodeType.cpp +++ b/compiler/luci/pass/src/VerifyQuantizedNodeType.cpp @@ -181,6 +181,12 @@ bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleFullyCon } template <loco::DataType Qtype, loco::DataType Btype> +bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleGelu *node) +{ + return group_has_type(node, Qtype); +} + +template <loco::DataType Qtype, loco::DataType Btype> bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleGreater *node) { RETURN_FALSE_UNLESS(has_type(node, loco::DataType::BOOL)) @@ -454,6 +460,15 @@ bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleStridedS } template <loco::DataType Qtype, loco::DataType Btype> +bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleSum *node) +{ + RETURN_FALSE_UNLESS(has_type(node, Qtype)) + RETURN_FALSE_UNLESS(has_type(node->input(), Qtype)) + RETURN_FALSE_UNLESS(has_type(node->reduction_indices(), loco::DataType::S32)) + return true; +} + +template <loco::DataType Qtype, loco::DataType Btype> bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleTranspose *node) { RETURN_FALSE_UNLESS(has_type(node, Qtype)) diff --git a/compiler/luci/pass/src/VerifyQuantizedNodeType.h b/compiler/luci/pass/src/VerifyQuantizedNodeType.h index 789d3c7cd..03f1e1d86 100644 --- a/compiler/luci/pass/src/VerifyQuantizedNodeType.h +++ b/compiler/luci/pass/src/VerifyQuantizedNodeType.h @@ -88,6 +88,7 @@ private: bool visit(const luci::CircleFloor *node); bool visit(const luci::CircleFloorDiv *node); bool visit(const luci::CircleFullyConnected *node); + bool visit(const luci::CircleGelu *node); bool visit(const luci::CircleGreater *node); bool visit(const luci::CircleGreaterEqual *node); bool visit(const luci::CircleInstanceNorm *node); @@ -119,6 +120,7 @@ private: bool visit(const 
luci::CircleSplitVOut *node); bool visit(const luci::CircleSqrt *node); bool visit(const luci::CircleStridedSlice *node); + bool visit(const luci::CircleSum *node); bool visit(const luci::CircleTranspose *node); bool visit(const luci::CircleTransposeConv *node); bool visit(const luci::CircleUnpack *node); diff --git a/compiler/luci/pass/src/helpers/CreateCircleConst.cpp b/compiler/luci/pass/src/helpers/CreateCircleConst.cpp new file mode 100644 index 000000000..bf1b0baf7 --- /dev/null +++ b/compiler/luci/pass/src/helpers/CreateCircleConst.cpp @@ -0,0 +1,20 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CreateCircleConst.h" + +// NOTE Do NOT delete this file; this file enforces compiler to check whether 'CreateCircleConst.h' +// is complete. diff --git a/compiler/luci/pass/src/helpers/CreateCircleConst.h b/compiler/luci/pass/src/helpers/CreateCircleConst.h new file mode 100644 index 000000000..89c1a47be --- /dev/null +++ b/compiler/luci/pass/src/helpers/CreateCircleConst.h @@ -0,0 +1,88 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_PASS_HELPERS_CREATE_CIRCLE_CONST_H__ +#define __LUCI_PASS_HELPERS_CREATE_CIRCLE_CONST_H__ + +#include <luci/IR/CircleNodes.h> + +#include "TypeMapper.h" + +#include <vector> + +namespace luci +{ + +// Create CircleConst filled with a single value +// Never return nullptr +// TODO Remove dtype from the argument +template <typename T> +CircleConst *create_const_node(loco::Graph *g, const loco::DataType dtype, + const std::vector<uint32_t> &shape, const T value) +{ + auto node = g->nodes()->create<CircleConst>(); + node->dtype(dtype); + node->rank(shape.size()); + + uint32_t size = 1; + for (uint32_t i = 0; i < shape.size(); ++i) + { + node->dim(i) = shape.at(i); + size *= shape.at(i); + } + node->shape_status(ShapeStatus::VALID); + + node->size<TypeMapper<T>::get()>(size); + for (uint32_t i = 0; i < size; i++) + { + node->at<TypeMapper<T>::get()>(i) = value; + } + + return node; +} + +// Create CircleConst filled with values +// Never return nullptr +// TODO Remove dtype from the argument +template <typename T> +luci::CircleConst *create_const_node(loco::Graph *g, const loco::DataType dtype, + const std::vector<uint32_t> &shape, + const std::vector<T> &values) +{ + auto node = g->nodes()->create<luci::CircleConst>(); + node->dtype(dtype); + node->rank(shape.size()); + + uint32_t size = 1; + for (uint32_t i = 0; i < shape.size(); ++i) + { + node->dim(i) = shape.at(i); + size *= shape.at(i); + } + node->shape_status(luci::ShapeStatus::VALID); + + node->size<TypeMapper<T>::get()>(size); + for (uint32_t i = 0; i < size; i++) + { + 
node->at<TypeMapper<T>::get()>(i) = values[i]; + } + + return node; +} + +} // namespace luci + +#endif // __LUCI_PASS_HELPERS_CREATE_CIRCLE_CONST_H__ diff --git a/compiler/luci/pass/src/helpers/TypeMapper.h b/compiler/luci/pass/src/helpers/TypeMapper.h index 90760e95b..a3e27d259 100644 --- a/compiler/luci/pass/src/helpers/TypeMapper.h +++ b/compiler/luci/pass/src/helpers/TypeMapper.h @@ -14,6 +14,9 @@ * limitations under the License. */ +#ifndef __LUCI_PASS_HELPERS_TYPE_MAPPER_H__ +#define __LUCI_PASS_HELPERS_TYPE_MAPPER_H__ + #include <loco/IR/DataType.h> #include <cstdint> @@ -75,3 +78,5 @@ template <> struct TypeMapper<int64_t> }; } // namespace luci + +#endif // __LUCI_PASS_HELPERS_TYPE_MAPPER_H__ |