diff options
author | Hyeongseok Oh <hseok82.oh@samsung.com> | 2023-04-12 15:42:02 +0900 |
---|---|---|
committer | Hyeongseok Oh <hseok82.oh@samsung.com> | 2023-04-12 15:42:02 +0900 |
commit | 323663bb115ef625642391a5a8e9b35fee8b2ae3 (patch) | |
tree | 17e2a6b91535e6f53f4cacda5e4db6aa0303dd22 /compiler/luci/pass | |
parent | c690d52bdd137ed6a17353aa7af35e8141ece77b (diff) | |
download | nnfw-323663bb115ef625642391a5a8e9b35fee8b2ae3.tar.gz nnfw-323663bb115ef625642391a5a8e9b35fee8b2ae3.tar.bz2 nnfw-323663bb115ef625642391a5a8e9b35fee8b2ae3.zip |
Imported Upstream version 1.22.0upstream/1.22.0
Diffstat (limited to 'compiler/luci/pass')
53 files changed, 4375 insertions, 443 deletions
diff --git a/compiler/luci/pass/CMakeLists.txt b/compiler/luci/pass/CMakeLists.txt index d9d004db9..ac18a5f8d 100644 --- a/compiler/luci/pass/CMakeLists.txt +++ b/compiler/luci/pass/CMakeLists.txt @@ -31,7 +31,7 @@ target_link_libraries(luci_pass PRIVATE luci_log) target_link_libraries(luci_pass PRIVATE luci_service) target_link_libraries(luci_pass PRIVATE luci_logex) target_link_libraries(luci_pass PRIVATE luci_profile) -target_link_libraries(luci_pass PRIVATE mio_tflite280_inc) +target_link_libraries(luci_pass PRIVATE luci_compute) target_link_libraries(luci_pass PRIVATE nncc_common) target_link_libraries(luci_pass PRIVATE pepper_csv2vec) target_link_libraries(luci_pass PRIVATE oops) diff --git a/compiler/luci/pass/include/luci/CircleOptimizer.h b/compiler/luci/pass/include/luci/CircleOptimizer.h index b94822c35..d77e89db1 100644 --- a/compiler/luci/pass/include/luci/CircleOptimizer.h +++ b/compiler/luci/pass/include/luci/CircleOptimizer.h @@ -52,14 +52,17 @@ public: FoldCast, FoldDensify, FoldDepthwiseConv2D, + FoldFullyConnected, FoldDequantize, FoldGather, FoldSparseToDense, ForwardReshapeToUnaryOp, + ForwardTransposeOp, SparsifyTensorPass, FusePreActivationBatchNorm, MakeBatchNormGammaPositive, FuseActivationFunction, + FusePRelu, ShuffleWeightTo16x1Float32, RemoveRedundantTranspose, ReplaceMulAddWithDepthwiseConv, @@ -83,6 +86,8 @@ public: RemoveRedundantReshape, RemoveFakeQuant, RemoveQuantDequantSeq, + RemoveDuplicateConst, + UnrollUnidirSeqLSTM, }; enum AlgorithmParameters diff --git a/compiler/luci/pass/include/luci/Pass/FoldFullyConnectedPass.h b/compiler/luci/pass/include/luci/Pass/FoldFullyConnectedPass.h new file mode 100644 index 000000000..bd36ff149 --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/FoldFullyConnectedPass.h @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_FOLD_FULLY_CONNECTED_PASS_H__ +#define __LUCI_FOLD_FULLY_CONNECTED_PASS_H__ + +#include <logo/Pass.h> + +namespace luci +{ + +/** + * @brief Class to fold FullyConnected with constant input and filter into a + * constant tensor + */ +struct FoldFullyConnectedPass final : public logo::Pass +{ + const char *name(void) const final { return "luci::FoldFullyConnectedPass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_FOLD_FULLY_CONNECTED_PASS_H__ diff --git a/compiler/luci/pass/src/test/TestIOGraph.test.cpp b/compiler/luci/pass/include/luci/Pass/ForwardTransposeOpPass.h index e58a13f2b..b44b1bde1 100644 --- a/compiler/luci/pass/src/test/TestIOGraph.test.cpp +++ b/compiler/luci/pass/include/luci/Pass/ForwardTransposeOpPass.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,6 +14,24 @@ * limitations under the License. */ -#include "TestIOGraph.h" +#ifndef __LUCI_FORWARD_TRANSPOSE_OP_PASS_H__ +#define __LUCI_FORWARD_TRANSPOSE_OP_PASS_H__ -// This file validates "TestIOGraph.h". Pleaes DO NOT remove this file. +#include <logo/Pass.h> + +namespace luci +{ + +/** + * @brief Class to Forward Transpose Ops for further optimization. + */ +struct ForwardTransposeOpPass final : public logo::Pass +{ + const char *name(void) const final { return "luci::ForwardTransposeOpPass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_FORWARD_TRANSPOSE_OP_PASS_H__ diff --git a/compiler/luci/pass/include/luci/Pass/FusePReluPass.h b/compiler/luci/pass/include/luci/Pass/FusePReluPass.h new file mode 100644 index 000000000..a21acf49d --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/FusePReluPass.h @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_FUSE_PRELU_PASS_H__ +#define __LUCI_FUSE_PRELU_PASS_H__ + +#include <logo/Pass.h> + +namespace luci +{ + +/** + * @brief Class to fuse certain pattern of subgraph into CirclePRelu + * with auxiliary nodes + * + * For detailed subgraph pattern to be fused, please check its implementation. + */ +struct FusePReluPass final : public logo::Pass +{ + const char *name(void) const final { return "luci::FusePReluPass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_FUSE_PRELU_PASS_H__ diff --git a/compiler/luci/pass/include/luci/Pass/QuantizeWithMinMaxPass.h b/compiler/luci/pass/include/luci/Pass/QuantizeWithMinMaxPass.h index ea6db85d1..6874046f0 100644 --- a/compiler/luci/pass/include/luci/Pass/QuantizeWithMinMaxPass.h +++ b/compiler/luci/pass/include/luci/Pass/QuantizeWithMinMaxPass.h @@ -39,29 +39,12 @@ public: loco::DataType input_model_dtype = loco::DataType::Unknown; loco::DataType output_model_dtype = loco::DataType::Unknown; QuantizationGranularity granularity = QuantizationGranularity::ChannelWise; - loco::DataType input_type = loco::DataType::Unknown; - loco::DataType output_type = loco::DataType::Unknown; + std::vector<loco::DataType> input_types; + std::vector<loco::DataType> output_types; bool TF_style_maxpool = false; std::vector<LayerInfo> layers_info; }; - // For backward-compatibility - // TODO Remove this constructor -public: - QuantizeWithMinMaxPass(loco::DataType input_model_dtype, loco::DataType output_model_dtype, - QuantizationGranularity granularity) - { - _ctx = std::make_unique<Context>(); - { - _ctx->input_model_dtype = input_model_dtype; - _ctx->output_model_dtype = output_model_dtype; - _ctx->granularity = granularity; - _ctx->input_type = output_model_dtype; - _ctx->output_type = output_model_dtype; - _ctx->TF_style_maxpool = false; - } - } - public: QuantizeWithMinMaxPass(std::unique_ptr<Context> &&ctx) : _ctx{std::move(ctx)} { diff --git a/compiler/luci/pass/include/luci/Pass/RemoveDuplicateConstPass.h b/compiler/luci/pass/include/luci/Pass/RemoveDuplicateConstPass.h new file mode 100644 index 000000000..000cdcc43 --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/RemoveDuplicateConstPass.h @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_REMOVE_DUPLICATE_CONST_PASS_H__ +#define __LUCI_REMOVE_DUPLICATE_CONST_PASS_H__ + +#include <luci/IR/CircleNodes.h> +#include <logo/Pass.h> + +namespace luci +{ + +/** + * @brief Class to remove duplicate Const nodes. + */ +struct RemoveDuplicateConstPass final : public logo::Pass +{ + const char *name(void) const final { return "luci::RemoveDuplicateConstPass"; } + + bool run(loco::Graph *g) final; + +private: + bool remove_duplicate_const(); + + template <loco::DataType DT> void add_to_map(luci::CircleConst *const_node); + + std::map<float, std::vector<CircleConst *>> _sum_to_const; +}; + +} // namespace luci + +#endif // __LUCI_REMOVE_DUPLICATE_CONST_PASS_H__ diff --git a/compiler/luci/pass/include/luci/Pass/UnrollUnidirectionalSequenceLSTMPass.h b/compiler/luci/pass/include/luci/Pass/UnrollUnidirectionalSequenceLSTMPass.h new file mode 100644 index 000000000..fd5a708e8 --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/UnrollUnidirectionalSequenceLSTMPass.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_UNROLL_UNIDIRECTIONALSEQUENCELSTM_PASS_H__ +#define __LUCI_UNROLL_UNIDIRECTIONALSEQUENCELSTM_PASS_H__ + +#include <logo/Pass.h> + +namespace luci +{ + +/** + * @brief Class to Unroll UnidirectionalSequenceLSTM + */ +struct UnrollUnidirectionalSequenceLSTMPass final : public logo::Pass +{ + const char *name(void) const final { return "luci::UnrollUnidirectionalSequenceLSTMPass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_UNROLL_UNIDIRECTIONALSEQUENCELSTM_PASS_H__ diff --git a/compiler/luci/pass/src/CircleOptimizer.cpp b/compiler/luci/pass/src/CircleOptimizer.cpp index 74c569d20..5e1613ad9 100644 --- a/compiler/luci/pass/src/CircleOptimizer.cpp +++ b/compiler/luci/pass/src/CircleOptimizer.cpp @@ -23,9 +23,11 @@ #include "luci/Pass/FoldDensifyPass.h" #include "luci/Pass/FoldDepthwiseConv2DPass.h" #include "luci/Pass/FoldDequantizePass.h" +#include "luci/Pass/FoldFullyConnectedPass.h" #include "luci/Pass/FoldGatherPass.h" #include "luci/Pass/FoldSparseToDensePass.h" #include "luci/Pass/ForwardReshapeToUnaryOpPass.h" +#include "luci/Pass/ForwardTransposeOpPass.h" #include "luci/Pass/FuseActivationFunctionPass.h" #include "luci/Pass/FuseAddWithFullyConnectedPass.h" #include "luci/Pass/FuseAddWithTConvPass.h" @@ -36,8 +38,10 @@ #include "luci/Pass/FuseInstanceNormPass.h" #include "luci/Pass/FuseMeanWithMeanPass.h" #include "luci/Pass/FusePreActivationBatchNormPass.h" +#include "luci/Pass/FusePReluPass.h" #include "luci/Pass/FuseTransposeWithMeanPass.h" #include "luci/Pass/MakeBatchNormGammaPositivePass.h" +#include "luci/Pass/RemoveDuplicateConstPass.h" #include "luci/Pass/RemoveFakeQuantPass.h" #include "luci/Pass/RemoveQuantDequantSeqPass.h" #include "luci/Pass/RemoveRedundantReshapePass.h" @@ -66,6 +70,7 @@ #include "luci/Pass/SubstituteTransposeToReshapePass.h" #include "luci/Pass/TransformMinMaxToRelu6Pass.h" #include "luci/Pass/TransformMinReluToRelu6Pass.h" +#include "luci/Pass/UnrollUnidirectionalSequenceLSTMPass.h" // TODO add more passes #include "luci/Pass/CircleShapeInferencePass.h" @@ -274,6 +279,10 @@ void CircleOptimizer::optimize(loco::Graph *g) const { phase.emplace_back(std::make_unique<FuseActivationFunctionPass>()); } + if (_options->query(Options::Algorithm::FusePRelu)) + { + phase.emplace_back(std::make_unique<FusePReluPass>()); + } if (_options->query(Options::Algorithm::FuseTransposeWithMean)) { phase.emplace_back(std::make_unique<FuseTransposeWithMeanPass>()); @@ -298,6 +307,10 @@ void CircleOptimizer::optimize(loco::Graph *g) const { phase.emplace_back(std::make_unique<luci::FoldDequantizePass>()); } + if (_options->query(Options::Algorithm::FoldFullyConnected)) + { + phase.emplace_back(std::make_unique<luci::FoldFullyConnectedPass>()); + } if (_options->query(Options::Algorithm::FoldGather)) { phase.emplace_back(std::make_unique<luci::FoldGatherPass>()); @@ -310,6 +323,10 @@ void CircleOptimizer::optimize(loco::Graph *g) const { phase.emplace_back(std::make_unique<luci::ForwardReshapeToUnaryOpPass>()); } + if (_options->query(Options::Algorithm::ForwardTransposeOp)) + { + phase.emplace_back(std::make_unique<luci::ForwardTransposeOpPass>()); + } if (_options->query(Options::Algorithm::FusePreActivationBatchNorm)) { phase.emplace_back(std::make_unique<luci::FusePreActivationBatchNormPass>()); @@ -326,6 +343,10 @@ void CircleOptimizer::optimize(loco::Graph *g) const { phase.emplace_back(std::make_unique<luci::ExpandBroadcastConstPass>()); } + if (_options->query(Options::Algorithm::RemoveDuplicateConst)) + { + phase.emplace_back(std::make_unique<luci::RemoveDuplicateConstPass>()); + } if (_options->query(Options::Algorithm::RemoveFakeQuant)) { phase.emplace_back(std::make_unique<luci::RemoveFakeQuantPass>()); @@ -407,6 +428,10 @@ void CircleOptimizer::optimize(loco::Graph *g) const { phase.emplace_back(std::make_unique<luci::TransformMinReluToRelu6Pass>()); } + if (_options->query(Options::Algorithm::UnrollUnidirSeqLSTM)) + { + phase.emplace_back(std::make_unique<luci::UnrollUnidirectionalSequenceLSTMPass>()); + } /* TRANSFORM DECLARATION END */ diff --git a/compiler/luci/pass/src/CircleQuantizer.cpp b/compiler/luci/pass/src/CircleQuantizer.cpp index 9a6550b9f..3ffa1180c 100644 --- a/compiler/luci/pass/src/CircleQuantizer.cpp +++ b/compiler/luci/pass/src/CircleQuantizer.cpp @@ -40,6 +40,7 @@ #include <luci/IR/CircleNode.h> #include <logo/Phase.h> +#include <pepper/csv2vec.h> #include <memory> @@ -49,6 +50,154 @@ namespace using namespace luci; using LayerParam = luci::CircleQuantizer::Options::LayerParam; +// This function updates user-given input_type to match with the input signature of graph +// If user gives only one input_type, it will be expanded to the number of graph inputs +void canonicalize_input_type(loco::Graph *g, std::vector<loco::DataType> &input_type) +{ + if (g == nullptr) + return; + + const auto inputs = g->inputs(); + + assert(inputs); // FIX_CALLER_UNLESS + + // Check validity of the number of input dtype given by a user + if (input_type.size() != 1 and input_type.size() != inputs->size()) + { + throw std::runtime_error( + "Invalid number of input dtype. The number of input dtype should be 1 or " + "the same as the number of graph inputs."); + } + + // Handle the case when a user gives only one input dtype + if (input_type.size() == 1) + { + const auto user_given_dtype = input_type[0]; + input_type.clear(); + + // Expand input dtype to the number of graph inputs + // Since quantizer can only quantize float32, user_given_dtype is set only for float32 inputs + auto input_nodes = loco::input_nodes(g); + for (uint32_t i = 0; i < input_nodes.size(); i++) + { + auto input = loco::must_cast<luci::CircleInput *>(input_nodes[i]); + + if (input->dtype() == loco::DataType::FLOAT32) + input_type.push_back(user_given_dtype); + else + input_type.push_back(input->dtype()); + } + } + + // Finally, check validity of input_type + // input_type is valid if + // C1. for non-float32 model input, input_type == model's input dtype + // or + // C2. for float32 model input, input_type == uint8, int16, or float32 + auto input_nodes = loco::input_nodes(g); + for (uint32_t i = 0; i < input_nodes.size(); i++) + { + auto input = loco::must_cast<luci::CircleInput *>(input_nodes[i]); + assert(i == input->index()); // FIX_ME_UNLESS + + if (input->dtype() != loco::DataType::FLOAT32) + { + // C1 + if (input->dtype() != input_type[i]) + throw std::runtime_error( + "Input dtype of " + input->name() + + " is invalid. It has to be the same with the model's input dtype."); + } + else + { + // C2 + if (input_type[i] != loco::DataType::FLOAT32 and input_type[i] != loco::DataType::U8 and + input_type[i] != loco::DataType::S16) + { + throw std::runtime_error("Input dtype of " + input->name() + + " is invalid. For float32 input, the input dtype after " + "quantization must be one of uint8, int16, or float32."); + } + } + } +} + +// This function updates user-given output_type to match with the output signature of graph +// If user gives only one output_type, it will be expanded to the number of graph outputs +// NOTE This function is almost same with canonicalize_input_type, but it is written as a +// separate function for more precise error messaging. +// TODO Find a way to reduce duplicate codes +void canonicalize_output_type(loco::Graph *g, std::vector<loco::DataType> &output_type) +{ + if (g == nullptr) + return; + + const auto outputs = g->outputs(); + + assert(outputs); // FIX_CALLER_UNLESS + + // Check validity of the number of output dtype given by a user + if (output_type.size() != 1 and output_type.size() != outputs->size()) + { + throw std::runtime_error( + "Invalid number of output dtype. The number of output dtype should be 1 or " + "the same as the number of graph outputs."); + } + + // Handle the case when a user gives only one output dtype + if (output_type.size() == 1) + { + const auto user_given_dtype = output_type[0]; + output_type.clear(); + + // Expand output dtype to the number of graph outputs + // If dtype of graph output is float32, it will be replaced with user_given_dtype + // Otherwise, it will not change + auto output_nodes = loco::output_nodes(g); + for (uint32_t i = 0; i < output_nodes.size(); i++) + { + auto output = loco::must_cast<luci::CircleOutput *>(output_nodes[i]); + + if (output->dtype() == loco::DataType::FLOAT32) + output_type.push_back(user_given_dtype); + else + output_type.push_back(output->dtype()); + } + } + + // Finally, check validity of output_type + // output_type is valid if + // C1. for non-float32 model output, output_type == model's output dtype + // or + // C2. for float32 model output, output_type == uint8, int16, or float32 + auto output_nodes = loco::output_nodes(g); + for (uint32_t i = 0; i < output_nodes.size(); i++) + { + auto output = loco::must_cast<luci::CircleOutput *>(output_nodes[i]); + assert(i == output->index()); // FIX_ME_UNLESS + + if (output->dtype() != loco::DataType::FLOAT32) + { + // C1 + if (output->dtype() != output_type[i]) + throw std::runtime_error( + "Output dtype of " + output->name() + + " is invalid. It has to be the same with the model's output dtype."); + } + else + { + // C2 + if (output_type[i] != loco::DataType::FLOAT32 and output_type[i] != loco::DataType::U8 and + output_type[i] != loco::DataType::S16) + { + throw std::runtime_error("Output dtype of " + output->name() + + " is invalid. For float32 output, the output dtype after " + "quantization must be one of uint8, int16, or float32."); + } + } + } +} + template <typename T> T lexical_cast(const std::string &str) { std::istringstream ss; @@ -253,8 +402,10 @@ void CircleQuantizer::quantize(loco::Graph *g) const static const std::vector<std::string> qwmm_supported_input_model_dtype{"float32"}; static const std::vector<std::string> qwmm_supported_output_model_dtype{"uint8", "int16"}; static const std::vector<std::string> qwmm_supported_granularity{"layer", "channel"}; - static const std::vector<std::string> qwmm_supported_input_type{"uint8", "int16", "float32"}; - static const std::vector<std::string> qwmm_supported_output_type{"uint8", "int16", "float32"}; + static const std::vector<std::string> qwmm_supported_input_type{"uint8", "int16", "int32", + "int64", "float32", "bool"}; + static const std::vector<std::string> qwmm_supported_output_type{"uint8", "int16", "int32", + "int64", "float32", "bool"}; auto input_model_dtype = _options->param(Options::AlgorithmParameters::Quantize_input_model_dtype); @@ -268,6 +419,9 @@ void CircleQuantizer::quantize(loco::Graph *g) const if (output_type.empty()) output_type = output_model_dtype; + auto input_type_vec = pepper::csv_to_vector<std::string>(input_type); + auto output_type_vec = pepper::csv_to_vector<std::string>(output_type); + bool TF_style_maxpool = _options->param(Options::AlgorithmParameters::Quantize_TF_style_maxpool) == "True"; @@ -285,13 +439,19 @@ void CircleQuantizer::quantize(loco::Graph *g) const throw std::runtime_error("Unsupported granularity. List of supported granularity: " + to_string(qwmm_supported_granularity)); - if (!in_array(to_lower_case(input_type), qwmm_supported_input_type)) - throw std::runtime_error("Unsupported input type. List of supported input types: " + - to_string(qwmm_supported_input_type)); + for (auto dtype : input_type_vec) + { + if (!in_array(to_lower_case(dtype), qwmm_supported_input_type)) + throw std::runtime_error("Unsupported input type. List of supported input types: " + + to_string(qwmm_supported_input_type)); + } - if (!in_array(to_lower_case(output_type), qwmm_supported_output_type)) - throw std::runtime_error("Unsupported output type. List of supported output types: " + - to_string(qwmm_supported_output_type)); + for (auto dtype : output_type_vec) + { + if (!in_array(to_lower_case(dtype), qwmm_supported_output_type)) + throw std::runtime_error("Unsupported output type. List of supported output types: " + + to_string(qwmm_supported_output_type)); + } if (str_to_granularity(granularity) == QuantizationGranularity::LayerWise && str_to_dtype(output_model_dtype) != loco::DataType::U8) @@ -314,6 +474,13 @@ void CircleQuantizer::quantize(loco::Graph *g) const } } + auto input_types = str_vec_to_dtype_vec(input_type_vec); + auto output_types = str_vec_to_dtype_vec(output_type_vec); + + // Canonicalize user-given input/output_type (match with # of inputs/outputs) + canonicalize_input_type(g, input_types); + canonicalize_output_type(g, output_types); + // Input model checker for quantization luci::QuantizePreCheckerPass input_model_checker{}; input_model_checker.run(g); @@ -323,8 +490,8 @@ void CircleQuantizer::quantize(loco::Graph *g) const ctx->input_model_dtype = str_to_dtype(input_model_dtype); ctx->output_model_dtype = str_to_dtype(output_model_dtype); ctx->granularity = str_to_granularity(granularity); - ctx->input_type = str_to_dtype(input_type); - ctx->output_type = str_to_dtype(output_type); + ctx->input_types = input_types; + ctx->output_types = output_types; ctx->TF_style_maxpool = TF_style_maxpool; for (auto layer_param : layer_params) @@ -347,8 +514,8 @@ void CircleQuantizer::quantize(loco::Graph *g) const { verify_ctx->output_model_dtype = str_to_dtype(output_model_dtype); verify_ctx->granularity = str_to_granularity(granularity); - verify_ctx->input_type = str_to_dtype(input_type); - verify_ctx->output_type = str_to_dtype(output_type); + verify_ctx->input_types = input_types; + verify_ctx->output_types = output_types; verify_ctx->TF_style_maxpool = TF_style_maxpool; for (auto layer_param : layer_params) diff --git a/compiler/luci/pass/src/ConvertNCHWToNHWCPass.cpp b/compiler/luci/pass/src/ConvertNCHWToNHWCPass.cpp index 55a29d105..99e1e2939 100644 --- a/compiler/luci/pass/src/ConvertNCHWToNHWCPass.cpp +++ b/compiler/luci/pass/src/ConvertNCHWToNHWCPass.cpp @@ -503,43 +503,30 @@ bool is_NCHW(const luci::CirclePadV2 *node) return true; } -// NOTE Following conditions can be extended later -// NOTE Used for Maximum, Miminum as ReLU/ReLU6 -// -// Find T with an NCHW pattern described below -// - Input (non-constant) shape : [N, C, H, W] -// - Input (constant) shape : [1] or [] -// - Output shape : [N, C, H, W] -template <class T> -bool is_NCHW_with_s_const(const T *node, luci::CircleNode *&pred_node, - luci::CircleConst *&comp_const) +bool is_const(const loco::Node *node) { - auto x = dynamic_cast<luci::CircleConst *>(node->x()); - auto y = dynamic_cast<luci::CircleConst *>(node->y()); - - if (x != nullptr && y == nullptr) - { - pred_node = loco::must_cast<luci::CircleNode *>(node->y()); - comp_const = x; - } - else if (x == nullptr && y != nullptr) - { - pred_node = loco::must_cast<luci::CircleNode *>(node->x()); - comp_const = y; - } - else - { - // Ignore if T does not have a comp_const input. + if (not dynamic_cast<const luci::CircleConst *>(node)) return false; - } - if (pred_node->rank() != 4) + return true; +} + +bool is_scalar_const(const loco::Node *node) +{ + auto const_node = dynamic_cast<const luci::CircleConst *>(node); + if (not const_node) return false; - // Check if scalar - const auto const_rank = comp_const->rank(); - if (const_rank == 0 || (const_rank == 1 && comp_const->dim(0).value() == 1)) + const auto const_rank = const_node->rank(); + // shape of scalar + // 1. rank = 0 + // 2. rank = 1, dimension = 1 + if (const_rank == 0) + return true; + + if (const_rank == 1 && const_node->dim(0).value() == 1) return true; + return false; } @@ -854,22 +841,30 @@ class ConvertNCHWToNHWC final : public luci::CircleNodeMutableVisitor<bool> bool visit(luci::CircleLogistic *node) { return convert_unary_x<luci::CircleLogistic>(node); } - bool visit(luci::CircleLogSoftmax *node) - { - return convert_unary_logits<luci::CircleLogSoftmax>(node); - } - bool visit(luci::CircleMaximum *node) { - luci::CircleNode *pred_node = nullptr; - luci::CircleConst *comp_constant = nullptr; - - if (is_NCHW_with_s_const<luci::CircleMaximum>(node, pred_node, comp_constant)) + if ((not is_const(node->x())) and is_scalar_const(node->y())) { auto pre_trans = create_pre_transpose(node); - pre_trans->a(pred_node); + pre_trans->a(node->x()); node->x(pre_trans); } + else if (is_scalar_const(node->x()) and (not is_const(node->y()))) + { + auto pre_trans = create_pre_transpose(node); + pre_trans->a(node->y()); + node->y(pre_trans); + } + else if ((not is_const(node->x())) and (not is_const(node->y()))) + { + auto pre_trans_x = create_pre_transpose(node); + pre_trans_x->a(node->x()); + node->x(pre_trans_x); + + auto pre_trans_y = create_pre_transpose(node); + pre_trans_y->a(node->y()); + node->y(pre_trans_y); + } else { // TODO support other cases @@ -963,15 +958,18 @@ class ConvertNCHWToNHWC final : public luci::CircleNodeMutableVisitor<bool> bool visit(luci::CircleMinimum *node) { - luci::CircleNode *pred_node = nullptr; - luci::CircleConst *comp_constant = nullptr; - - if (is_NCHW_with_s_const<luci::CircleMinimum>(node, pred_node, comp_constant)) + if ((not is_const(node->x())) and is_scalar_const(node->y())) { auto pre_trans = create_pre_transpose(node); - pre_trans->a(pred_node); + pre_trans->a(node->x()); node->x(pre_trans); } + else if (is_scalar_const(node->x()) and (not is_const(node->y()))) + { + auto pre_trans = create_pre_transpose(node); + pre_trans->a(node->y()); + node->y(pre_trans); + } else { // TODO support other cases @@ -1168,14 +1166,88 @@ class ConvertNCHWToNHWC final : public luci::CircleNodeMutableVisitor<bool> return true; } + // TODO Reduce duplicate codes with CircleReduceMax + bool visit(luci::CircleReduceMin *node) + { + auto input = loco::must_cast<luci::CircleNode *>(node->input()); + if (input->rank() != 4) + return false; + + auto rindices = dynamic_cast<luci::CircleConst *>(node->reduction_indices()); + if (not rindices) + return false; + + auto nhwc_rindices = create_NHWC_rindices(rindices); + if (not nhwc_rindices) + return false; + + auto pre_trans = create_pre_transpose(node); + pre_trans->a(input); + node->input(pre_trans); + + // Do shape inference for this node again. + node->shape_status(luci::ShapeStatus::UNDEFINED); + + node->reduction_indices(nhwc_rindices); + + if (node->keep_dims()) + { + auto post_trans = create_post_transpose(node); + loco::replace(node).with(post_trans); + + post_trans->a(node); + + return true; + } + + // The below codes handle the cases where node->keep_dims() == false + // 1D output never needs a transpose + if (node->rank() <= 1) + return true; + + std::vector<bool> reduced_dims_nhwc(4, false); + uint32_t num_reduced_indices = nhwc_rindices->size<loco::DataType::S32>(); + + for (uint32_t ri = 0; ri < num_reduced_indices; ++ri) + { + reduced_dims_nhwc[nhwc_rindices->at<loco::DataType::S32>(ri)] = true; + } + + // if channel dimension has been reduced, we don't need a transpose + if (reduced_dims_nhwc[3]) + return true; + + // likewise, if both space dimensions are reduced, no transpose is needed + if (reduced_dims_nhwc[1] && reduced_dims_nhwc[2]) + return true; + + std::vector<int32_t> post_trans_ind; + // case 1: only N is reduced + if (num_reduced_indices == 1 && reduced_dims_nhwc[0]) + post_trans_ind = {2, 0, 1}; + + // case 2: only H or W is reduced + if (num_reduced_indices == 1 && (reduced_dims_nhwc[1] || reduced_dims_nhwc[2])) + post_trans_ind = {0, 2, 1}; + + // case 3: N and either H or W are reduced + if (num_reduced_indices == 2) + post_trans_ind = {1, 0}; + + auto post_trans = create_Nd_transpose(node, post_trans_ind); + loco::replace(node).with(post_trans); + + post_trans->a(node); + + return true; + } + bool visit(luci::CircleRelu *node) { return convert_unary_features<luci::CircleRelu>(node); } bool visit(luci::CircleRelu6 *node) { return convert_unary_features<luci::CircleRelu6>(node); } bool visit(luci::CircleRsqrt *node) { return convert_unary_x<luci::CircleRsqrt>(node); } - bool visit(luci::CircleSoftmax *node) { return convert_unary_logits<luci::CircleSoftmax>(node); } - bool visit(luci::CircleSplitV *node) { // Change split dimension @@ -1375,6 +1447,10 @@ bool ConvertNCHWToNHWCPass::run(loco::Graph *g) collect_intermediate = [&](loco::Node *n) { for (auto succ : loco::succs(n)) { + // Skip unnecessary traversal + if (intermediate.find(succ) != intermediate.end()) + continue; + // Exit condition if (is_post_transpose(succ) || is_post_reshape(succ)) continue; @@ -1429,12 +1505,13 @@ bool ConvertNCHWToNHWCPass::run(loco::Graph *g) set_data_format(node, DataFormat::NCHW); } break; + // SOFTMAX, LOG_SOFTMAX are not converted, because + // tflite/circle assumes the last channel is always axis case luci::CircleOpcode::ADD: case luci::CircleOpcode::CONCATENATION: case luci::CircleOpcode::ELU: case luci::CircleOpcode::LEAKY_RELU: case luci::CircleOpcode::LOGISTIC: - case luci::CircleOpcode::LOG_SOFTMAX: case luci::CircleOpcode::MAXIMUM: case luci::CircleOpcode::MEAN: case luci::CircleOpcode::MINIMUM: @@ -1443,10 +1520,10 @@ bool ConvertNCHWToNHWCPass::run(loco::Graph *g) case luci::CircleOpcode::PAD: case luci::CircleOpcode::PADV2: case luci::CircleOpcode::REDUCE_MAX: + case luci::CircleOpcode::REDUCE_MIN: case luci::CircleOpcode::RELU: case luci::CircleOpcode::RELU6: case luci::CircleOpcode::RSQRT: - case luci::CircleOpcode::SOFTMAX: case luci::CircleOpcode::SPLIT_V: case luci::CircleOpcode::SQUARED_DIFFERENCE: case luci::CircleOpcode::SUB: @@ -1487,7 +1564,8 @@ bool ConvertNCHWToNHWCPass::run(loco::Graph *g) { // TODO replace the check above with the input rank check, and remove the condition below if (not dynamic_cast<luci::CircleMean *>(node) and - not dynamic_cast<luci::CircleReduceMax *>(node)) + not dynamic_cast<luci::CircleReduceMax *>(node) and + not dynamic_cast<luci::CircleReduceMin *>(node)) continue; } diff --git a/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp b/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp index 6bb3d3268..fd326518e 100644 --- a/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp +++ b/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp @@ -483,22 +483,6 @@ public: luci::CircleLogistic *logistic = nullptr; }; -class LogSoftmaxGraph final : public SimpleGraph -{ -protected: - loco::Node *insertGraphBody(loco::Node *input) override - { - log_softmax = g.nodes()->create<luci::CircleLogSoftmax>(); - log_softmax->logits(input); - log_softmax->name("log_softmax"); - - return log_softmax; - } - -public: - luci::CircleLogSoftmax *log_softmax = nullptr; -}; - class MaximumGraph final : public SimpleGraph { protected: @@ -530,6 +514,27 @@ public: luci::CircleConst *limit = nullptr; }; +class MaximumNonConstGraph final : public SimpleGraph +{ +protected: + loco::Node *insertGraphBody(loco::Node *input) override + { + max = g.nodes()->create<luci::CircleMaximum>(); + max->dtype(loco::DataType::FLOAT32); + max->shape({1, 16, 4, 4}); + + max->x(input); + max->y(input); + + max->name("max"); + + return max; + } + +public: + luci::CircleMaximum *max = nullptr; +}; + class MeanGraph final : public SimpleGraph { protected: @@ -874,6 +879,51 @@ private: std::initializer_list<uint32_t> _shape = {1, 16, 1, 1}; }; +class ReduceMinGraph final : public SimpleGraph +{ +protected: + loco::Node *insertGraphBody(loco::Node *input) override + { + rm = g.nodes()->create<luci::CircleReduceMin>(); + rindices = g.nodes()->create<luci::CircleConst>(); + + rm->dtype(loco::DataType::FLOAT32); + rindices->dtype(loco::DataType::S32); + + rm->shape(_shape); + rindices->shape({static_cast<uint32_t>(_axes.size())}); + + rindices->size<loco::DataType::S32>(_axes.size()); + for (uint32_t i = 0; i < _axes.size(); ++i) + { + rindices->at<loco::DataType::S32>(i) = _axes[i]; + } + + rm->input(input); + rm->reduction_indices(rindices); + rm->keep_dims(_keep_dims); + + rm->name("reduce_max"); + rindices->name("rindices"); + + return rm; + } + +public: + void keep_dims(bool val) { _keep_dims = val; } + void axes(std::vector<int32_t> val) { _axes = val; } + void shape(std::initializer_list<uint32_t> val) { _shape = val; } + +public: + luci::CircleReduceMin *rm = nullptr; + luci::CircleConst *rindices = nullptr; + +private: + bool _keep_dims = true; + std::vector<int32_t> _axes = {2, 3}; + std::initializer_list<uint32_t> _shape = {1, 16, 1, 1}; +}; + class ReluGraph final : public SimpleGraph { protected: @@ -922,22 +972,6 @@ public: luci::CircleRsqrt *rsqrt = nullptr; }; -class SoftmaxGraph final : public SimpleGraph -{ -protected: - loco::Node *insertGraphBody(loco::Node *input) override - { - softmax = g.nodes()->create<luci::CircleSoftmax>(); - softmax->logits(input); - softmax->name("softmax"); - - return softmax; - } - -public: - luci::CircleSoftmax *softmax = nullptr; -}; - class SplitVGraphlet { public: @@ -1357,44 +1391,50 @@ TEST(ConvertNCHWToNHWC, Logistic) EXPECT_EQ(16, g.logistic->dim(3).value()); } -TEST(ConvertNCHWToNHWC, LogSoftmax) +TEST(ConvertNCHWToNHWC, Maximum) { - LogSoftmaxGraph g; + MaximumGraph g; g.init(); - run_phase(&g.g, true, true); + run_phase(&g.g, false, false); + + auto input_succs = loco::succs(g.input); + EXPECT_EQ(1, input_succs.size()); + check_post_trans(*input_succs.begin()); - check_pre_trans(g.log_softmax->logits()); + check_pre_trans(g.max->x()); - auto log_softmax_succs = loco::succs(g.log_softmax); - EXPECT_EQ(1, log_softmax_succs.size()); - check_post_trans(*log_softmax_succs.begin()); + auto max_succs = loco::succs(g.max); + EXPECT_EQ(1, max_succs.size()); + check_post_trans(*max_succs.begin()); - // Check log_softmax shape - EXPECT_EQ(1, g.log_softmax->dim(0).value()); - EXPECT_EQ(4, g.log_softmax->dim(1).value()); - EXPECT_EQ(4, g.log_softmax->dim(2).value()); - EXPECT_EQ(16, g.log_softmax->dim(3).value()); + check_pre_trans(g.output->from()); } -TEST(ConvertNCHWToNHWC, Maximum) +TEST(ConvertNCHWToNHWC, Maximum_non_scalar_NEG) { MaximumGraph g; g.init(); - run_phase(&g.g, false, false); + g.limit->shape({3}); - auto input_succs = loco::succs(g.input); - EXPECT_EQ(1, input_succs.size()); - check_post_trans(*input_succs.begin()); + luci::ConvertNCHWToNHWCPass pass(true, true); + EXPECT_FALSE(pass.run(&g.g)); +} + +TEST(ConvertNCHWToNHWC, MaximumNonConst) +{ + MaximumNonConstGraph g; + g.init(); + + run_phase(&g.g, true, true); check_pre_trans(g.max->x()); + check_pre_trans(g.max->y()); auto max_succs = loco::succs(g.max); EXPECT_EQ(1, max_succs.size()); check_post_trans(*max_succs.begin()); - - check_pre_trans(g.output->from()); } TEST(ConvertNCHWToNHWC, Mean) @@ -1553,6 +1593,17 @@ TEST(ConvertNCHWToNHWC, Minimum) check_pre_trans(g.output->from()); } +TEST(ConvertNCHWToNHWC, Minimum_non_scalar_NEG) +{ + MinimumGraph g; + g.init(); + + g.limit->shape({3}); + + luci::ConvertNCHWToNHWCPass pass(true, true); + EXPECT_FALSE(pass.run(&g.g)); +} + TEST(ConvertNCHWToNHWC, Mul) { MulGraph g; @@ -1893,6 +1944,85 @@ TEST(ConvertNCHWToNHWC, ReduceMax_keep_dims_false) } } +TEST(ConvertNCHWToNHWC, ReduceMin) +{ + ReduceMinGraph g; + g.init(); + + run_phase(&g.g, true, true); + + check_pre_trans(g.rm->input()); + + auto rm_succs = loco::succs(g.rm); + EXPECT_EQ(1, rm_succs.size()); + check_post_trans(*rm_succs.begin()); + + auto new_rindices = dynamic_cast<luci::CircleConst *>(g.rm->reduction_indices()); + EXPECT_NE(nullptr, new_rindices); + EXPECT_EQ(1, new_rindices->rank()); + EXPECT_EQ(2, new_rindices->dim(0).value()); + EXPECT_EQ(2, new_rindices->size<loco::DataType::S32>()); + EXPECT_EQ(1, new_rindices->at<loco::DataType::S32>(0)); + EXPECT_EQ(2, new_rindices->at<loco::DataType::S32>(1)); +} + +TEST(ConvertNCHWToNHWC, ReduceMin_keep_dims_false) +{ + struct TC + { + std::vector<int32_t> nchw_ind; + std::vector<int32_t> nhwc_ind; + std::initializer_list<uint32_t> shape; + bool needs_transpose = false; + }; + + uint32_t n = 1; + uint32_t c = 16; + uint32_t h = 4; + uint32_t w = 4; + + std::vector<TC> test_cases{{{0}, {0}, {c, h, w}, true}, {{1}, {3}, {n, h, w}, false}, + {{2}, {1}, {n, c, w}, true}, {{3}, {2}, {n, c, h}, true}, + {{0, 1}, {0, 3}, {h, w}, false}, {{0, 2}, {0, 1}, {c, w}, true}, + {{0, 3}, {0, 2}, {c, h}, true}, {{1, 2}, {3, 1}, {n, w}, false}, + {{1, 3}, {3, 2}, {n, h}, false}, {{2, 3}, {1, 2}, {n, c}, false}, + {{0, 1, 2}, {0, 3, 1}, {w}, false}}; + + for (auto &tc : test_cases) + { + ReduceMinGraph g; + g.keep_dims(false); + g.axes(tc.nchw_ind); + g.shape(tc.shape); + g.init(); + + run_phase(&g.g, true, true); + + check_pre_trans(g.rm->input()); + + auto rm_succs = loco::succs(g.rm); + EXPECT_EQ(1, rm_succs.size()); + if (tc.needs_transpose) + { + EXPECT_NE(nullptr, dynamic_cast<luci::CircleTranspose *>(*rm_succs.begin())); + } + else + { + EXPECT_NE(nullptr, dynamic_cast<luci::CircleOutput *>(*rm_succs.begin())); + } + + auto new_rindices = dynamic_cast<luci::CircleConst *>(g.rm->reduction_indices()); + EXPECT_NE(nullptr, new_rindices); + EXPECT_EQ(1, new_rindices->rank()); + EXPECT_EQ(tc.nhwc_ind.size(), new_rindices->dim(0).value()); + EXPECT_EQ(tc.nhwc_ind.size(), new_rindices->size<loco::DataType::S32>()); + for (uint32_t i = 0; i < tc.nhwc_ind.size(); ++i) + { + EXPECT_EQ(tc.nhwc_ind[i], new_rindices->at<loco::DataType::S32>(i)); + } + } +} + TEST(ConvertNCHWToNHWC, Relu) { ReluGraph g; @@ -1953,26 +2083,6 @@ TEST(ConvertNCHWToNHWC, Rsqrt) EXPECT_EQ(16, g.rsqrt->dim(3).value()); } -TEST(ConvertNCHWToNHWC, Softmax) -{ - SoftmaxGraph g; - g.init(); - - run_phase(&g.g, true, true); - - check_pre_trans(g.softmax->logits()); - - auto softmax_succs = loco::succs(g.softmax); - EXPECT_EQ(1, softmax_succs.size()); - check_post_trans(*softmax_succs.begin()); - - // Check softmax shape - EXPECT_EQ(1, g.softmax->dim(0).value()); - EXPECT_EQ(4, g.softmax->dim(1).value()); - EXPECT_EQ(4, g.softmax->dim(2).value()); - EXPECT_EQ(16, g.softmax->dim(3).value()); -} - TEST(ConvertNCHWToNHWC, SplitV) { SplitVGraph g; diff --git a/compiler/luci/pass/src/ConvertToFakeQuantizedModelPass.cpp b/compiler/luci/pass/src/ConvertToFakeQuantizedModelPass.cpp index 72f590135..aacfce3d0 100644 --- a/compiler/luci/pass/src/ConvertToFakeQuantizedModelPass.cpp +++ b/compiler/luci/pass/src/ConvertToFakeQuantizedModelPass.cpp @@ -31,7 +31,10 @@ namespace luci::CircleQuantize *create_quantize(luci::CircleNode *node) { auto quantize = node->graph()->nodes()->create<luci::CircleQuantize>(); - quantize->name(node->name() + "_Quantize"); + // DESIGN NOTE: Why use '_FQ_Quantize' instead of '_Quantize'? + // '_Quantize' is used in mixed-precision quantization + // We add '_FQ' to distinguish Op from mixed-precision quantization + quantize->name(node->name() + "_FQ_Quantize"); quantize->dtype(node->dtype()); quantize->rank(node->rank()); for (uint32_t i = 0; i < node->rank(); i++) @@ -50,7 +53,10 @@ luci::CircleQuantize *create_quantize(luci::CircleNode *node) luci::CircleDequantize *create_dequantize(luci::CircleNode *node) { auto dequantize = node->graph()->nodes()->create<luci::CircleDequantize>(); - dequantize->name(node->name() + "_Dequantize"); + // DESIGN NOTE: Why use '_FQ_Dequantize' instead of '_Dequantize'? + // '_Dequantize' is used in mixed-precision quantization + // We add '_FQ' to distinguish Op from mixed-precision quantization + dequantize->name(node->name() + "_FQ_Dequantize"); dequantize->dtype(loco::DataType::FLOAT32); dequantize->rank(node->rank()); for (uint32_t i = 0; i < node->rank(); i++) @@ -184,6 +190,7 @@ struct FakeQuantize final : public luci::CircleNodeMutableVisitor<void> // For non-const activation, insert Quantize-Dequantize Ops // and dequantize the node + void visit(luci::CircleAbs *node) { fq_activation(node); } void visit(luci::CircleAdd *node) { fq_activation(node); } void visit(luci::CircleAveragePool2D *node) { fq_activation(node); } void visit(luci::CircleBatchMatMul *node) { fq_activation(node); } @@ -201,6 +208,7 @@ struct FakeQuantize final : public luci::CircleNodeMutableVisitor<void> void visit(luci::CirclePad *node) { fq_activation(node); } void visit(luci::CirclePRelu *node) { fq_activation(node); } void visit(luci::CircleMean *node) { fq_activation(node); } + void visit(luci::CircleReduceProd *node) { fq_activation(node); } void visit(luci::CircleReduceMax *node) { fq_activation(node); } void visit(luci::CircleRelu *node) { fq_activation(node); } void visit(luci::CircleRelu6 *node) { fq_activation(node); } @@ -216,15 +224,20 @@ struct FakeQuantize final : public luci::CircleNodeMutableVisitor<void> // (dtype will be automatically updated by type inference) void visit(luci::CircleCast *) {} void visit(luci::CircleConcatenation *) {} + void visit(luci::CircleDepthToSpace *) {} void visit(luci::CircleGather *) {} void visit(luci::CircleSlice *) {} void visit(luci::CircleStridedSlice *) {} void visit(luci::CircleReshape *) {} + void visit(luci::CircleSpaceToDepth *) {} void visit(luci::CircleSplit *) {} void visit(luci::CircleSplitOut *) {} void visit(luci::CircleSplitV *) {} void visit(luci::CircleSplitVOut *) {} void visit(luci::CircleTranspose *) {} + void visit(luci::CirclePack *) {} + void visit(luci::CircleUnpack *) {} + void visit(luci::CircleUnpackOut *) {} // For Ops that return index, fake quantization is unnecessary void visit(luci::CircleArgMax *) {} diff --git a/compiler/luci/pass/src/ExpandBroadcastConstPass.test.cpp b/compiler/luci/pass/src/ExpandBroadcastConstPass.test.cpp index 0734e0778..5df1b72dc 100644 --- a/compiler/luci/pass/src/ExpandBroadcastConstPass.test.cpp +++ b/compiler/luci/pass/src/ExpandBroadcastConstPass.test.cpp @@ -19,6 +19,8 @@ #include <luci/IR/CircleNodes.h> +#include <limits> // std::numeric_limits + #include <gtest/gtest.h> namespace diff --git a/compiler/luci/pass/src/FoldDepthwiseConv2DPass.cpp b/compiler/luci/pass/src/FoldDepthwiseConv2DPass.cpp index 6e423e3d9..33f9f1d77 100644 --- a/compiler/luci/pass/src/FoldDepthwiseConv2DPass.cpp +++ b/compiler/luci/pass/src/FoldDepthwiseConv2DPass.cpp @@ -23,6 +23,8 @@ #include <luci/Log.h> +#include <limits> // std::numeric_limits + namespace { diff --git a/compiler/luci/pass/src/FoldDepthwiseConv2DPass.test.cpp b/compiler/luci/pass/src/FoldDepthwiseConv2DPass.test.cpp index b1ef56833..36cae0437 100644 --- a/compiler/luci/pass/src/FoldDepthwiseConv2DPass.test.cpp +++ b/compiler/luci/pass/src/FoldDepthwiseConv2DPass.test.cpp @@ -19,6 +19,8 @@ #include <luci/IR/CircleNodes.h> +#include <limits> // std::numeric_limits + #include <gtest/gtest.h> namespace diff --git a/compiler/luci/pass/src/FoldFullyConnectedPass.cpp b/compiler/luci/pass/src/FoldFullyConnectedPass.cpp new file mode 100644 index 000000000..a3bca7eda --- /dev/null +++ b/compiler/luci/pass/src/FoldFullyConnectedPass.cpp @@ -0,0 +1,198 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/FoldFullyConnectedPass.h" + +#include <tensorflow/lite/kernels/internal/reference/fully_connected.h> + +#include <luci/IR/CircleNodes.h> +#include <luci/IR/AttrFusedActFunc.h> + +#include <luci/Log.h> + +#include <limits> // std::numeric_limits + +namespace +{ + +bool set_kernel_parameters(tflite::FullyConnectedParams *params, luci::CircleFullyConnected *node) +{ + switch (node->fusedActivationFunction()) + { + case luci::FusedActFunc::NONE: + case luci::FusedActFunc::TANH: + params->float_activation_min = std::numeric_limits<float>::lowest(); + params->float_activation_max = std::numeric_limits<float>::max(); + break; + case luci::FusedActFunc::RELU: + params->float_activation_min = 0; + params->float_activation_max = std::numeric_limits<float>::max(); + break; + case luci::FusedActFunc::RELU_N1_TO_1: + params->float_activation_min = -1; + params->float_activation_max = 1; + break; + case luci::FusedActFunc::RELU6: + params->float_activation_min = 0; + params->float_activation_max = 6; + break; + default: + { + LOGGER(l); + WARN(l) << "Unsupported activation: " << uint32_t(node->fusedActivationFunction()); + return false; + } + } + + assert(node->weights_format() == + luci::CircleFullyConnected::WeightsFormat::DEFAULT); // FIX_CALLER_UNLESS + params->weights_format = tflite::FullyConnectedWeightsFormat::kDefault; + + return true; +} + +#define RETURN_FALSE_UNLESS(cond) \ + if (not(cond)) \ + return false; + +/** + * Fold FullyConnected with constant input and filter into a constant tensor + * + * BEFORE + * + * [CircleConst] [CircleConst] + * | | + * [CircleFullyConnected] + * + * AFTER + * + * [CircleConst] + */ +bool fold_fully_connected(luci::CircleFullyConnected *node) +{ + RETURN_FALSE_UNLESS(node != nullptr); + + LOGGER(l); + + auto const input = dynamic_cast<luci::CircleConst *>(node->input()); + auto const weights = dynamic_cast<luci::CircleConst *>(node->weights()); + auto const bias = dynamic_cast<luci::CircleConst *>(node->bias()); + auto const no_bias = dynamic_cast<luci::CircleOutputExclude *>(node->bias()); + + RETURN_FALSE_UNLESS(input != nullptr); + RETURN_FALSE_UNLESS(weights != nullptr); + RETURN_FALSE_UNLESS(node->weights_format() == luci::CircleFullyConnected::WeightsFormat::DEFAULT); + RETURN_FALSE_UNLESS(bias != nullptr or no_bias != nullptr); + + RETURN_FALSE_UNLESS(input->dtype() == loco::DataType::FLOAT32); + RETURN_FALSE_UNLESS(weights->dtype() == loco::DataType::FLOAT32); + if (bias) + RETURN_FALSE_UNLESS(bias->dtype() == loco::DataType::FLOAT32); + + auto const input_elems = input->size<loco::DataType::FLOAT32>(); + + RETURN_FALSE_UNLESS(weights->rank() == 2); + RETURN_FALSE_UNLESS(input_elems % weights->dim(1).value() == 0); + auto const batch_size = input_elems / weights->dim(1).value(); + auto const num_units = weights->dim(0).value(); + + if (bias) + RETURN_FALSE_UNLESS(bias->size<loco::DataType::FLOAT32>() == num_units); + + tflite::FullyConnectedParams params{}; + if (!set_kernel_parameters(¶ms, node)) + return false; // Unsupported kernel parameter values + + std::vector<uint32_t> output_shape; + if (node->keep_num_dims() == false) + { + output_shape.push_back(batch_size); + output_shape.push_back(num_units); + } + else + { + output_shape.resize(input->rank()); + for (uint32_t i = 0; i < input->rank(); i++) + output_shape[i] = input->dim(i).value(); + output_shape[input->rank() - 1] = num_units; + } + + auto constant = node->graph()->nodes()->create<luci::CircleConst>(); + { + constant->name(node->name()); + constant->dtype(node->dtype()); + constant->rank(node->rank()); + constant->shape_status(luci::ShapeStatus::VALID); + uint32_t num_elem = 1; + for (uint32_t i = 0; i < node->rank(); ++i) + { + constant->dim(i).set(node->dim(i).value()); + num_elem *= node->dim(i).value(); + } + constant->size<loco::DataType::FLOAT32>(num_elem); + } + + auto tensor_shape = [](luci::CircleNode *node) { + if (node == nullptr) + return tflite::RuntimeShape(); + + tflite::RuntimeShape runtime_shape(node->rank()); + for (uint32_t i = 0; i < node->rank(); ++i) + runtime_shape.SetDim(i, node->dim(i).value()); + return runtime_shape; + }; + + auto tensor_data = [](luci::CircleConst *node) -> float * { + if (node == nullptr) + return nullptr; + + return &node->at<loco::DataType::FLOAT32>(0); + }; + + tflite::reference_ops::FullyConnected( + params, tensor_shape(input), tensor_data(input), tensor_shape(weights), tensor_data(weights), + tensor_shape(bias), tensor_data(bias), tensor_shape(constant), tensor_data(constant)); + + loco::replace(node).with(constant); + + return true; +} + +} // namespace + +namespace luci +{ + +/** + * Constant Folding for FullyConnected Op + **/ +bool FoldFullyConnectedPass::run(loco::Graph *g) +{ + bool changed = false; + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + auto fc = dynamic_cast<CircleFullyConnected *>(node); + + if (fold_fully_connected(fc)) + changed = true; + } + + return changed; +} + +} // namespace luci + +#undef RETURN_FALSE_UNLESS diff --git a/compiler/luci/pass/src/FoldFullyConnectedPass.test.cpp b/compiler/luci/pass/src/FoldFullyConnectedPass.test.cpp new file mode 100644 index 000000000..a8e64a24b --- /dev/null +++ b/compiler/luci/pass/src/FoldFullyConnectedPass.test.cpp @@ -0,0 +1,160 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/FoldFullyConnectedPass.h" +#include "PassTestGraphs.h" + +#include <luci/IR/CircleNodes.h> + +#include <limits> // std::numeric_limits + +#include <gtest/gtest.h> + +namespace +{ + +/** + * Graph has an FullyConnected Op with constant inputs + * + * BEFORE + * + * [CircleConst] [CircleConst] + * | | + * [CircleFullyConnected] + * + * AFTER + * + * [CircleConst] + */ +class FoldFullyConnectedTest : public luci::ConstantFoldingTestGraph, public ::testing::Test +{ +#define INPUT_DIM 80 +#define NUM_UNITS 32 + +public: + FoldFullyConnectedTest() : luci::ConstantFoldingTestGraph({INPUT_DIM}, loco::DataType::FLOAT32) + { + _fc = _g.nodes()->create<luci::CircleFullyConnected>(); + _fc_input = _g.nodes()->create<luci::CircleConst>(); + _fc_weights = _g.nodes()->create<luci::CircleConst>(); + _fc_bias = _g.nodes()->create<luci::CircleConst>(); + + _fc->dtype(loco::DataType::FLOAT32); + _fc->fusedActivationFunction(luci::FusedActFunc::NONE); + _fc->input(_fc_input); + _fc->weights(_fc_weights); + _fc->bias(_fc_bias); + _fc->shape({NUM_UNITS}); + _fc->weights_format(luci::CircleFullyConnected::WeightsFormat::DEFAULT); + _fc->keep_num_dims(true); + + _fc_input->dtype(loco::DataType::FLOAT32); + _fc_input->shape({INPUT_DIM}); + _fc_input->size<loco::DataType::FLOAT32>(INPUT_DIM); + + _fc_weights->dtype(loco::DataType::FLOAT32); + _fc_weights->shape({NUM_UNITS, INPUT_DIM}); + _fc_weights->size<loco::DataType::FLOAT32>(NUM_UNITS * INPUT_DIM); + + _fc_bias->dtype(loco::DataType::FLOAT32); + _fc_bias->shape({1, NUM_UNITS}); + _fc_bias->size<loco::DataType::FLOAT32>(NUM_UNITS); + + for (uint32_t i = 0; i < INPUT_DIM; ++i) + _fc_input->at<loco::DataType::FLOAT32>(i) = 1.0; + + for (uint32_t i = 0; i < INPUT_DIM * NUM_UNITS; ++i) + _fc_weights->at<loco::DataType::FLOAT32>(i) = 1.0; + + for (uint32_t i = 0; i < NUM_UNITS; ++i) + _fc_bias->at<loco::DataType::FLOAT32>(i) = 0.0; + + _output->from(_fc); + } + +protected: + void init() final {} + +protected: + loco::Node *createFoldedPattern() final { return nullptr; } + +protected: + luci::CircleConst *getFoldedPattern() final + { + return loco::must_cast<luci::CircleConst *>(_output->from()); + } + +protected: + luci::CircleFullyConnected *_fc = nullptr; + luci::CircleConst *_fc_input = nullptr; + luci::CircleConst *_fc_weights = nullptr; + luci::CircleConst *_fc_bias = nullptr; +#undef INPUT_DIM +#undef NUM_UNITS +}; + +} // namespace + +TEST_F(FoldFullyConnectedTest, fold_fc) +{ + luci::FoldFullyConnectedPass pass; + ASSERT_TRUE(pass.run(&_g)); + + auto folded_const = getFoldedPattern(); + EXPECT_EQ(folded_const->dtype(), loco::DataType::FLOAT32); + EXPECT_EQ(1, folded_const->rank()); + EXPECT_EQ(32, folded_const->dim(0)); + EXPECT_EQ(32, folded_const->size<loco::DataType::FLOAT32>()); + for (uint32_t i = 0; i < 32; ++i) + EXPECT_NEAR(folded_const->at<loco::DataType::FLOAT32>(i), 80, + std::numeric_limits<float>::min()); +} + +TEST_F(FoldFullyConnectedTest, fold_fc_no_bias) +{ + auto no_bias = _g.nodes()->create<luci::CircleOutputExclude>(); + _fc->bias(no_bias); + + luci::FoldFullyConnectedPass pass; + ASSERT_TRUE(pass.run(&_g)); + + auto folded_const = getFoldedPattern(); + EXPECT_EQ(loco::DataType::FLOAT32, folded_const->dtype()); + EXPECT_EQ(1, folded_const->rank()); + EXPECT_EQ(32, folded_const->dim(0)); + EXPECT_EQ(32, folded_const->size<loco::DataType::FLOAT32>()); + for (uint32_t i = 0; i < 32; ++i) + EXPECT_NEAR(folded_const->at<loco::DataType::FLOAT32>(i), 80, + std::numeric_limits<float>::min()); +} + +TEST_F(FoldFullyConnectedTest, fold_fc_NEG) +{ + auto new_fc = _g.nodes()->create<luci::CircleFullyConnected>(); + _fc->input(new_fc); + + luci::FoldFullyConnectedPass pass; + ASSERT_FALSE(pass.run(&_g)); +} + +TEST_F(FoldFullyConnectedTest, fold_fc_weight_format_NEG) +{ + auto new_fc = _g.nodes()->create<luci::CircleFullyConnected>(); + _fc->weights_format(luci::CircleFullyConnected::WeightsFormat::SHUFFLED4x16INT8); + + luci::FoldFullyConnectedPass pass; + ASSERT_FALSE(pass.run(&_g)); +} diff --git a/compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.cpp b/compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.cpp index bc09abee2..3494a6e60 100644 --- a/compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.cpp +++ b/compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.cpp @@ -76,6 +76,26 @@ luci::CircleReshape *create_cloned_reshape(luci::CircleReshape *reshape) return new_reshape; } +bool forward_reshape(luci::CircleReshape *reshape, luci::CircleAbs *abs) +{ + assert(reshape != nullptr); // FIX_CALLER_UNLESS + assert(abs != nullptr); // FIX_CALLER_UNLESS + + auto new_reshape = create_cloned_reshape(reshape); + if (not new_reshape) + return false; + + // reconnect network + loco::replace(abs).with(new_reshape); + abs->x(reshape->tensor()); + new_reshape->tensor(abs); + + // Do shape inference for this node again. + abs->shape_status(luci::ShapeStatus::UNDEFINED); + + return true; +} + bool forward_reshape(luci::CircleReshape *reshape, luci::CircleNeg *neg) { assert(reshape != nullptr); @@ -136,6 +156,14 @@ protected: return false; } + bool visit(luci::CircleAbs *node) + { + auto reshape = as_reshape(node->x()); + if (reshape == nullptr) + return false; + return forward_reshape(reshape, node); + } + bool visit(luci::CircleNeg *node) { auto reshape = as_reshape(node->x()); diff --git a/compiler/luci/pass/src/ForwardTransposeOpPass.cpp b/compiler/luci/pass/src/ForwardTransposeOpPass.cpp new file mode 100644 index 000000000..c76d73344 --- /dev/null +++ b/compiler/luci/pass/src/ForwardTransposeOpPass.cpp @@ -0,0 +1,366 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/ForwardTransposeOpPass.h" + +#include <luci/IR/CircleNodes.h> +#include <luci/IR/CircleNodeVisitor.h> +#include <luci/Profile/CircleNodeOrigin.h> +#include <luci/Service/Nodes/CircleConst.h> +#include <luci/Service/CircleNodeClone.h> + +using namespace luci; + +namespace +{ + +// Create new Transpose Op including perm +// Return nullptr if failed +CircleTranspose *create_cloned_transpose(CircleTranspose *transpose) +{ + assert(transpose != nullptr); // FIX_CALLER_UNLESS + + auto perm = dynamic_cast<CircleConst *>(transpose->perm()); + if (not perm) + return nullptr; + + CircleConst *cloned_perm = clone(perm); + if (cloned_perm == nullptr) + return nullptr; + + cloned_perm->name(perm->name() + "_C"); + luci::add_origin(cloned_perm, luci::get_origin(perm)); + + auto cloned_node = clone_node(transpose, transpose->graph()); + if (cloned_node == nullptr) + return nullptr; + + auto new_transpose = loco::must_cast<luci::CircleTranspose *>(cloned_node); + new_transpose->perm(cloned_perm); + new_transpose->name(transpose->name() + "_C"); + luci::add_origin(new_transpose, luci::get_origin(transpose)); + + return new_transpose; +} + +uint32_t cal_offset(const std::vector<uint32_t> &shape, const std::vector<uint32_t> &indices) +{ + assert(shape.size() == indices.size()); // FIX_CALLER_UNLESS + + uint32_t offset = 0; + for (uint32_t i = 0; i < indices.size(); i++) + { + uint32_t index = indices[i]; + for (uint32_t j = shape.size() - 1; j > i; j--) + { + index *= shape[j]; + } + offset += index; + } + return offset; +} + +// Return reverse-transpose of 'node' +// i.e., Transpose(return value) = node +CircleConst *reverse_transposed(CircleConst *node, std::vector<uint32_t> &t) +{ + assert(node->rank() == t.size()); // FIX_CALLER_UNLESS + assert(node->rank() == 4); // FIX_CALLER_UNLESS + + std::vector<uint32_t> orig_shape(node->rank()); + std::vector<uint32_t> new_shape(node->rank()); + + for (uint32_t i = 0; i < node->rank(); i++) + { + assert(t[i] < node->rank()); // FIX_CALLER_UNLESS + + orig_shape[i] = node->dim(i).value(); + new_shape[t[i]] = node->dim(i).value(); + } + + auto clone_const = clone(node); + for (uint32_t i = 0; i < node->rank(); i++) + clone_const->dim(i).set(new_shape[i]); + + clone_const->name(clone_const->name() + "_r_transposed"); + add_origin(clone_const, luci::get_origin(node)); + + for (uint32_t n = 0; n < clone_const->dim(0).value(); n++) + { + for (uint32_t h = 0; h < clone_const->dim(1).value(); h++) + { + for (uint32_t w = 0; w < clone_const->dim(2).value(); w++) + { + for (uint32_t c = 0; c < clone_const->dim(3).value(); c++) + { + std::vector<uint32_t> new_indices{n, h, w, c}; + std::vector<uint32_t> orig_indices{new_indices[t[0]], new_indices[t[1]], + new_indices[t[2]], new_indices[t[3]]}; + + const auto data = node->at<loco::DataType::FLOAT32>(cal_offset(orig_shape, orig_indices)); + clone_const->at<loco::DataType::FLOAT32>(cal_offset(new_shape, new_indices)) = data; + } + } + } + } + + return clone_const; +} + +bool check_rank_four(const CircleConst *c) { return c->rank() == 4; } + +// Return true if below conditions are met +// 1. t->perm() is CircleConst +// 2. t->perm() is S32 +bool check_perm(const CircleTranspose *t) +{ + auto perm = dynamic_cast<CircleConst *>(t->perm()); + if (not perm) + return false; + + switch (perm->dtype()) + { + case loco::DataType::S32: + for (uint32_t i = 0; i < perm->size<loco::DataType::S32>(); i++) + { + auto data = perm->at<loco::DataType::S32>(i); + // TODO Support not normalized index + if (data < 0 or data >= static_cast<int32_t>(t->rank())) + return false; + } + break; + // TODO Support S64 data type + default: + return false; + } + + return true; +} + +#define RETURN_FALSE_UNLESS(COND) \ + if (not(COND)) \ + return false; + +// Elementwise Binary Operator with const +class EBOWithConstPattern final : public CircleNodeMutableVisitor<bool> +{ +private: + template <typename CIRCLE_OP_PTR> bool has_pattern(CIRCLE_OP_PTR node) + { + if (auto x = dynamic_cast<luci::CircleConst *>(node->x())) + { + if (auto y = dynamic_cast<luci::CircleTranspose *>(node->y())) + { + RETURN_FALSE_UNLESS(check_rank_four(x)); + RETURN_FALSE_UNLESS(check_perm(y)); + + auto new_const = gen_new_const(y, x); + assert(new_const); // FIX_ME_UNLESS + + auto new_transpose = create_cloned_transpose(y); + assert(new_transpose); // FIX_ME_UNLESS + + // Reconnect network + node->x(new_const); + node->y(y->a()); + loco::replace(node).with(new_transpose); + new_transpose->a(node); + + // Do shape inference for this node again. + node->shape_status(luci::ShapeStatus::UNDEFINED); + + return true; + } + } + + if (auto y = dynamic_cast<luci::CircleConst *>(node->y())) + { + if (auto x = dynamic_cast<luci::CircleTranspose *>(node->x())) + { + RETURN_FALSE_UNLESS(check_rank_four(y)); + RETURN_FALSE_UNLESS(check_perm(x)); + + auto new_const = gen_new_const(x, y); + assert(new_const); // FIX_ME_UNLESS + + auto new_transpose = create_cloned_transpose(x); + assert(new_transpose); // FIX_ME_UNLESS + + // Reconnect network + node->y(new_const); + node->x(x->a()); + loco::replace(node).with(new_transpose); + new_transpose->a(node); + + // Do shape inference for this node again. + node->shape_status(luci::ShapeStatus::UNDEFINED); + + return true; + } + } + + return false; + } + +public: + // Default + bool visit(luci::CircleNode *) { return false; } + + bool visit(luci::CircleAdd *node) { return has_pattern(node); } + + bool visit(luci::CircleMul *node) { return has_pattern(node); } + +private: + // Return a new const node after Tranpose Op is forwarded + // Return nullptr if unsupported cases + CircleConst *gen_new_const(CircleTranspose *t, CircleConst *c) + { + const auto perm = dynamic_cast<CircleConst *>(t->perm()); + + // Only support constant perm + if (not perm) + return nullptr; + + std::vector<uint32_t> perm_data; + switch (perm->dtype()) + { + case loco::DataType::S32: + for (uint32_t i = 0; i < perm->size<loco::DataType::S32>(); i++) + { + auto data = perm->at<loco::DataType::S32>(i); + assert(data >= 0 and data < static_cast<int32_t>(t->rank())); + perm_data.emplace_back(static_cast<uint32_t>(data)); + } + break; + // TODO Support S64 data type + default: + return nullptr; + } + + assert(perm_data.size() == t->rank()); // FIX_CALLER_UNLESS + + return reverse_transposed(c, perm_data); + } +}; + +// Elementwise Unary Operator +class EwUnaryPattern final : public CircleNodeMutableVisitor<bool> +{ +private: + // input is 'x' + template <typename CIRCLE_OP_PTR> bool has_pattern_x(CIRCLE_OP_PTR node) + { + if (auto x = dynamic_cast<luci::CircleTranspose *>(node->x())) + { + RETURN_FALSE_UNLESS(check_perm(x)); + + auto new_transpose = create_cloned_transpose(x); + assert(new_transpose); // FIX_ME_UNLESS + + // Reconnect network + node->x(x->a()); + loco::replace(node).with(new_transpose); + new_transpose->a(node); + + // Do shape inference for this node again. + node->shape_status(luci::ShapeStatus::UNDEFINED); + + return true; + } + + return false; + } + +public: + // Default + bool visit(luci::CircleNode *) { return false; } + + bool visit(luci::CircleAbs *node) { return has_pattern_x(node); } +}; + +} // namespace + +namespace luci +{ + +/** + * BEFORE + * | + * [CircleNode] [CircleConst] + * | / + * [CircleTranspose] [CircleConst] + * / | / + * [CircleNode] [(BinaryOp)] + * | | \ + * | | [CircleNode] + * | | | + * + * BinaryOp: CircleAdd, CircleMul, ... + * + * | + * [CircleNode] [CircleConst] + * | / + * [CircleTranspose] + * / | + * [CircleNode] [(UnaryOp)] + * | | \ + * | | [CircleNode] + * | | | + * + * UnaryOp: CircleAbs, ... + * + * AFTER + * | + * [CircleConst] [CircleNode] [CircleConst(updated)] + * | / | / + * [CircleTranspose] [(BinaryOp)] [CircleConst] + * | | / + * [CircleNode] [CircleTranspose] + * | | \ + * | | [CircleNode] + * | | | + * + * | + * [CircleConst] [CircleNode] + * | / | + * [CircleTranspose] [(UnaryOp)] [CircleConst] + * | | / + * [CircleNode] [CircleTranspose] + * | | \ + * | | [CircleNode] + * | | | + * + * Note: new [CircleTranspose] is added after [(BinaryOp)] + */ +bool ForwardTransposeOpPass::run(loco::Graph *g) +{ + bool changed = false; + EBOWithConstPattern eboc; + EwUnaryPattern ewu; + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + auto circle_node = loco::must_cast<luci::CircleNode *>(node); + if (circle_node->accept(&eboc)) + changed = true; + else if (circle_node->accept(&ewu)) + changed = true; + } + return changed; +} + +#undef RETURN_FALSE_UNLESS + +} // namespace luci diff --git a/compiler/luci/pass/src/ForwardTransposeOpPass.test.cpp b/compiler/luci/pass/src/ForwardTransposeOpPass.test.cpp new file mode 100644 index 000000000..2d061c2a3 --- /dev/null +++ b/compiler/luci/pass/src/ForwardTransposeOpPass.test.cpp @@ -0,0 +1,524 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/ForwardTransposeOpPass.h" +#include "luci/Pass/CircleShapeInferencePass.h" + +#include <logo/Phase.h> +#include <luci/IR/CircleNodes.h> +#include <luci/test/TestIOGraph.h> + +#include <gtest/gtest.h> + +#include <vector> + +namespace +{ + +using namespace luci::test; + +template <typename T> class TransposeBinaryOpGraphlet +{ +public: + TransposeBinaryOpGraphlet() = default; + +public: + virtual ~TransposeBinaryOpGraphlet() = default; + +public: + void init(loco::Graph *g, const ShapeU32 shape_in, const ShapeU32 perm) + { + std::vector<uint32_t> shape_in_v = shape_in; + std::vector<uint32_t> perm_v = perm; + + assert(shape_in_v.size() == perm_v.size()); // FIX_CALLER_UNLESS + + _perm = g->nodes()->create<luci::CircleConst>(); + _const = g->nodes()->create<luci::CircleConst>(); + _transpose = g->nodes()->create<luci::CircleTranspose>(); + _binary = g->nodes()->create<T>(); + + _perm->dtype(loco::DataType::S32); + _perm->rank(1); + _perm->dim(0).set(perm_v.size()); + _perm->shape_status(luci::ShapeStatus::VALID); + + _const->dtype(loco::DataType::FLOAT32); + _const->rank(shape_in_v.size()); + for (uint32_t i = 0; i < shape_in_v.size(); i++) + _const->dim(i).set(shape_in_v[perm_v[i]]); + _const->shape_status(luci::ShapeStatus::VALID); + + // values + const auto size = perm_v.size(); + _perm->size<loco::DataType::S32>(size); + for (uint32_t i = 0; i < size; i++) + _perm->at<loco::DataType::S32>(i) = perm_v[i]; + + uint32_t elems = 1; + for (uint32_t i = 0; i < size; i++) + elems *= shape_in_v[i]; + + _const->size<loco::DataType::FLOAT32>(elems); + for (uint32_t i = 0; i < elems; i++) + _const->at<loco::DataType::FLOAT32>(i) = i; + + _perm->name("transpose_perm"); + _transpose->name("transpose"); + _binary->name("binary"); + } + + luci::CircleTranspose *transpose(void) { return _transpose; } + + void switch_xy(void) + { + assert(_binary); // FIX_CALLER_UNLESS + auto temp = _binary->x(); + _binary->x(_binary->y()); + _binary->y(temp); + } + +protected: + luci::CircleTranspose *_transpose = nullptr; + T *_binary = nullptr; + luci::CircleConst *_perm = nullptr; + luci::CircleConst *_const = nullptr; +}; + +using TransposeAddGraphlet = TransposeBinaryOpGraphlet<luci::CircleAdd>; +using TransposeMulGraphlet = TransposeBinaryOpGraphlet<luci::CircleMul>; + +class ForwardTransposeToAddGraph : public TestIOGraph, public TransposeAddGraphlet +{ +public: + void init(const ShapeU32 shape_in, const ShapeU32 shape_out) + { + TestIOGraph::init(shape_in, shape_out); + TransposeAddGraphlet::init(g(), shape_in, shape_out); + + // connect network + _transpose->a(input()); + _transpose->perm(_perm); + _binary->x(_transpose); + _binary->y(_const); + + output()->from(_binary); + } +}; + +class ForwardTransposeToAddInvalidGraph : public TestIOGraph, public TransposeAddGraphlet +{ +public: + void init(const ShapeU32 shape_in, const ShapeU32 shape_out) + { + TestIOGraph::init(shape_in, shape_out); + TransposeAddGraphlet::init(g(), shape_in, shape_out); + + // connect network + _transpose->a(input()); + _transpose->perm(_perm); + _binary->x(_transpose); + _binary->y(_transpose); + + output()->from(_binary); + } +}; + +class ForwardTransposeToMulGraph : public TestIOGraph, public TransposeMulGraphlet +{ +public: + void init(const ShapeU32 shape_in, const ShapeU32 shape_out) + { + TestIOGraph::init(shape_in, shape_out); + TransposeMulGraphlet::init(g(), shape_in, shape_out); + + // connect network + _transpose->a(input()); + _transpose->perm(_perm); + _binary->x(_transpose); + _binary->y(_const); + + output()->from(_binary); + } +}; + +void run_phase(loco::Graph *g) +{ + logo::Phase phase; + + // Default passes. + phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>()); + + // Pass to test + phase.emplace_back(std::make_unique<luci::ForwardTransposeOpPass>()); + + logo::PhaseRunner<logo::PhaseStrategy::Restart> phase_runner{g}; + phase_runner.run(phase); +} + +class ForwardTransposeToAddGraphTest : public ::testing::Test +{ +public: + void run_pass(void) { run_phase(_graph.g()); } + +protected: + ForwardTransposeToAddGraph _graph; +}; + +class ForwardTransposeToAddGraphNegTest : public ::testing::Test +{ +public: + void run_pass(void) { run_phase(_graph.g()); } + +protected: + ForwardTransposeToAddInvalidGraph _graph; +}; + +class ForwardTransposeToMulGraphTest : public ::testing::Test +{ +public: + void run_pass(void) { run_phase(_graph.g()); } + +protected: + ForwardTransposeToMulGraph _graph; +}; + +} // namespace + +TEST_F(ForwardTransposeToAddGraphTest, forward_add_xy) +{ + _graph.init({1, 64, 51, 1}, {0, 3, 2, 1}); + + run_pass(); + + auto transpose = dynamic_cast<luci::CircleTranspose *>(_graph.output()->from()); + EXPECT_NE(nullptr, transpose); + EXPECT_EQ(4, transpose->rank()); + EXPECT_EQ(1, transpose->dim(0).value()); + EXPECT_EQ(1, transpose->dim(1).value()); + EXPECT_EQ(51, transpose->dim(2).value()); + EXPECT_EQ(64, transpose->dim(3).value()); + + auto add = dynamic_cast<luci::CircleAdd *>(transpose->a()); + EXPECT_NE(nullptr, add); + EXPECT_EQ(4, add->rank()); + EXPECT_EQ(1, add->dim(0).value()); + EXPECT_EQ(64, add->dim(1).value()); + EXPECT_EQ(51, add->dim(2).value()); + EXPECT_EQ(1, add->dim(3).value()); + + auto add_const = dynamic_cast<luci::CircleConst *>(add->y()); + EXPECT_NE(nullptr, add_const); + EXPECT_EQ(4, add_const->rank()); + EXPECT_EQ(1, add_const->dim(0).value()); + EXPECT_EQ(64, add_const->dim(1).value()); + EXPECT_EQ(51, add_const->dim(2).value()); + EXPECT_EQ(1, add_const->dim(3).value()); +} + +TEST_F(ForwardTransposeToAddGraphTest, forward_add_yx) +{ + _graph.init({1, 64, 51, 1}, {0, 3, 2, 1}); + _graph.switch_xy(); + + run_pass(); + + auto transpose = dynamic_cast<luci::CircleTranspose *>(_graph.output()->from()); + EXPECT_NE(nullptr, transpose); + EXPECT_EQ(4, transpose->rank()); + EXPECT_EQ(1, transpose->dim(0).value()); + EXPECT_EQ(1, transpose->dim(1).value()); + EXPECT_EQ(51, transpose->dim(2).value()); + EXPECT_EQ(64, transpose->dim(3).value()); + + auto mul = dynamic_cast<luci::CircleAdd *>(transpose->a()); + EXPECT_NE(nullptr, mul); + EXPECT_EQ(4, mul->rank()); + EXPECT_EQ(1, mul->dim(0).value()); + EXPECT_EQ(64, mul->dim(1).value()); + EXPECT_EQ(51, mul->dim(2).value()); + EXPECT_EQ(1, mul->dim(3).value()); + + auto mul_const = dynamic_cast<luci::CircleConst *>(mul->x()); + EXPECT_NE(nullptr, mul_const); + EXPECT_EQ(4, mul_const->rank()); + EXPECT_EQ(1, mul_const->dim(0).value()); + EXPECT_EQ(64, mul_const->dim(1).value()); + EXPECT_EQ(51, mul_const->dim(2).value()); + EXPECT_EQ(1, mul_const->dim(3).value()); +} + +TEST_F(ForwardTransposeToMulGraphTest, forward_mul_xy) +{ + _graph.init({1, 64, 51, 1}, {0, 3, 2, 1}); + + run_pass(); + + auto transpose = dynamic_cast<luci::CircleTranspose *>(_graph.output()->from()); + EXPECT_NE(nullptr, transpose); + EXPECT_EQ(4, transpose->rank()); + EXPECT_EQ(1, transpose->dim(0).value()); + EXPECT_EQ(1, transpose->dim(1).value()); + EXPECT_EQ(51, transpose->dim(2).value()); + EXPECT_EQ(64, transpose->dim(3).value()); + + auto mul = dynamic_cast<luci::CircleMul *>(transpose->a()); + EXPECT_NE(nullptr, mul); + EXPECT_EQ(4, mul->rank()); + EXPECT_EQ(1, mul->dim(0).value()); + EXPECT_EQ(64, mul->dim(1).value()); + EXPECT_EQ(51, mul->dim(2).value()); + EXPECT_EQ(1, mul->dim(3).value()); + + auto mul_const = dynamic_cast<luci::CircleConst *>(mul->y()); + EXPECT_NE(nullptr, mul_const); + EXPECT_EQ(4, mul_const->rank()); + EXPECT_EQ(1, mul_const->dim(0).value()); + EXPECT_EQ(64, mul_const->dim(1).value()); + EXPECT_EQ(51, mul_const->dim(2).value()); + EXPECT_EQ(1, mul_const->dim(3).value()); +} + +TEST_F(ForwardTransposeToMulGraphTest, forward_mul_yx) +{ + _graph.init({1, 64, 51, 1}, {0, 3, 2, 1}); + _graph.switch_xy(); + + run_pass(); + + auto transpose = dynamic_cast<luci::CircleTranspose *>(_graph.output()->from()); + EXPECT_NE(nullptr, transpose); + EXPECT_EQ(4, transpose->rank()); + EXPECT_EQ(1, transpose->dim(0).value()); + EXPECT_EQ(1, transpose->dim(1).value()); + EXPECT_EQ(51, transpose->dim(2).value()); + EXPECT_EQ(64, transpose->dim(3).value()); + + auto mul = dynamic_cast<luci::CircleMul *>(transpose->a()); + EXPECT_NE(nullptr, mul); + EXPECT_EQ(4, mul->rank()); + EXPECT_EQ(1, mul->dim(0).value()); + EXPECT_EQ(64, mul->dim(1).value()); + EXPECT_EQ(51, mul->dim(2).value()); + EXPECT_EQ(1, mul->dim(3).value()); + + auto mul_const = dynamic_cast<luci::CircleConst *>(mul->x()); + EXPECT_NE(nullptr, mul_const); + EXPECT_EQ(4, mul_const->rank()); + EXPECT_EQ(1, mul_const->dim(0).value()); + EXPECT_EQ(64, mul_const->dim(1).value()); + EXPECT_EQ(51, mul_const->dim(2).value()); + EXPECT_EQ(1, mul_const->dim(3).value()); +} + +TEST_F(ForwardTransposeToAddGraphTest, forward_transpose_add_NEG) +{ + _graph.init({1, 64, 51, 1}, {0, 3, 2, 1}); + + // Remove add + _graph.output()->from(_graph.transpose()); + + luci::ForwardTransposeOpPass pass; + EXPECT_FALSE(pass.run(_graph.g())); +} + +TEST_F(ForwardTransposeToAddGraphNegTest, forward_transpose_add_non_const_NEG) +{ + _graph.init({1, 64, 51, 1}, {0, 3, 2, 1}); + + luci::ForwardTransposeOpPass pass; + EXPECT_FALSE(pass.run(_graph.g())); +} + +TEST_F(ForwardTransposeToMulGraphTest, forward_transpose_mul_NEG) +{ + _graph.init({1, 64, 51, 1}, {0, 3, 2, 1}); + + // Remove mul + _graph.output()->from(_graph.transpose()); + + luci::ForwardTransposeOpPass pass; + EXPECT_FALSE(pass.run(_graph.g())); +} + +// Unary + +namespace +{ + +template <typename T> class TransposeUnaryOpGraphlet +{ +public: + TransposeUnaryOpGraphlet() = default; + +public: + virtual ~TransposeUnaryOpGraphlet() = default; + +public: + void init(loco::Graph *g, const ShapeU32 shape_in, const ShapeU32 perm) + { + std::vector<uint32_t> shape_in_v = shape_in; + std::vector<uint32_t> perm_v = perm; + + assert(shape_in_v.size() == perm_v.size()); // FIX_CALLER_UNLESS + + _perm = g->nodes()->create<luci::CircleConst>(); + _const = g->nodes()->create<luci::CircleConst>(); + _transpose = g->nodes()->create<luci::CircleTranspose>(); + _unary = g->nodes()->create<T>(); + + _perm->dtype(loco::DataType::S32); + _perm->rank(1); + _perm->dim(0).set(perm_v.size()); + _perm->shape_status(luci::ShapeStatus::VALID); + + _const->dtype(loco::DataType::FLOAT32); + _const->rank(shape_in_v.size()); + for (uint32_t i = 0; i < shape_in_v.size(); i++) + _const->dim(i).set(shape_in_v[perm_v[i]]); + _const->shape_status(luci::ShapeStatus::VALID); + + // values + const auto size = perm_v.size(); + _perm->size<loco::DataType::S32>(size); + for (uint32_t i = 0; i < size; i++) + _perm->at<loco::DataType::S32>(i) = perm_v[i]; + + uint32_t elems = 1; + for (uint32_t i = 0; i < size; i++) + elems *= shape_in_v[i]; + + _const->size<loco::DataType::FLOAT32>(elems); + for (uint32_t i = 0; i < elems; i++) + _const->at<loco::DataType::FLOAT32>(i) = i; + + _perm->name("transpose_perm"); + _transpose->name("transpose"); + _unary->name("_unary"); + } + + luci::CircleTranspose *transpose(void) { return _transpose; } + +protected: + luci::CircleTranspose *_transpose = nullptr; + T *_unary = nullptr; + luci::CircleConst *_perm = nullptr; + luci::CircleConst *_const = nullptr; +}; + +using TransposeAbsGraphlet = TransposeUnaryOpGraphlet<luci::CircleAbs>; + +class ForwardTransposeToAbsGraph : public TestIOGraph, public TransposeAbsGraphlet +{ +public: + void init(const ShapeU32 shape_in, const ShapeU32 shape_out) + { + TestIOGraph::init(shape_in, shape_out); + TransposeAbsGraphlet::init(g(), shape_in, shape_out); + + // connect network + _transpose->a(input()); + _transpose->perm(_perm); + _unary->x(_transpose); + + output()->from(_unary); + } +}; + +class ForwardTransposeToAbsInvalidGraph : public TestIOGraph, public TransposeAbsGraphlet +{ +public: + void init(const ShapeU32 shape_in, const ShapeU32 shape_out) + { + TestIOGraph::init(shape_in, shape_out); + TransposeAbsGraphlet::init(g(), shape_in, shape_out); + + _relu = g()->nodes()->create<luci::CircleRelu>(); + _relu->dtype(loco::DataType::FLOAT32); + _relu->name("relu"); + + // connect network + _relu->features(input()); + _unary->x(_relu); + + output()->from(_unary); + } + +protected: + luci::CircleRelu *_relu = nullptr; +}; + +class ForwardTransposeToAbsGraphTest : public ::testing::Test +{ +public: + void run_pass(void) { run_phase(_graph.g()); } + +protected: + ForwardTransposeToAbsGraph _graph; +}; + +class ForwardTransposeToAbsGraphNegTest : public ::testing::Test +{ +public: + void run_pass(void) { run_phase(_graph.g()); } + +protected: + ForwardTransposeToAbsInvalidGraph _graph; +}; + +} // namespace + +TEST_F(ForwardTransposeToAbsGraphTest, forward_abs_x) +{ + _graph.init({1, 64, 51, 1}, {0, 3, 2, 1}); + + run_pass(); + + auto transpose = dynamic_cast<luci::CircleTranspose *>(_graph.output()->from()); + EXPECT_NE(nullptr, transpose); + EXPECT_EQ(4, transpose->rank()); + EXPECT_EQ(1, transpose->dim(0).value()); + EXPECT_EQ(1, transpose->dim(1).value()); + EXPECT_EQ(51, transpose->dim(2).value()); + EXPECT_EQ(64, transpose->dim(3).value()); + + auto abs = dynamic_cast<luci::CircleAbs *>(transpose->a()); + EXPECT_NE(nullptr, abs); + EXPECT_EQ(4, abs->rank()); + EXPECT_EQ(1, abs->dim(0).value()); + EXPECT_EQ(64, abs->dim(1).value()); + EXPECT_EQ(51, abs->dim(2).value()); + EXPECT_EQ(1, abs->dim(3).value()); +} + +TEST_F(ForwardTransposeToAbsGraphTest, forward_transpose_abs_NEG) +{ + _graph.init({1, 64, 51, 1}, {0, 3, 2, 1}); + + // Remove abs + _graph.output()->from(_graph.transpose()); + + luci::ForwardTransposeOpPass pass; + EXPECT_FALSE(pass.run(_graph.g())); +} + +TEST_F(ForwardTransposeToAbsGraphNegTest, forward_transpose_abs_non_transpose_NEG) +{ + _graph.init({1, 64, 51, 1}, {0, 3, 2, 1}); + + luci::ForwardTransposeOpPass pass; + EXPECT_FALSE(pass.run(_graph.g())); +} diff --git a/compiler/luci/pass/src/FuseAddWithFullyConnectedPass.cpp b/compiler/luci/pass/src/FuseAddWithFullyConnectedPass.cpp index 3cf31ed10..1d4a2e3bf 100644 --- a/compiler/luci/pass/src/FuseAddWithFullyConnectedPass.cpp +++ b/compiler/luci/pass/src/FuseAddWithFullyConnectedPass.cpp @@ -86,6 +86,14 @@ bool fuse_add_with_fc(luci::CircleFullyConnected *fc) if (not(addition->dim(rank - 1) == weights->dim(0))) return false; + auto bias = loco::must_cast<luci::CircleNode *>(fc->bias()); + + // We only support (1) constant bias (2) no bias + // If bias is neither (1) nor (2), it would be a feature map + if (bias->opcode() != luci::CircleOpcode::CIRCLECONST and + bias->opcode() != luci::CircleOpcode::CIRCLEOUTPUTEXCLUDE) + return false; + auto fused_bias = luci::clone(addition); // Add existing bias values diff --git a/compiler/luci/pass/src/FuseAddWithFullyConnectedPass.test.cpp b/compiler/luci/pass/src/FuseAddWithFullyConnectedPass.test.cpp index 4cc2eb599..300796594 100644 --- a/compiler/luci/pass/src/FuseAddWithFullyConnectedPass.test.cpp +++ b/compiler/luci/pass/src/FuseAddWithFullyConnectedPass.test.cpp @@ -125,6 +125,15 @@ public: public: luci::CircleFullyConnected *fc() { return _fc; } +public: + void to_fm_bias(void) + { + assert(_fc != nullptr); // FIX_ME_UNLESS + + auto new_fc = _fc->graph()->nodes()->create<luci::CircleFullyConnected>(); + _fc->bias(new_fc); + } + protected: luci::CircleFullyConnected *_fc = nullptr; luci::CircleAdd *_add = nullptr; @@ -174,3 +183,14 @@ TEST_F(FuseAddWithFullyConnectedPassTest, simple_test) EXPECT_EQ(i, bias->at<loco::DataType::FLOAT32>(i)); } } + +TEST_F(FuseAddWithFullyConnectedPassTest, fm_bias_NEG) +{ + g.init(); + + // Bias is a feature map. Add is not fused. + g.to_fm_bias(); + + auto ret = pass.run(g.g()); + EXPECT_EQ(false, ret); +} diff --git a/compiler/luci/pass/src/FuseBCQPass.cpp b/compiler/luci/pass/src/FuseBCQPass.cpp index 09180d8c1..3f8f700a9 100644 --- a/compiler/luci/pass/src/FuseBCQPass.cpp +++ b/compiler/luci/pass/src/FuseBCQPass.cpp @@ -679,7 +679,6 @@ bool FuseBCQPass::run(luci::Module *m) if (output_node->index() == 0 || (int)output_node->index() > original_output_cnt) { auto noOp = main_graph->nodes()->create<luci::CircleOutputExclude>(); - noOp->dtype(loco::DataType::FLOAT32); // TODO Remove this setting output_node->from(noOp); changed = true; } diff --git a/compiler/luci/pass/src/FuseBatchNormWithTConvPass.cpp b/compiler/luci/pass/src/FuseBatchNormWithTConvPass.cpp index e6b54df36..265a8398b 100644 --- a/compiler/luci/pass/src/FuseBatchNormWithTConvPass.cpp +++ b/compiler/luci/pass/src/FuseBatchNormWithTConvPass.cpp @@ -23,6 +23,26 @@ namespace { + +template <class CIRCLENODE> +void replace_with_relu(luci::CircleNode *target, luci::CircleNode *feature, + const std::string &relu_name) +{ + assert(target != nullptr); + assert(feature != nullptr); + + auto relu = target->graph()->nodes()->create<CIRCLENODE>(); + relu->features(feature); + relu->name(relu_name); + luci::add_origin(relu, luci::get_origin(target)); + + replace(target).with(relu); +} + +} // namespace + +namespace +{ /** * Fuse Mul-Add to TransposeConv if possible. * @@ -49,10 +69,10 @@ namespace * | / / | / * [CircleTransposeConv] [CircleAdd] * | - * ([CircleRelu6]) + * ([CircleRelu]/[CircleRelu6]) * | * - * Note: CircleRelu6 is inserted if Add activation is ReLU6 + * Note: CircleRelu or CircleRelu6 is inserted if Add activation is ReLU/ReLU6 */ bool fused_batch_norm_with_tconv(luci::CircleAdd *add) { @@ -80,7 +100,8 @@ bool fused_batch_norm_with_tconv(luci::CircleAdd *add) if (add->dtype() != loco::DataType::FLOAT32) return false; if (add->fusedActivationFunction() != luci::FusedActFunc::NONE && - add->fusedActivationFunction() != luci::FusedActFunc::RELU6) + add->fusedActivationFunction() != luci::FusedActFunc::RELU6 && + add->fusedActivationFunction() != luci::FusedActFunc::RELU) return false; // tconv bias is optional @@ -202,19 +223,23 @@ bool fused_batch_norm_with_tconv(luci::CircleAdd *add) luci::add_origin(fused_tconv, luci::get_origin(bias)); } - if (add->fusedActivationFunction() == luci::FusedActFunc::RELU6) + switch (add->fusedActivationFunction()) { - // separate relu op from add op - auto relu = add->graph()->nodes()->create<luci::CircleRelu6>(); - relu->features(fused_tconv); - relu->name(name + "/Relu6"); - luci::add_origin(relu, luci::get_origin(add)); + case luci::FusedActFunc::RELU6: + replace_with_relu<luci::CircleRelu6>(add, fused_tconv, name + "/Relu6"); + break; - replace(add).with(relu); - } - else - { - replace(add).with(fused_tconv); + case luci::FusedActFunc::RELU: + replace_with_relu<luci::CircleRelu>(add, fused_tconv, name + "/Relu"); + break; + + case luci::FusedActFunc::NONE: + replace(add).with(fused_tconv); + break; + + default: + assert(false); + break; } return true; diff --git a/compiler/luci/pass/src/FusePReluPass.cpp b/compiler/luci/pass/src/FusePReluPass.cpp new file mode 100644 index 000000000..a5ce60ebf --- /dev/null +++ b/compiler/luci/pass/src/FusePReluPass.cpp @@ -0,0 +1,202 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/FusePReluPass.h" +#include "helpers/NodeFiller.h" + +#include <luci/IR/CircleNodes.h> + +#include <luci/Profile/CircleNodeOrigin.h> +#include <luci/Service/CircleNodeClone.h> + +#include <cassert> + +// Helper to fuse PRelu +namespace +{ + +/** + * Below diagram shows PRelu pattern to fuse. + * - this pattern will be replaced with one PRelu + * + * [In] + * | + * V + * +---- ifm ----+ + * | | | + * | | V + * | | abs + * | V | + * | sub <---+ + * | | + * | V + * | mul_alpha (alpha of PRelu) + * | | + * V V + * relu mul_half (0.5) + * | | + * | V + * +---> add + * | + * V + * [Out] + * + */ +class PReluPattern final +{ +public: + PReluPattern(luci::CircleAdd *candidate) + { + assert(candidate); + _add_ofm = candidate; + } + +public: + bool matched(); + +public: + luci::CircleNode *_ifm = nullptr; + luci::CircleRelu *_relu = nullptr; + luci::CircleAbs *_abs = nullptr; + luci::CircleSub *_sub = nullptr; + luci::CircleMul *_mul_alpha = nullptr; + luci::CircleMul *_mul_half = nullptr; + luci::CircleAdd *_add_ofm = nullptr; + luci::CircleConst *_const_alpha = nullptr; + luci::CircleConst *_const_half = nullptr; +}; + +#define CHECK_OR_FALSE(condition) \ + if (not(condition)) \ + return false; + +bool PReluPattern::matched() +{ + // check pattern + CHECK_OR_FALSE(luci::fill(&_relu, &_mul_half).with_commutative_args_of(_add_ofm)); + CHECK_OR_FALSE(luci::fill(&_mul_alpha, &_const_half).with_commutative_args_of(_mul_half)); + CHECK_OR_FALSE(luci::fill(&_sub, &_const_alpha).with_commutative_args_of(_mul_alpha)); + + CHECK_OR_FALSE(luci::fill(&_ifm, &_abs).with_args_of(_sub)); + + CHECK_OR_FALSE(_relu->features() == _ifm); + CHECK_OR_FALSE(_abs->x() == _ifm); + + // Check Activation to be NONE + CHECK_OR_FALSE(_sub->fusedActivationFunction() == luci::FusedActFunc::NONE); + CHECK_OR_FALSE(_mul_alpha->fusedActivationFunction() == luci::FusedActFunc::NONE); + CHECK_OR_FALSE(_mul_half->fusedActivationFunction() == luci::FusedActFunc::NONE); + CHECK_OR_FALSE(_add_ofm->fusedActivationFunction() == luci::FusedActFunc::NONE); + + // TODO support other types? + // check if _const_half is really FLOAT32 & 0.5 + CHECK_OR_FALSE(_const_half->dtype() == loco::DataType::FLOAT32); + CHECK_OR_FALSE(_const_half->size<loco::DataType::FLOAT32>() == 1); + CHECK_OR_FALSE(_const_half->at<loco::DataType::FLOAT32>(0) == 0.5); + + // check _const_alpha condition + CHECK_OR_FALSE(_const_alpha->dtype() == loco::DataType::FLOAT32); + // TODO add more if needed + + return true; +} + +#undef CHECK_OR_FALSE + +class FusePRelu final +{ +public: + FusePRelu(const PReluPattern &p) : _p(p) {} + +public: + void apply(void); + +private: + luci::CirclePRelu *create_prelu(loco::Graph *graph); + +private: + const PReluPattern &_p; +}; + +luci::CirclePRelu *FusePRelu::create_prelu(loco::Graph *graph) +{ + assert(graph); + + auto prelu = graph->nodes()->create<luci::CirclePRelu>(); + prelu->input(_p._ifm); + prelu->alpha(_p._const_alpha); + prelu->name(_p._add_ofm->name() + "_prelu"); + return prelu; +} + +void FusePRelu::apply() +{ + auto graph = _p._add_ofm->graph(); + + auto prelu = create_prelu(graph); + + // set origin + std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{ + luci::get_origin(_p._relu), luci::get_origin(_p._abs), luci::get_origin(_p._sub), + luci::get_origin(_p._mul_alpha), luci::get_origin(_p._mul_half), luci::get_origin(_p._add_ofm)}; + + luci::add_origin(prelu, luci::composite_origin(origin_vec)); + + replace(_p._add_ofm).with(prelu); +} + +} // namespace + +namespace +{ + +bool fuse_prelu(luci::CircleAdd *add) +{ + assert(add); + + PReluPattern pattern(add); + if (pattern.matched()) + { + FusePRelu fuse(pattern); + fuse.apply(); + return true; + } + return false; +} + +} // namespace + +namespace luci +{ + +bool FusePReluPass::run(loco::Graph *g) +{ + bool changed = false; + + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + auto add = dynamic_cast<luci::CircleAdd *>(node); + if (not add) + continue; + + if (fuse_prelu(add)) + changed = true; + } + + return changed; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/FusePReluPass.test.cpp b/compiler/luci/pass/src/FusePReluPass.test.cpp new file mode 100644 index 000000000..209fe3911 --- /dev/null +++ b/compiler/luci/pass/src/FusePReluPass.test.cpp @@ -0,0 +1,187 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/FusePReluPass.h" + +#include <luci/IR/CircleNodes.h> + +#include <luci/test/TestIOGraph.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class PReluGraphlet +{ +public: + PReluGraphlet() = default; + + void init(loco::Graph *g) + { + _abs = g->nodes()->create<luci::CircleAbs>(); + _sub = g->nodes()->create<luci::CircleSub>(); + _mul_alpha = g->nodes()->create<luci::CircleMul>(); + _mul_half = g->nodes()->create<luci::CircleMul>(); + _relu = g->nodes()->create<luci::CircleRelu>(); + _add = g->nodes()->create<luci::CircleAdd>(); + _const_alpha = g->nodes()->create<luci::CircleConst>(); + _const_half = g->nodes()->create<luci::CircleConst>(); + + _sub->fusedActivationFunction(luci::FusedActFunc::NONE); + _mul_alpha->fusedActivationFunction(luci::FusedActFunc::NONE); + _mul_half->fusedActivationFunction(luci::FusedActFunc::NONE); + _add->fusedActivationFunction(luci::FusedActFunc::NONE); + + _abs->name("abs"); + _sub->name("sub"); + _mul_alpha->name("mul_alpha"); + _mul_half->name("mul_half"); + _relu->name("relu"); + _add->name("add"); + _const_alpha->name("const_alpha"); + _const_half->name("const_half"); + + _const_alpha->dtype(loco::DataType::FLOAT32); + _const_alpha->size<loco::DataType::FLOAT32>(1); + _const_alpha->shape({1}); + _const_alpha->at<loco::DataType::FLOAT32>(0) = 0.1; + _const_alpha->shape_status(luci::ShapeStatus::VALID); + + _const_half->dtype(loco::DataType::FLOAT32); + _const_half->size<loco::DataType::FLOAT32>(1); + _const_half->shape({1}); + _const_half->at<loco::DataType::FLOAT32>(0) = 0.5; + _const_half->shape_status(luci::ShapeStatus::VALID); + } + + void invalid_half() { _const_half->at<loco::DataType::FLOAT32>(0) = 0.1; } + void invalid_act() { _add->fusedActivationFunction(luci::FusedActFunc::RELU); } + +protected: + luci::CircleAbs *_abs = nullptr; + luci::CircleSub *_sub = nullptr; + luci::CircleMul *_mul_alpha = nullptr; + luci::CircleMul *_mul_half = nullptr; + luci::CircleRelu *_relu = nullptr; + luci::CircleAdd *_add = nullptr; + luci::CircleConst *_const_alpha = nullptr; + luci::CircleConst *_const_half = nullptr; +}; + +class FusePReluTestGraph : public TestIOGraph, public PReluGraphlet +{ +public: + FusePReluTestGraph() = default; + + void init(void) + { + TestIOGraph::init({1}, {1}); + PReluGraphlet::init(g()); + + _relu->features(input()); + _abs->x(input()); + _sub->x(input()); + _sub->y(_abs); + _mul_alpha->x(_sub); + _mul_alpha->y(_const_alpha); + _mul_half->x(_mul_alpha); + _mul_half->y(_const_half); + _add->x(_relu); + _add->y(_mul_half); + + output()->from(_add); + } +}; + +class FusePReluTestNegGraph : public TestIOGraph, public PReluGraphlet +{ +public: + FusePReluTestNegGraph() = default; + + void init(void) + { + TestIOGraph::init({1}, {1}); + PReluGraphlet::init(g()); + + _relu->features(input()); + _abs->x(input()); + // NOTE x and y are incorrect + _sub->x(_abs); + _sub->y(input()); + _mul_alpha->x(_sub); + _mul_alpha->y(_const_alpha); + _mul_half->x(_mul_alpha); + _mul_half->y(_const_half); + _add->x(_relu); + _add->y(_mul_half); + + output()->from(_add); + } +}; + +} // namespace + +TEST(FusePReluPassTest, name) +{ + luci::FusePReluPass pass; + auto const name = pass.name(); + ASSERT_NE(nullptr, name); +} + +TEST(FusePReluPassTest, fuse) +{ + FusePReluTestGraph g; + luci::FusePReluPass pass; + + g.init(); + + EXPECT_TRUE(pass.run(g.g())); +} + +TEST(FusePReluPassTest, fuse_invalid_half_NEG) +{ + FusePReluTestNegGraph g; + luci::FusePReluPass pass; + + g.init(); + g.invalid_half(); + + EXPECT_FALSE(pass.run(g.g())); +} + +TEST(FusePReluPassTest, fuse_invalid_act_NEG) +{ + FusePReluTestNegGraph g; + luci::FusePReluPass pass; + + g.init(); + g.invalid_act(); + + EXPECT_FALSE(pass.run(g.g())); +} + +TEST(FusePReluPassTest, fuse_NEG) +{ + FusePReluTestNegGraph g; + luci::FusePReluPass pass; + + g.init(); + + EXPECT_FALSE(pass.run(g.g())); +} diff --git a/compiler/luci/pass/src/QuantizationUtils.cpp b/compiler/luci/pass/src/QuantizationUtils.cpp index 06a4ae9f6..45d229a0b 100644 --- a/compiler/luci/pass/src/QuantizationUtils.cpp +++ b/compiler/luci/pass/src/QuantizationUtils.cpp @@ -34,6 +34,8 @@ bool is_quantized(const CircleNode *node) node->dtype() == loco::DataType::S64); // bias (int16 quant) } +bool is_fp32(const CircleNode *node) { return node->dtype() == loco::DataType::FLOAT32; } + uint8_t fp32_to_uint8_cast(float f) { assert(std::numeric_limits<uint8_t>::min() <= f); @@ -124,8 +126,8 @@ void compute_sym_scale_zp(float min, float max, float &scaling_factor, int64_t & : scale_factor_from_max_side; // protect scale from being very low to avoid overflow/underflow - if (scaling_factor < 1e-9) - scaling_factor = 1e-9; + if (scaling_factor < 1e-8) + scaling_factor = 1e-8; zp = 0; nudged_min = static_cast<float>(qmin_double * scaling_factor); diff --git a/compiler/luci/pass/src/QuantizationUtils.h b/compiler/luci/pass/src/QuantizationUtils.h index 4d5316ccb..0720c9839 100644 --- a/compiler/luci/pass/src/QuantizationUtils.h +++ b/compiler/luci/pass/src/QuantizationUtils.h @@ -60,6 +60,9 @@ void propagate_pad_v2_quantparam(luci::CirclePadV2 *pad_v2); // Return true if the node is quantized bool is_quantized(const CircleNode *node); +// Return true if the node is fp32 +bool is_fp32(const CircleNode *node); + enum ActivationQType { MinMax, // Quantize using recorded min/max diff --git a/compiler/luci/pass/src/QuantizeActivation.cpp b/compiler/luci/pass/src/QuantizeActivation.cpp index 95251a82c..214e61c1e 100644 --- a/compiler/luci/pass/src/QuantizeActivation.cpp +++ b/compiler/luci/pass/src/QuantizeActivation.cpp @@ -44,12 +44,8 @@ void QuantizeActivation::visit(luci::CircleNode *node) LOGGER(l); INFO(l) << "QuantizeActivation visit node: " << node->name() << std::endl; - // Check if this is already quantized - if (is_quantized(node)) - return; - - // Check if this is bool type (bool type is not quantized) - if (node->dtype() == loco::DataType::BOOL) + // Check if node is fp32 + if (not is_fp32(node)) return; // Check if this is const (const activation is handled by QuantizeConstInputActivation) @@ -185,7 +181,7 @@ void QuantizeConstInputActivation::visit(luci::CircleNode *node) { \ auto input = node->INPUT_NAME(); \ auto const_node = dynamic_cast<luci::CircleConst *>(input); \ - if (const_node && !is_quantized(const_node)) \ + if (const_node && is_fp32(const_node)) \ { \ auto new_const = luci::clone(const_node); \ quant_const(new_const, _output_type); \ @@ -199,7 +195,7 @@ void QuantizeConstInputActivation::visit(luci::CircleNode *node) { \ auto input1 = node->INPUT_NAME1(); \ auto const_node1 = dynamic_cast<luci::CircleConst *>(input1); \ - if (const_node1 && !is_quantized(const_node1)) \ + if (const_node1 && is_fp32(const_node1)) \ { \ auto new_const1 = luci::clone(const_node1); \ quant_const(new_const1, _output_type); \ @@ -207,7 +203,7 @@ void QuantizeConstInputActivation::visit(luci::CircleNode *node) } \ auto input2 = node->INPUT_NAME2(); \ auto const_node2 = dynamic_cast<luci::CircleConst *>(input2); \ - if (const_node2 && !is_quantized(const_node2)) \ + if (const_node2 && is_fp32(const_node2)) \ { \ auto new_const2 = luci::clone(const_node2); \ quant_const(new_const2, _output_type); \ @@ -216,6 +212,7 @@ void QuantizeConstInputActivation::visit(luci::CircleNode *node) } // Ops that receive a single activation as an input +QUANTIZE_SINGLE_CONST_INPUT(luci::CircleAbs, x) QUANTIZE_SINGLE_CONST_INPUT(luci::CircleArgMax, input) QUANTIZE_SINGLE_CONST_INPUT(luci::CircleArgMin, input) QUANTIZE_SINGLE_CONST_INPUT(luci::CircleBatchToSpaceND, input) @@ -278,7 +275,7 @@ void QuantizeConstInputActivation::visit(luci::CircleAddN *node) { auto input_node = node->inputs(i); auto const_node = dynamic_cast<luci::CircleConst *>(input_node); - if (const_node && !is_quantized(const_node)) + if (const_node && is_fp32(const_node)) { auto new_const = luci::clone(const_node); quant_const(new_const, _output_type); diff --git a/compiler/luci/pass/src/QuantizeActivation.h b/compiler/luci/pass/src/QuantizeActivation.h index fc32d1cde..c6c991a76 100644 --- a/compiler/luci/pass/src/QuantizeActivation.h +++ b/compiler/luci/pass/src/QuantizeActivation.h @@ -102,6 +102,7 @@ private: void visit(luci::CircleNode *node); // Ops that receive a single activation as an input + void visit(luci::CircleAbs *node); void visit(luci::CircleArgMax *node); void visit(luci::CircleArgMin *node); void visit(luci::CircleBatchToSpaceND *node); diff --git a/compiler/luci/pass/src/QuantizeWeights.cpp b/compiler/luci/pass/src/QuantizeWeights.cpp index 500ae12ed..29cdaffff 100644 --- a/compiler/luci/pass/src/QuantizeWeights.cpp +++ b/compiler/luci/pass/src/QuantizeWeights.cpp @@ -90,6 +90,118 @@ void asym_wquant_per_channel(CircleConst *node, std::vector<float> &min, } } +// TODO Reduce duplicate code with QuantizeDequantizeWeights +void sym_wquant_per_channel(CircleConst *node, std::vector<float> &min, std::vector<float> &max, + std::vector<float> &scaling_factor, std::vector<int64_t> &zp, + std::vector<float> &nudged_min, std::vector<float> &nudged_max, + int32_t &channel_dim_index) +{ + assert(node->dtype() == loco::DataType::FLOAT32); + const int32_t kMaxScale = std::numeric_limits<int16_t>::max(); + const int32_t kMinScale = -kMaxScale; + + uint32_t size = node->size<loco::DataType::FLOAT32>(); + std::vector<int32_t> quantized_values(size); + + for (size_t i = 0; i < min.size(); ++i) + { + compute_sym_scale_zp(min[i], max[i], scaling_factor[i], zp[i], nudged_min[i], nudged_max[i]); + } + + auto quantize = [&](uint32_t *indices, loco::TensorShape &dimension, int channel_dim_index) { + int channel_idx = indices[channel_dim_index]; + const float scaling_factor_inv = 1.0 / scaling_factor[channel_idx]; + auto data = node->at<loco::DataType::FLOAT32>(cal_offset(dimension, indices)); + data = data < nudged_min[channel_idx] ? nudged_min[channel_idx] : data; + data = data > nudged_max[channel_idx] ? nudged_max[channel_idx] : data; + quantized_values[cal_offset(dimension, indices)] = + static_cast<int32_t>(std::round(data * scaling_factor_inv)); + }; + + iterate_per_channel(node, channel_dim_index, quantize); + + node->dtype(loco::DataType::S16); // change the type of tensor + node->size<loco::DataType::S16>(size); // resize tensor + for (uint32_t i = 0; i < size; ++i) + { + node->at<loco::DataType::S16>(i) = + std::min(kMaxScale, std::max(kMinScale, quantized_values[i])); + } +} + +void cal_minmax_per_channel(CircleConst *node, std::vector<float> &min, std::vector<float> &max, + int32_t &channel_dim_index) +{ + loco::TensorShape dimension; + dimension.rank(4); + + if (!get_channel_dim_index(node, dimension, channel_dim_index)) + { + throw std::runtime_error("Failed to find channel index in " + node->name()); + } + auto size = dimension.dim(channel_dim_index).value(); + + std::vector<bool> has_min_max_value(size, false); + min.resize(size); + max.resize(size); + + auto cal_minmax = [&](uint32_t *indices, loco::TensorShape &dimension, int channel_dim_index) { + int channel_idx = indices[channel_dim_index]; + auto data = node->at<loco::DataType::FLOAT32>(cal_offset(dimension, indices)); + if (has_min_max_value[channel_idx]) + { + min[channel_idx] = data < min[channel_idx] ? data : min[channel_idx]; + max[channel_idx] = data > max[channel_idx] ? data : max[channel_idx]; + } + else + { + min[channel_idx] = data; + max[channel_idx] = data; + has_min_max_value[channel_idx] = true; + } + }; + + iterate_per_channel(node, channel_dim_index, cal_minmax); +} + +void asymmetric_wquant_per_channel(CircleConst *node, std::vector<float> &min, + std::vector<float> &max, std::vector<float> &scaling_factor, + std::vector<int64_t> &zp, std::vector<float> &nudged_min, + std::vector<float> &nudged_max, int32_t &channel_dim_index) +{ + assert(node->dtype() == loco::DataType::FLOAT32); + + const int32_t kMinScale = 0; + const int32_t kMaxScale = 255; + + uint32_t size = node->size<loco::DataType::FLOAT32>(); + std::vector<int32_t> quantized_values(size); + + for (size_t i = 0; i < min.size(); ++i) + { + compute_asym_scale_zp(min[i], max[i], scaling_factor[i], zp[i], nudged_min[i], nudged_max[i]); + } + + auto quantize = [&](uint32_t *indices, loco::TensorShape &dimension, int channel_dim_index) { + int channel_idx = indices[channel_dim_index]; + const float scaling_factor_inv = 1.0 / scaling_factor[channel_idx]; + auto data = node->at<loco::DataType::FLOAT32>(cal_offset(dimension, indices)); + data = data < nudged_min[channel_idx] ? nudged_min[channel_idx] : data; + data = data > nudged_max[channel_idx] ? nudged_max[channel_idx] : data; + quantized_values[cal_offset(dimension, indices)] = + static_cast<int32_t>(std::round((data - nudged_min[channel_idx]) * scaling_factor_inv)); + }; + + iterate_per_channel(node, channel_dim_index, quantize); + + node->dtype(loco::DataType::U8); // change the type of tensor + node->size<loco::DataType::U8>(size); // resize tensor + for (uint32_t i = 0; i < size; ++i) + { + node->at<loco::DataType::U8>(i) = std::min(kMaxScale, std::max(kMinScale, quantized_values[i])); + } +} + void sym_wquant_per_channel(CircleConst *node, std::vector<float> &scaling_factor, int32_t &channel_dim_index) { @@ -250,7 +362,37 @@ void QuantizeWeights::quantize_weights(luci::CircleConst *weights) auto quantparam = weights->quantparam(); if (quantparam == nullptr) { - assert(false && "quantparam is nullptr"); + // Find min/max on the fly + // NOTE This is for the case when QuantizeDequantizeWeights is skipped + // TODO Reduce duplicate codes + std::vector<float> min; + std::vector<float> max; + int32_t channel_dim_index = 0; + + cal_minmax_per_channel(weights, min, max, channel_dim_index); + + std::vector<float> nudged_min(min.size()); + std::vector<float> nudged_max(min.size()); + std::vector<float> scaling_factor(min.size()); + std::vector<int64_t> zp(min.size()); + + if (output_type == loco::DataType::U8) + { + asymmetric_wquant_per_channel(weights, min, max, scaling_factor, zp, nudged_min, nudged_max, + channel_dim_index); + } + else + { + sym_wquant_per_channel(weights, min, max, scaling_factor, zp, nudged_min, nudged_max, + channel_dim_index); + } + + auto quantparam = std::make_unique<CircleQuantParam>(); + quantparam->scale = scaling_factor; + quantparam->zerop = zp; + quantparam->quantized_dimension = channel_dim_index; + weights->quantparam(std::move(quantparam)); + return; } @@ -273,8 +415,35 @@ void QuantizeWeights::quantize_weights(luci::CircleConst *weights) // Find min/max per layer-wise else { - // Quantize using recorded quantparam auto quantparam = weights->quantparam(); + if (quantparam == nullptr) + { + // Find min/max on the fly + // NOTE This is for the case when QuantizeDequantizeWeights is skipped + // TODO Reduce duplicate codes + float min = std::numeric_limits<float>::max(); + float max = std::numeric_limits<float>::lowest(); + for (uint32_t i = 0; i < weights->size<loco::DataType::FLOAT32>(); i++) + { + auto data = weights->at<loco::DataType::FLOAT32>(i); + min = data < min ? data : min; + max = data > max ? data : max; + } + float scaling_factor{0}; + int64_t zp{0}; + float nudged_min{0}; + float nudged_max{0}; + + asymmetric_wquant_with_minmax_per_layer(weights, min, max, scaling_factor, zp, nudged_min, + nudged_max); + auto quantparam = std::make_unique<CircleQuantParam>(); + quantparam->scale.push_back(scaling_factor); + quantparam->zerop.push_back(zp); + weights->quantparam(std::move(quantparam)); + return; + } + + // Quantize using recorded quantparam assert(quantparam != nullptr); assert(quantparam->min.size() == 1); // only support layer-wise quant assert(quantparam->scale.size() == 1); // only support layer-wise quant diff --git a/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp b/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp index 005144516..c68e06712 100644 --- a/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp +++ b/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp @@ -32,8 +32,6 @@ #include <luci/Log.h> #include <logo/Phase.h> -#include <oops/UserExn.h> - #include <iostream> #include <cmath> @@ -154,8 +152,8 @@ namespace * 2. After output feature map * * For example, if default_dtype = U8 and op_dtype = S16, - * 1. Quantize Op for U8->S16 is inserted before ifm - * 2. Quantize Op for S16->U8 is inserted after ofm + * 1. Quantize (U8->S16) is inserted before ifm + * 2. Quantize (S16->U8) is inserted after ofm * * Why not insert Quantize Op for const ifm? * We quantize const tensor at once to preserve precision. @@ -181,6 +179,10 @@ private: if (input->opcode() == luci::CircleOpcode::CIRCLECONST) return nullptr; + // input is not quantizable (ex: index) + if (input->quantparam() == nullptr) + return nullptr; + auto input_quant = create_quantize_op(input, _op_dtype); input_quant->input(input); auto origin_node = loco::must_cast<luci::CircleNode *>(origin); @@ -192,6 +194,11 @@ private: { auto output = loco::must_cast<luci::CircleNode *>(node); assert(output->opcode() != luci::CircleOpcode::CIRCLECONST); // FIX_CALLER_UNLESS + + // output is not quantizable (ex: index) + if (output->quantparam() == nullptr) + return; + auto output_quant = create_quantize_op(output, _default_dtype); luci::add_origin(output_quant, luci::get_origin(output)); @@ -253,6 +260,7 @@ private: void visit(luci::CircleUnpackOut *) {} // Ops that receive a single activation as an input + INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleAbs, x) INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleAveragePool2D, value) INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleBatchToSpaceND, input) INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleConv2D, input) @@ -365,10 +373,20 @@ private: void QuantizeWithMinMaxPass::set_input_type(loco::Graph *g) const { auto inputs = g->inputs(); - for (auto node : loco::input_nodes(g)) + + assert(inputs); // FIX_CALLER_UNLESS + assert(inputs->size() == _ctx->input_types.size()); // FIX_CALLER_UNLESS + + // NOTE loco::input_nodes returns input nodes following the order of InputIndex + auto input_nodes = loco::input_nodes(g); + for (uint32_t i = 0; i < input_nodes.size(); i++) { - auto input = loco::must_cast<luci::CircleInput *>(node); - if (input->dtype() == _ctx->input_type) + auto input = loco::must_cast<luci::CircleInput *>(input_nodes[i]); + assert(i == input->index()); // Fix input_type logic + + const auto user_given_dtype = _ctx->input_types[i]; + + if (input->dtype() == user_given_dtype) continue; // Bool type is not quantizable @@ -394,7 +412,7 @@ void QuantizeWithMinMaxPass::set_input_type(loco::Graph *g) const // Update qparam of input // This step is skipped if input_type is float32 - if (_ctx->input_type != loco::DataType::FLOAT32) + if (user_given_dtype != loco::DataType::FLOAT32) { auto quantparam = input->quantparam(); assert(quantparam); @@ -408,13 +426,13 @@ void QuantizeWithMinMaxPass::set_input_type(loco::Graph *g) const float nudged_min{0}; float nudged_max{0}; - if (_ctx->input_type == loco::DataType::U8) + if (user_given_dtype == loco::DataType::U8) { compute_asym_scale_zp(min, max, scaling_factor, zp, nudged_min, nudged_max); } else { - assert(_ctx->input_type == loco::DataType::S16); + assert(user_given_dtype == loco::DataType::S16); compute_sym_scale_zp(min, max, scaling_factor, zp, nudged_min, nudged_max); } input->quantparam()->scale[0] = scaling_factor; @@ -422,20 +440,29 @@ void QuantizeWithMinMaxPass::set_input_type(loco::Graph *g) const } // Update dtype of input - input->dtype(_ctx->input_type); + input->dtype(user_given_dtype); auto graph_input = inputs->at(input->index()); - graph_input->dtype(_ctx->input_type); + graph_input->dtype(user_given_dtype); } } void QuantizeWithMinMaxPass::set_output_type(loco::Graph *g) const { auto outputs = g->outputs(); - for (auto node : loco::output_nodes(g)) + assert(outputs); // FIX_CALLER_UNLESS + assert(outputs->size() == _ctx->output_types.size()); // Fix CircleQuantizer unless + + // NOTE loco::output_nodes returns output nodes following the order of OutputIndex + auto output_nodes = loco::output_nodes(g); + for (uint32_t i = 0; i < output_nodes.size(); i++) { - auto output = loco::must_cast<luci::CircleOutput *>(node); - if (output->dtype() == _ctx->output_type) + auto output = loco::must_cast<luci::CircleOutput *>(output_nodes[i]); + assert(i == output->index()); // Fix output_type logic + + const auto user_given_dtype = _ctx->output_types[i]; + + if (output->dtype() == user_given_dtype) continue; // Bool type is not quantizable @@ -444,12 +471,12 @@ void QuantizeWithMinMaxPass::set_output_type(loco::Graph *g) const auto from = loco::must_cast<luci::CircleNode *>(output->from()); - // The last Op is not quantizable Op (ex: ArgMax) + // The last Op is not quantizable (ex: ArgMax) if (not from->quantparam()) continue; // Insert Dequantize Op for float32 output_type - if (_ctx->output_type == loco::DataType::FLOAT32) + if (user_given_dtype == loco::DataType::FLOAT32) { auto dequant_op = create_dequantize(from); loco::replace(from).with(dequant_op); @@ -458,7 +485,7 @@ void QuantizeWithMinMaxPass::set_output_type(loco::Graph *g) const else { // Insert Quantize Op for non-float32 output_type - auto quant_op = create_quantize_op(from, _ctx->output_type); + auto quant_op = create_quantize_op(from, user_given_dtype); loco::replace(from).with(quant_op); quant_op->input(from); @@ -467,10 +494,10 @@ void QuantizeWithMinMaxPass::set_output_type(loco::Graph *g) const } // Update dtype of output - output->dtype(_ctx->output_type); + output->dtype(user_given_dtype); auto graph_output = outputs->at(output->index()); - graph_output->dtype(_ctx->output_type); + graph_output->dtype(user_given_dtype); } } @@ -493,9 +520,9 @@ void QuantizeWithMinMaxPass::set_output_type(loco::Graph *g) const * Weights is quantized using min/max of its value * * Bias is quantized using input scale (s_i) and weights scale (s_w) - * - Activation and weights should be quantized earlier than bias + * - Therefore, activation and weights should be quantized earlier than bias * - * Quantization Steps + * Overall Quantization Steps * 1. Quantize Activation * - Quantize using recorded min/max (QuantizeActivation) * - Insert Quantize Ops for mixed-precision quantization (InsertQuantizeOp) @@ -550,7 +577,10 @@ bool QuantizeWithMinMaxPass::run(loco::Graph *g) }; // Quantize activation - for (auto node : loco::active_nodes(loco::output_nodes(g))) + // Why all_nodes? + // Models can have inactive (unused) inputs. + // We do not reject such models, but quantize them too + for (auto node : loco::all_nodes(g)) { auto circle_node = loco::must_cast<luci::CircleNode *>(node); QuantizeActivation qa(_ctx->input_model_dtype, quantize_dtype(circle_node)); diff --git a/compiler/luci/pass/src/QuantizeWithMinMaxPass.test.cpp b/compiler/luci/pass/src/QuantizeWithMinMaxPass.test.cpp index d5fa21ffd..49c2d4652 100644 --- a/compiler/luci/pass/src/QuantizeWithMinMaxPass.test.cpp +++ b/compiler/luci/pass/src/QuantizeWithMinMaxPass.test.cpp @@ -53,8 +53,14 @@ public: TEST(QuantizeWithMinMaxPassTest, name) { - luci::QuantizeWithMinMaxPass pass(loco::DataType::FLOAT32, loco::DataType::U8, - luci::QuantizationGranularity::LayerWise); + auto ctx = std::make_unique<luci::QuantizeWithMinMaxPass::Context>(); + { + ctx->input_model_dtype = loco::DataType::FLOAT32; + ctx->output_model_dtype = loco::DataType::U8; + ctx->granularity = luci::QuantizationGranularity::LayerWise; + } + + luci::QuantizeWithMinMaxPass pass(std::move(ctx)); auto const name = pass.name(); ASSERT_NE(nullptr, name); } @@ -65,8 +71,14 @@ TEST(QuantizeWithMinMaxPassTest, int_concat) { SimpleConcatGraph g(loco::DataType::S32); - luci::QuantizeWithMinMaxPass qwmm(loco::DataType::FLOAT32, loco::DataType::U8, - luci::QuantizationGranularity::LayerWise); + auto ctx = std::make_unique<luci::QuantizeWithMinMaxPass::Context>(); + { + ctx->input_model_dtype = loco::DataType::FLOAT32; + ctx->output_model_dtype = loco::DataType::U8; + ctx->granularity = luci::QuantizationGranularity::LayerWise; + } + + luci::QuantizeWithMinMaxPass qwmm(std::move(ctx)); qwmm.run(&g.g); @@ -74,3 +86,22 @@ TEST(QuantizeWithMinMaxPassTest, int_concat) EXPECT_EQ(nullptr, g.input_1->quantparam()); EXPECT_EQ(nullptr, g.input_2->quantparam()); } + +TEST(QuantizeWithMinMaxPassTest, inactive_input) +{ + SimpleConcatGraph g(loco::DataType::FLOAT32); + + // Unused input + g.g.nodes()->create<luci::CircleInput>(); + + auto ctx = std::make_unique<luci::QuantizeWithMinMaxPass::Context>(); + { + ctx->input_model_dtype = loco::DataType::FLOAT32; + ctx->output_model_dtype = loco::DataType::U8; + ctx->granularity = luci::QuantizationGranularity::LayerWise; + } + + luci::QuantizeWithMinMaxPass qwmm(std::move(ctx)); + + EXPECT_NO_THROW(qwmm.run(&g.g)); +} diff --git a/compiler/luci/pass/src/QuantizedModelVerifier.h b/compiler/luci/pass/src/QuantizedModelVerifier.h index 7409a51d7..d9bea434d 100644 --- a/compiler/luci/pass/src/QuantizedModelVerifier.h +++ b/compiler/luci/pass/src/QuantizedModelVerifier.h @@ -38,26 +38,13 @@ public: { loco::DataType output_model_dtype = loco::DataType::Unknown; QuantizationGranularity granularity = QuantizationGranularity::ChannelWise; - loco::DataType input_type = loco::DataType::Unknown; - loco::DataType output_type = loco::DataType::Unknown; + std::vector<loco::DataType> input_types; + std::vector<loco::DataType> output_types; bool TF_style_maxpool = false; std::vector<LayerInfo> layers_info; }; public: - QuantizedModelVerifier(loco::DataType quantized_dtype, QuantizationGranularity granularity) - { - _ctx = std::make_unique<Context>(); - { - _ctx->output_model_dtype = quantized_dtype; - _ctx->granularity = granularity; - _ctx->input_type = quantized_dtype; - _ctx->output_type = quantized_dtype; - _ctx->TF_style_maxpool = false; - } - } - -public: QuantizedModelVerifier(std::unique_ptr<Context> &&ctx) : _ctx{std::move(ctx)} { // DO NOTHING diff --git a/compiler/luci/pass/src/QuantizedModelVerifier.test.cpp b/compiler/luci/pass/src/QuantizedModelVerifier.test.cpp index 21b4fe1c6..05ec31727 100644 --- a/compiler/luci/pass/src/QuantizedModelVerifier.test.cpp +++ b/compiler/luci/pass/src/QuantizedModelVerifier.test.cpp @@ -18,7 +18,9 @@ #include "luci/Pass/QuantizeWithMinMaxPass.h" #include "luci/Pass/QuantizationParameters.h" +#include "luci/Pass/CircleTypeInferencePass.h" +#include <logo/Phase.h> #include <luci/test/TestIOGraph.h> #include <gtest/gtest.h> @@ -104,12 +106,56 @@ void insert_scale_zp(luci::CircleNode *node, float scale, int64_t zp) qparam->zerop.push_back(zp); } +void run_phase(loco::Graph *g, Type quantized_dtype, Granularity granularity) +{ + logo::Phase phase; + + // Default passes. + phase.emplace_back(std::make_unique<luci::CircleTypeInferencePass>()); + + auto ctx = std::make_unique<luci::QuantizeWithMinMaxPass::Context>(); + { + ctx->input_model_dtype = loco::DataType::FLOAT32; + ctx->output_model_dtype = quantized_dtype; + ctx->granularity = granularity; + // Test graph has only one input/output + ctx->input_types = {quantized_dtype}; + ctx->output_types = {quantized_dtype}; + } + + phase.emplace_back(std::make_unique<luci::QuantizeWithMinMaxPass>(std::move(ctx))); + + logo::PhaseRunner<logo::PhaseStrategy::Restart> phase_runner{g}; + phase_runner.run(phase); +} + +void run_phase(loco::Graph *g, std::unique_ptr<luci::QuantizeWithMinMaxPass::Context> &&ctx) +{ + logo::Phase phase; + + // Default passes. + phase.emplace_back(std::make_unique<luci::CircleTypeInferencePass>()); + + phase.emplace_back(std::make_unique<luci::QuantizeWithMinMaxPass>(std::move(ctx))); + + logo::PhaseRunner<logo::PhaseStrategy::Restart> phase_runner{g}; + phase_runner.run(phase); +} + void quantize_and_verify(loco::Graph *g, Type quantized_dtype, Granularity granularity) { - luci::QuantizeWithMinMaxPass pass(Type::FLOAT32, quantized_dtype, granularity); - pass.run(g); + run_phase(g, quantized_dtype, granularity); - luci::QuantizedModelVerifier verifier(quantized_dtype, granularity); + auto ctx = std::make_unique<luci::QuantizedModelVerifier::Context>(); + { + ctx->output_model_dtype = quantized_dtype; + ctx->granularity = granularity; + // Test graph has only one input/output + ctx->input_types = {quantized_dtype}; + ctx->output_types = {quantized_dtype}; + } + + luci::QuantizedModelVerifier verifier(std::move(ctx)); verifier.verify(g); } @@ -132,14 +178,14 @@ void quantize_and_verify_with_layer_info(loco::Graph *g, Type quantized_dtype, ctx->input_model_dtype = Type::FLOAT32; ctx->output_model_dtype = quantized_dtype; ctx->granularity = granularity; - ctx->input_type = quantized_dtype; - ctx->output_type = quantized_dtype; + // Test graph has only one input/output + ctx->input_types = {quantized_dtype}; + ctx->output_types = {quantized_dtype}; ctx->TF_style_maxpool = false; ctx->layers_info.push_back(info); } - luci::QuantizeWithMinMaxPass pass(std::move(ctx)); - pass.run(g); + run_phase(g, std::move(ctx)); } // Do verification @@ -148,8 +194,8 @@ void quantize_and_verify_with_layer_info(loco::Graph *g, Type quantized_dtype, { ctx->output_model_dtype = quantized_dtype; ctx->granularity = granularity; - ctx->input_type = quantized_dtype; - ctx->output_type = quantized_dtype; + ctx->input_types = {quantized_dtype}; + ctx->output_types = {quantized_dtype}; ctx->TF_style_maxpool = false; ctx->layers_info.push_back(info); } @@ -164,13 +210,21 @@ void quantize_and_verify_with_layer_info(loco::Graph *g, Type quantized_dtype, void quantize_and_verify_with_wrong_type(luci::test::TestIOGraph *g, Type quantized_dtype, Granularity granularity, Type wrong_dtype) { - luci::QuantizeWithMinMaxPass pass(Type::FLOAT32, quantized_dtype, granularity); - pass.run(g->g()); + run_phase(g->g(), quantized_dtype, granularity); auto node = loco::must_cast<luci::CircleNode *>(g->output()->from()); node->dtype(wrong_dtype); - luci::QuantizedModelVerifier verifier(quantized_dtype, granularity); + auto ctx = std::make_unique<luci::QuantizedModelVerifier::Context>(); + { + ctx->output_model_dtype = quantized_dtype; + ctx->granularity = granularity; + // Test graph has only one input/output + ctx->input_types = {quantized_dtype}; + ctx->output_types = {quantized_dtype}; + } + + luci::QuantizedModelVerifier verifier(std::move(ctx)); verifier.verify(g->g()); } @@ -179,13 +233,21 @@ void quantize_and_verify_with_wrong_type(luci::test::TestIOGraph *g, Type quanti void quantize_and_verify_with_wrong_granularity(luci::test::TestIOGraph *g, Type quantized_dtype, Granularity granularity) { - luci::QuantizeWithMinMaxPass pass(Type::FLOAT32, quantized_dtype, granularity); - pass.run(g->g()); + run_phase(g->g(), quantized_dtype, granularity); auto node = loco::must_cast<luci::CircleNode *>(g->output()->from()); insert_scale_zp(node, 1.0, 1); - luci::QuantizedModelVerifier verifier(quantized_dtype, granularity); + auto ctx = std::make_unique<luci::QuantizedModelVerifier::Context>(); + { + ctx->output_model_dtype = quantized_dtype; + ctx->granularity = granularity; + // Test graph has only one input/output + ctx->input_types = {quantized_dtype}; + ctx->output_types = {quantized_dtype}; + } + + luci::QuantizedModelVerifier verifier(std::move(ctx)); verifier.verify(g->g()); } @@ -238,6 +300,24 @@ public: virtual void init(void) = 0; }; +class TypedTestGraph : public luci::test::TestIOGraph +{ +protected: + void init(Type T, const luci::test::ShapeU32 shape_in, const luci::test::ShapeU32 shape_out) + { + TestIOGraph::init(shape_in, shape_out); + + input()->dtype(T); + output()->dtype(T); + + g()->inputs()->at(0)->dtype(T); + g()->outputs()->at(0)->dtype(T); + } + +public: + virtual void init(void) = 0; +}; + class InstanceNormTestGraph final : public SimpleTestGraph { public: @@ -603,6 +683,9 @@ public: output()->from(_argmax); set_minmax_to_non_const(g(), -1, 1); + + // Sync output dtype with graph's output dtype + g()->outputs()->at(0)->dtype(output()->dtype()); } public: @@ -904,6 +987,9 @@ public: output()->from(_op); set_minmax_to_non_const(g(), -1, 1); + + // Sync output dtype with graph's output dtype + g()->outputs()->at(0)->dtype(output()->dtype()); } loco::Node *x(void) const { return _op->x(); } @@ -934,6 +1020,9 @@ public: output()->from(_op); set_minmax_to_non_const(g(), -1, 1); + + // Sync output dtype with graph's output dtype + g()->outputs()->at(0)->dtype(output()->dtype()); } loco::Node *x(void) const { return _op->x(); } @@ -1218,6 +1307,33 @@ private: luci::CircleConst *_const = nullptr; }; +template <Type T> class IntMulTestGraph final : public TypedTestGraph +{ +public: + void init(void) override + { + TypedTestGraph::init(T, {32}, {32}); + + _const = create_dummy_const<T>(g(), {32}); + _mul = g()->nodes()->create<luci::CircleMul>(); + { + _mul->x(input()); + _mul->y(_const); + _mul->fusedActivationFunction(luci::FusedActFunc::NONE); + _mul->name("test"); + _mul->dtype(T); + } + output()->from(_mul); + } + + loco::Node *x() { return _mul->x(); } + loco::Node *y() { return _mul->y(); } + +private: + luci::CircleMul *_mul = nullptr; + luci::CircleConst *_const = nullptr; +}; + class AddTestGraph final : public SimpleTestGraph { public: @@ -1246,6 +1362,33 @@ private: luci::CircleConst *_const = nullptr; }; +template <Type T> class IntAddTestGraph final : public TypedTestGraph +{ +public: + void init(void) override + { + TypedTestGraph::init(T, {32}, {32}); + + _const = create_dummy_const<T>(g(), {32}); + _add = g()->nodes()->create<luci::CircleAdd>(); + { + _add->x(input()); + _add->y(_const); + _add->fusedActivationFunction(luci::FusedActFunc::NONE); + _add->name("test"); + _add->dtype(T); + } + output()->from(_add); + } + + loco::Node *x() { return _add->x(); } + loco::Node *y() { return _add->y(); } + +private: + luci::CircleAdd *_add = nullptr; + luci::CircleConst *_const = nullptr; +}; + } // namespace // Quantize and verify with given configurations @@ -1286,34 +1429,46 @@ private: // Quantize and verify with wrong type // Users can specify the test target -#define TEST_WITH_WRONG_TYPE_TARGET(graph, type, granularity, wrong_dtype, target) \ - do \ - { \ - graph g; \ - g.init(); \ - auto node = loco::must_cast<luci::CircleNode *>(target); \ - luci::QuantizeWithMinMaxPass pass(Type::FLOAT32, type, granularity); \ - pass.run(g.g()); \ - auto after_node = loco::must_cast<luci::CircleNode *>(target); \ - after_node->dtype(wrong_dtype); \ - luci::QuantizedModelVerifier verifier(type, granularity); \ - EXPECT_ANY_THROW(verifier.verify(g.g())); \ +#define TEST_WITH_WRONG_TYPE_TARGET(graph, type, granularity_, wrong_dtype, target) \ + do \ + { \ + graph g; \ + g.init(); \ + auto node = loco::must_cast<luci::CircleNode *>(target); \ + run_phase(g.g(), type, granularity_); \ + auto after_node = loco::must_cast<luci::CircleNode *>(target); \ + after_node->dtype(wrong_dtype); \ + auto ctx = std::make_unique<luci::QuantizedModelVerifier::Context>(); \ + { \ + ctx->output_model_dtype = type; \ + ctx->granularity = granularity_; \ + ctx->input_types = {type}; \ + ctx->output_types = {type}; \ + } \ + luci::QuantizedModelVerifier verifier(std::move(ctx)); \ + EXPECT_ANY_THROW(verifier.verify(g.g())); \ } while (0) // Quantize and verify with wrong granularity // Users can specify the test target -#define TEST_WITH_WRONG_GRANULARITY_TARGET(graph, type, granularity, target) \ - do \ - { \ - graph g; \ - g.init(); \ - auto node = loco::must_cast<luci::CircleNode *>(target); \ - luci::QuantizeWithMinMaxPass pass(Type::FLOAT32, type, granularity); \ - pass.run(g.g()); \ - auto after_node = loco::must_cast<luci::CircleNode *>(target); \ - insert_scale_zp(after_node, 1.0, 1); \ - luci::QuantizedModelVerifier verifier(type, granularity); \ - EXPECT_ANY_THROW(verifier.verify(g.g())); \ +#define TEST_WITH_WRONG_GRANULARITY_TARGET(graph, type, granularity_, target) \ + do \ + { \ + graph g; \ + g.init(); \ + auto node = loco::must_cast<luci::CircleNode *>(target); \ + run_phase(g.g(), type, granularity_); \ + auto after_node = loco::must_cast<luci::CircleNode *>(target); \ + insert_scale_zp(after_node, 1.0, 1); \ + auto ctx = std::make_unique<luci::QuantizedModelVerifier::Context>(); \ + { \ + ctx->output_model_dtype = type; \ + ctx->granularity = granularity_; \ + ctx->input_types = {type}; \ + ctx->output_types = {type}; \ + } \ + luci::QuantizedModelVerifier verifier(std::move(ctx)); \ + EXPECT_ANY_THROW(verifier.verify(g.g())); \ } while (0) // Test a local helper function @@ -2512,6 +2667,29 @@ TEST(QuantizedModelVerifierTest, Add_wrong_granularity_NEG) SUCCEED(); } +TEST(QuantizedModelVerifierTest, Add_inttype) +{ + // Tests for S32 + TEST_WITH_GRAPH(IntAddTestGraph<Type::S32>, Type::U8, Granularity::LayerWise); + TEST_WITH_GRAPH(IntAddTestGraph<Type::S32>, Type::U8, Granularity::ChannelWise); + TEST_WITH_GRAPH(IntAddTestGraph<Type::S32>, Type::S16, Granularity::ChannelWise); + + TEST_WITH_LAYER_INFO(IntAddTestGraph<Type::S32>, Type::U8, Granularity::LayerWise); + TEST_WITH_LAYER_INFO(IntAddTestGraph<Type::S32>, Type::U8, Granularity::ChannelWise); + TEST_WITH_LAYER_INFO(IntAddTestGraph<Type::S32>, Type::S16, Granularity::ChannelWise); + + // Tests for S64 + TEST_WITH_GRAPH(IntAddTestGraph<Type::S64>, Type::U8, Granularity::LayerWise); + TEST_WITH_GRAPH(IntAddTestGraph<Type::S64>, Type::U8, Granularity::ChannelWise); + TEST_WITH_GRAPH(IntAddTestGraph<Type::S64>, Type::S16, Granularity::ChannelWise); + + TEST_WITH_LAYER_INFO(IntAddTestGraph<Type::S64>, Type::U8, Granularity::LayerWise); + TEST_WITH_LAYER_INFO(IntAddTestGraph<Type::S64>, Type::U8, Granularity::ChannelWise); + TEST_WITH_LAYER_INFO(IntAddTestGraph<Type::S64>, Type::S16, Granularity::ChannelWise); + + SUCCEED(); +} + TEST(QuantizedModelVerifierTest, Mul) { TEST_WITH_GRAPH(MulTestGraph, Type::U8, Granularity::LayerWise); @@ -2544,6 +2722,29 @@ TEST(QuantizedModelVerifierTest, Mul_wrong_granularity_NEG) SUCCEED(); } +TEST(QuantizedModelVerifierTest, Mul_inttype) +{ + // Tests for S32 + TEST_WITH_GRAPH(IntMulTestGraph<Type::S32>, Type::U8, Granularity::LayerWise); + TEST_WITH_GRAPH(IntMulTestGraph<Type::S32>, Type::U8, Granularity::ChannelWise); + TEST_WITH_GRAPH(IntMulTestGraph<Type::S32>, Type::S16, Granularity::ChannelWise); + + TEST_WITH_LAYER_INFO(IntMulTestGraph<Type::S32>, Type::U8, Granularity::LayerWise); + TEST_WITH_LAYER_INFO(IntMulTestGraph<Type::S32>, Type::U8, Granularity::ChannelWise); + TEST_WITH_LAYER_INFO(IntMulTestGraph<Type::S32>, Type::S16, Granularity::ChannelWise); + + // Tests for S64 + TEST_WITH_GRAPH(IntMulTestGraph<Type::S64>, Type::U8, Granularity::LayerWise); + TEST_WITH_GRAPH(IntMulTestGraph<Type::S64>, Type::U8, Granularity::ChannelWise); + TEST_WITH_GRAPH(IntMulTestGraph<Type::S64>, Type::S16, Granularity::ChannelWise); + + TEST_WITH_LAYER_INFO(IntMulTestGraph<Type::S64>, Type::U8, Granularity::LayerWise); + TEST_WITH_LAYER_INFO(IntMulTestGraph<Type::S64>, Type::U8, Granularity::ChannelWise); + TEST_WITH_LAYER_INFO(IntMulTestGraph<Type::S64>, Type::S16, Granularity::ChannelWise); + + SUCCEED(); +} + // TODO Add following testcases // // CircleConv2D diff --git a/compiler/luci/pass/src/RemoveDuplicateConstPass.cpp b/compiler/luci/pass/src/RemoveDuplicateConstPass.cpp new file mode 100644 index 000000000..e50dda9e0 --- /dev/null +++ b/compiler/luci/pass/src/RemoveDuplicateConstPass.cpp @@ -0,0 +1,225 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/RemoveDuplicateConstPass.h" + +#include <luci/Log.h> + +namespace +{ + +bool compare_quant_params(luci::CircleConst *left, luci::CircleConst *right) +{ + const auto left_quant_param = left->quantparam(); + const auto right_quant_param = right->quantparam(); + + if (left_quant_param == right_quant_param) + return true; + + if (left_quant_param != nullptr and right_quant_param != nullptr) + { + if (left_quant_param->scale == right_quant_param->scale and + left_quant_param->quantized_dimension == right_quant_param->quantized_dimension and + left_quant_param->zerop == right_quant_param->zerop and + left_quant_param->min == right_quant_param->min and + left_quant_param->max == right_quant_param->max) + { + return true; + } + } + return false; +} + +bool compare_dim_values(luci::CircleConst *left, luci::CircleConst *right) +{ + const auto left_rank = left->rank(); + const auto right_rank = right->rank(); + + if (left_rank != right_rank) + return false; + + for (uint32_t i = 0; i < left_rank; ++i) + { + if (left->dim(i).value() != right->dim(i).value()) + return false; + } + + return true; +} + +template <loco::DataType DT> bool is_equal_consts(luci::CircleConst *left, luci::CircleConst *right) +{ + if (not compare_quant_params(left, right)) + return false; + + if (not compare_dim_values(left, right)) + return false; + + for (uint32_t i = 0; i < left->size<DT>(); ++i) + { + if (left->at<DT>(i) != right->at<DT>(i)) + return false; + } + + return true; +} + +} // namespace + +namespace luci +{ + +bool RemoveDuplicateConstPass::remove_duplicate_const() +{ + bool changed = false; + + for (auto &cur_pair : _sum_to_const) + { + // if single const - continue + if (cur_pair.second.size() == 1) + continue; + + for (auto reference_const : cur_pair.second) + { + if (reference_const == nullptr) + continue; + + for (uint32_t i = 0; i < cur_pair.second.size(); ++i) + { + auto cur_const = cur_pair.second.at(i); + if (cur_const == nullptr or cur_const == reference_const) + continue; + + if (cur_const->dtype() != reference_const->dtype()) + continue; + + bool is_equal = false; + + switch (cur_const->dtype()) + { + case loco::DataType::FLOAT32: + is_equal = is_equal_consts<loco::DataType::FLOAT32>(reference_const, cur_const); + break; + case loco::DataType::S32: + is_equal = is_equal_consts<loco::DataType::S32>(reference_const, cur_const); + break; + case loco::DataType::S16: + is_equal = is_equal_consts<loco::DataType::S16>(reference_const, cur_const); + break; + case loco::DataType::S8: + is_equal = is_equal_consts<loco::DataType::S8>(reference_const, cur_const); + break; + case loco::DataType::U8: + is_equal = is_equal_consts<loco::DataType::U8>(reference_const, cur_const); + break; + default: + continue; + } + + if (not is_equal) + continue; + + loco::replace(cur_const).with(reference_const); + + // Remove from next checking + cur_pair.second[i] = nullptr; + + changed = true; + } + } + } + + return changed; +} + +template <loco::DataType DT> +void RemoveDuplicateConstPass::add_to_map(luci::CircleConst *const_node) +{ + const auto const_size = const_node->size<DT>(); + float sum = 0.0; + + for (uint32_t i = 0; i < const_size; ++i) + { + sum += const_node->at<DT>(i); + } + + if (_sum_to_const.find(sum) == _sum_to_const.end()) + { + _sum_to_const[sum] = {const_node}; + } + else + { + _sum_to_const.at(sum).push_back(const_node); + } +} + +/** + * Remove duplicate Const nodes. + * + * BEFORE + * [CircleNode] [CircleConst] + * | / + * | / + * [CircleNode] [CircleConst] + * | / + * | / + * [CircleNode] + * + * AFTER + * + * [CircleNode] [CircleConst] + * | / / + * | / / + * [CircleNode] / + * | / + * | / + * [CircleNode] + * + */ +bool RemoveDuplicateConstPass::run(loco::Graph *g) +{ + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + auto const_node = dynamic_cast<luci::CircleConst *>(node); + if (const_node == nullptr) + continue; + + switch (const_node->dtype()) + { + case loco::DataType::FLOAT32: + add_to_map<loco::DataType::FLOAT32>(const_node); + break; + case loco::DataType::S32: + add_to_map<loco::DataType::S32>(const_node); + break; + case loco::DataType::S16: + add_to_map<loco::DataType::S16>(const_node); + break; + case loco::DataType::S8: + add_to_map<loco::DataType::S8>(const_node); + break; + case loco::DataType::U8: + add_to_map<loco::DataType::U8>(const_node); + break; + default: + continue; + } + } + + return remove_duplicate_const(); +} + +} // namespace luci diff --git a/compiler/luci/pass/src/RemoveDuplicateConstPass.test.cpp b/compiler/luci/pass/src/RemoveDuplicateConstPass.test.cpp new file mode 100644 index 000000000..5052a3e01 --- /dev/null +++ b/compiler/luci/pass/src/RemoveDuplicateConstPass.test.cpp @@ -0,0 +1,159 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/RemoveDuplicateConstPass.h" + +#include <luci/IR/CircleNodes.h> +#include <luci/test/TestIOGraph.h> +#include <gtest/gtest.h> + +namespace +{ +using namespace luci::test; + +class DuplicateConstsGraphlet +{ +public: + DuplicateConstsGraphlet() = default; + +public: + void init(loco::Graph *g, bool is_duplicate) + { + _reshape_shape = g->nodes()->create<luci::CircleConst>(); + _reshape_shape->rank(1); + _reshape_shape->dim(0).set(1); + _reshape_shape->shape_status(luci::ShapeStatus::VALID); + _reshape_shape->dtype(loco::DataType::S32); + + _reshape_shape->size<loco::DataType::S32>(1); + _reshape_shape->at<loco::DataType::S32>(0) = 5; + _reshape_shape->name("reshape_shape_1"); + + _reshape_shape_duplicate = g->nodes()->create<luci::CircleConst>(); + _reshape_shape_duplicate->rank(1); + _reshape_shape_duplicate->dim(0).set(1); + _reshape_shape_duplicate->shape_status(luci::ShapeStatus::VALID); + _reshape_shape_duplicate->dtype(loco::DataType::S32); + if (is_duplicate) + { + _reshape_shape_duplicate->size<loco::DataType::S32>(1); + _reshape_shape_duplicate->at<loco::DataType::S32>(0) = 5; + } + else + { + _reshape_shape_duplicate->size<loco::DataType::S32>(2); + _reshape_shape_duplicate->at<loco::DataType::S32>(0) = 1; + _reshape_shape_duplicate->at<loco::DataType::S32>(1) = 5; + } + _reshape_shape_duplicate->name("reshape_shape_2"); + + _reshape_f = g->nodes()->create<luci::CircleReshape>(); + _reshape_f->newShape()->rank(1); + _reshape_f->newShape()->dim(0) = 5; + _reshape_f->name("reshape_f"); + + _reshape_s = g->nodes()->create<luci::CircleReshape>(); + if (is_duplicate) + { + _reshape_s->newShape()->rank(1); + _reshape_s->newShape()->dim(0) = 5; + } + else + { + _reshape_s->newShape()->rank(2); + _reshape_s->newShape()->dim(0) = 1; + _reshape_s->newShape()->dim(1) = 5; + } + _reshape_s->name("reshape_s"); + } + +protected: + luci::CircleReshape *_reshape_f = nullptr; + luci::CircleReshape *_reshape_s = nullptr; + luci::CircleConst *_reshape_shape = nullptr; + luci::CircleConst *_reshape_shape_duplicate = nullptr; +}; + +class DuplicateConstsGraph : public TestIOGraph, public DuplicateConstsGraphlet +{ +public: + DuplicateConstsGraph() = default; + +public: + void init(const ShapeU32 in_shape, const ShapeU32 out_shape, bool is_duplicate) + { + TestIOGraph::init(in_shape, out_shape); + + DuplicateConstsGraphlet::init(g(), is_duplicate); + + // connect graph + _reshape_f->tensor(input()); + _reshape_f->shape(_reshape_shape); + + _reshape_s->tensor(_reshape_f); + _reshape_s->shape(_reshape_shape_duplicate); + + output()->from(_reshape_s); + } +}; +} // namespace + +TEST(RemoveDuplicateConstPass, name) +{ + luci::RemoveDuplicateConstPass pass; + auto const name = pass.name(); + ASSERT_NE(nullptr, name); +} + +TEST(RemoveDuplicateConstPass, remove_duplicate) +{ + DuplicateConstsGraph g; + g.init({1, 5}, {5}, true); + + luci::RemoveDuplicateConstPass pass; + while (pass.run(g.g())) + ; + + uint32_t const_num = 0; + for (auto node : loco::active_nodes(loco::output_nodes(g.g()))) + { + auto target_node = dynamic_cast<luci::CircleConst *>(node); + if (target_node != nullptr) + const_num++; + } + + ASSERT_EQ(const_num, 1); +} + +TEST(RemoveDuplicateConstPass, remove_duplicate_NEG) +{ + DuplicateConstsGraph g; + g.init({1, 5}, {1, 5}, false); + + luci::RemoveDuplicateConstPass pass; + while (pass.run(g.g())) + ; + + uint32_t const_num = 0; + for (auto node : loco::active_nodes(loco::output_nodes(g.g()))) + { + auto target_node = dynamic_cast<luci::CircleConst *>(node); + if (target_node != nullptr) + const_num++; + } + + ASSERT_EQ(const_num, 2); +} diff --git a/compiler/luci/pass/src/ReplaceNonConstFCWithBatchMatMulPass.cpp b/compiler/luci/pass/src/ReplaceNonConstFCWithBatchMatMulPass.cpp index 741b70956..07457c1e8 100644 --- a/compiler/luci/pass/src/ReplaceNonConstFCWithBatchMatMulPass.cpp +++ b/compiler/luci/pass/src/ReplaceNonConstFCWithBatchMatMulPass.cpp @@ -64,6 +64,40 @@ luci::CircleNode *fromActivation(luci::CircleNode *inp, luci::FusedActFunc act) } } +// Create CircleReshape where +// - dtype is same with node +// - shape is same with node +// NOTE: User should set input(tensor) of the returned Op. +luci::CircleReshape *create_reshape(luci::CircleFullyConnected *node) +{ + assert(node); // FIX_CALLER_UNLESS + + auto g = node->graph(); + + auto reshape = g->nodes()->create<luci::CircleReshape>(); + reshape->name(node->name() + "/reshape"); + reshape->dtype(node->dtype()); + luci::add_origin(reshape, luci::get_origin(node)); + + auto shape_const = g->nodes()->create<luci::CircleConst>(); + shape_const->dtype(loco::DataType::S32); + shape_const->rank(1); + shape_const->dim(0).set(node->rank()); + shape_const->size<loco::DataType::S32>(node->rank()); + for (uint32_t i = 0; i < node->rank(); i++) + { + assert(node->dim(i).known()); // FIX_CALLER_UNLESS + shape_const->at<loco::DataType::S32>(i) = node->dim(i).value(); + } + shape_const->shape_status(luci::ShapeStatus::VALID); + shape_const->name(node->name() + "/shape"); + luci::add_origin(shape_const, luci::get_origin(node)); + + reshape->shape(shape_const); + + return reshape; +} + /** * Replace Fully Connected with Batched MatMul * @@ -79,19 +113,23 @@ luci::CircleNode *fromActivation(luci::CircleNode *inp, luci::FusedActFunc act) * * [Node1] [Node2] * \ / - * [BatchMatMul] [BiasValue]? + * [BatchMatMul] + * | + * [Reshape] [BiasValue]? * \ / * [Add]? * | * [Activation]? * * Nodes with "?" denote optional elements + * NOTE Reshape Op is inserted to keep the original shape of FullyConnected Op + * Reshape Op can be redundant (input shape == output shape). This can be removed + * by RemoveUnnecessaryReshapePass. */ bool replace_fc_with_matmul(luci::CircleFullyConnected *fc) { luci::CircleNode *x = nullptr; luci::CircleNode *y = nullptr; - luci::CircleNode *b = nullptr; luci::CircleTranspose *ty = nullptr; luci::CircleTranspose *tx = nullptr; bool adj_x = false; @@ -122,10 +160,13 @@ bool replace_fc_with_matmul(luci::CircleFullyConnected *fc) x = loco::must_cast<luci::CircleNode *>(fc->input()); } - b = loco::must_cast<luci::CircleNode *>(fc->bias()); + if (x->dtype() != loco::DataType::FLOAT32 || y->dtype() != loco::DataType::FLOAT32) + return false; - if (x->dtype() != loco::DataType::FLOAT32 || y->dtype() != loco::DataType::FLOAT32 || - b->dtype() != loco::DataType::FLOAT32) + auto bc = dynamic_cast<luci::CircleConst *>(fc->bias()); + // NOTE bias can be empty as CircleOutputExclude type + // NOTE we can only handle bias as FLOAT32 type as of now + if (nullptr != bc && bc->dtype() != loco::DataType::FLOAT32) return false; auto name = fc->name(); @@ -141,6 +182,9 @@ bool replace_fc_with_matmul(luci::CircleFullyConnected *fc) luci::add_origin(matmul, luci::get_origin(fc)); + auto reshape = create_reshape(fc); + reshape->tensor(matmul); + auto all_zero = [](const luci::CircleConst *c) { bool ac = true; for (uint32_t i = 0; i < c->size<loco::DataType::FLOAT32>() && ac; i++) @@ -150,12 +194,11 @@ bool replace_fc_with_matmul(luci::CircleFullyConnected *fc) return ac; }; - auto bc = dynamic_cast<luci::CircleConst *>(b); - if ((nullptr != bc) && !all_zero(bc)) + if (nullptr != bc && !all_zero(bc)) { auto bias_add = fc->graph()->nodes()->create<luci::CircleAdd>(); - bias_add->x(matmul); - bias_add->y(b); + bias_add->x(reshape); + bias_add->y(bc); bias_add->name(fc->name() + "/bias_add"); bias_add->dtype(fc->dtype()); add_origin(bias_add, get_origin(fc)); @@ -164,7 +207,8 @@ bool replace_fc_with_matmul(luci::CircleFullyConnected *fc) } else { - auto n = fromActivation(matmul, fc->fusedActivationFunction()); + // NOTE bias doesn't exist or bias is all zero + auto n = fromActivation(reshape, fc->fusedActivationFunction()); add_origin(n, luci::get_origin(fc)); n->name(fc->name() + "fusedActivation"); n->dtype(fc->dtype()); diff --git a/compiler/luci/pass/src/ReplaceNonConstFCWithBatchMatMulPass.test.cpp b/compiler/luci/pass/src/ReplaceNonConstFCWithBatchMatMulPass.test.cpp index 7606a6125..93024f3f7 100644 --- a/compiler/luci/pass/src/ReplaceNonConstFCWithBatchMatMulPass.test.cpp +++ b/compiler/luci/pass/src/ReplaceNonConstFCWithBatchMatMulPass.test.cpp @@ -159,8 +159,8 @@ TEST_F(ReplaceNonConstFCWithBatchMatMulPassTest, simple_test) auto ret = pass.run(g.g()); EXPECT_EQ(true, ret); - auto mm = dynamic_cast<luci::CircleBatchMatMul *>(g.output()->from()); - EXPECT_NE(nullptr, mm); + auto res = dynamic_cast<luci::CircleReshape *>(g.output()->from()); + EXPECT_NE(nullptr, res); } TEST_F(ReplaceNonConstFCWithBatchMatMulPassTest, nonzero_bias_test) diff --git a/compiler/luci/pass/src/ResolveCustomOpMatMulPass.cpp b/compiler/luci/pass/src/ResolveCustomOpMatMulPass.cpp index 1e8f681c8..f61882796 100644 --- a/compiler/luci/pass/src/ResolveCustomOpMatMulPass.cpp +++ b/compiler/luci/pass/src/ResolveCustomOpMatMulPass.cpp @@ -153,7 +153,6 @@ bool resolve_matmul(luci::CircleCustom *cop) } auto empty_bias = graph->nodes()->create<luci::CircleOutputExclude>(); - empty_bias->dtype(loco::DataType::FLOAT32); // Needed for type inference auto fc_node = graph->nodes()->create<luci::CircleFullyConnected>(); fc_node->input(lhs); diff --git a/compiler/luci/pass/src/ResolveCustomOpMaxPoolWithArgmaxPass.cpp b/compiler/luci/pass/src/ResolveCustomOpMaxPoolWithArgmaxPass.cpp index f37f27742..7c038d56d 100644 --- a/compiler/luci/pass/src/ResolveCustomOpMaxPoolWithArgmaxPass.cpp +++ b/compiler/luci/pass/src/ResolveCustomOpMaxPoolWithArgmaxPass.cpp @@ -23,6 +23,7 @@ #include <loco.h> #include <oops/InternalExn.h> +#include <limits> // std::numeric_limits #include <flatbuffers/flexbuffers.h> diff --git a/compiler/luci/pass/src/ResolveCustomOpSplitVPass.cpp b/compiler/luci/pass/src/ResolveCustomOpSplitVPass.cpp index a65065800..5a09e3930 100644 --- a/compiler/luci/pass/src/ResolveCustomOpSplitVPass.cpp +++ b/compiler/luci/pass/src/ResolveCustomOpSplitVPass.cpp @@ -20,6 +20,8 @@ #include <luci/Profile/CircleNodeOrigin.h> #include <luci/Service/Nodes/CircleConst.h> +#include <limits> // std::numeric_limits + namespace { diff --git a/compiler/luci/pass/src/UnrollUnidirectionalSequenceLSTMPass.cpp b/compiler/luci/pass/src/UnrollUnidirectionalSequenceLSTMPass.cpp new file mode 100644 index 000000000..b73efafa5 --- /dev/null +++ b/compiler/luci/pass/src/UnrollUnidirectionalSequenceLSTMPass.cpp @@ -0,0 +1,672 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/UnrollUnidirectionalSequenceLSTMPass.h" + +#include "helpers/NodeFiller.h" +#include "helpers/TypeMapper.h" + +#include <luci/IR/CircleNodes.h> +#include <luci/Profile/CircleNodeOrigin.h> + +#include <string> +#include <vector> + +/** + * BEFORE + * [CircleNode] + * | + * [UnidirectionalSequenceLSTM] + * | + * [CircleNode] + * + * AFTER + * + * [CircleNode] + * | + * [CircleTranspose] + * | + * [CircleUnpack] + * | + * [CircleUnpackOut] + * | + * (Unrolled sub network) + * | + * [CirclePack] + * | | + * [CircleTranspose] [UnidirectionalSequenceLSTM] + * | | + * [CircleNode] + * + * NOTE for timesteps = 1, + * first [CircleTranspose] is not added and + * last [CirclePack] + [CircleTranspose] is replaced with [CircleReshape] + * + * First unrolled sub network is as follows + * - [] and 'Circle' are omitted + * - all FC has one or two Const for Weight/Bias + * + * (input) + * | + * FC + * | + * Split + * +---------+----------+----------+ + * | | | | + * | Logistic Logistic Tanh + * | Const | | | + * | | | | | + * | +-- Mul +-- Mul ---+ + * | | | + * | +---- Add ------+ + * | | + * | +----+----+ + * | | | + * Logistic Tanh | + * | | | + * +-- Mul ----+ | + * | | + * (output) (A) + * + * and following unrolled sub networks are; + * + * (prev-output) (input) + * | | + * FC FC + * | | + * +--- Add --+ + * Const | + * | | + * +------ Add + * | + * Split + * | + * +---------+----------+----------+ + * SplitOut SplitOut SplitOut SplitOut + * | | | | + * | Logistic Logistic Tanh + * | (A') | | | + * | | | | | + * | +--- Mul +-- Mul ---+ + * | | | + * | +---- Add ------+ + * | | + * | +----+----+ + * | | | + * Logistic Tanh | + * | | | + * +-- Mul ----+ | + * | | + * (output) (next) + * + * where (A) and (A') are connected + * + */ + +namespace +{ + +struct UnrollLSTM +{ + luci::CircleConst *transpose_perm(void); + luci::CircleTranspose *first_transpose(luci::CircleNode *input); + std::vector<luci::CircleUnpackOut *> input_unpacks(luci::CircleNode *input); + luci::CircleConst *merged_weights(luci::CircleConst *iw, luci::CircleConst *fw, + luci::CircleConst *cw, luci::CircleConst *ow); + luci::CircleFullyConnected *create_input_matmul(luci::CircleNode *input); + luci::CircleAdd *create_input_matmul(luci::CircleNode *input, luci::CircleMul *mul, + uint32_t step); + std::vector<luci::CircleSplitOut *> matmul_splits(luci::CircleNode *input, uint32_t step); + luci::CircleConst *forget_zero(void); + luci::CircleMul *forget_gate_cell(std::vector<luci::CircleSplitOut *> &splits, + luci::CircleNode *prev, uint32_t step, + luci::CircleNode **retadd); + luci::CircleReshape *last_reshape(luci::CircleNode *input); + luci::CircleTranspose *last_transpose(std::vector<luci::CircleMul *> &output_muls); + + luci::CircleUnidirectionalSequenceLSTM *_lstm{nullptr}; + loco::Graph::NodeContext *_nctx{nullptr}; + std::string _name; + uint32_t _batch{0}; + uint32_t _timesteps{0}; + uint32_t _units{0}; // output space dim +}; + +luci::CircleConst *UnrollLSTM::transpose_perm(void) +{ + auto perm = _nctx->create<luci::CircleConst>(); + perm->dtype(loco::DataType::S32); + perm->rank(1); + perm->dim(0) = 3; + perm->size<loco::DataType::S32>(3); + perm->at<loco::DataType::S32>(0) = 1; + perm->at<loco::DataType::S32>(1) = 0; + perm->at<loco::DataType::S32>(2) = 2; + perm->shape_status(luci::ShapeStatus::VALID); + + return perm; +} + +luci::CircleTranspose *UnrollLSTM::first_transpose(luci::CircleNode *input) +{ + assert(input != nullptr); + + auto perm = transpose_perm(); + perm->name(_name + "_perm1"); + luci::add_origin(perm, luci::get_origin(_lstm)); + + auto transpose = _nctx->create<luci::CircleTranspose>(); + transpose->a(input); + transpose->perm(perm); + transpose->name(_name + "_trans1"); + luci::add_origin(transpose, luci::get_origin(_lstm)); + + return transpose; +} + +std::vector<luci::CircleUnpackOut *> UnrollLSTM::input_unpacks(luci::CircleNode *input) +{ + assert(input != nullptr); + + // NOTE unpack input can be LSTM or Transpose + auto unpack = _nctx->create<luci::CircleUnpack>(); + unpack->num(_timesteps); + unpack->axis(0); + unpack->value(input); + unpack->name(_name + "_unpack"); + luci::add_origin(unpack, luci::get_origin(_lstm)); + + std::vector<luci::CircleUnpackOut *> outs; + for (uint32_t idx = 0; idx < _timesteps; ++idx) + { + auto unpackout = _nctx->create<luci::CircleUnpackOut>(); + unpackout->input(unpack); + unpackout->index(idx); + unpackout->name(_name + "_unpackout_" + std::to_string(idx)); + luci::add_origin(unpackout, luci::get_origin(_lstm)); + outs.push_back(unpackout); + } + + return outs; +} + +luci::CircleConst *UnrollLSTM::merged_weights(luci::CircleConst *iw, luci::CircleConst *fw, + luci::CircleConst *cw, luci::CircleConst *ow) +{ + assert(iw != nullptr); + assert(fw != nullptr); + assert(cw != nullptr); + assert(ow != nullptr); + + auto iw_rank = iw->rank(); + assert(iw_rank == fw->rank()); + assert(iw_rank == cw->rank()); + assert(iw_rank == ow->rank()); + + uint32_t ne_w = 1; + for (uint32_t i = 0; i < iw_rank; i++) + ne_w *= iw->dim(i).value(); + + assert(iw->dtype() == loco::DataType::FLOAT32); + assert(fw->dtype() == loco::DataType::FLOAT32); + assert(cw->dtype() == loco::DataType::FLOAT32); + assert(ow->dtype() == loco::DataType::FLOAT32); + + // merged weights + auto mw = _nctx->create<luci::CircleConst>(); + mw->dtype(iw->dtype()); + mw->rank(iw_rank); + mw->dim(0) = 4u * iw->dim(0).value(); + for (uint32_t i = 1; i < iw_rank; i++) + mw->dim(i) = iw->dim(i); + mw->size<loco::DataType::FLOAT32>(4 * ne_w); + mw->shape_status(luci::ShapeStatus::VALID); + for (uint32_t i = 0; i < ne_w; ++i) + { + mw->at<loco::DataType::FLOAT32>(i + ne_w * 0) = iw->at<loco::DataType::FLOAT32>(i); + mw->at<loco::DataType::FLOAT32>(i + ne_w * 1) = fw->at<loco::DataType::FLOAT32>(i); + mw->at<loco::DataType::FLOAT32>(i + ne_w * 2) = cw->at<loco::DataType::FLOAT32>(i); + mw->at<loco::DataType::FLOAT32>(i + ne_w * 3) = ow->at<loco::DataType::FLOAT32>(i); + } + return mw; +} + +luci::CircleFullyConnected *UnrollLSTM::create_input_matmul(luci::CircleNode *input) +{ + assert(input != nullptr); + + // weights + auto iw = loco::must_cast<luci::CircleConst *>(_lstm->input_to_input_weights()); + auto fw = loco::must_cast<luci::CircleConst *>(_lstm->input_to_forget_weights()); + auto cw = loco::must_cast<luci::CircleConst *>(_lstm->input_to_cell_weights()); + auto ow = loco::must_cast<luci::CircleConst *>(_lstm->input_to_output_weights()); + + auto fcw = merged_weights(iw, fw, cw, ow); + fcw->name(_name + "_fc_w"); + luci::add_origin(fcw, luci::get_origin(_lstm)); + + // bias + auto ib = loco::must_cast<luci::CircleConst *>(_lstm->input_gate_bias()); + auto fb = loco::must_cast<luci::CircleConst *>(_lstm->forget_gate_bias()); + auto cb = loco::must_cast<luci::CircleConst *>(_lstm->cell_gate_bias()); + auto ob = loco::must_cast<luci::CircleConst *>(_lstm->output_gate_bias()); + + auto fcb = merged_weights(ib, fb, cb, ob); + fcb->name(_name + "_fc_b"); + luci::add_origin(fcb, luci::get_origin(_lstm)); + + auto fc = _nctx->create<luci::CircleFullyConnected>(); + fc->input(input); + fc->weights(fcw); + fc->bias(fcb); + fc->fusedActivationFunction(luci::FusedActFunc::NONE); + fc->name(_name + "_fc"); + luci::add_origin(fc, luci::get_origin(_lstm)); + + return fc; +} + +luci::CircleAdd *UnrollLSTM::create_input_matmul(luci::CircleNode *input, luci::CircleMul *mul, + uint32_t step) +{ + assert(input != nullptr); + assert(mul != nullptr); + assert(step < _timesteps); + + auto base_name = _name + "_matmul" + std::to_string(step); + + // input weights + auto iw = loco::must_cast<luci::CircleConst *>(_lstm->input_to_input_weights()); + auto fw = loco::must_cast<luci::CircleConst *>(_lstm->input_to_forget_weights()); + auto cw = loco::must_cast<luci::CircleConst *>(_lstm->input_to_cell_weights()); + auto ow = loco::must_cast<luci::CircleConst *>(_lstm->input_to_output_weights()); + + auto fcw = merged_weights(iw, fw, cw, ow); + fcw->name(base_name + "_fc_w"); + luci::add_origin(fcw, luci::get_origin(_lstm)); + + auto fcb = _nctx->create<luci::CircleOutputExclude>(); + + auto fc = _nctx->create<luci::CircleFullyConnected>(); + fc->input(input); + fc->weights(fcw); + fc->bias(fcb); + fc->fusedActivationFunction(luci::FusedActFunc::NONE); + fc->name(base_name + "_fc"); + luci::add_origin(fc, luci::get_origin(_lstm)); + + // recurrent weights + auto ri = loco::must_cast<luci::CircleConst *>(_lstm->recurrent_to_input_weights()); + auto rf = loco::must_cast<luci::CircleConst *>(_lstm->recurrent_to_forget_weights()); + auto rc = loco::must_cast<luci::CircleConst *>(_lstm->recurrent_to_cell_weights()); + auto ro = loco::must_cast<luci::CircleConst *>(_lstm->recurrent_to_output_weights()); + + auto fcrw = merged_weights(ri, rf, rc, ro); + fcrw->name(base_name + "_fcr_w"); + luci::add_origin(fcrw, luci::get_origin(_lstm)); + + auto fcrb = _nctx->create<luci::CircleOutputExclude>(); + + auto fcr = _nctx->create<luci::CircleFullyConnected>(); + fcr->input(mul); + fcr->weights(fcrw); + fcr->bias(fcrb); + fcr->fusedActivationFunction(luci::FusedActFunc::NONE); + fcr->name(base_name + "_fcr"); + luci::add_origin(fcr, luci::get_origin(_lstm)); + + auto add_fc = _nctx->create<luci::CircleAdd>(); + add_fc->x(fcr); + add_fc->y(fc); + add_fc->fusedActivationFunction(luci::FusedActFunc::NONE); + add_fc->name(base_name + "_addfc"); + luci::add_origin(add_fc, luci::get_origin(_lstm)); + + // bias + auto ib = loco::must_cast<luci::CircleConst *>(_lstm->input_gate_bias()); + auto fb = loco::must_cast<luci::CircleConst *>(_lstm->forget_gate_bias()); + auto cb = loco::must_cast<luci::CircleConst *>(_lstm->cell_gate_bias()); + auto ob = loco::must_cast<luci::CircleConst *>(_lstm->output_gate_bias()); + + auto bias = merged_weights(ib, fb, cb, ob); + bias->name(base_name + "_bias"); + + auto add_bias = _nctx->create<luci::CircleAdd>(); + add_bias->x(add_fc); + add_bias->y(bias); + add_bias->fusedActivationFunction(luci::FusedActFunc::NONE); + add_bias->name(base_name + "_addbias"); + luci::add_origin(add_bias, luci::get_origin(_lstm)); + + return add_bias; +} + +std::vector<luci::CircleSplitOut *> UnrollLSTM::matmul_splits(luci::CircleNode *input, + uint32_t step) +{ + assert(input != nullptr); + assert(step < _timesteps); + + std::string split_name = _name + "_sp" + std::to_string(step); + + auto split_dim = _nctx->create<luci::CircleConst>(); + split_dim->dtype(loco::DataType::S32); + split_dim->rank(1); + split_dim->dim(0) = 1; + split_dim->size<loco::DataType::S32>(1); + split_dim->at<loco::DataType::S32>(0) = 1; + split_dim->shape_status(luci::ShapeStatus::VALID); + split_dim->name(split_name + "_dim"); + luci::add_origin(split_dim, luci::get_origin(_lstm)); + + auto split = _nctx->create<luci::CircleSplit>(); + split->num_split(4); + split->split_dim(split_dim); + split->input(input); + split->name(split_name); + luci::add_origin(split, luci::get_origin(_lstm)); + + auto split_o0 = _nctx->create<luci::CircleSplitOut>(); + split_o0->input(split); + split_o0->index(0); + split_o0->name(split_name + "_spo0"); + luci::add_origin(split_o0, luci::get_origin(_lstm)); + + auto split_o1 = _nctx->create<luci::CircleSplitOut>(); + split_o1->input(split); + split_o1->index(1); + split_o1->name(split_name + "_spo1"); + luci::add_origin(split_o1, luci::get_origin(_lstm)); + + auto split_o2 = _nctx->create<luci::CircleSplitOut>(); + split_o2->input(split); + split_o2->index(2); + split_o2->name(split_name + "_spo2"); + luci::add_origin(split_o2, luci::get_origin(_lstm)); + + auto split_o3 = _nctx->create<luci::CircleSplitOut>(); + split_o3->input(split); + split_o3->index(3); + split_o3->name(split_name + "_spo3"); + luci::add_origin(split_o3, luci::get_origin(_lstm)); + + std::vector<luci::CircleSplitOut *> outs; + outs.push_back(split_o0); + outs.push_back(split_o1); + outs.push_back(split_o2); + outs.push_back(split_o3); + return outs; +} + +luci::CircleConst *UnrollLSTM::forget_zero(void) +{ + uint32_t amount = _batch * _units; + + auto zero = _nctx->create<luci::CircleConst>(); + zero->dtype(loco::DataType::FLOAT32); + zero->rank(2); + zero->dim(0) = _batch; + zero->dim(1) = _units; + zero->size<loco::DataType::FLOAT32>(amount); + for (uint32_t idx = 0; idx < amount; ++idx) + zero->at<loco::DataType::FLOAT32>(idx) = 0.0; + zero->shape_status(luci::ShapeStatus::VALID); + zero->name(_name + "_zero"); + luci::add_origin(zero, luci::get_origin(_lstm)); + return zero; +} + +luci::CircleMul *UnrollLSTM::forget_gate_cell(std::vector<luci::CircleSplitOut *> &splits, + luci::CircleNode *prev, uint32_t step, + luci::CircleNode **retadd) +{ + assert(splits.size() > 0); + assert(prev != nullptr); + assert(step < _timesteps); + + std::string net_name = _name + "_net" + std::to_string(step); + + auto split_0 = splits[0]; // input-input : Logistic - Mul(c) - Add - Tanh - Mul + auto split_1 = splits[1]; // input-forget : Logistic - Mul(p) - Add - Tanh - Mul + auto split_2 = splits[2]; // input-cell : Tanh - Mul(c) - Add - Tanh - Mul + auto split_3 = splits[3]; // input-output : Logistic - Mul + + auto logis_0 = _nctx->create<luci::CircleLogistic>(); + logis_0->x(split_0); + logis_0->name(net_name + "_log0"); + luci::add_origin(logis_0, luci::get_origin(_lstm)); + + auto logis_1 = _nctx->create<luci::CircleLogistic>(); + logis_1->x(split_1); + logis_1->name(net_name + "_log1"); + luci::add_origin(logis_1, luci::get_origin(_lstm)); + + auto tanh_2 = _nctx->create<luci::CircleTanh>(); + tanh_2->x(split_2); + tanh_2->name(net_name + "_tanh2"); + luci::add_origin(tanh_2, luci::get_origin(_lstm)); + + auto logis_3 = _nctx->create<luci::CircleLogistic>(); + logis_3->x(split_3); + logis_3->name(net_name + "_log3"); + luci::add_origin(logis_3, luci::get_origin(_lstm)); + + auto mul_c = _nctx->create<luci::CircleMul>(); + mul_c->x(logis_0); + mul_c->y(tanh_2); + mul_c->fusedActivationFunction(luci::FusedActFunc::NONE); + mul_c->name(net_name + "_mul1"); + luci::add_origin(mul_c, luci::get_origin(_lstm)); + + auto mul_p = _nctx->create<luci::CircleMul>(); + mul_p->x(logis_1); + mul_p->y(prev); + mul_p->fusedActivationFunction(luci::FusedActFunc::NONE); + mul_p->name(net_name + "_mul2"); + luci::add_origin(mul_p, luci::get_origin(_lstm)); + + auto add_cp = _nctx->create<luci::CircleAdd>(); + add_cp->x(mul_c); + add_cp->y(mul_p); + add_cp->fusedActivationFunction(luci::FusedActFunc::NONE); + add_cp->name(net_name + "_add1"); + luci::add_origin(add_cp, luci::get_origin(_lstm)); + + if (retadd != nullptr) + *retadd = add_cp; + + auto tanh_cp = _nctx->create<luci::CircleTanh>(); + tanh_cp->x(add_cp); + tanh_cp->name(net_name + "_tanh3"); + luci::add_origin(tanh_cp, luci::get_origin(_lstm)); + + auto mul_out = _nctx->create<luci::CircleMul>(); + mul_out->x(logis_3); + mul_out->y(tanh_cp); + mul_out->fusedActivationFunction(luci::FusedActFunc::NONE); + mul_out->name(net_name + "_mul3"); + luci::add_origin(mul_out, luci::get_origin(_lstm)); + + return mul_out; +} + +luci::CircleReshape *UnrollLSTM::last_reshape(luci::CircleNode *input) +{ + assert(input != nullptr); + + auto reshape_s = _nctx->create<luci::CircleConst>(); + reshape_s->dtype(loco::DataType::S32); + reshape_s->rank(1); + reshape_s->dim(0) = 3; + reshape_s->size<loco::DataType::S32>(3); + reshape_s->at<loco::DataType::S32>(0) = _batch; + reshape_s->at<loco::DataType::S32>(1) = _timesteps; + reshape_s->at<loco::DataType::S32>(2) = _units; + reshape_s->shape_status(luci::ShapeStatus::VALID); + reshape_s->name(_name + "_reshape_s"); + luci::add_origin(reshape_s, luci::get_origin(_lstm)); + + auto reshape = _nctx->create<luci::CircleReshape>(); + reshape->tensor(input); + reshape->shape(reshape_s); + reshape->newShape()->rank(3); + reshape->newShape()->dim(0) = _batch; + reshape->newShape()->dim(1) = _timesteps; + reshape->newShape()->dim(2) = _units; + reshape->name(_name + "_reshape"); + luci::add_origin(reshape, luci::get_origin(_lstm)); + + return reshape; +} + +luci::CircleTranspose *UnrollLSTM::last_transpose(std::vector<luci::CircleMul *> &output_muls) +{ + assert(output_muls.size() == _timesteps); + + auto pack = _nctx->create<luci::CirclePack>(_timesteps); + pack->axis(0); + for (uint32_t idx = 0; idx < _timesteps; ++idx) + pack->values(idx, output_muls[idx]); + pack->name(_name + "_pack"); + luci::add_origin(pack, luci::get_origin(_lstm)); + + auto perm = transpose_perm(); + perm->name(_name + "_perm2"); + luci::add_origin(perm, luci::get_origin(_lstm)); + + auto transpose = _nctx->create<luci::CircleTranspose>(); + transpose->a(pack); + transpose->perm(perm); + transpose->name(_name + "_trans2"); + luci::add_origin(transpose, luci::get_origin(_lstm)); + + return transpose; +} + +bool unroll_lstm(luci::CircleUnidirectionalSequenceLSTM *lstm) +{ + // NOTE shape of input of lstm is interpreted as [batch, timesteps, feature] + // shape of output of lstm is interpreted as [batch, timesteps, units] + // TODO add more conditions to check LSTM + assert(lstm != nullptr); + assert(lstm->rank() == 3); // use assert to findout when this happens + if (lstm->rank() != 3) + return false; + if (!(lstm->dim(0).known() and lstm->dim(1).known() and lstm->dim(2).known())) + return false; + + UnrollLSTM ulstm; + ulstm._lstm = lstm; + ulstm._nctx = lstm->graph()->nodes(); + ulstm._name = lstm->name(); + ulstm._batch = lstm->dim(0).value(); + ulstm._timesteps = lstm->dim(1).value(); + ulstm._units = lstm->dim(2).value(); // output space dim + + luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(lstm->input()); + assert(input->rank() == 3); // use assert to findout when this happens + if (input->rank() != 3) + return false; + assert(input->dim(0).value() == ulstm._batch); + assert(input->dim(1).value() == ulstm._timesteps); + + if (ulstm._timesteps > 1) + { + // Transpose to switch batch <-> timesteps + // NOTE TF uses Reshape when batch is 1 but as there is Transpose->Reshape + // Pass, we can just use Transpose for both cases + auto transpose = ulstm.first_transpose(input); + input = transpose; + } + + auto unpacks = ulstm.input_unpacks(input); + assert(unpacks.size() == ulstm._timesteps); + uint32_t step = 0; + auto unpackout = unpacks[step]; + + // First FC + auto fc_1 = ulstm.create_input_matmul(unpackout); + assert(fc_1 != nullptr); + auto splits = ulstm.matmul_splits(fc_1, step); + assert(splits.size() == 4); + + luci::CircleNode *prev = nullptr; // prev step CircleAdd + luci::CircleNode *this_add = nullptr; + + prev = ulstm.forget_zero(); // provide all zero constant for first step + + std::vector<luci::CircleMul *> output_muls; + auto mul_gc = ulstm.forget_gate_cell(splits, prev, step, &this_add); + assert(mul_gc != nullptr); + assert(this_add != nullptr); + // gather all Muls for last Pack + output_muls.push_back(mul_gc); + + for (step = 1; step < ulstm._timesteps; ++step) + { + auto unpackout = unpacks[step]; + auto add_n = ulstm.create_input_matmul(unpackout, mul_gc, step); + + auto splits = ulstm.matmul_splits(add_n, step); + assert(splits.size() == 4); + + prev = this_add; + mul_gc = ulstm.forget_gate_cell(splits, prev, step, &this_add); + assert(mul_gc != nullptr); + assert(this_add != nullptr); + + output_muls.push_back(mul_gc); + } + assert(output_muls.size() == ulstm._timesteps); + + if (ulstm._timesteps == 1) + { + // Reshape for single step + auto reshape = ulstm.last_reshape(mul_gc); + loco::replace(lstm).with(reshape); + } + else + { + // Pack + Transpose for two or more steps + auto transpose = ulstm.last_transpose(output_muls); + loco::replace(lstm).with(transpose); + } + + return true; +} + +} // namespace + +namespace luci +{ + +bool UnrollUnidirectionalSequenceLSTMPass::run(loco::Graph *g) +{ + bool changed = false; + + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + if (auto lstm = dynamic_cast<luci::CircleUnidirectionalSequenceLSTM *>(node)) + { + if (unroll_lstm(lstm)) + changed = true; + } + } + + return changed; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/UnrollUnidirectionalSequenceLSTMPass.test.cpp b/compiler/luci/pass/src/UnrollUnidirectionalSequenceLSTMPass.test.cpp new file mode 100644 index 000000000..3f273cbd3 --- /dev/null +++ b/compiler/luci/pass/src/UnrollUnidirectionalSequenceLSTMPass.test.cpp @@ -0,0 +1,211 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/UnrollUnidirectionalSequenceLSTMPass.h" + +#include <luci/test/TestIOGraph.h> + +#include <luci/IR/Nodes/CircleUnidirectionalSequenceLSTM.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class UniSeqLSTMGraphlet +{ +public: + UniSeqLSTMGraphlet() = default; + + void init(loco::Graph *g, const ShapeU32 oshape) + { + _uslstm = g->nodes()->create<luci::CircleUnidirectionalSequenceLSTM>(); + _uslstm->dtype(loco::DataType::FLOAT32); + _uslstm->shape(oshape); + _uslstm->name("uslstm"); + + _uslstm->fusedActivationFunction(luci::FusedActFunc::TANH); + _uslstm->cell_clip(0.0); + _uslstm->proj_clip(0.0); + _uslstm->time_major(false); + _uslstm->asymmetric_quantize_inputs(false); + + _iw = weight_1x1(g); + _rw = weight_1x1(g); + _gb = weight_1(g); + _ex = g->nodes()->create<luci::CircleOutputExclude>(); + } + +protected: + luci::CircleConst *weight_1x1(loco::Graph *g) + { + auto w = g->nodes()->create<luci::CircleConst>(); + w->dtype(loco::DataType::FLOAT32); + w->rank(2); + w->dim(0) = 1; + w->dim(1) = 1; + w->size<loco::DataType::FLOAT32>(1); + w->at<loco::DataType::FLOAT32>(0) = 1.0; + w->shape_status(luci::ShapeStatus::VALID); + return w; + } + + luci::CircleConst *weight_1(loco::Graph *g) + { + auto w = g->nodes()->create<luci::CircleConst>(); + w->dtype(loco::DataType::FLOAT32); + w->rank(1); + w->dim(0) = 1; + w->size<loco::DataType::FLOAT32>(1); + w->at<loco::DataType::FLOAT32>(0) = 1.0; + w->shape_status(luci::ShapeStatus::VALID); + return w; + } + +protected: + luci::CircleUnidirectionalSequenceLSTM *_uslstm = nullptr; + luci::CircleConst *_iw = nullptr; + luci::CircleConst *_rw = nullptr; + luci::CircleConst *_gb = nullptr; + luci::CircleOutputExclude *_ex = nullptr; +}; + +class UnrollUniSeqLSTMPassTestGraph : public TestIOGraph, public UniSeqLSTMGraphlet +{ +public: + UnrollUniSeqLSTMPassTestGraph() = default; + + void init(const ShapeU32 ishape, const ShapeU32 oshape) + { + TestIOGraph::init(ishape, oshape); + UniSeqLSTMGraphlet::init(g(), oshape); + + auto inode = input(); + _uslstm->input(inode); + + _uslstm->input_to_input_weights(_iw); + _uslstm->input_to_forget_weights(_iw); + _uslstm->input_to_cell_weights(_iw); + _uslstm->input_to_output_weights(_iw); + + _uslstm->recurrent_to_input_weights(_rw); + _uslstm->recurrent_to_forget_weights(_rw); + _uslstm->recurrent_to_cell_weights(_rw); + _uslstm->recurrent_to_output_weights(_rw); + + _uslstm->cell_to_input_weights(_ex); + _uslstm->cell_to_forget_weights(_ex); + _uslstm->cell_to_output_weights(_ex); + + _uslstm->input_gate_bias(_gb); + _uslstm->forget_gate_bias(_gb); + _uslstm->cell_gate_bias(_gb); + _uslstm->output_gate_bias(_gb); + + _uslstm->projection_weights(_ex); + _uslstm->projection_bias(_ex); + + _uslstm->output_state(_ex); + _uslstm->cell_state(_ex); + + _uslstm->input_layer_norm_coefficients(_ex); + _uslstm->forget_layer_norm_coefficients(_ex); + _uslstm->cell_layer_norm_coefficients(_ex); + _uslstm->output_layer_norm_coefficients(_ex); + + output()->from(_uslstm); + } +}; + +} // namespace + +namespace +{ + +using namespace luci::test; + +// FakeQuantGraphlet is for simple negative test +class FakeQuantGraphlet +{ +public: + FakeQuantGraphlet() = default; + +public: + void init(loco::Graph *g) + { + _fq = g->nodes()->create<luci::CircleFakeQuant>(); + _fq->name("fq"); + } + +protected: + luci::CircleFakeQuant *_fq = nullptr; +}; + +class FakeQuantGraph : public TestIOGraph, public FakeQuantGraphlet +{ +public: + FakeQuantGraph() = default; + +public: + void init(void) + { + TestIOGraph::init({1, 1, 1}, {1, 1, 1}); + FakeQuantGraphlet::init(g()); + + _fq->inputs(input()); + + output()->from(_fq); + } +}; + +} // namespace + +TEST(UnrollUnidirectionalSequenceLSTMPassTestName, name) +{ + luci::UnrollUnidirectionalSequenceLSTMPass pass; + auto const name = pass.name(); + ASSERT_NE(nullptr, name); +} + +class UnrollUnidirectionalSequenceLSTMPassTest : public ::testing::Test +{ +public: + UnrollUniSeqLSTMPassTestGraph g; + luci::UnrollUnidirectionalSequenceLSTMPass pass; +}; + +TEST_F(UnrollUnidirectionalSequenceLSTMPassTest, simple_run) +{ + g.init({1, 1, 1}, {1, 1, 1}); + + EXPECT_TRUE(pass.run(g.g())); +} + +class UnrollUnidirectionalSequenceLSTMPassTestN : public ::testing::Test +{ +public: + FakeQuantGraph g; + luci::UnrollUnidirectionalSequenceLSTMPass pass; +}; + +TEST_F(UnrollUnidirectionalSequenceLSTMPassTestN, simple_run_NEG) +{ + g.init(); + + EXPECT_FALSE(pass.run(g.g())); +} diff --git a/compiler/luci/pass/src/VerifyQuantizedNodeGranularity.h b/compiler/luci/pass/src/VerifyQuantizedNodeGranularity.h index 408e6b8d9..6bf7ff698 100644 --- a/compiler/luci/pass/src/VerifyQuantizedNodeGranularity.h +++ b/compiler/luci/pass/src/VerifyQuantizedNodeGranularity.h @@ -133,6 +133,10 @@ private: bool visit(const luci::CircleAdd *node) { + // Skip granularity check for indices + if (node->dtype() == loco::DataType::S32 or node->dtype() == loco::DataType::S64) + return true; + RETURN_FALSE_UNLESS(is_lwq(node)); RETURN_FALSE_UNLESS(is_lwq(node->x())); RETURN_FALSE_UNLESS(is_lwq(node->y())); @@ -176,6 +180,10 @@ private: bool visit(const luci::CircleMul *node) { + // Skip granularity check for indices + if (node->dtype() == loco::DataType::S32 or node->dtype() == loco::DataType::S64) + return true; + RETURN_FALSE_UNLESS(is_lwq(node)); RETURN_FALSE_UNLESS(is_lwq(node->x())); RETURN_FALSE_UNLESS(is_lwq(node->y())); diff --git a/compiler/luci/pass/src/VerifyQuantizedNodeType.cpp b/compiler/luci/pass/src/VerifyQuantizedNodeType.cpp index cf86acabe..3ce32555b 100644 --- a/compiler/luci/pass/src/VerifyQuantizedNodeType.cpp +++ b/compiler/luci/pass/src/VerifyQuantizedNodeType.cpp @@ -47,6 +47,10 @@ namespace luci template <loco::DataType Qtype, loco::DataType Btype> bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleAdd *node) { + // Allow add of indices + if (group_has_type(node, loco::DataType::S32) or group_has_type(node, loco::DataType::S64)) + return true; + return group_has_type(node, Qtype); } @@ -240,6 +244,10 @@ bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleMirrorPa template <loco::DataType Qtype, loco::DataType Btype> bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleMul *node) { + // Allow mul of indices + if (group_has_type(node, loco::DataType::S32) or group_has_type(node, loco::DataType::S64)) + return true; + return group_has_type(node, Qtype); } diff --git a/compiler/luci/pass/src/helpers/NodeFiller.h b/compiler/luci/pass/src/helpers/NodeFiller.h index b80f085b0..10113e8dd 100644 --- a/compiler/luci/pass/src/helpers/NodeFiller.h +++ b/compiler/luci/pass/src/helpers/NodeFiller.h @@ -57,6 +57,12 @@ public: */ template <class COMM_NODE> bool with_commutative_args_of(const COMM_NODE *node); + /** + * @note Similar as with_commutative_args_of but not commutative. + * _arg_1 and _arg_2 must match that of ARG_TYPE_1 and ARG_TYPE_2. + */ + template <class COMM_NODE> bool with_args_of(const COMM_NODE *node); + private: ARG_TYPE_1 **_arg_1; ARG_TYPE_2 **_arg_2; @@ -101,4 +107,24 @@ bool NodeFiller<ARG_TYPE_1, ARG_TYPE_2>::with_commutative_args_of(const COMM_NOD return false; } +template <class ARG_TYPE_1, class ARG_TYPE_2> +template <class COMM_NODE> +bool NodeFiller<ARG_TYPE_1, ARG_TYPE_2>::with_args_of(const COMM_NODE *node) +{ + // X == ARG_TYPE_1 / Y == ARG_TYPE_2 + { + auto x = dynamic_cast<ARG_TYPE_1 *>(node->x()); + auto y = dynamic_cast<ARG_TYPE_2 *>(node->y()); + + if (x && y) + { + *_arg_1 = x; + *_arg_2 = y; + return true; + } + } + + return false; +} + } // namespace luci diff --git a/compiler/luci/pass/src/helpers/SparsityFormatConverter.h b/compiler/luci/pass/src/helpers/SparsityFormatConverter.h index fcd9bbcd0..e01430489 100644 --- a/compiler/luci/pass/src/helpers/SparsityFormatConverter.h +++ b/compiler/luci/pass/src/helpers/SparsityFormatConverter.h @@ -18,6 +18,7 @@ #ifndef __LUCI_PASS_HELPERS_SPARSITY_FORMAT_CONVERTER_H__ #define __LUCI_PASS_HELPERS_SPARSITY_FORMAT_CONVERTER_H__ +#include <cstddef> #include <cstdint> #include <vector> diff --git a/compiler/luci/pass/src/helpers/Strings.cpp b/compiler/luci/pass/src/helpers/Strings.cpp index d020f6ddc..2628726c1 100644 --- a/compiler/luci/pass/src/helpers/Strings.cpp +++ b/compiler/luci/pass/src/helpers/Strings.cpp @@ -77,6 +77,15 @@ loco::DataType str_to_dtype(const std::string &str) return loco::DataType::Unknown; } +// Convert string to a vector of loco::DataType +std::vector<loco::DataType> str_vec_to_dtype_vec(std::vector<std::string> &vec) +{ + std::vector<loco::DataType> res; + std::transform(vec.begin(), vec.end(), std::back_inserter(res), + [](std::string s) -> loco::DataType { return str_to_dtype(to_lower_case(s)); }); + return res; +} + QuantizationGranularity str_to_granularity(const std::string &str) { if (to_lower_case(str).compare("layer") == 0) diff --git a/compiler/luci/pass/src/helpers/Strings.h b/compiler/luci/pass/src/helpers/Strings.h index 0e7818517..485f37948 100644 --- a/compiler/luci/pass/src/helpers/Strings.h +++ b/compiler/luci/pass/src/helpers/Strings.h @@ -36,6 +36,8 @@ std::string to_lower_case(std::string); loco::DataType str_to_dtype(const std::string &); +std::vector<loco::DataType> str_vec_to_dtype_vec(std::vector<std::string> &); + QuantizationGranularity str_to_granularity(const std::string &); } // namespace luci diff --git a/compiler/luci/pass/src/helpers/Strings.test.cpp b/compiler/luci/pass/src/helpers/Strings.test.cpp index d77b65038..6d854ad4f 100644 --- a/compiler/luci/pass/src/helpers/Strings.test.cpp +++ b/compiler/luci/pass/src/helpers/Strings.test.cpp @@ -48,3 +48,26 @@ TEST(StringsTest, str_to_granularity) EXPECT_THROW(luci::str_to_granularity("foo"), std::runtime_error); } + +TEST(StringsTest, str_vec_to_dtype_vec) +{ + std::vector<std::string> input1 = {"uint8", "int16", "float32"}; + auto result1 = luci::str_vec_to_dtype_vec(input1); + ASSERT_EQ(3, result1.size()); + ASSERT_EQ(loco::DataType::U8, result1[0]); + ASSERT_EQ(loco::DataType::S16, result1[1]); + ASSERT_EQ(loco::DataType::FLOAT32, result1[2]); + + std::vector<std::string> input2 = {"uint8", "int16", "float32", ""}; + auto result2 = luci::str_vec_to_dtype_vec(input2); + ASSERT_EQ(4, result2.size()); + ASSERT_EQ(loco::DataType::U8, result2[0]); + ASSERT_EQ(loco::DataType::S16, result2[1]); + ASSERT_EQ(loco::DataType::FLOAT32, result2[2]); + ASSERT_EQ(loco::DataType::Unknown, result2[3]); + + std::vector<std::string> input3 = {"uint8"}; + auto result3 = luci::str_vec_to_dtype_vec(input3); + ASSERT_EQ(1, result3.size()); + ASSERT_EQ(loco::DataType::U8, result3[0]); +} diff --git a/compiler/luci/pass/src/test/TestIOGraph.h b/compiler/luci/pass/src/test/TestIOGraph.h deleted file mode 100644 index b1fc41f90..000000000 --- a/compiler/luci/pass/src/test/TestIOGraph.h +++ /dev/null @@ -1,161 +0,0 @@ -/* - * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef __LUCI_PASS_TEST_IO_GRAPH_H__ -#define __LUCI_PASS_TEST_IO_GRAPH_H__ - -#include "TestShape.h" - -#include <luci/IR/CircleNodes.h> - -namespace luci -{ -namespace test -{ - -/** - * @brief Graphlet with Inputs and loco::Graph for multiple inputs - * @note Every Graph will have Input(s) and Output(s) - * We put loco::Graph only in IsGraphlet not to declare separate - * class for loco::Graph - */ -template <unsigned N> class TestIsGraphlet -{ -public: - TestIsGraphlet() - { - for (uint32_t n = 0; n < N; ++n) - { - _graph_inputs[n] = nullptr; - _inputs[n] = nullptr; - } - } - -public: - virtual void init(loco::Graph *g, const ShapeU32 shape_in) - { - for (uint32_t n = 0; n < N; ++n) - { - _graph_inputs[n] = g->inputs()->create(); - - _inputs[n] = g->nodes()->create<luci::CircleInput>(); - _inputs[n]->shape(shape_in); - _inputs[n]->shape_status(luci::ShapeStatus::VALID); - _inputs[n]->dtype(loco::DataType::FLOAT32); - _inputs[n]->name("input_" + std::to_string(n)); - - _inputs[n]->index(_graph_inputs[n]->index()); - - auto input_shape = std::make_unique<loco::TensorShape>(); - set_shape_vector(input_shape.get(), shape_in); - _graph_inputs[n]->shape(std::move(input_shape)); - _graph_inputs[n]->dtype(loco::DataType::FLOAT32); - } - } - -public: - loco::Graph *g(void) { return &_g; } - luci::CircleInput *input(int idx) { return _inputs[idx]; } - -protected: - loco::Graph _g; - std::array<loco::GraphInput *, N> _graph_inputs; - std::array<luci::CircleInput *, N> _inputs; -}; - -/** - * @brief Graphlet with one Input - */ -class TestIGraphlet : public TestIsGraphlet<1> -{ -public: - luci::CircleInput *input() { return _inputs[0]; } -}; - -/** - * @brief Graphlet with Outputs for multiple outputs - */ -template <unsigned N> class TestOsGraphlet -{ -public: - TestOsGraphlet() - { - for (uint32_t n = 0; n < N; ++n) - { - _graph_outputs[n] = nullptr; - _outputs[n] = nullptr; - } - } - -public: - virtual void init(loco::Graph *g, const ShapeU32 shape_out) - { - for (uint32_t n = 0; n < N; ++n) - { - _graph_outputs[n] = g->outputs()->create(); - - _outputs[n] = g->nodes()->create<luci::CircleOutput>(); - _outputs[n]->shape(shape_out); - _outputs[n]->shape_status(luci::ShapeStatus::VALID); - _outputs[n]->dtype(loco::DataType::FLOAT32); - _outputs[n]->name("output_" + std::to_string(n)); - - _outputs[n]->index(_graph_outputs[n]->index()); - - auto output_shape = std::make_unique<loco::TensorShape>(); - set_shape_vector(output_shape.get(), shape_out); - _graph_outputs[n]->shape(std::move(output_shape)); - _graph_outputs[n]->dtype(loco::DataType::FLOAT32); - } - } - -public: - luci::CircleOutput *output(int idx) { return _outputs[idx]; } - -protected: - std::array<loco::GraphOutput *, N> _graph_outputs; - std::array<luci::CircleOutput *, N> _outputs; -}; - -/** - * @brief Graphlet with one Output - */ -class TestOGraphlet : public TestOsGraphlet<1> -{ -public: - luci::CircleOutput *output() { return _outputs[0]; } -}; - -/** - * @brief Graph with Input and Output - */ -class TestIOGraph : public TestIGraphlet, public TestOGraphlet -{ -public: - TestIOGraph() = default; - -public: - virtual void init(const ShapeU32 shape_in, const ShapeU32 shape_out) - { - TestIsGraphlet<1>::init(g(), shape_in); - TestOsGraphlet<1>::init(g(), shape_out); - } -}; - -} // namespace test -} // namespace luci - -#endif // __LUCI_PASS_TEST_IO_GRAPH_H__ |