diff options
author | Chunseok Lee <chunseok.lee@samsung.com> | 2022-09-07 19:04:21 +0900 |
---|---|---|
committer | Chunseok Lee <chunseok.lee@samsung.com> | 2022-09-07 19:04:21 +0900 |
commit | c690d52bdd137ed6a17353aa7af35e8141ece77b (patch) | |
tree | dbb7dd99133132dfbffcb8c9e9af4f1ffc2f4808 /compiler/luci/pass | |
parent | 3ad689f0803519e343c36d5700646e86059df961 (diff) | |
download | nnfw-tizen_7.0.tar.gz nnfw-tizen_7.0.tar.bz2 nnfw-tizen_7.0.zip |
Imported Upstream version 1.21.0upstream/1.21.0tizen_7.0_m2_releaseaccepted/tizen/unified/20220912.170817accepted/tizen/unified/20220912.164738accepted/tizen/7.0/unified/hotfix/20221116.105341accepted/tizen/7.0/unified/20221110.060236tizen_7.0_hotfixtizen_7.0accepted/tizen_7.0_unified_hotfixaccepted/tizen_7.0_unified
Diffstat (limited to 'compiler/luci/pass')
47 files changed, 4300 insertions, 266 deletions
diff --git a/compiler/luci/pass/CMakeLists.txt b/compiler/luci/pass/CMakeLists.txt index 5237c6d3f..d9d004db9 100644 --- a/compiler/luci/pass/CMakeLists.txt +++ b/compiler/luci/pass/CMakeLists.txt @@ -1,9 +1,16 @@ nnas_find_package(FlatBuffers EXACT 2.0 QUIET) +nnas_find_package(Fp16Source QUIET) + if(NOT FlatBuffers_FOUND) message(STATUS "FlatBuffers NOT FOUND") return() endif(NOT FlatBuffers_FOUND) +if(NOT Fp16Source_FOUND) + message(STATUS "Fp16Source NOT FOUND") + return() +endif(NOT Fp16Source_FOUND) + file(GLOB_RECURSE SOURCES "src/*.cpp") file(GLOB_RECURSE TESTS "src/*.test.cpp") list(REMOVE_ITEM SOURCES ${TESTS}) @@ -14,6 +21,7 @@ endif(NOT LUCI_LIBRARY_TYPE) add_library(luci_pass ${LUCI_LIBRARY_TYPE} ${SOURCES}) target_include_directories(luci_pass PRIVATE src) +target_include_directories(luci_pass PRIVATE ${Fp16Source_DIR}/include) target_include_directories(luci_pass PUBLIC include) target_link_libraries(luci_pass PUBLIC loco) target_link_libraries(luci_pass PUBLIC logo_core) diff --git a/compiler/luci/pass/include/luci/CircleOptimizer.h b/compiler/luci/pass/include/luci/CircleOptimizer.h index c803898f6..b94822c35 100644 --- a/compiler/luci/pass/include/luci/CircleOptimizer.h +++ b/compiler/luci/pass/include/luci/CircleOptimizer.h @@ -47,8 +47,10 @@ public: ResolveCustomOpBatchMatMul, ResolveCustomOpMatMul, ResolveCustomOpMaxPoolWithArgmax, + ResolveCustomOpSplitV, FoldAddV2, FoldCast, + FoldDensify, FoldDepthwiseConv2D, FoldDequantize, FoldGather, @@ -61,6 +63,7 @@ public: ShuffleWeightTo16x1Float32, RemoveRedundantTranspose, ReplaceMulAddWithDepthwiseConv, + ReplaceNonConstFCWithBatchMatMul, ReplaceSubWithAdd, SubstitutePackToReshape, SubstitutePadV2ToPad, diff --git a/compiler/luci/pass/include/luci/Pass/FoldDensifyPass.h b/compiler/luci/pass/include/luci/Pass/FoldDensifyPass.h new file mode 100644 index 000000000..8ec81b1d4 --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/FoldDensifyPass.h @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_FOLD_DENSIFY_PASS_H__ +#define __LUCI_FOLD_DENSIFY_PASS_H__ + +#include <logo/Pass.h> + +namespace luci +{ + +/** + * @brief Class to Fold Densify if input is Sparse Constant + * + */ +struct FoldDensifyPass final : public logo::Pass +{ + const char *name(void) const final { return "luci::FoldDensifyPass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_FOLD_DENSIFY_PASS_H__ diff --git a/compiler/luci/pass/include/luci/Pass/RemoveRedundantDequantizePass.h b/compiler/luci/pass/include/luci/Pass/RemoveRedundantDequantizePass.h new file mode 100644 index 000000000..2deb75297 --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/RemoveRedundantDequantizePass.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_REMOVE_REDUNDANT_DEQUANTIZE_PASS_H__ +#define __LUCI_REMOVE_REDUNDANT_DEQUANTIZE_PASS_H__ + +#include <logo/Pass.h> + +namespace luci +{ + +/** + * @brief Class to remove redundant dequantize operations + */ +struct RemoveRedundantDequantizePass final : public logo::Pass +{ + const char *name(void) const final { return "luci::RemoveRedundantDequantizePass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_REMOVE_REDUNDANT_DEQUANTIZE_PASS_H__ diff --git a/compiler/luci/pass/include/luci/Pass/RemoveUnnecessaryReshapeNetPass.h b/compiler/luci/pass/include/luci/Pass/RemoveUnnecessaryReshapeNetPass.h new file mode 100644 index 000000000..19948a31c --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/RemoveUnnecessaryReshapeNetPass.h @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_REMOVE_UNNECESSARY_RESHAPE_NET_PASS_H__ +#define __LUCI_REMOVE_UNNECESSARY_RESHAPE_NET_PASS_H__ + +#include <logo/Pass.h> + +namespace luci +{ + +/** + * @brief Class to remove unnecessary Reshape nodes. + * @details This class will remove unnecessary pre/post-Reshape nodes. + * See https://github.com/Samsung/ONE/issues/9600 for more details. + */ +struct RemoveUnnecessaryReshapeNetPass final : public logo::Pass +{ + const char *name(void) const final { return "luci::RemoveUnnecessaryReshapeNetPass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_REMOVE_UNNECESSARY_RESHAPE_NET_PASS_H__ diff --git a/compiler/luci/pass/include/luci/Pass/ReplaceNonConstFCWithBatchMatMulPass.h b/compiler/luci/pass/include/luci/Pass/ReplaceNonConstFCWithBatchMatMulPass.h new file mode 100644 index 000000000..24e16ec49 --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/ReplaceNonConstFCWithBatchMatMulPass.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_REPLACE_NONCONST_FC_WITH_BATCH_MATMUL_PASS_H__ +#define __LUCI_REPLACE_NONCONST_FC_WITH_BATCH_MATMUL_PASS_H__ + +#include <logo/Pass.h> + +namespace luci +{ + +/** + * @brief Class to replace "FC with non-const weight" with Batched MatMul + */ +struct ReplaceNonConstFCWithBatchMatMulPass final : public logo::Pass +{ + const char *name(void) const final { return "luci::ReplaceNonConstFCWithBatchMatMulPass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_REPLACE_NONCONST_FC_WITH_BATCH_MATMUL_PASS_H__ diff --git a/compiler/luci/pass/include/luci/Pass/ResolveCustomOpSplitVPass.h b/compiler/luci/pass/include/luci/Pass/ResolveCustomOpSplitVPass.h new file mode 100644 index 000000000..d4f0147e8 --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/ResolveCustomOpSplitVPass.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_RESOLVE_CUSTOM_OP_SPLIT_V_PASS_H__ +#define __LUCI_RESOLVE_CUSTOM_OP_SPLIT_V_PASS_H__ + +#include <logo/Pass.h> + +namespace luci +{ + +/** + * @brief Class to resolve certain custom op of subgraph into splitv op in circle schema. + */ +struct ResolveCustomOpSplitVPass final : public logo::Pass +{ + const char *name(void) const final { return "luci::ResolveCustomOpSplitVPass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_RESOLVE_CUSTOM_OP_SPLIT_V_PASS_H__ diff --git a/compiler/luci/pass/src/CircleOptimizer.cpp b/compiler/luci/pass/src/CircleOptimizer.cpp index 6dbb22d7c..74c569d20 100644 --- a/compiler/luci/pass/src/CircleOptimizer.cpp +++ b/compiler/luci/pass/src/CircleOptimizer.cpp @@ -20,6 +20,7 @@ #include "luci/Pass/ExpandBroadcastConstPass.h" #include "luci/Pass/FoldAddV2Pass.h" #include "luci/Pass/FoldCastPass.h" +#include "luci/Pass/FoldDensifyPass.h" #include "luci/Pass/FoldDepthwiseConv2DPass.h" #include "luci/Pass/FoldDequantizePass.h" #include "luci/Pass/FoldGatherPass.h" @@ -43,15 +44,18 @@ #include "luci/Pass/RemoveRedundantTransposePass.h" #include "luci/Pass/RemoveRedundantQuantizePass.h" #include "luci/Pass/RemoveUnnecessaryReshapePass.h" +#include "luci/Pass/RemoveUnnecessaryReshapeNetPass.h" #include "luci/Pass/RemoveUnnecessarySlicePass.h" #include "luci/Pass/RemoveUnnecessaryStridedSlicePass.h" #include "luci/Pass/RemoveUnnecessarySplitPass.h" +#include "luci/Pass/ReplaceNonConstFCWithBatchMatMulPass.h" #include "luci/Pass/ReplaceMulAddWithDepthwiseConvPass.h" #include "luci/Pass/ReplaceSubWithAddPass.h" #include "luci/Pass/ResolveCustomOpAddPass.h" #include "luci/Pass/ResolveCustomOpBatchMatMulPass.h" #include "luci/Pass/ResolveCustomOpMatMulPass.h" #include "luci/Pass/ResolveCustomOpMaxPoolWithArgmaxPass.h" +#include "luci/Pass/ResolveCustomOpSplitVPass.h" #include "luci/Pass/SparsifyTensorPass.h" #include "luci/Pass/ShuffleWeightTo16x1Float32Pass.h" #include "luci/Pass/SubstitutePackToReshapePass.h" @@ -127,7 +131,8 @@ bool OptimizeOptionsImpl::query(Algorithm algo) return true; } -void convert_nchw_to_nhwc(loco::Graph *g, bool preserve_input, bool preserve_output) +// TODO Make a struct for args +void convert_nchw_to_nhwc(loco::Graph *g, bool preserve_input, bool preserve_output, bool fuse_fc) { logo::Phase phase; @@ -135,6 +140,21 @@ void convert_nchw_to_nhwc(loco::Graph *g, bool preserve_input, bool preserve_out phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>()); phase.emplace_back(std::make_unique<luci::CircleTypeInferencePass>()); + // Resolve custom Ops + phase.emplace_back(std::make_unique<luci::ResolveCustomOpAddPass>()); + phase.emplace_back(std::make_unique<luci::ResolveCustomOpBatchMatMulPass>()); + phase.emplace_back(std::make_unique<luci::ResolveCustomOpMatMulPass>()); + phase.emplace_back(std::make_unique<luci::ResolveCustomOpMaxPoolWithArgmaxPass>()); + phase.emplace_back(std::make_unique<luci::ResolveCustomOpSplitVPass>()); + + // Fuse FullyConnected with Add + // Why we perform FuseAddWithFullyConnectedPass before ConvertNCHWToNHWCPass? + // FullyConnected Op's layout is not changed in ConvertNCHWToNHWCPass, while + // Add Op's layer is changed from NCHW to NHWC. + // This disables fusion of Add and FullyConnected after ConvertNCHWToNHWC. + if (fuse_fc) + phase.emplace_back(std::make_unique<luci::FuseAddWithFullyConnectedPass>()); + phase.emplace_back( std::make_unique<luci::ConvertNCHWToNHWCPass>(preserve_input, preserve_output)); @@ -190,7 +210,9 @@ void CircleOptimizer::optimize(loco::Graph *g) const bool preserve_output = _options->param(Options::AlgorithmParameters::NCHW_to_NHWC_output_shape) != "true"; - convert_nchw_to_nhwc(g, preserve_input, preserve_output); + bool fuse_fc = _options->query(Options::Algorithm::FuseAddWithFullyConnected); + + convert_nchw_to_nhwc(g, preserve_input, preserve_output, fuse_fc); } /* TRANSFORM DECLARATION BEGIN */ @@ -220,6 +242,10 @@ void CircleOptimizer::optimize(loco::Graph *g) const { phase.emplace_back(std::make_unique<luci::ResolveCustomOpMaxPoolWithArgmaxPass>()); } + if (_options->query(Options::Algorithm::ResolveCustomOpSplitV)) + { + phase.emplace_back(std::make_unique<luci::ResolveCustomOpSplitVPass>()); + } if (_options->query(Options::Algorithm::FuseInstanceNorm)) { phase.emplace_back(std::make_unique<FuseInstanceNormPass>()); @@ -260,6 +286,10 @@ void CircleOptimizer::optimize(loco::Graph *g) const { phase.emplace_back(std::make_unique<luci::FoldCastPass>()); } + if (_options->query(Options::Algorithm::FoldDensify)) + { + phase.emplace_back(std::make_unique<luci::FoldDensifyPass>()); + } if (_options->query(Options::Algorithm::FoldDepthwiseConv2D)) { phase.emplace_back(std::make_unique<luci::FoldDepthwiseConv2DPass>()); @@ -307,6 +337,7 @@ void CircleOptimizer::optimize(loco::Graph *g) const if (_options->query(Options::Algorithm::RemoveUnnecessaryReshape)) { phase.emplace_back(std::make_unique<luci::RemoveUnnecessaryReshapePass>()); + phase.emplace_back(std::make_unique<luci::RemoveUnnecessaryReshapeNetPass>()); } if (_options->query(Options::Algorithm::RemoveUnnecessarySlice)) { @@ -332,6 +363,10 @@ void CircleOptimizer::optimize(loco::Graph *g) const { phase.emplace_back(std::make_unique<luci::RemoveRedundantQuantizePass>()); } + if (_options->query(Options::Algorithm::ReplaceNonConstFCWithBatchMatMul)) + { + phase.emplace_back(std::make_unique<luci::ReplaceNonConstFCWithBatchMatMulPass>()); + } if (_options->query(Options::Algorithm::ReplaceMulAddWithDepthwiseConv)) { phase.emplace_back(std::make_unique<luci::ReplaceMulAddWithDepthwiseConvPass>()); diff --git a/compiler/luci/pass/src/CircleQuantizer.cpp b/compiler/luci/pass/src/CircleQuantizer.cpp index ce38a90b9..9a6550b9f 100644 --- a/compiler/luci/pass/src/CircleQuantizer.cpp +++ b/compiler/luci/pass/src/CircleQuantizer.cpp @@ -22,6 +22,7 @@ #include "luci/Pass/RequantizePass.h" #include "luci/Pass/ConvertToFakeQuantizedModelPass.h" #include "luci/Pass/FoldDequantizePass.h" +#include "luci/Pass/RemoveRedundantDequantizePass.h" #include "luci/Pass/QuantizePreCheckerPass.h" #include "luci/Pass/QuantizeWithMinMaxPass.h" #include "luci/Pass/QuantizeDequantizeWeightsPass.h" @@ -252,8 +253,8 @@ void CircleQuantizer::quantize(loco::Graph *g) const static const std::vector<std::string> qwmm_supported_input_model_dtype{"float32"}; static const std::vector<std::string> qwmm_supported_output_model_dtype{"uint8", "int16"}; static const std::vector<std::string> qwmm_supported_granularity{"layer", "channel"}; - static const std::vector<std::string> qwmm_supported_input_type{"uint8", "int16"}; - static const std::vector<std::string> qwmm_supported_output_type{"uint8", "int16"}; + static const std::vector<std::string> qwmm_supported_input_type{"uint8", "int16", "float32"}; + static const std::vector<std::string> qwmm_supported_output_type{"uint8", "int16", "float32"}; auto input_model_dtype = _options->param(Options::AlgorithmParameters::Quantize_input_model_dtype); @@ -434,6 +435,8 @@ void CircleQuantizer::quantize(loco::Graph *g) const phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>()); phase.emplace_back(std::make_unique<luci::CircleTypeInferencePass>()); + // Remove redundant Dequantize Ops generated during fake quantization + phase.emplace_back(std::make_unique<luci::RemoveRedundantDequantizePass>()); // Fold Dequantize Ops generated during fake quantization phase.emplace_back(std::make_unique<luci::FoldDequantizePass>()); diff --git a/compiler/luci/pass/src/ConvertNCHWToNHWCPass.cpp b/compiler/luci/pass/src/ConvertNCHWToNHWCPass.cpp index ce4f54035..55a29d105 100644 --- a/compiler/luci/pass/src/ConvertNCHWToNHWCPass.cpp +++ b/compiler/luci/pass/src/ConvertNCHWToNHWCPass.cpp @@ -28,6 +28,69 @@ namespace { +// Return true if from can be broadcasted to to +// to's shape is [N, C, H, W] +bool broadcastable(const luci::CircleConst *from, const luci::CircleNode *to) +{ + assert(to->rank() == 4); // FIX_CALLER_UNLESS + + const auto from_rank = from->rank(); + if (from_rank > 4) + return false; + + // Scalar is always broadcastable + if (from_rank == 0) + return true; + + for (uint32_t i = 1; i <= from_rank; i++) + { + auto to_index = 4 - i; + auto from_index = from_rank - i; + + if (from->dim(from_index).value() != to->dim(to_index).value() and + from->dim(from_index).value() != 1) + return false; + } + + return true; +} + +// Expand node to rank 4 +// node should have rank less than or equal to 4 +void expand_to_rank_4(luci::CircleConst *node) +{ + auto original_rank = node->rank(); + + assert(original_rank <= 4); // FIX_CALLER_UNLESS + + if (original_rank == 4) + return; + + std::vector<uint32_t> original_shape; + for (uint32_t i = 0; i < original_rank; i++) + { + original_shape.emplace_back(node->dim(i).value()); + } + + node->rank(4); + for (uint32_t i = 0; i < (4 - original_rank); i++) + node->dim(i) = 1; + + for (uint32_t i = 0; i < original_rank; i++) + node->dim(i + (4 - original_rank)) = original_shape.at(i); +} + +bool is_output(const loco::Node *node) +{ + auto cnode = loco::must_cast<const luci::CircleNode *>(node); + auto opcode = cnode->opcode(); + if (opcode == luci::CircleOpcode::CIRCLEOUTPUT || + opcode == luci::CircleOpcode::CIRCLEOUTPUTEXCLUDE) + return true; + + return false; +} + bool is_same_shape(const luci::CircleNode *node, const std::vector<loco::Dimension> &shape) { if (not node) @@ -484,7 +547,7 @@ bool is_NCHW_with_s_const(const T *node, luci::CircleNode *&pred_node, // // Find MUL with an NCHW pattern described below // - Input (non-constant) shape : [N, C, H, W] -// - Input (constant) shape : [1, C, 1, 1], [N, C, H, W] or a scalar (1) +// - Input (constant) shape : broadcastable to [N, C, H, W] // - Output shape : [N, C, H, W] bool is_NCHW_with_const(const luci::CircleMul *node, luci::CircleNode *&pred_node, luci::CircleConst *&multiplier) @@ -511,32 +574,12 @@ bool is_NCHW_with_const(const luci::CircleMul *node, luci::CircleNode *&pred_nod if (pred_node->rank() != 4) return false; - const auto const_rank = multiplier->rank(); - // Support Rank 4 or scalar (rank 0 or 1) - if (const_rank != 4 && const_rank != 0 && const_rank != 1) + if (not broadcastable(multiplier, node)) return false; - const auto input_cdim = pred_node->dim(1); - const auto output_cdim = node->dim(1); - - if (const_rank == 4) - { - bool supported_shape = false; - - // Check multiplier is (1, C, 1, 1) - if (is_same_shape(multiplier, {1, node->dim(1), 1, 1})) - supported_shape = true; - - // Check multiplier is (N, C, H, W) - if (is_same_shape(multiplier, {node->dim(0), node->dim(1), node->dim(2), node->dim(3)})) - supported_shape = true; + expand_to_rank_4(multiplier); - return supported_shape; - } - if (input_cdim == output_cdim) - return true; - else - return false; + return true; } // We assume ADD with const input is NCHW if, @@ -569,32 +612,12 @@ bool is_NCHW_with_const(const luci::CircleAdd *node, luci::CircleNode *&pred_nod if (pred_node->rank() != 4) return false; - const auto const_rank = beta->rank(); - // Support Rank 4 or scalar (rank 0 or 1) - if (const_rank != 4 && const_rank != 0 && const_rank != 1) + if (not broadcastable(beta, node)) return false; - const auto input_cdim = pred_node->dim(1); - const auto output_cdim = node->dim(1); - - if (const_rank == 4) - { - bool supported_shape = false; - - // Check beta is (1, C, 1, 1) - if (is_same_shape(beta, {1, node->dim(1), 1, 1})) - supported_shape = true; - - // Check beta is (N, C, H, W) - if (is_same_shape(beta, {node->dim(0), node->dim(1), node->dim(2), node->dim(3)})) - supported_shape = true; + expand_to_rank_4(beta); - return supported_shape; - } - if (input_cdim == output_cdim) - return true; - else - return false; + return true; } // We assume SUB with const input is NCHW if, @@ -675,6 +698,24 @@ template <class T> bool convert_unary_x(T *node) return true; } +template <class T> bool convert_unary_logits(T *node) +{ + const auto pred_node = loco::must_cast<luci::CircleNode *>(node->logits()); + auto pre_trans = create_pre_transpose(node); + pre_trans->a(pred_node); + node->logits(pre_trans); + + // Do shape inference for this node again. + node->shape_status(luci::ShapeStatus::UNDEFINED); + + auto post_trans = create_post_transpose(node); + loco::replace(node).with(post_trans); + + post_trans->a(node); + + return true; +} + class ConvertNCHWToNHWC final : public luci::CircleNodeMutableVisitor<bool> { // Default @@ -742,17 +783,14 @@ class ConvertNCHWToNHWC final : public luci::CircleNodeMutableVisitor<bool> if (is_NCHW_with_const(node, pred_node, beta)) { + assert(beta->rank() == 4); // FIX is_NCHW_with_const unless + auto nhwc_const = create_NHWC_from_NCHW(beta); + if (nhwc_const == nullptr) + return false; + node->y(nhwc_const); + auto pre_trans = create_pre_transpose(node); pre_trans->a(pred_node); - - if (beta->rank() == 4) - { - auto nhwc_const = create_NHWC_from_NCHW(beta); - if (nhwc_const == nullptr) - return false; - node->y(nhwc_const); - } - node->x(pre_trans); } else if (beta == nullptr) @@ -816,6 +854,11 @@ class ConvertNCHWToNHWC final : public luci::CircleNodeMutableVisitor<bool> bool visit(luci::CircleLogistic *node) { return convert_unary_x<luci::CircleLogistic>(node); } + bool visit(luci::CircleLogSoftmax *node) + { + return convert_unary_logits<luci::CircleLogSoftmax>(node); + } + bool visit(luci::CircleMaximum *node) { luci::CircleNode *pred_node = nullptr; @@ -954,15 +997,15 @@ class ConvertNCHWToNHWC final : public luci::CircleNodeMutableVisitor<bool> if (is_NCHW_with_const(node, pred_node, multiplier)) { + assert(multiplier->rank() == 4); // FIX is_NCHW_with_const unless + auto nhwc_const = create_NHWC_from_NCHW(multiplier); + if (nhwc_const == nullptr) + return false; + node->y(nhwc_const); + auto pre_trans = create_pre_transpose(node); pre_trans->a(pred_node); node->x(pre_trans); - - if (multiplier->rank() == 4) - { - auto nhwc_const = create_NHWC_from_NCHW(multiplier); - node->y(nhwc_const); - } } else if (multiplier == nullptr) { @@ -1049,12 +1092,127 @@ class ConvertNCHWToNHWC final : public luci::CircleNodeMutableVisitor<bool> return true; } + // TODO Reduce duplicate code with CircleMean + bool visit(luci::CircleReduceMax *node) + { + auto input = loco::must_cast<luci::CircleNode *>(node->input()); + if (input->rank() != 4) + return false; + + auto rindices = dynamic_cast<luci::CircleConst *>(node->reduction_indices()); + if (not rindices) + return false; + + auto nhwc_rindices = create_NHWC_rindices(rindices); + if (not nhwc_rindices) + return false; + + auto pre_trans = create_pre_transpose(node); + pre_trans->a(input); + node->input(pre_trans); + + // Do shape inference for this node again. + node->shape_status(luci::ShapeStatus::UNDEFINED); + + node->reduction_indices(nhwc_rindices); + + if (node->keep_dims()) + { + auto post_trans = create_post_transpose(node); + loco::replace(node).with(post_trans); + + post_trans->a(node); + + return true; + } + + // The below codes handle the cases where node->keep_dims() == false + // 1D output never needs a transpose + if (node->rank() <= 1) + return true; + + std::vector<bool> reduced_dims_nhwc(4, false); + uint32_t num_reduced_indices = nhwc_rindices->size<loco::DataType::S32>(); + + for (uint32_t ri = 0; ri < num_reduced_indices; ++ri) + { + reduced_dims_nhwc[nhwc_rindices->at<loco::DataType::S32>(ri)] = true; + } + + // if channel dimension has been reduced, we don't need a transpose + if (reduced_dims_nhwc[3]) + return true; + + // likewise, if both space dimensions are reduced, no transpose is needed + if (reduced_dims_nhwc[1] && reduced_dims_nhwc[2]) + return true; + + std::vector<int32_t> post_trans_ind; + // case 1: only N is reduced + if (num_reduced_indices == 1 && reduced_dims_nhwc[0]) + post_trans_ind = {2, 0, 1}; + + // case 2: only H or W is reduced + if (num_reduced_indices == 1 && (reduced_dims_nhwc[1] || reduced_dims_nhwc[2])) + post_trans_ind = {0, 2, 1}; + + // case 3: N and either H or W are reduced + if (num_reduced_indices == 2) + post_trans_ind = {1, 0}; + + auto post_trans = create_Nd_transpose(node, post_trans_ind); + loco::replace(node).with(post_trans); + + post_trans->a(node); + + return true; + } + bool visit(luci::CircleRelu *node) { return convert_unary_features<luci::CircleRelu>(node); } bool visit(luci::CircleRelu6 *node) { return convert_unary_features<luci::CircleRelu6>(node); } bool visit(luci::CircleRsqrt *node) { return convert_unary_x<luci::CircleRsqrt>(node); } + bool visit(luci::CircleSoftmax *node) { return convert_unary_logits<luci::CircleSoftmax>(node); } + + bool visit(luci::CircleSplitV *node) + { + // Change split dimension + auto axis = dynamic_cast<luci::CircleConst *>(node->split_dim()); + if (not axis) + return false; + + if (axis->dtype() != loco::DataType::S32) + return false; + + if (axis->size<loco::DataType::S32>() != 1) + return false; + + axis->at<loco::DataType::S32>(0) = nchw_axis_to_nhwc(axis->at<loco::DataType::S32>(0)); + + // Insert pre-transpose + const auto pred_node = loco::must_cast<luci::CircleNode *>(node->input()); + auto pre_trans = create_pre_transpose(node); + pre_trans->a(pred_node); + node->input(pre_trans); + + // Do shape inference for this node again. + node->shape_status(luci::ShapeStatus::UNDEFINED); + + // Insert post-transposes + for (auto succ : loco::succs(node)) + { + auto svo = loco::must_cast<luci::CircleSplitVOut *>(succ); + + auto post_trans = create_post_transpose(svo); + loco::replace(svo).with(post_trans); + post_trans->a(svo); + } + + return true; + } + bool visit(luci::CircleSquaredDifference *node) { // TODO support CircleConst input @@ -1195,6 +1353,8 @@ bool ConvertNCHWToNHWCPass::run(loco::Graph *g) // pre-Transpose --- [intermediate Ops] --- post-Transpose // | // +--[intermediate Ops] --- post-Transpose + // + // NOTE Intermediate Ops SHOULD NOT contain pre-Transpose/Reshape for (auto node : loco::postorder_traversal(loco::output_nodes(g))) { if (has_data_format(node)) @@ -1202,25 +1362,51 @@ bool ConvertNCHWToNHWCPass::run(loco::Graph *g) if (is_pre_transpose(node) || is_pre_reshape(node)) { + std::set<loco::Node *> intermediate; + + // Variable to check intermediate Ops contain pre-Transpose/Reshape + bool has_pre = false; + + // Variable to check the pattern is closed with post-Transpose/Reshape + bool is_closed = true; + // For recursive call of lambda - std::function<void(loco::Node *)> set_data_format_to_succs; - set_data_format_to_succs = [&](loco::Node *n) { + std::function<void(loco::Node *)> collect_intermediate; + collect_intermediate = [&](loco::Node *n) { for (auto succ : loco::succs(n)) { // Exit condition if (is_post_transpose(succ) || is_post_reshape(succ)) continue; - if (not has_data_format(succ)) + if (is_pre_transpose(succ) || is_pre_reshape(succ)) + { + has_pre = true; + break; + } + + if (is_output(succ)) { - set_data_format(succ, DataFormat::NHWC); + is_closed = false; + break; } - set_data_format_to_succs(succ); + intermediate.emplace(succ); + + collect_intermediate(succ); } }; - set_data_format_to_succs(node); + collect_intermediate(node); + + if (has_pre or not is_closed) + continue; + + for (auto inter : intermediate) + { + if (not has_data_format(inter)) + set_data_format(inter, DataFormat::NHWC); + } } } @@ -1248,6 +1434,7 @@ bool ConvertNCHWToNHWCPass::run(loco::Graph *g) case luci::CircleOpcode::ELU: case luci::CircleOpcode::LEAKY_RELU: case luci::CircleOpcode::LOGISTIC: + case luci::CircleOpcode::LOG_SOFTMAX: case luci::CircleOpcode::MAXIMUM: case luci::CircleOpcode::MEAN: case luci::CircleOpcode::MINIMUM: @@ -1255,9 +1442,12 @@ bool ConvertNCHWToNHWCPass::run(loco::Graph *g) case luci::CircleOpcode::NEG: case luci::CircleOpcode::PAD: case luci::CircleOpcode::PADV2: + case luci::CircleOpcode::REDUCE_MAX: case luci::CircleOpcode::RELU: case luci::CircleOpcode::RELU6: case luci::CircleOpcode::RSQRT: + case luci::CircleOpcode::SOFTMAX: + case luci::CircleOpcode::SPLIT_V: case luci::CircleOpcode::SQUARED_DIFFERENCE: case luci::CircleOpcode::SUB: if (!has_data_format(node)) @@ -1296,7 +1486,8 @@ bool ConvertNCHWToNHWCPass::run(loco::Graph *g) if (circle_node->rank() != 4) { // TODO replace the check above with the input rank check, and remove the condition below - if (not dynamic_cast<luci::CircleMean *>(node)) + if (not dynamic_cast<luci::CircleMean *>(node) and + not dynamic_cast<luci::CircleReduceMax *>(node)) continue; } diff --git a/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp b/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp index dd81d1380..6bb3d3268 100644 --- a/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp +++ b/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp @@ -16,6 +16,8 @@ #include <logo/Phase.h> +#include <luci/test/TestIOGraph.h> + #include "luci/Pass/ConvertNCHWToNHWCPass.h" #include "luci/Pass/CircleShapeInferencePass.h" @@ -23,6 +25,8 @@ #include <gtest/gtest.h> +using namespace luci::test; + namespace { @@ -202,6 +206,173 @@ public: luci::CircleConst *post_shape = nullptr; }; +/** + * Graph with pre-Reshape but no post-Transpose/Reshape. + * + * BEFORE + * [Input] + * | + * [Pre-Reshape] + * | + * [Relu] + * | + * [Output] + * + * AFTER + * [Input] + * | + * [Pre-Reshape] + * | + * [Pre-Transpose] + * | + * [Relu] + * | + * [Post-Transpose] + * | + * [Output] + */ +class NoPostReshapeGraph final : public SimpleGraph +{ +protected: + loco::Node *insertGraphBody(loco::Node *input) override + { + relu = g.nodes()->create<luci::CircleRelu>(); + pre_reshape = g.nodes()->create<luci::CircleReshape>(); + pre_shape = g.nodes()->create<luci::CircleConst>(); + + pre_shape->dtype(loco::DataType::S32); + + uint32_t channel_size = 16; + auto in = loco::must_cast<luci::CircleNode *>(input); + in->shape({1, channel_size, 4, 4}); + pre_shape->shape({4}); + + pre_shape->size<loco::DataType::S32>(4); + pre_shape->at<loco::DataType::S32>(0) = 1; + pre_shape->at<loco::DataType::S32>(1) = 4; + pre_shape->at<loco::DataType::S32>(2) = 4; + pre_shape->at<loco::DataType::S32>(3) = channel_size; + + pre_reshape->tensor(input); + pre_reshape->shape(pre_shape); + relu->features(pre_reshape); + + relu->name("Relu"); + pre_reshape->name("pre-reshape"); + + return relu; + } + +public: + luci::CircleRelu *relu = nullptr; + luci::CircleReshape *pre_reshape = nullptr; + luci::CircleConst *pre_shape = nullptr; +}; + +/** + * Graph with two pre-Reshapes + * + * BEFORE + * [Input] + * | + * [Pre-Reshape] + * | + * [Relu] + * | + * [Pre-Reshape] + * | + * [Post-Reshape] + * | + * [Output] + * + * AFTER + * [Input] + * | + * [Pre-Reshape] + * | + * [Pre-Transpose] + * | + * [Relu] + * | + * [Post-Transpose] + * | + * [Pre-Reshape] + * | + * [Post-Reshape] + * | + * [Output] + */ +class ReluNotClosedGraph final : public SimpleGraph +{ +protected: + loco::Node *insertGraphBody(loco::Node *input) override + { + relu = g.nodes()->create<luci::CircleRelu>(); + pre_reshape = g.nodes()->create<luci::CircleReshape>(); + pre_reshape_2 = g.nodes()->create<luci::CircleReshape>(); + post_reshape = g.nodes()->create<luci::CircleReshape>(); + pre_shape = g.nodes()->create<luci::CircleConst>(); + pre_shape_2 = g.nodes()->create<luci::CircleConst>(); + post_shape = g.nodes()->create<luci::CircleConst>(); + + pre_shape->dtype(loco::DataType::S32); + pre_shape_2->dtype(loco::DataType::S32); + post_shape->dtype(loco::DataType::S32); + + uint32_t channel_size = 16; + auto in = loco::must_cast<luci::CircleNode *>(input); + in->shape({1, channel_size, 4, 4}); + pre_shape->shape({4}); + pre_shape_2->shape({4}); + post_shape->shape({4}); + + pre_shape->size<loco::DataType::S32>(4); + pre_shape->at<loco::DataType::S32>(0) = 1; + pre_shape->at<loco::DataType::S32>(1) = 4; + pre_shape->at<loco::DataType::S32>(2) = 4; + pre_shape->at<loco::DataType::S32>(3) = channel_size; + + pre_shape_2->size<loco::DataType::S32>(4); + pre_shape_2->at<loco::DataType::S32>(0) = 1; + pre_shape_2->at<loco::DataType::S32>(1) = 4; + pre_shape_2->at<loco::DataType::S32>(2) = channel_size; + pre_shape_2->at<loco::DataType::S32>(3) = 4; + + post_shape->size<loco::DataType::S32>(4); + post_shape->at<loco::DataType::S32>(0) = 1; + post_shape->at<loco::DataType::S32>(1) = 4; + post_shape->at<loco::DataType::S32>(2) = 4; + post_shape->at<loco::DataType::S32>(3) = channel_size; + + pre_reshape->tensor(input); + pre_reshape->shape(pre_shape); + + relu->features(pre_reshape); + + pre_reshape_2->tensor(relu); + pre_reshape_2->shape(pre_shape_2); + + post_reshape->tensor(pre_reshape_2); + post_reshape->shape(post_shape); + + relu->name("Relu"); + pre_reshape->name("pre-reshape"); + pre_reshape->name("pre-reshape-2"); + post_reshape->name("post-reshape"); + + return post_reshape; + } + +public: + luci::CircleRelu *relu = nullptr; + luci::CircleReshape *pre_reshape = nullptr; + luci::CircleReshape *pre_reshape_2 = nullptr; + luci::CircleReshape *post_reshape = nullptr; + luci::CircleConst *pre_shape = nullptr; + luci::CircleConst *pre_shape_2 = nullptr; + luci::CircleConst *post_shape = nullptr; +}; + class AddScalarGraph final : public SimpleGraph { protected: @@ -312,6 +483,22 @@ public: luci::CircleLogistic *logistic = nullptr; }; +class LogSoftmaxGraph final : public SimpleGraph +{ +protected: + loco::Node *insertGraphBody(loco::Node *input) override + { + log_softmax = g.nodes()->create<luci::CircleLogSoftmax>(); + log_softmax->logits(input); + log_softmax->name("log_softmax"); + + return log_softmax; + } + +public: + luci::CircleLogSoftmax *log_softmax = nullptr; +}; + class MaximumGraph final : public SimpleGraph { protected: @@ -642,6 +829,51 @@ public: luci::CircleConst *const_value = nullptr; }; +class ReduceMaxGraph final : public SimpleGraph +{ +protected: + loco::Node *insertGraphBody(loco::Node *input) override + { + rm = g.nodes()->create<luci::CircleReduceMax>(); + rindices = g.nodes()->create<luci::CircleConst>(); + + rm->dtype(loco::DataType::FLOAT32); + rindices->dtype(loco::DataType::S32); + + rm->shape(_shape); + rindices->shape({static_cast<uint32_t>(_axes.size())}); + + rindices->size<loco::DataType::S32>(_axes.size()); + for (uint32_t i = 0; i < _axes.size(); ++i) + { + rindices->at<loco::DataType::S32>(i) = _axes[i]; + } + + rm->input(input); + rm->reduction_indices(rindices); + rm->keep_dims(_keep_dims); + + rm->name("reduce_max"); + rindices->name("rindices"); + + return rm; + } + +public: + void keep_dims(bool val) { _keep_dims = val; } + void axes(std::vector<int32_t> val) { _axes = val; } + void shape(std::initializer_list<uint32_t> val) { _shape = val; } + +public: + luci::CircleReduceMax *rm = nullptr; + luci::CircleConst *rindices = nullptr; + +private: + bool _keep_dims = true; + std::vector<int32_t> _axes = {2, 3}; + std::initializer_list<uint32_t> _shape = {1, 16, 1, 1}; +}; + class ReluGraph final : public SimpleGraph { protected: @@ -690,6 +922,111 @@ public: luci::CircleRsqrt *rsqrt = nullptr; }; +class SoftmaxGraph final : public SimpleGraph +{ +protected: + loco::Node *insertGraphBody(loco::Node *input) override + { + softmax = g.nodes()->create<luci::CircleSoftmax>(); + softmax->logits(input); + softmax->name("softmax"); + + return softmax; + } + +public: + luci::CircleSoftmax *softmax = nullptr; +}; + +class SplitVGraphlet +{ +public: + SplitVGraphlet() = default; + +public: + void init(loco::Graph *g) + { + // CircleCustom(SplitV) + _splitv = g->nodes()->create<luci::CircleSplitV>(); + _splitv->shape({1, 2, 2, 192}); + _splitv->dtype(loco::DataType::FLOAT32); + _splitv->name("splitv"); + + // CircleConst + auto size_splits = g->nodes()->create<luci::CircleConst>(); + size_splits->dtype(loco::DataType::S32); + size_splits->shape({3}); + size_splits->size<loco::DataType::S32>(3); + size_splits->at<loco::DataType::S32>(0) = 32; + size_splits->at<loco::DataType::S32>(1) = 32; + size_splits->at<loco::DataType::S32>(2) = 128; + + // CircleConst + auto split_dim = g->nodes()->create<luci::CircleConst>(); + split_dim->dtype(loco::DataType::S32); + split_dim->rank(0); + split_dim->size<loco::DataType::S32>(1); + split_dim->scalar<loco::DataType::S32>() = 3; + + _splitv->size_splits(size_splits); + _splitv->split_dim(split_dim); + _splitv->num_split(3); + + // CircleSplitVOut + _splitv_out1 = g->nodes()->create<luci::CircleSplitVOut>(); + _splitv_out1->shape({1, 2, 2, 32}); + _splitv_out1->dtype(loco::DataType::FLOAT32); + _splitv_out1->index(0); + _splitv_out1->input(_splitv); + _splitv_out1->name("splitv_out1"); + + // CircleSplitVOut + _splitv_out2 = g->nodes()->create<luci::CircleSplitVOut>(); + _splitv_out2->shape({1, 2, 2, 32}); + _splitv_out2->dtype(loco::DataType::FLOAT32); + _splitv_out2->index(1); + _splitv_out2->input(_splitv); + _splitv_out2->name("splitv_out2"); + + // CircleSplitVOut + _splitv_out3 = g->nodes()->create<luci::CircleSplitVOut>(); + _splitv_out3->shape({1, 2, 2, 128}); + _splitv_out3->dtype(loco::DataType::FLOAT32); + _splitv_out3->index(2); + _splitv_out3->input(_splitv); + _splitv_out3->name("splitv_out3"); + } + +public: + luci::CircleSplitV *splitv() { return _splitv; } + +protected: + luci::CircleSplitV *_splitv = nullptr; + luci::CircleSplitVOut *_splitv_out1 = nullptr; + luci::CircleSplitVOut *_splitv_out2 = nullptr; + luci::CircleSplitVOut *_splitv_out3 = nullptr; +}; + +class SplitVGraph : public TestIGraphlet, public TestOsGraphlet<3>, public SplitVGraphlet +{ +public: + SplitVGraph() = default; + + void init(void) + { + TestIGraphlet::init(g(), {1, 2, 2, 192}); + TestOsGraphlet<3>::init(g(), {{1, 2, 2, 32}, {1, 2, 2, 32}, {1, 2, 2, 128}}); + SplitVGraphlet::init(g()); + + // connect graph + _splitv->input(input()); + + output(0)->from(_splitv_out1); + output(1)->from(_splitv_out2); + output(2)->from(_splitv_out3); + } +}; + class SquaredDifferenceGraph final : public SimpleGraph { protected: @@ -929,8 +1266,11 @@ TEST(ConvertNCHWToNHWC, AddScalar) auto new_beta = dynamic_cast<luci::CircleConst *>(g.add->y()); EXPECT_NE(nullptr, new_beta); - EXPECT_EQ(1, new_beta->rank()); + EXPECT_EQ(4, new_beta->rank()); EXPECT_EQ(1, new_beta->dim(0).value()); + EXPECT_EQ(1, new_beta->dim(1).value()); + EXPECT_EQ(1, new_beta->dim(2).value()); + EXPECT_EQ(1, new_beta->dim(3).value()); check_pre_trans(g.output->from()); } @@ -1017,6 +1357,26 @@ TEST(ConvertNCHWToNHWC, Logistic) EXPECT_EQ(16, g.logistic->dim(3).value()); } +TEST(ConvertNCHWToNHWC, LogSoftmax) +{ + LogSoftmaxGraph g; + g.init(); + + run_phase(&g.g, true, true); + + check_pre_trans(g.log_softmax->logits()); + + auto log_softmax_succs = loco::succs(g.log_softmax); + EXPECT_EQ(1, log_softmax_succs.size()); + check_post_trans(*log_softmax_succs.begin()); + + // Check log_softmax shape + EXPECT_EQ(1, g.log_softmax->dim(0).value()); + EXPECT_EQ(4, g.log_softmax->dim(1).value()); + EXPECT_EQ(4, g.log_softmax->dim(2).value()); + EXPECT_EQ(16, g.log_softmax->dim(3).value()); +} + TEST(ConvertNCHWToNHWC, Maximum) { MaximumGraph g; @@ -1265,8 +1625,11 @@ TEST(ConvertNCHWToNHWC, MulScalar) auto new_multiplier = dynamic_cast<luci::CircleConst *>(g.mul->y()); EXPECT_NE(nullptr, new_multiplier); - EXPECT_EQ(1, new_multiplier->rank()); + EXPECT_EQ(4, new_multiplier->rank()); EXPECT_EQ(1, new_multiplier->dim(0).value()); + EXPECT_EQ(1, new_multiplier->dim(1).value()); + EXPECT_EQ(1, new_multiplier->dim(2).value()); + EXPECT_EQ(1, new_multiplier->dim(3).value()); check_pre_trans(g.output->from()); } @@ -1451,6 +1814,85 @@ TEST(ConvertNCHWToNHWC, Preserve_Input_Output) } } +TEST(ConvertNCHWToNHWC, ReduceMax) +{ + ReduceMaxGraph g; + g.init(); + + run_phase(&g.g, false, false); + + check_pre_trans(g.rm->input()); + + auto rm_succs = loco::succs(g.rm); + EXPECT_EQ(1, rm_succs.size()); + check_post_trans(*rm_succs.begin()); + + auto new_rindices = dynamic_cast<luci::CircleConst *>(g.rm->reduction_indices()); + EXPECT_NE(nullptr, new_rindices); + EXPECT_EQ(1, new_rindices->rank()); + EXPECT_EQ(2, new_rindices->dim(0).value()); + EXPECT_EQ(2, new_rindices->size<loco::DataType::S32>()); + EXPECT_EQ(1, new_rindices->at<loco::DataType::S32>(0)); + EXPECT_EQ(2, new_rindices->at<loco::DataType::S32>(1)); +} + +TEST(ConvertNCHWToNHWC, ReduceMax_keep_dims_false) +{ + struct TC + { + std::vector<int32_t> nchw_ind; + std::vector<int32_t> nhwc_ind; + std::initializer_list<uint32_t> shape; + bool needs_transpose = false; + }; + + uint32_t n = 1; + uint32_t c = 16; + uint32_t h = 4; + uint32_t w = 4; + + std::vector<TC> test_cases{{{0}, {0}, {c, h, w}, true}, {{1}, {3}, {n, h, w}, false}, + {{2}, {1}, {n, c, w}, true}, {{3}, {2}, {n, c, h}, true}, + {{0, 1}, {0, 3}, {h, w}, false}, {{0, 2}, {0, 1}, {c, w}, true}, + {{0, 3}, {0, 2}, {c, h}, true}, {{1, 2}, {3, 1}, {n, w}, false}, + {{1, 3}, {3, 2}, {n, h}, false}, {{2, 3}, {1, 2}, {n, c}, false}, + {{0, 1, 2}, {0, 3, 1}, {w}, false}}; + + for (auto &tc : test_cases) + { + ReduceMaxGraph g; + g.keep_dims(false); + g.axes(tc.nchw_ind); + g.shape(tc.shape); + g.init(); + + run_phase(&g.g, true, true); + + check_pre_trans(g.rm->input()); + + auto rm_succs = loco::succs(g.rm); + EXPECT_EQ(1, rm_succs.size()); + if (tc.needs_transpose) + { + EXPECT_NE(nullptr, dynamic_cast<luci::CircleTranspose *>(*rm_succs.begin())); + } + else + { + EXPECT_NE(nullptr, dynamic_cast<luci::CircleOutput *>(*rm_succs.begin())); + } + + auto new_rindices = dynamic_cast<luci::CircleConst *>(g.rm->reduction_indices()); + EXPECT_NE(nullptr, new_rindices); + EXPECT_EQ(1, new_rindices->rank()); + EXPECT_EQ(tc.nhwc_ind.size(), new_rindices->dim(0).value()); + EXPECT_EQ(tc.nhwc_ind.size(), new_rindices->size<loco::DataType::S32>()); + for (uint32_t i = 0; i < tc.nhwc_ind.size(); ++i) + { + EXPECT_EQ(tc.nhwc_ind[i], new_rindices->at<loco::DataType::S32>(i)); + } + } +} + TEST(ConvertNCHWToNHWC, Relu) { ReluGraph g; @@ -1511,6 +1953,57 @@ TEST(ConvertNCHWToNHWC, Rsqrt) EXPECT_EQ(16, g.rsqrt->dim(3).value()); } +TEST(ConvertNCHWToNHWC, Softmax) +{ + SoftmaxGraph g; + g.init(); + + run_phase(&g.g, true, true); + + check_pre_trans(g.softmax->logits()); + + auto softmax_succs = loco::succs(g.softmax); + EXPECT_EQ(1, softmax_succs.size()); + check_post_trans(*softmax_succs.begin()); + + // Check softmax shape + EXPECT_EQ(1, g.softmax->dim(0).value()); + EXPECT_EQ(4, g.softmax->dim(1).value()); + EXPECT_EQ(4, g.softmax->dim(2).value()); + EXPECT_EQ(16, g.softmax->dim(3).value()); +} + +TEST(ConvertNCHWToNHWC, SplitV) +{ + SplitVGraph g; + g.init(); + + run_phase(g.g(), true, true); + + check_pre_trans(g.splitv()->input()); + + auto splitv_succs = loco::succs(g.splitv()); + for (auto svo : loco::succs(g.splitv())) + { + for (auto succ : loco::succs(svo)) + { + check_post_trans(succ); + } + } + + // Check splitv() shape + EXPECT_EQ(1, g.splitv()->dim(0).value()); + EXPECT_EQ(2, g.splitv()->dim(1).value()); + EXPECT_EQ(192, g.splitv()->dim(2).value()); + EXPECT_EQ(2, g.splitv()->dim(3).value()); + + // Check axis + auto axis = dynamic_cast<luci::CircleConst *>(g.splitv()->split_dim()); + EXPECT_NE(nullptr, axis); + EXPECT_EQ(1, axis->size<loco::DataType::S32>()); + EXPECT_EQ(2, axis->at<loco::DataType::S32>(0)); +} + TEST(ConvertNCHWToNHWC, SquaredDifference) { SquaredDifferenceGraph g; @@ -1602,3 +2095,31 @@ TEST(ConvertNCHWToNHWC, SubScalar) check_pre_trans(g.output->from()); } + +TEST(ConvertNCHWToNHWC, Not_Closed_Case1_NEG) +{ + NoPostReshapeGraph g; + g.init(); + + run_phase(&g.g, true, true); + + check_pre_trans(g.relu->features()); + + auto relu_succs = loco::succs(g.relu); + EXPECT_EQ(1, relu_succs.size()); + check_post_trans(*relu_succs.begin()); +} + +TEST(ConvertNCHWToNHWC, Not_Closed_Case2_NEG) +{ + ReluNotClosedGraph g; + g.init(); + + run_phase(&g.g, true, true); + + check_pre_trans(g.relu->features()); + + auto relu_succs = loco::succs(g.relu); + EXPECT_EQ(1, relu_succs.size()); + check_post_trans(*relu_succs.begin()); +} diff --git a/compiler/luci/pass/src/ConvertToFakeQuantizedModelPass.cpp b/compiler/luci/pass/src/ConvertToFakeQuantizedModelPass.cpp index 11970fff5..72f590135 100644 --- a/compiler/luci/pass/src/ConvertToFakeQuantizedModelPass.cpp +++ b/compiler/luci/pass/src/ConvertToFakeQuantizedModelPass.cpp @@ -184,8 +184,63 @@ struct FakeQuantize final : public luci::CircleNodeMutableVisitor<void> // For non-const activation, insert Quantize-Dequantize Ops // and dequantize the node - void visit(luci::CircleConv2D *node) { fq_activation(node); } void visit(luci::CircleAdd *node) { fq_activation(node); } + void visit(luci::CircleAveragePool2D *node) { fq_activation(node); } + void visit(luci::CircleBatchMatMul *node) { fq_activation(node); } + void visit(luci::CircleConv2D *node) { fq_activation(node); } + void visit(luci::CircleDepthwiseConv2D *node) { fq_activation(node); } + void visit(luci::CircleDiv *node) { fq_activation(node); } + void visit(luci::CircleFullyConnected *node) { fq_activation(node); } + void visit(luci::CircleInstanceNorm *node) { fq_activation(node); } + void visit(luci::CircleLeakyRelu *node) { fq_activation(node); } + void visit(luci::CircleLogistic *node) { fq_activation(node); } + void visit(luci::CircleLogSoftmax *node) { fq_activation(node); } + void visit(luci::CircleMaxPool2D *node) { fq_activation(node); } + void visit(luci::CircleMul *node) { fq_activation(node); } + void visit(luci::CircleNeg *node) { fq_activation(node); } + void visit(luci::CirclePad *node) { fq_activation(node); } + void visit(luci::CirclePRelu *node) { fq_activation(node); } + void visit(luci::CircleMean *node) { fq_activation(node); } + void visit(luci::CircleReduceMax *node) { fq_activation(node); } + void visit(luci::CircleRelu *node) { fq_activation(node); } + void visit(luci::CircleRelu6 *node) { fq_activation(node); } + void visit(luci::CircleResizeBilinear *node) { fq_activation(node); } + void visit(luci::CircleResizeNearestNeighbor *node) { fq_activation(node); } + void visit(luci::CircleRsqrt *node) { fq_activation(node); } + void visit(luci::CircleSoftmax *node) { fq_activation(node); } + void visit(luci::CircleSqrt *node) { fq_activation(node); } + void visit(luci::CircleTanh *node) { fq_activation(node); } + void visit(luci::CircleTransposeConv *node) { fq_activation(node); } + + // For Ops that do not change the value of input, do nothing + // (dtype will be automatically updated by type inference) + void visit(luci::CircleCast *) {} + void visit(luci::CircleConcatenation *) {} + void visit(luci::CircleGather *) {} + void visit(luci::CircleSlice *) {} + void visit(luci::CircleStridedSlice *) {} + void visit(luci::CircleReshape *) {} + void visit(luci::CircleSplit *) {} + void visit(luci::CircleSplitOut *) {} + void visit(luci::CircleSplitV *) {} + void visit(luci::CircleSplitVOut *) {} + void visit(luci::CircleTranspose *) {} + + // For Ops that return index, fake quantization is unnecessary + void visit(luci::CircleArgMax *) {} + + // Virtual node + void visit(luci::CircleOutputExclude *) {} + + void visit(luci::CircleQuantize *node) + { + RETURN_UNLESS(is_quant_act(node)); + + insert_dequantize(node); + } + + // Dequantize Op does nothing in fp32 model + void visit(luci::CircleDequantize *) {} }; #undef RETURN_UNLESS diff --git a/compiler/luci/pass/src/FoldDensifyPass.cpp b/compiler/luci/pass/src/FoldDensifyPass.cpp new file mode 100644 index 000000000..5ddc743e5 --- /dev/null +++ b/compiler/luci/pass/src/FoldDensifyPass.cpp @@ -0,0 +1,180 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * Copyright 2020 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/FoldDensifyPass.h" +#include "helpers/SparsityFormatConverter.h" + +#include <luci/IR/CircleNodes.h> +#include <luci/Profile/CircleNodeOrigin.h> + +#include <cassert> +#include <vector> + +namespace +{ + +bool is_foldable_const(luci::CircleConst *node) +{ + if (node->sparsityparam() == nullptr) + return false; + + if (node->dtype() == loco::DataType::FLOAT32) + return true; + if (node->dtype() == loco::DataType::FLOAT16) + return true; + + return false; +} + +luci::CircleConst *densified_const_node(luci::CircleConst *const_node) +{ + assert(const_node->sparsityparam()); + + auto name = const_node->name(); + assert(name.length() > 0); + auto g = const_node->graph(); + auto new_const_node = g->nodes()->create<luci::CircleConst>(); + + new_const_node->dtype(const_node->dtype()); + new_const_node->rank(const_node->rank()); + + uint32_t dim_size = 1; + std::vector<int> dense_shape; + for (uint32_t i = 0; i < new_const_node->rank(); ++i) + { + assert(const_node->dim(i).known()); + new_const_node->dim(i) = const_node->dim(i); + + uint32_t value = const_node->dim(i).value(); + dim_size *= value; + dense_shape.emplace_back(static_cast<int32_t>(value)); + } + + if (const_node->dtype() == loco::DataType::FLOAT32) + new_const_node->size<loco::DataType::FLOAT32>(dim_size); + else + { + assert(const_node->dtype() == loco::DataType::FLOAT16); + new_const_node->size<loco::DataType::FLOAT16>(dim_size); + } + + new_const_node->shape_status(luci::ShapeStatus::VALID); + new_const_node->name(name + "_DS"); + + if (const_node->dtype() == loco::DataType::FLOAT32) + { + auto const_items = const_node->size<loco::DataType::FLOAT32>(); + auto f_data = std::make_unique<float[]>(const_items); + for (size_t i = 0; i < const_items; ++i) + f_data[i] = const_node->at<loco::DataType::FLOAT32>(i); + + sparsity::TfLiteSparsity sp = to_tflite_sparsity(const_node->sparsityparam()); + sparsity::FormatConverter<float> converter(dense_shape, sp); + converter.SparseToDense(f_data.get()); + const auto &data_dense = converter.GetData(); + assert(data_dense.size() == dim_size); + + for (uint32_t i = 0; i < dim_size; ++i) + new_const_node->at<loco::DataType::FLOAT32>(i) = data_dense[i]; + + luci::freeTfLiteSparsity(sp); + } + else + { + assert(const_node->dtype() == loco::DataType::FLOAT16); + + auto const_items = const_node->size<loco::DataType::FLOAT16>(); + auto f_data = std::make_unique<uint16_t[]>(const_items); + for (size_t i = 0; i < const_items; ++i) + f_data[i] = const_node->at<loco::DataType::FLOAT16>(i); + + // Primitive type for FLOAT16 is UINT16 + sparsity::TfLiteSparsity sp = to_tflite_sparsity(const_node->sparsityparam()); + sparsity::FormatConverter<uint16_t> converter(dense_shape, sp); + converter.SparseToDense(f_data.get()); + const auto &data_dense = converter.GetData(); + assert(data_dense.size() == dim_size); + for (uint32_t i = 0; i < dim_size; ++i) + new_const_node->at<loco::DataType::FLOAT16>(i) = data_dense[i]; + + luci::freeTfLiteSparsity(sp); + } + + return new_const_node; +} + +/** + * @brief Fold Densify if input is Sparse Constant + */ +bool fold_densify(luci::CircleDensify *densify) +{ + auto const_input = dynamic_cast<luci::CircleConst *>(densify->input()); + if (not const_input) + return false; + + if (not is_foldable_const(const_input)) + return false; + + auto dense_const = densified_const_node(const_input); + assert(dense_const); + + loco::replace(densify).with(dense_const); + luci::add_origin(dense_const, luci::composite_origin( + {luci::get_origin(densify), luci::get_origin(const_input)})); + + return true; +} + +} // namespace + +namespace luci +{ + +/** + * BEFORE + * + * [CircleConst](sparse) + * | + * [CircleDensify] + * | + * [CircleNode] + * | + * + * AFTER + * + * [CircleConst](dense) [CircleConst](sparse) + * | | + * [CircleNode] [CircleDensify] + * | + */ +bool FoldDensifyPass::run(loco::Graph *g) +{ + bool changed = false; + + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + if (auto densify = dynamic_cast<luci::CircleDensify *>(node)) + { + if (fold_densify(densify)) + changed = true; + } + } + + return changed; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/FoldDensifyPass.test.cpp b/compiler/luci/pass/src/FoldDensifyPass.test.cpp new file mode 100644 index 000000000..2f9736f49 --- /dev/null +++ b/compiler/luci/pass/src/FoldDensifyPass.test.cpp @@ -0,0 +1,158 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/FoldDensifyPass.h" +#include "PassTestGraphs.h" + +#include <luci/IR/CircleNodes.h> + +#include <gtest/gtest.h> + +namespace +{ + +class FoldDensifyPassGraph : public luci::ConstantFoldingAddTestGraph +{ +public: + FoldDensifyPassGraph(std::initializer_list<uint32_t> shape) + : luci::ConstantFoldingAddTestGraph(shape, loco::DataType::FLOAT32) + { + _densify = _g.nodes()->create<luci::CircleDensify>(); + _x = _g.nodes()->create<luci::CircleConst>(); + + _densify->dtype(loco::DataType::FLOAT32); + _x->dtype(loco::DataType::FLOAT32); + + _densify->shape(shape); + _x->shape(shape); + + _densify->input(_x); + + _densify->name("densify"); + _x->name("x"); + } + + loco::Node *createFoldedPattern() override { return _densify; } + +public: + void fill_const_dense(void) + { + uint32_t num_elems = 1; + for (uint32_t r = 0; r < _x->rank(); ++r) + num_elems *= _x->dim(r).value(); + + _x->size<loco::DataType::FLOAT32>(num_elems); + for (uint32_t i = 0; i < num_elems; i++) + _x->at<loco::DataType::FLOAT32>(i) = static_cast<float>(i + 1); + } + + void fill_const_sparse(void) + { + // fill 4x4 of + // [[1 0 0 0] + // [0 2 0 0] + // [0 0 3 0] + // [0 0 0 4]] + + // values of 1.0, 2.0, 3.0, 4.0 + uint32_t udata[] = {0x3f800000, 0x40000000, 0x40400000, 0x40800000}; + float *fdata = reinterpret_cast<float *>(udata); + + _x->size<loco::DataType::FLOAT32>(4); + for (uint32_t i = 0; i < 4; i++) + _x->at<loco::DataType::FLOAT32>(i) = fdata[i]; + + auto sparsityparam = std::make_unique<luci::SparsityParam>(); + sparsityparam->traversal_order = std::vector<int32_t>({0, 1}); + sparsityparam->block_map = std::vector<int32_t>({}); + + auto dm0 = luci::DimMetaData(luci::DimensionType::DENSE, 4); + + std::vector<int32_t> as_vec = {0, 1, 2, 3, 4}; + std::vector<int32_t> ai_vec = {0, 1, 2, 3}; + auto as = luci::SparseIndexVector(luci::SparseIndexVectorType::I32, as_vec); + auto ai = luci::SparseIndexVector(luci::SparseIndexVectorType::I32, ai_vec); + auto dm1 = luci::DimMetaData(luci::DimensionType::SPARSE_CSR, 0, as, ai); + sparsityparam->dim_metadata.emplace_back(dm0); + sparsityparam->dim_metadata.emplace_back(dm1); + + _x->sparsityparam(std::move(sparsityparam)); + } + +protected: + luci::CircleDensify *_densify = nullptr; + luci::CircleConst *_x = nullptr; +}; + +class FoldDensifyPassGraphTest : public FoldDensifyPassGraph, public ::testing::Test +{ +public: + FoldDensifyPassGraphTest() : FoldDensifyPassGraph({4, 4}) {} + + virtual void SetUp() { init(); } +}; + +} // namespace + +TEST(FoldDensifyPassGraph, name) +{ + luci::FoldDensifyPass pass; + auto const name = pass.name(); + ASSERT_NE(nullptr, name); +} + +TEST_F(FoldDensifyPassGraphTest, no_sparsity_param_NEG) +{ + fill_const_dense(); + + luci::FoldDensifyPass pass; + while (pass.run(graph())) + ; + + auto folded_const = getFoldedPattern(); + EXPECT_EQ(nullptr, folded_const); +} + +TEST_F(FoldDensifyPassGraphTest, sparsity_param) +{ + fill_const_sparse(); + + luci::FoldDensifyPass pass; + while (pass.run(graph())) + ; + + auto folded_const = getFoldedPattern(); + EXPECT_NE(nullptr, folded_const); + + EXPECT_EQ(2, folded_const->rank()); + EXPECT_EQ(4, folded_const->dim(0).value()); + EXPECT_EQ(4, folded_const->dim(1).value()); + EXPECT_EQ(16, folded_const->size<loco::DataType::FLOAT32>()); + for (int y = 0; y < 4; ++y) + { + for (int x = 0; x < 4; ++x) + { + float ovalue = folded_const->at<loco::DataType::FLOAT32>(y * 4 + x); + float fvalue = 0.0; + if (x == y) + { + // diagonal position + fvalue = static_cast<float>(y + 1); + } + EXPECT_EQ(fvalue, ovalue); + } + } +} diff --git a/compiler/luci/pass/src/FoldDequantizePass.cpp b/compiler/luci/pass/src/FoldDequantizePass.cpp index 3dd4f8cea..b6526deb0 100644 --- a/compiler/luci/pass/src/FoldDequantizePass.cpp +++ b/compiler/luci/pass/src/FoldDequantizePass.cpp @@ -19,6 +19,8 @@ #include <luci/IR/CircleNodes.h> #include <luci/Profile/CircleNodeOrigin.h> +#include <fp16.h> + namespace { @@ -32,6 +34,9 @@ bool is_hybrid_kernel_supported(loco::Node *node) bool is_foldable_const(luci::CircleConst *node) { + if (node->dtype() == loco::DataType::FLOAT16) + return true; + if (node->quantparam() == nullptr) return false; @@ -39,17 +44,18 @@ bool is_foldable_const(luci::CircleConst *node) return true; if (node->dtype() == loco::DataType::U8) return true; + if (node->dtype() == loco::DataType::S16) + return true; + if (node->dtype() == loco::DataType::S32) + return true; + if (node->dtype() == loco::DataType::S64) + return true; return false; } luci::CircleConst *dequantized_const_node(luci::CircleConst *const_node) { - if (const_node->quantparam() == nullptr) - { - throw std::runtime_error("Given constant node has no quantization parameter"); - } - auto name = const_node->name(); assert(name.length() > 0); auto g = const_node->graph(); @@ -67,38 +73,70 @@ luci::CircleConst *dequantized_const_node(luci::CircleConst *const_node) new_const_node->shape_status(luci::ShapeStatus::VALID); new_const_node->name(name + "_DQ"); + if (const_node->dtype() == loco::DataType::FLOAT16) + { + for (uint32_t i = 0; i < new_const_node->size<loco::DataType::FLOAT32>(); ++i) + { + auto raw = const_node->at<loco::DataType::FLOAT16>(i); + new_const_node->at<loco::DataType::FLOAT32>(i) = fp16_ieee_to_fp32_value(raw); + } + return new_const_node; + } + + if (const_node->quantparam() == nullptr) + { + throw std::runtime_error("Given constant node has no quantization parameter"); + } + const int32_t q_dim = const_node->quantparam()->quantized_dimension; - const int32_t q_dim_value = const_node->dim(q_dim).value(); + // For scalar, q_dim_value is 1 + // For non-scalar, q_dim_value is the size of quantized dimension + const int32_t q_dim_value = const_node->rank() == 0 ? 1 : const_node->dim(q_dim).value(); int32_t right_count = q_dim_value; for (uint32_t i = q_dim + 1; i < const_node->rank(); ++i) right_count *= const_node->dim(i).value(); - if (const_node->dtype() == loco::DataType::S8) + for (uint32_t i = 0; i < new_const_node->size<loco::DataType::FLOAT32>(); ++i) { - for (uint32_t i = 0; i < const_node->size<loco::DataType::S8>(); ++i) - { - uint32_t qd = (i % right_count) / (right_count / q_dim_value); - if (qd >= const_node->quantparam()->zerop.size()) - qd = 0; + uint32_t qd = (i % right_count) / (right_count / q_dim_value); + if (qd >= const_node->quantparam()->zerop.size()) + qd = 0; - new_const_node->at<loco::DataType::FLOAT32>(i) = - (float)(const_node->at<loco::DataType::S8>(i) - const_node->quantparam()->zerop.at(qd)) * - const_node->quantparam()->scale.at(qd); - } - } - else - { - for (uint32_t i = 0; i < const_node->size<loco::DataType::U8>(); ++i) + switch (const_node->dtype()) { - uint32_t qd = (i % right_count) / (right_count / q_dim_value); - if (qd >= const_node->quantparam()->zerop.size()) - qd = 0; - - new_const_node->at<loco::DataType::FLOAT32>(i) = - (float)((int)const_node->at<loco::DataType::U8>(i) - - const_node->quantparam()->zerop.at(qd)) * - const_node->quantparam()->scale.at(qd); + case loco::DataType::S8: + new_const_node->at<loco::DataType::FLOAT32>(i) = + static_cast<float>(const_node->at<loco::DataType::S8>(i) - + const_node->quantparam()->zerop.at(qd)) * + const_node->quantparam()->scale.at(qd); + break; + case loco::DataType::S16: + new_const_node->at<loco::DataType::FLOAT32>(i) = + static_cast<float>(const_node->at<loco::DataType::S16>(i) - + const_node->quantparam()->zerop.at(qd)) * + const_node->quantparam()->scale.at(qd); + break; + case loco::DataType::S32: + new_const_node->at<loco::DataType::FLOAT32>(i) = + static_cast<float>(const_node->at<loco::DataType::S32>(i) - + const_node->quantparam()->zerop.at(qd)) * + const_node->quantparam()->scale.at(qd); + break; + case loco::DataType::S64: + new_const_node->at<loco::DataType::FLOAT32>(i) = + static_cast<float>(const_node->at<loco::DataType::S64>(i) - + const_node->quantparam()->zerop.at(qd)) * + const_node->quantparam()->scale.at(qd); + break; + case loco::DataType::U8: + new_const_node->at<loco::DataType::FLOAT32>(i) = + static_cast<float>(const_node->at<loco::DataType::U8>(i) - + const_node->quantparam()->zerop.at(qd)) * + const_node->quantparam()->scale.at(qd); + break; + default: + throw std::runtime_error("Not supported dtype for FoldDequantizePass"); } } @@ -160,7 +198,7 @@ bool FoldDequantizePass::run(loco::Graph *g) { bool changed = false; - for (auto node : loco::all_nodes(g)) + for (auto node : loco::active_nodes(loco::output_nodes(g))) { if (auto circle_dequant = dynamic_cast<luci::CircleDequantize *>(node)) { diff --git a/compiler/luci/pass/src/FoldDequantizePass.test.cpp b/compiler/luci/pass/src/FoldDequantizePass.test.cpp index d82a7bc87..fb5b6adc0 100644 --- a/compiler/luci/pass/src/FoldDequantizePass.test.cpp +++ b/compiler/luci/pass/src/FoldDequantizePass.test.cpp @@ -15,12 +15,389 @@ */ #include "luci/Pass/FoldDequantizePass.h" +#include "PassTestGraphs.h" #include <gtest/gtest.h> +namespace +{ + +template <loco::DataType DT> +class FoldDequantizeTest : public luci::ConstantFoldingAddTestGraph, public ::testing::Test +{ +public: + FoldDequantizeTest() : luci::ConstantFoldingAddTestGraph({2, 2, 2}, DT) {} + + virtual void SetUp() { init(); } + + loco::Node *createFoldedPattern() override + { + _dequantize = _g.nodes()->create<luci::CircleDequantize>(); + _input = _g.nodes()->create<luci::CircleConst>(); + + _dequantize->dtype(loco::DataType::FLOAT32); + _input->dtype(DT); + + _input->shape({2, 2, 2}); + + _input->size<DT>(8); + _input->at<DT>(0) = 0; + _input->at<DT>(1) = 1; + _input->at<DT>(2) = 2; + _input->at<DT>(3) = 3; + _input->at<DT>(4) = 4; + _input->at<DT>(5) = 5; + _input->at<DT>(6) = 6; + _input->at<DT>(7) = 7; + + auto qparam = std::make_unique<luci::CircleQuantParam>(); + qparam->quantized_dimension = 1; + qparam->scale.push_back(5.0); + qparam->scale.push_back(10.0); + qparam->zerop.push_back(1); + qparam->zerop.push_back(2); + _input->quantparam(std::move(qparam)); + + _dequantize->input(_input); + + _dequantize->name("dequantize"); + _input->name("input"); + + return _dequantize; + } + + void createScalarPattern() + { + _input->rank(0); + _input->size<DT>(1); + _input->at<DT>(0) = 1; + + auto qparam = std::make_unique<luci::CircleQuantParam>(); + qparam->quantized_dimension = 0; + qparam->scale.push_back(1.0); + qparam->zerop.push_back(0); + _input->quantparam(std::move(qparam)); + } + + void createNotFoldablePattern() { _input->quantparam(nullptr); } + +protected: + luci::CircleDequantize *_dequantize = nullptr; + luci::CircleConst *_input = nullptr; +}; + +class S8FoldDequantizeTest : public FoldDequantizeTest<loco::DataType::S8> +{ +}; + +class S16FoldDequantizeTest : public FoldDequantizeTest<loco::DataType::S16> +{ +}; + +class S32FoldDequantizeTest : public FoldDequantizeTest<loco::DataType::S32> +{ +}; + +class S64FoldDequantizeTest : public FoldDequantizeTest<loco::DataType::S64> +{ +}; + +class U8FoldDequantizeTest : public FoldDequantizeTest<loco::DataType::U8> +{ +}; + +class F16FoldDequantizeTest : public luci::ConstantFoldingTestGraph, public ::testing::Test +{ +public: + F16FoldDequantizeTest() : ConstantFoldingTestGraph({2, 2}, loco::DataType::FLOAT16) {} + + virtual void SetUp() { init(); } + + loco::Node *createFoldedPattern() override + { + const auto DT = loco::DataType::FLOAT16; + _dequantize = _g.nodes()->create<luci::CircleDequantize>(); + _f16const = _g.nodes()->create<luci::CircleConst>(); + + _dequantize->dtype(loco::DataType::FLOAT32); + _f16const->dtype(DT); + + _f16const->shape({2, 2}); + + _f16const->size<loco::DataType::FLOAT16>(4); + _f16const->at<DT>(0) = 49408; // -2.5f + _f16const->at<DT>(1) = 47104; // -0.5f + _f16const->at<DT>(2) = 0; // 0.0f + _f16const->at<DT>(3) = 15872; // 1.5f + // NOTE how to get uint16_t value of float16 ? + // Use compiler/souschef/src/Gaussian.cpp GaussianFloat16DataChef::generate() + // uint16_t value = fp16_ieee_from_fp32_value(-2.5); + // printf("-2.5 = %u\r\n", value); + + _dequantize->input(_f16const); + + _dequantize->name("dequantize"); + _f16const->name("input"); + + _output->from(_dequantize); + + return _dequantize; + } + + void createNotFoldablePattern() { _dequantize->input(_input); } + +protected: + luci::CircleConst *getFoldedPattern() override + { + return dynamic_cast<luci::CircleConst *>(_output->from()); + } + + void init() override { createFoldedPattern(); } + +protected: + luci::CircleDequantize *_dequantize = nullptr; + luci::CircleConst *_f16const = nullptr; +}; + +} // namespace + TEST(FoldDequantizePassTest, name) { luci::FoldDequantizePass pass; auto const name = pass.name(); ASSERT_NE(nullptr, name); } + +TEST_F(U8FoldDequantizeTest, fold_dequant_basic) +{ + luci::FoldDequantizePass pass; + while (pass.run(graph())) + ; + + auto folded_const = getFoldedPattern(); + EXPECT_NE(nullptr, folded_const); + + // Chec type, shape, values of folded const + EXPECT_EQ(loco::DataType::FLOAT32, folded_const->dtype()); + EXPECT_EQ(3, folded_const->rank()); + EXPECT_EQ(2, folded_const->dim(0).value()); + EXPECT_EQ(2, folded_const->dim(1).value()); + EXPECT_EQ(2, folded_const->dim(2).value()); + EXPECT_EQ(-5.0, folded_const->at<loco::DataType::FLOAT32>(0)); + EXPECT_EQ(0.0, folded_const->at<loco::DataType::FLOAT32>(1)); + EXPECT_EQ(0.0, folded_const->at<loco::DataType::FLOAT32>(2)); + EXPECT_EQ(10.0, folded_const->at<loco::DataType::FLOAT32>(3)); + EXPECT_EQ(15.0, folded_const->at<loco::DataType::FLOAT32>(4)); + EXPECT_EQ(20.0, folded_const->at<loco::DataType::FLOAT32>(5)); + EXPECT_EQ(40.0, folded_const->at<loco::DataType::FLOAT32>(6)); + EXPECT_EQ(50.0, folded_const->at<loco::DataType::FLOAT32>(7)); +} + +TEST_F(U8FoldDequantizeTest, fold_dequant_basic_NEG) +{ + createNotFoldablePattern(); + + luci::FoldDequantizePass pass; + while (pass.run(graph())) + ; + + auto folded_const = getFoldedPattern(); + EXPECT_EQ(nullptr, folded_const); +} + +TEST_F(S8FoldDequantizeTest, fold_dequant_basic) +{ + luci::FoldDequantizePass pass; + while (pass.run(graph())) + ; + + auto folded_const = getFoldedPattern(); + EXPECT_NE(nullptr, folded_const); + + // Chec type, shape, values of folded const + EXPECT_EQ(loco::DataType::FLOAT32, folded_const->dtype()); + EXPECT_EQ(3, folded_const->rank()); + EXPECT_EQ(2, folded_const->dim(0).value()); + EXPECT_EQ(2, folded_const->dim(1).value()); + EXPECT_EQ(2, folded_const->dim(2).value()); + EXPECT_EQ(-5.0, folded_const->at<loco::DataType::FLOAT32>(0)); + EXPECT_EQ(0.0, folded_const->at<loco::DataType::FLOAT32>(1)); + EXPECT_EQ(0.0, folded_const->at<loco::DataType::FLOAT32>(2)); + EXPECT_EQ(10.0, folded_const->at<loco::DataType::FLOAT32>(3)); + EXPECT_EQ(15.0, folded_const->at<loco::DataType::FLOAT32>(4)); + EXPECT_EQ(20.0, folded_const->at<loco::DataType::FLOAT32>(5)); + EXPECT_EQ(40.0, folded_const->at<loco::DataType::FLOAT32>(6)); + EXPECT_EQ(50.0, folded_const->at<loco::DataType::FLOAT32>(7)); +} + +TEST_F(S8FoldDequantizeTest, fold_dequant_basic_NEG) +{ + createNotFoldablePattern(); + + luci::FoldDequantizePass pass; + while (pass.run(graph())) + ; + + auto folded_const = getFoldedPattern(); + EXPECT_EQ(nullptr, folded_const); +} + +TEST_F(S16FoldDequantizeTest, fold_dequant_basic) +{ + luci::FoldDequantizePass pass; + while (pass.run(graph())) + ; + + auto folded_const = getFoldedPattern(); + EXPECT_NE(nullptr, folded_const); + + // Chec type, shape, values of folded const + EXPECT_EQ(loco::DataType::FLOAT32, folded_const->dtype()); + EXPECT_EQ(3, folded_const->rank()); + EXPECT_EQ(2, folded_const->dim(0).value()); + EXPECT_EQ(2, folded_const->dim(1).value()); + EXPECT_EQ(2, folded_const->dim(2).value()); + EXPECT_EQ(-5.0, folded_const->at<loco::DataType::FLOAT32>(0)); + EXPECT_EQ(0.0, folded_const->at<loco::DataType::FLOAT32>(1)); + EXPECT_EQ(0.0, folded_const->at<loco::DataType::FLOAT32>(2)); + EXPECT_EQ(10.0, folded_const->at<loco::DataType::FLOAT32>(3)); + EXPECT_EQ(15.0, folded_const->at<loco::DataType::FLOAT32>(4)); + EXPECT_EQ(20.0, folded_const->at<loco::DataType::FLOAT32>(5)); + EXPECT_EQ(40.0, folded_const->at<loco::DataType::FLOAT32>(6)); + EXPECT_EQ(50.0, folded_const->at<loco::DataType::FLOAT32>(7)); +} + +TEST_F(S16FoldDequantizeTest, fold_dequant_basic_NEG) +{ + createNotFoldablePattern(); + + luci::FoldDequantizePass pass; + while (pass.run(graph())) + ; + + auto folded_const = getFoldedPattern(); + EXPECT_EQ(nullptr, folded_const); +} + +TEST_F(S32FoldDequantizeTest, fold_dequant_basic) +{ + luci::FoldDequantizePass pass; + while (pass.run(graph())) + ; + + auto folded_const = getFoldedPattern(); + EXPECT_NE(nullptr, folded_const); + + // Chec type, shape, values of folded const + EXPECT_EQ(loco::DataType::FLOAT32, folded_const->dtype()); + EXPECT_EQ(3, folded_const->rank()); + EXPECT_EQ(2, folded_const->dim(0).value()); + EXPECT_EQ(2, folded_const->dim(1).value()); + EXPECT_EQ(2, folded_const->dim(2).value()); + EXPECT_EQ(-5.0, folded_const->at<loco::DataType::FLOAT32>(0)); + EXPECT_EQ(0.0, folded_const->at<loco::DataType::FLOAT32>(1)); + EXPECT_EQ(0.0, folded_const->at<loco::DataType::FLOAT32>(2)); + EXPECT_EQ(10.0, folded_const->at<loco::DataType::FLOAT32>(3)); + EXPECT_EQ(15.0, folded_const->at<loco::DataType::FLOAT32>(4)); + EXPECT_EQ(20.0, folded_const->at<loco::DataType::FLOAT32>(5)); + EXPECT_EQ(40.0, folded_const->at<loco::DataType::FLOAT32>(6)); + EXPECT_EQ(50.0, folded_const->at<loco::DataType::FLOAT32>(7)); +} + +TEST_F(S32FoldDequantizeTest, fold_dequant_basic_NEG) +{ + createNotFoldablePattern(); + + luci::FoldDequantizePass pass; + while (pass.run(graph())) + ; + + auto folded_const = getFoldedPattern(); + EXPECT_EQ(nullptr, folded_const); +} + +TEST_F(S64FoldDequantizeTest, fold_dequant_basic) +{ + luci::FoldDequantizePass pass; + while (pass.run(graph())) + ; + + auto folded_const = getFoldedPattern(); + EXPECT_NE(nullptr, folded_const); + + // Chec type, shape, values of folded const + EXPECT_EQ(loco::DataType::FLOAT32, folded_const->dtype()); + EXPECT_EQ(3, folded_const->rank()); + EXPECT_EQ(2, folded_const->dim(0).value()); + EXPECT_EQ(2, folded_const->dim(1).value()); + EXPECT_EQ(2, folded_const->dim(2).value()); + EXPECT_EQ(-5.0, folded_const->at<loco::DataType::FLOAT32>(0)); + EXPECT_EQ(0.0, folded_const->at<loco::DataType::FLOAT32>(1)); + EXPECT_EQ(0.0, folded_const->at<loco::DataType::FLOAT32>(2)); + EXPECT_EQ(10.0, folded_const->at<loco::DataType::FLOAT32>(3)); + EXPECT_EQ(15.0, folded_const->at<loco::DataType::FLOAT32>(4)); + EXPECT_EQ(20.0, folded_const->at<loco::DataType::FLOAT32>(5)); + EXPECT_EQ(40.0, folded_const->at<loco::DataType::FLOAT32>(6)); + EXPECT_EQ(50.0, folded_const->at<loco::DataType::FLOAT32>(7)); +} + +TEST_F(S64FoldDequantizeTest, fold_dequant_basic_NEG) +{ + createNotFoldablePattern(); + + luci::FoldDequantizePass pass; + while (pass.run(graph())) + ; + + auto folded_const = getFoldedPattern(); + EXPECT_EQ(nullptr, folded_const); +} + +TEST_F(U8FoldDequantizeTest, fold_dequant_scalar) +{ + createScalarPattern(); + + luci::FoldDequantizePass pass; + while (pass.run(graph())) + ; + + auto folded_const = getFoldedPattern(); + EXPECT_NE(nullptr, folded_const); + + // Check type, shape, values of folded const + EXPECT_EQ(loco::DataType::FLOAT32, folded_const->dtype()); + EXPECT_EQ(0, folded_const->rank()); + EXPECT_EQ(1.0, folded_const->at<loco::DataType::FLOAT32>(0)); +} + +TEST_F(F16FoldDequantizeTest, fold_dequant_basic) +{ + luci::FoldDequantizePass pass; + while (pass.run(graph())) + ; + + auto folded_const = getFoldedPattern(); + EXPECT_NE(nullptr, folded_const); + + // Chec type, shape, values of folded const + EXPECT_EQ(loco::DataType::FLOAT32, folded_const->dtype()); + EXPECT_EQ(2, folded_const->rank()); + EXPECT_EQ(2, folded_const->dim(0).value()); + EXPECT_EQ(2, folded_const->dim(1).value()); + EXPECT_EQ(-2.5, folded_const->at<loco::DataType::FLOAT32>(0)); + EXPECT_EQ(-0.5, folded_const->at<loco::DataType::FLOAT32>(1)); + EXPECT_EQ(0.0, folded_const->at<loco::DataType::FLOAT32>(2)); + EXPECT_EQ(1.5, folded_const->at<loco::DataType::FLOAT32>(3)); +} + +TEST_F(F16FoldDequantizeTest, fold_dequant_basic_NEG) +{ + createNotFoldablePattern(); + + luci::FoldDequantizePass pass; + while (pass.run(graph())) + ; + + auto folded_const = getFoldedPattern(); + EXPECT_EQ(nullptr, folded_const); +} diff --git a/compiler/luci/pass/src/FoldSparseToDensePass.cpp b/compiler/luci/pass/src/FoldSparseToDensePass.cpp index 0c6fc43ed..ed60d8899 100644 --- a/compiler/luci/pass/src/FoldSparseToDensePass.cpp +++ b/compiler/luci/pass/src/FoldSparseToDensePass.cpp @@ -19,6 +19,8 @@ #include <luci/IR/CircleNodes.h> +#include <limits> + namespace { diff --git a/compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.cpp b/compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.cpp index 2c990f0a5..bc09abee2 100644 --- a/compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.cpp +++ b/compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.cpp @@ -22,6 +22,7 @@ #include <luci/Profile/CircleNodeOrigin.h> #include <luci/Service/CircleShapeInference.h> #include <luci/Service/Nodes/CircleConst.h> +#include <luci/Service/CircleNodeClone.h> namespace { @@ -55,6 +56,26 @@ void copy_shape(luci::CircleReshape *reshape, luci::CircleReshape *new_reshape) new_reshape->newShape()->dim(r) = reshape->newShape()->dim(r); } +luci::CircleReshape *create_cloned_reshape(luci::CircleReshape *reshape) +{ + assert(reshape != nullptr); // FIX_CALLER_UNLESS + + luci::CircleConst *cloned_shape = clone_shape(reshape); + if (cloned_shape == nullptr) + return nullptr; + + auto cloned_node = luci::clone_node(reshape, reshape->graph()); + if (cloned_node == nullptr) + return nullptr; + + auto new_reshape = loco::must_cast<luci::CircleReshape *>(cloned_node); + new_reshape->shape(cloned_shape); + new_reshape->name(reshape->name() + "_C"); + luci::add_origin(new_reshape, luci::get_origin(reshape)); + + return new_reshape; +} + bool forward_reshape(luci::CircleReshape *reshape, luci::CircleNeg *neg) { assert(reshape != nullptr); @@ -85,6 +106,26 @@ bool forward_reshape(luci::CircleReshape *reshape, luci::CircleNeg *neg) return true; } +bool forward_reshape(luci::CircleReshape *reshape, luci::CircleLogistic *logit) +{ + assert(reshape != nullptr); // FIX_CALLER_UNLESS + assert(logit != nullptr); // FIX_CALLER_UNLESS + + auto new_reshape = create_cloned_reshape(reshape); + if (not new_reshape) + return false; + + // reconnect network + loco::replace(logit).with(new_reshape); + logit->x(reshape->tensor()); + new_reshape->tensor(logit); + + // Do shape inference for this node again. + logit->shape_status(luci::ShapeStatus::UNDEFINED); + + return true; +} + class ForwardReshape final : public luci::CircleNodeMutableVisitor<bool> { protected: @@ -103,6 +144,14 @@ protected: return forward_reshape(reshape, node); } + bool visit(luci::CircleLogistic *node) + { + auto reshape = as_reshape(node->x()); + if (reshape == nullptr) + return false; + + return forward_reshape(reshape, node); + } // TODO add more unary operators }; diff --git a/compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.test.cpp b/compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.test.cpp index 2593a014c..373513270 100644 --- a/compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.test.cpp +++ b/compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.test.cpp @@ -65,6 +65,42 @@ protected: luci::CircleConst *_reshape_shape = nullptr; }; +// TODO Reduce duplicate code with ReshapeNegGraphlet +class ReshapeLogisticGraphlet +{ +public: + ReshapeLogisticGraphlet() = default; + +public: + void init(loco::Graph *g, const ShapeU32 shape_in, const ShapeU32 shape_out) + { + std::vector<uint32_t> shape_out_v = shape_out; + + _reshape_shape = g->nodes()->create<luci::CircleConst>(); + _reshape = g->nodes()->create<luci::CircleReshape>(); + _logistic = g->nodes()->create<luci::CircleLogistic>(); + + _reshape_shape->dtype(loco::DataType::S32); + _reshape_shape->rank(1); + _reshape_shape->dim(0).set(shape_out_v.size()); + _reshape_shape->shape_status(luci::ShapeStatus::VALID); + // values + const auto size = shape_out_v.size(); + _reshape_shape->size<loco::DataType::S32>(size); + for (uint32_t i = 0; i < size; i++) + _reshape_shape->at<loco::DataType::S32>(i) = shape_out_v[i]; + + _reshape_shape->name("reshape_shape"); + _reshape->name("reshape"); + _logistic->name("logistic"); + } + +protected: + luci::CircleReshape *_reshape = nullptr; + luci::CircleLogistic *_logistic = nullptr; + luci::CircleConst *_reshape_shape = nullptr; +}; + class ForwardReshapeToNegGraph : public TestIOGraph, public ReshapeNegGraphlet { public: @@ -85,6 +121,26 @@ public: } }; +class ForwardReshapeToLogisticGraph : public TestIOGraph, public ReshapeLogisticGraphlet +{ +public: + ForwardReshapeToLogisticGraph() = default; + +public: + void init(const ShapeU32 shape_in, const ShapeU32 shape_out) + { + TestIOGraph::init(shape_in, shape_out); + ReshapeLogisticGraphlet::init(g(), shape_in, shape_out); + + // connect network + _reshape->tensor(input()); + _reshape->shape(_reshape_shape); + _logistic->x(_reshape); + + output()->from(_logistic); + } +}; + class ForwardReshapeToNegGraphTest : public ::testing::Test { public: @@ -101,6 +157,22 @@ protected: luci::ForwardReshapeToUnaryOpPass _pass; }; +class ForwardReshapeToLogisticGraphTest : public ::testing::Test +{ +public: + ForwardReshapeToLogisticGraphTest() = default; + + void run_pass(void) + { + while (_pass.run(_graph.g())) + ; + } + +protected: + ForwardReshapeToLogisticGraph _graph; + luci::ForwardReshapeToUnaryOpPass _pass; +}; + } // namespace TEST(ForwardReshapeToUnaryOpPassTest, name) @@ -123,3 +195,17 @@ TEST_F(ForwardReshapeToNegGraphTest, simple_forward) neg = dynamic_cast<luci::CircleNeg *>(reshape->tensor()); ASSERT_NE(nullptr, neg); } + +TEST_F(ForwardReshapeToLogisticGraphTest, forward) +{ + _graph.init({2, 2, 2}, {2, 4}); + + run_pass(); + + auto reshape = dynamic_cast<luci::CircleReshape *>(_graph.output()->from()); + auto log = dynamic_cast<luci::CircleLogistic *>(_graph.output()->from()); + ASSERT_NE(nullptr, reshape); + ASSERT_EQ(nullptr, log); + log = dynamic_cast<luci::CircleLogistic *>(reshape->tensor()); + ASSERT_NE(nullptr, log); +} diff --git a/compiler/luci/pass/src/FuseAddWithFullyConnectedPass.cpp b/compiler/luci/pass/src/FuseAddWithFullyConnectedPass.cpp index 97a962cb6..3cf31ed10 100644 --- a/compiler/luci/pass/src/FuseAddWithFullyConnectedPass.cpp +++ b/compiler/luci/pass/src/FuseAddWithFullyConnectedPass.cpp @@ -99,6 +99,12 @@ bool fuse_add_with_fc(luci::CircleFullyConnected *fc) fused_bias->at<loco::DataType::FLOAT32>(i) += const_bias->at<loco::DataType::FLOAT32>(i); } + // At this point, it is guarateed that fused_bias's shape is [1, 1, ..., N] or [N] + // where N is weights->dim(0). + // The shape is normalized to [N] to become the bias of FC + fused_bias->rank(1); + fused_bias->dim(0) = weights->dim(0); + fc->bias(fused_bias); fc->fusedActivationFunction(add->fusedActivationFunction()); diff --git a/compiler/luci/pass/src/FuseAddWithTConvPass.cpp b/compiler/luci/pass/src/FuseAddWithTConvPass.cpp index 2bca57014..852bc8b63 100644 --- a/compiler/luci/pass/src/FuseAddWithTConvPass.cpp +++ b/compiler/luci/pass/src/FuseAddWithTConvPass.cpp @@ -37,10 +37,10 @@ namespace * \ | * [CircleTransposeConv] [CircleAdd] * | - * ([CircleRelu6]) + * ([CircleRelu/Relu6]) * | * - * Note: CircleRelu6 is inserted if Add activation is ReLU6 + * Note: CircleRelu/Relu6 is inserted if Add activation is ReLU6 */ bool fuse_add_with_tconv(luci::CircleTransposeConv *tconv) { @@ -65,7 +65,8 @@ bool fuse_add_with_tconv(luci::CircleTransposeConv *tconv) if (add->dtype() != loco::DataType::FLOAT32) return false; if (add->fusedActivationFunction() != luci::FusedActFunc::NONE && - add->fusedActivationFunction() != luci::FusedActFunc::RELU6) + add->fusedActivationFunction() != luci::FusedActFunc::RELU6 && + add->fusedActivationFunction() != luci::FusedActFunc::RELU) return false; // get addition @@ -102,6 +103,19 @@ bool fuse_add_with_tconv(luci::CircleTransposeConv *tconv) // remove add node replace(add).with(relu); } + else if (add->fusedActivationFunction() == luci::FusedActFunc::RELU) + { + auto name = addition->name(); + assert(name.length() > 0); + // separate relu op from add op + auto relu = add->graph()->nodes()->create<luci::CircleRelu>(); + relu->features(tconv); + relu->name(name + "/Relu"); + luci::add_origin(relu, luci::get_origin(add)); + + // remove add node + replace(add).with(relu); + } else { replace(add).with(tconv); diff --git a/compiler/luci/pass/src/FuseBatchNormWithTConvPass.cpp b/compiler/luci/pass/src/FuseBatchNormWithTConvPass.cpp index 337954960..e6b54df36 100644 --- a/compiler/luci/pass/src/FuseBatchNormWithTConvPass.cpp +++ b/compiler/luci/pass/src/FuseBatchNormWithTConvPass.cpp @@ -29,7 +29,7 @@ namespace * NOTE TF's BatchNormalization is converted to Mul and Add. * * BEFORE - * | [CircleOutputExclude] + * | [CircleConst]/[CircleOutputExclude] * | / [CircleConst] * | / / * [CircleTransposeConv] [CircleConst] @@ -40,7 +40,7 @@ namespace * | * * AFTER - * | [CircleOutputExclude] + * | [CircleConst]/[CircleOutputExclude] * +-------------------------------------+ / [CircleConst] * | | / / * | [CircleTransposeConv] [CircleConst] @@ -69,9 +69,10 @@ bool fused_batch_norm_with_tconv(luci::CircleAdd *add) return false; // check scale and shift constant attributes - if (scale->rank() != 1) + // TODO maybe rank check is not needed + if (scale->rank() != 1 && scale->rank() != 4) return false; - if (shift->rank() != 1) + if (shift->rank() != 1 && shift->rank() != 4) return false; // check mul, add attributes if (mul->dtype() != loco::DataType::FLOAT32) @@ -82,9 +83,8 @@ bool fused_batch_norm_with_tconv(luci::CircleAdd *add) add->fusedActivationFunction() != luci::FusedActFunc::RELU6) return false; - // tconv bias should be not set - if (not dynamic_cast<luci::CircleOutputExclude *>(tconv->bias())) - return false; + // tconv bias is optional + auto bias = dynamic_cast<luci::CircleConst *>(tconv->bias()); // get weight of tconv auto filter = dynamic_cast<luci::CircleConst *>(tconv->filter()); @@ -96,10 +96,36 @@ bool fused_batch_norm_with_tconv(luci::CircleAdd *add) return false; auto filter_out_chn = filter->dim(0).value(); - if (filter_out_chn != scale->dim(0).value()) + // allow scale/shift and bias shape of [N], [1,1,1,N]; BN works for "channel-wise" + auto srank = scale->rank() - 1; + if (filter_out_chn != scale->dim(srank).value()) return false; - if (filter_out_chn != shift->dim(0).value()) + for (uint32_t d = 0; d < srank; ++d) + { + if (1 != scale->dim(d).value()) + return false; + } + srank = shift->rank() - 1; + if (filter_out_chn != shift->dim(srank).value()) return false; + for (uint32_t d = 0; d < srank; ++d) + { + if (1 != shift->dim(d).value()) + return false; + } + if (bias) + { + if (bias->dtype() != loco::DataType::FLOAT32) + return false; + srank = bias->rank() - 1; + if (filter_out_chn != bias->dim(srank).value()) + return false; + for (uint32_t d = 0; d < srank; ++d) + { + if (1 != bias->dim(d).value()) + return false; + } + } auto name = add->name(); assert(name.length() > 0); @@ -151,6 +177,11 @@ bool fused_batch_norm_with_tconv(luci::CircleAdd *add) for (uint32_t c = 0; c < filter_out_chn; ++c) { fused_bias->at<loco::DataType::FLOAT32>(c) = shift->at<loco::DataType::FLOAT32>(c); + if (bias != nullptr) + { + fused_bias->at<loco::DataType::FLOAT32>(c) += + bias->at<loco::DataType::FLOAT32>(c) * scale->at<loco::DataType::FLOAT32>(c); + } } fused_bias->name(name + "/TransposeConv/bias"); @@ -166,6 +197,10 @@ bool fused_batch_norm_with_tconv(luci::CircleAdd *add) luci::add_origin(fused_tconv, luci::composite_origin( {luci::get_origin(add), luci::get_origin(mul), luci::get_origin(tconv)})); + if (bias != nullptr) + { + luci::add_origin(fused_tconv, luci::get_origin(bias)); + } if (add->fusedActivationFunction() == luci::FusedActFunc::RELU6) { diff --git a/compiler/luci/pass/src/FuseInstanceNormPass.cpp b/compiler/luci/pass/src/FuseInstanceNormPass.cpp index f3ec6cd9e..10a651e35 100644 --- a/compiler/luci/pass/src/FuseInstanceNormPass.cpp +++ b/compiler/luci/pass/src/FuseInstanceNormPass.cpp @@ -325,6 +325,10 @@ public: } private: + bool condition_common_1_5(uint32_t ifm_channel_depth); + bool condition_common_3_4(); + +private: template <enum PatternVersion> bool match(); public: @@ -368,21 +372,8 @@ private: if (not(condition)) \ return false; -template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_1>() +bool InstanceNormPattern::condition_common_1_5(uint32_t ifm_channel_depth) { - CHECK_OR_FALSE(luci::fill(&mul_as_scaled_ifm, &sub).with_commutative_args_of(add_as_terminal)); - CHECK_OR_FALSE(luci::fill(&ifm, &mul_gamma).with_commutative_args_of(mul_as_scaled_ifm)); - - auto ifm_circle = loco::must_cast<luci::CircleNode *>(ifm); - CHECK_OR_FALSE(ifm_circle->shape_status() == luci::ShapeStatus::VALID); - CHECK_OR_FALSE(ifm_circle->rank() == 4); - CHECK_OR_FALSE(ifm_circle->dim(3).known()); - uint32_t ifm_channel_depth = ifm_circle->dim(3).value(); - - CHECK_OR_FALSE(luci::fill(&rsqrt, &const_as_gamma).with_commutative_args_of(mul_gamma)); - - CHECK_OR_FALSE(is_1D_with_dummy_dim(const_as_gamma, ifm_channel_depth)); - add_as_variance = dynamic_cast<luci::CircleAdd *>(rsqrt->x()); CHECK_OR_FALSE(add_as_variance); @@ -408,6 +399,70 @@ template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion: CHECK_OR_FALSE(const_as_beta); CHECK_OR_FALSE(is_1D_with_dummy_dim(const_as_beta, ifm_channel_depth)); + return true; +} + +bool InstanceNormPattern::condition_common_3_4() +{ + // check left sub + ifm = sub->x(); + CHECK_OR_FALSE(ifm); + + luci::CircleNode *ifm_node = loco::must_cast<luci::CircleNode *>(ifm); + CHECK_OR_FALSE(ifm_node->rank() == 4); + CHECK_OR_FALSE(ifm_node->dim(3).known()); + + mean_of_ifm = dynamic_cast<luci::CircleMean *>(sub->y()); + CHECK_OR_FALSE(mean_of_ifm); + CHECK_OR_FALSE(ifm == mean_of_ifm->input()); + + // continue search from add_as_variance + CHECK_OR_FALSE(luci::fill(&sqrt, &const_as_epsilon).with_commutative_args_of(add_as_variance)); + CHECK_OR_FALSE(const_as_epsilon->dtype() == loco::DataType::FLOAT32); + // TODO Support regarding broadcast + CHECK_OR_FALSE(const_as_epsilon->size<loco::DataType::FLOAT32>() == 1); + + mean_as_variance = dynamic_cast<luci::CircleMean *>(sqrt->x()); + CHECK_OR_FALSE(mean_as_variance); + + square = dynamic_cast<luci::CircleSquare *>(mean_as_variance->input()); + CHECK_OR_FALSE(square); + + sub_2 = dynamic_cast<luci::CircleSub *>(square->x()); + CHECK_OR_FALSE(sub_2); + CHECK_OR_FALSE(ifm == sub_2->x()); + + mean_of_ifm_2 = dynamic_cast<luci::CircleMean *>(sub_2->y()); + CHECK_OR_FALSE(mean_of_ifm_2); + CHECK_OR_FALSE(ifm == mean_of_ifm_2->input()); + + loco::Node *ifm_should_be = nullptr; + luci::CircleMean *mean_of_ifm_2_should_be = nullptr; + CHECK_OR_FALSE( + luci::fill(&ifm_should_be, &mean_of_ifm_2_should_be).with_commutative_args_of(sub_2)); + CHECK_OR_FALSE(ifm == ifm_should_be); + CHECK_OR_FALSE(mean_of_ifm_2 == mean_of_ifm_2_should_be); + + return true; +} + +template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_1>() +{ + CHECK_OR_FALSE(luci::fill(&mul_as_scaled_ifm, &sub).with_commutative_args_of(add_as_terminal)); + CHECK_OR_FALSE(luci::fill(&ifm, &mul_gamma).with_commutative_args_of(mul_as_scaled_ifm)); + + auto ifm_circle = loco::must_cast<luci::CircleNode *>(ifm); + CHECK_OR_FALSE(ifm_circle->shape_status() == luci::ShapeStatus::VALID); + CHECK_OR_FALSE(ifm_circle->rank() == 4); + CHECK_OR_FALSE(ifm_circle->dim(3).known()); + uint32_t ifm_channel_depth = ifm_circle->dim(3).value(); + + CHECK_OR_FALSE(luci::fill(&rsqrt, &const_as_gamma).with_commutative_args_of(mul_gamma)); + + CHECK_OR_FALSE(is_1D_with_dummy_dim(const_as_gamma, ifm_channel_depth)); + + CHECK_OR_FALSE(condition_common_1_5(ifm_channel_depth)); + luci::CircleMul *mul_gamma_should_be = nullptr; luci::CircleMean *mean_of_ifm_should_be = nullptr; @@ -488,44 +543,7 @@ template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion: CHECK_OR_FALSE(luci::fill(&div, &const_as_gamma).with_commutative_args_of(mul_gamma)); CHECK_OR_FALSE(luci::fill(&sub, &add_as_variance).with_commutative_args_of(div)); - // check left sub - ifm = sub->x(); - CHECK_OR_FALSE(ifm); - - luci::CircleNode *ifm_node = loco::must_cast<luci::CircleNode *>(ifm); - CHECK_OR_FALSE(ifm_node->rank() == 4); - CHECK_OR_FALSE(ifm_node->dim(3).known()); - - mean_of_ifm = dynamic_cast<luci::CircleMean *>(sub->y()); - CHECK_OR_FALSE(mean_of_ifm); - CHECK_OR_FALSE(ifm == mean_of_ifm->input()); - - // continue search from add_as_variance - CHECK_OR_FALSE(luci::fill(&sqrt, &const_as_epsilon).with_commutative_args_of(add_as_variance)); - CHECK_OR_FALSE(const_as_epsilon->dtype() == loco::DataType::FLOAT32); - // TODO Support regarding broadcast - CHECK_OR_FALSE(const_as_epsilon->size<loco::DataType::FLOAT32>() == 1); - - mean_as_variance = dynamic_cast<luci::CircleMean *>(sqrt->x()); - CHECK_OR_FALSE(mean_as_variance); - - square = dynamic_cast<luci::CircleSquare *>(mean_as_variance->input()); - CHECK_OR_FALSE(square); - - sub_2 = dynamic_cast<luci::CircleSub *>(square->x()); - CHECK_OR_FALSE(sub_2); - CHECK_OR_FALSE(ifm == sub_2->x()); - - mean_of_ifm_2 = dynamic_cast<luci::CircleMean *>(sub_2->y()); - CHECK_OR_FALSE(mean_of_ifm_2); - CHECK_OR_FALSE(ifm == mean_of_ifm_2->input()); - - loco::Node *ifm_should_be = nullptr; - luci::CircleMean *mean_of_ifm_2_should_be = nullptr; - CHECK_OR_FALSE( - luci::fill(&ifm_should_be, &mean_of_ifm_2_should_be).with_commutative_args_of(sub_2)); - CHECK_OR_FALSE(ifm == ifm_should_be); - CHECK_OR_FALSE(mean_of_ifm_2 == mean_of_ifm_2_should_be); + CHECK_OR_FALSE(condition_common_3_4()); _matched = true; return true; @@ -546,44 +564,7 @@ template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion: CHECK_OR_FALSE(div); CHECK_OR_FALSE(luci::fill(&sub, &add_as_variance).with_commutative_args_of(div)); - // check left sub - ifm = sub->x(); - CHECK_OR_FALSE(ifm); - - luci::CircleNode *ifm_node = loco::must_cast<luci::CircleNode *>(ifm); - CHECK_OR_FALSE(ifm_node->rank() == 4); - CHECK_OR_FALSE(ifm_node->dim(3).known()); - - mean_of_ifm = dynamic_cast<luci::CircleMean *>(sub->y()); - CHECK_OR_FALSE(mean_of_ifm); - CHECK_OR_FALSE(ifm == mean_of_ifm->input()); - - // continue search from add_as_variance - CHECK_OR_FALSE(luci::fill(&sqrt, &const_as_epsilon).with_commutative_args_of(add_as_variance)); - CHECK_OR_FALSE(const_as_epsilon->dtype() == loco::DataType::FLOAT32); - // TODO Support regarding broadcast - CHECK_OR_FALSE(const_as_epsilon->size<loco::DataType::FLOAT32>() == 1); - - mean_as_variance = dynamic_cast<luci::CircleMean *>(sqrt->x()); - CHECK_OR_FALSE(mean_as_variance); - - square = dynamic_cast<luci::CircleSquare *>(mean_as_variance->input()); - CHECK_OR_FALSE(square); - - sub_2 = dynamic_cast<luci::CircleSub *>(square->x()); - CHECK_OR_FALSE(sub_2); - CHECK_OR_FALSE(ifm == sub_2->x()); - - mean_of_ifm_2 = dynamic_cast<luci::CircleMean *>(sub_2->y()); - CHECK_OR_FALSE(mean_of_ifm_2); - CHECK_OR_FALSE(ifm == mean_of_ifm_2->input()); - - loco::Node *ifm_should_be = nullptr; - luci::CircleMean *mean_of_ifm_2_should_be = nullptr; - CHECK_OR_FALSE( - luci::fill(&ifm_should_be, &mean_of_ifm_2_should_be).with_commutative_args_of(sub_2)); - CHECK_OR_FALSE(ifm == ifm_should_be); - CHECK_OR_FALSE(mean_of_ifm_2 == mean_of_ifm_2_should_be); + CHECK_OR_FALSE(condition_common_3_4()); assert(const_as_gamma == nullptr); assert(const_as_beta == nullptr); @@ -612,30 +593,7 @@ template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion: CHECK_OR_FALSE(ifm_circle->dim(3).known()); uint32_t ifm_channel_depth = ifm_circle->dim(3).value(); - add_as_variance = dynamic_cast<luci::CircleAdd *>(rsqrt->x()); - CHECK_OR_FALSE(add_as_variance); - - CHECK_OR_FALSE( - luci::fill(&mean_as_variance, &const_as_epsilon).with_commutative_args_of(add_as_variance)); - - CHECK_OR_FALSE(const_as_epsilon->dtype() == loco::DataType::FLOAT32); - // TODO Support regarding broadcast - CHECK_OR_FALSE(const_as_epsilon->size<loco::DataType::FLOAT32>() == 1); - - CHECK_OR_FALSE(is_instance_mean_v1(mean_as_variance)); - - sqdiff = dynamic_cast<luci::CircleSquaredDifference *>(mean_as_variance->input()); - CHECK_OR_FALSE(sqdiff); - - loco::Node *ifm_should_be = nullptr; - CHECK_OR_FALSE(luci::fill(&ifm_should_be, &mean_of_ifm).with_commutative_args_of(sqdiff)); - CHECK_OR_FALSE(ifm == ifm_should_be); - CHECK_OR_FALSE(is_instance_mean_v1(mean_of_ifm)); - CHECK_OR_FALSE(ifm == mean_of_ifm->input()); - - const_as_beta = dynamic_cast<luci::CircleConst *>(sub->x()); - CHECK_OR_FALSE(const_as_beta); - CHECK_OR_FALSE(is_1D_with_dummy_dim(const_as_beta, ifm_channel_depth)); + CHECK_OR_FALSE(condition_common_1_5(ifm_channel_depth)); luci::CircleRsqrt *rsqrt_should_be = nullptr; luci::CircleMean *mean_of_ifm_should_be = nullptr; diff --git a/compiler/luci/pass/src/PropagateQParamBackwardPass.cpp b/compiler/luci/pass/src/PropagateQParamBackwardPass.cpp index b4975486d..e8fa2a478 100644 --- a/compiler/luci/pass/src/PropagateQParamBackwardPass.cpp +++ b/compiler/luci/pass/src/PropagateQParamBackwardPass.cpp @@ -23,6 +23,7 @@ #include <luci/Log.h> #include <cmath> +#include <limits> namespace { diff --git a/compiler/luci/pass/src/PropagateQParamForwardPass.cpp b/compiler/luci/pass/src/PropagateQParamForwardPass.cpp index 003e4c293..aaadb2864 100644 --- a/compiler/luci/pass/src/PropagateQParamForwardPass.cpp +++ b/compiler/luci/pass/src/PropagateQParamForwardPass.cpp @@ -138,13 +138,18 @@ struct PropagateQParamForward final : public luci::CircleNodeMutableVisitor<bool auto qtype = luci::activation_qtype(input_node); switch (qtype) { - case luci::ActivationQType::PreDefinedValue: - node->quantparam(luci::make_predefined_qparam(input_node->opcode(), node->dtype())); + case luci::ActivationQType::PreDefinedLogistic: + case luci::ActivationQType::PreDefinedTanh: + case luci::ActivationQType::PreDefinedSoftmax: + node->quantparam(luci::make_predefined_qparam(qtype, node->dtype())); break; case luci::ActivationQType::IntScale: luci::set_int_scale(node); break; default: + // This assert ensures this switch-satement handles all ActivationQTypes + // TODO Find a better design to remove coupling with ActivationQType + assert(qtype == luci::ActivationQType::MinMax); break; } diff --git a/compiler/luci/pass/src/QuantizationUtils.cpp b/compiler/luci/pass/src/QuantizationUtils.cpp index ad86cedf4..06a4ae9f6 100644 --- a/compiler/luci/pass/src/QuantizationUtils.cpp +++ b/compiler/luci/pass/src/QuantizationUtils.cpp @@ -20,6 +20,7 @@ #include <iostream> #include <cmath> +#include <limits> namespace luci { @@ -276,31 +277,70 @@ uint32_t cal_offset(loco::TensorShape &dimension, uint32_t *indices) indices[2] * dimension.dim(3).value() + indices[3]; } +// Activation (ofm) qtype is determined in different ways. +// 1. Pre-defined values: Some Ops have pre-defined qparams (ex: LOGISTIC, TANH) +// 2. Integer scale: Output of some Ops should be integers (ex: FLOOR, CEIL) +// 3. Activation qtype of input: Some Ops propagate qparam from input to output (ex: QUANTIZE, +// TRANSPOSE, etc. See PropagateQParamForwardPass.cpp for more details). ActivationQType activation_qtype(const CircleNode *node) { auto fused_act_node = dynamic_cast<const CircleNodeMixin<CircleNodeTrait::FusedActFunc> *>(node); if (fused_act_node && fused_act_node->fusedActivationFunction() == FusedActFunc::TANH) - return ActivationQType::PreDefinedValue; + return ActivationQType::PreDefinedTanh; + +#define RETURN_INPUT_ACTIVATION_QTYPE(CLASS, INPUT) \ + { \ + auto n = loco::must_cast<const CLASS *>(node); \ + auto input = loco::must_cast<CircleNode *>(n->INPUT()); \ + return activation_qtype(input); \ + } switch (node->opcode()) { case CircleOpcode::LOGISTIC: + return ActivationQType::PreDefinedLogistic; case CircleOpcode::TANH: + return ActivationQType::PreDefinedTanh; case CircleOpcode::SOFTMAX: - return ActivationQType::PreDefinedValue; + return ActivationQType::PreDefinedSoftmax; case CircleOpcode::FLOOR: case CircleOpcode::FLOOR_DIV: case CircleOpcode::FLOOR_MOD: case CircleOpcode::CEIL: return ActivationQType::IntScale; + case CircleOpcode::GATHER: + RETURN_INPUT_ACTIVATION_QTYPE(CircleGather, params); + case CircleOpcode::RESHAPE: + RETURN_INPUT_ACTIVATION_QTYPE(CircleReshape, tensor); + case CircleOpcode::TRANSPOSE: + RETURN_INPUT_ACTIVATION_QTYPE(CircleTranspose, a); + case CircleOpcode::STRIDED_SLICE: + RETURN_INPUT_ACTIVATION_QTYPE(CircleStridedSlice, input); + case CircleOpcode::SPLIT: + RETURN_INPUT_ACTIVATION_QTYPE(CircleSplit, input); + case CircleOpcode::CIRCLESPLITOUT: + RETURN_INPUT_ACTIVATION_QTYPE(CircleSplitOut, input); + case CircleOpcode::SPLIT_V: + RETURN_INPUT_ACTIVATION_QTYPE(CircleSplitV, input); + case CircleOpcode::CIRCLESPLITVOUT: + RETURN_INPUT_ACTIVATION_QTYPE(CircleSplitVOut, input); + case CircleOpcode::UNPACK: + RETURN_INPUT_ACTIVATION_QTYPE(CircleUnpack, value); + case CircleOpcode::CIRCLEUNPACKOUT: + RETURN_INPUT_ACTIVATION_QTYPE(CircleUnpackOut, input); + case CircleOpcode::QUANTIZE: + RETURN_INPUT_ACTIVATION_QTYPE(CircleQuantize, input); default: break; } +#undef RETURN_INPUT_ACTIVATION_QTYPE + return ActivationQType::MinMax; } -std::unique_ptr<CircleQuantParam> make_predefined_qparam(CircleOpcode opcode, loco::DataType dtype) +std::unique_ptr<CircleQuantParam> make_predefined_qparam(ActivationQType qtype, + loco::DataType dtype) { auto qparam = std::make_unique<CircleQuantParam>(); @@ -309,9 +349,9 @@ std::unique_ptr<CircleQuantParam> make_predefined_qparam(CircleOpcode opcode, lo qparam->zerop.emplace_back(zp); }; - switch (opcode) + switch (qtype) { - case CircleOpcode::LOGISTIC: + case ActivationQType::PreDefinedLogistic: if (dtype == loco::DataType::U8) set_qparam(1.0f / 256.0f, 0); else @@ -320,7 +360,7 @@ std::unique_ptr<CircleQuantParam> make_predefined_qparam(CircleOpcode opcode, lo set_qparam(1.0f / 32768.0f, 0); } break; - case CircleOpcode::TANH: + case ActivationQType::PreDefinedTanh: if (dtype == loco::DataType::U8) set_qparam(2.0f / 256.0f, 128); else @@ -329,7 +369,7 @@ std::unique_ptr<CircleQuantParam> make_predefined_qparam(CircleOpcode opcode, lo set_qparam(1.0f / 32768.0f, 0); } break; - case CircleOpcode::SOFTMAX: + case ActivationQType::PreDefinedSoftmax: if (dtype == loco::DataType::U8) set_qparam(1.0f / 255.0f, 0); else @@ -341,7 +381,7 @@ std::unique_ptr<CircleQuantParam> make_predefined_qparam(CircleOpcode opcode, lo default: throw std::runtime_error("Unsupported opcode with pre-defined qparam"); } - return std::move(qparam); + return qparam; } // For nodes with integer output, we use integer scale @@ -395,4 +435,74 @@ void quant_const(luci::CircleConst *node, loco::DataType quant_type) node->quantparam(std::move(quantparam)); } +namespace +{ + +// TODO move this to a more global helper file +int nbits(loco::DataType dt) noexcept +{ + switch (dt) + { + case loco::DataType::S8: + case loco::DataType::U8: + return 8; + case loco::DataType::S16: + case loco::DataType::U16: + case loco::DataType::FLOAT16: + return 16; + case loco::DataType::S32: + case loco::DataType::U32: + case loco::DataType::FLOAT32: + return 32; + case loco::DataType::S64: + return 64; + default: + return 64; // a safe large default + } +} + +// TODO Check if the metric is valid +// Returns true if [min,max] is poorly representable +bool range_check(float min, float max, loco::DataType dtype) +{ + float thresh = 1.5f; + return log2f(max) - log2f(min) > nbits(dtype) * thresh; +} + +bool warn_scale_zp(float scale, int64_t zp, luci::CircleNode *n) +{ + float min, max; + // estimate min/max + switch (n->dtype()) + { + case loco::DataType::U8: + min = scale * (0 - zp); + max = scale * (255 - zp); + break; + case loco::DataType::S16: + min = scale * (-32767); + max = scale * (32767); + break; + default: + return false; + } + return range_check(min, max, n->dtype()); +} + +} // namespace + +void warn_accuracy_with_range(luci::CircleNode *n) +{ + LOGGER(l); + auto qp = n->quantparam(); + auto k = qp->zerop.size(); + for (uint32_t i = 0; i < k; i++) + { + if (warn_scale_zp(qp->scale[i], qp->zerop[i], n)) + WARN(l) << "Quantization of " << i << "-th channel of " << n->name() + << "'s quantization may cause accuracy issues" << std::endl; + ; + } +} + } // namespace luci diff --git a/compiler/luci/pass/src/QuantizationUtils.h b/compiler/luci/pass/src/QuantizationUtils.h index cd8cec95a..4d5316ccb 100644 --- a/compiler/luci/pass/src/QuantizationUtils.h +++ b/compiler/luci/pass/src/QuantizationUtils.h @@ -62,15 +62,19 @@ bool is_quantized(const CircleNode *node); enum ActivationQType { - MinMax, // Quantize using recorded min/max - PreDefinedValue, // Quantize using pre-defined values - IntScale, // Round scale to a positive integer + MinMax, // Quantize using recorded min/max + PreDefinedLogistic, // Quantize using pre-defined values + PreDefinedTanh, // Quantize using pre-defined values + PreDefinedSoftmax, // Quantize using pre-defined values + IntScale, // Round scale to a positive integer }; ActivationQType activation_qtype(const CircleNode *node); // Create qparam with pre-defined values for speical operators -std::unique_ptr<CircleQuantParam> make_predefined_qparam(CircleOpcode opcode, loco::DataType dtype); +std::unique_ptr<CircleQuantParam> make_predefined_qparam(CircleNode *node, loco::DataType dtype); +std::unique_ptr<CircleQuantParam> make_predefined_qparam(ActivationQType qtype, + loco::DataType dtype); // Update node's scale to a positive integer (for special Ops e.g., Floor, Ceil) void set_int_scale(luci::CircleNode *node); @@ -78,6 +82,10 @@ void set_int_scale(luci::CircleNode *node); // Quantize const tensor using its min/max values void quant_const(luci::CircleConst *node, loco::DataType quant_type); +// Check that a node is quantized without significant loss of precision; +// Emits warnings to log with WARN +void warn_accuracy_with_range(luci::CircleNode *n); + } // namespace luci #endif // __LUCI_QUANTIZATION_UTILS_H__ diff --git a/compiler/luci/pass/src/QuantizeActivation.cpp b/compiler/luci/pass/src/QuantizeActivation.cpp index 149331824..95251a82c 100644 --- a/compiler/luci/pass/src/QuantizeActivation.cpp +++ b/compiler/luci/pass/src/QuantizeActivation.cpp @@ -114,29 +114,26 @@ void QuantizeSpecialActivation::visit(luci::CircleNode *node) auto fused_act_node = dynamic_cast<CircleNodeMixin<CircleNodeTrait::FusedActFunc> *>(node); if (fused_act_node != nullptr && fused_act_node->fusedActivationFunction() == FusedActFunc::TANH) { - auto qparam = make_predefined_qparam(luci::CircleOpcode::TANH, output_type); + auto qparam = make_predefined_qparam(luci::ActivationQType::PreDefinedTanh, output_type); node->quantparam(std::move(qparam)); } } void QuantizeSpecialActivation::visit(luci::CircleLogistic *node) { - assert(activation_qtype(node) == luci::ActivationQType::PreDefinedValue); - auto qparam = make_predefined_qparam(luci::CircleOpcode::LOGISTIC, output_type); + auto qparam = make_predefined_qparam(luci::ActivationQType::PreDefinedLogistic, output_type); node->quantparam(std::move(qparam)); } void QuantizeSpecialActivation::visit(luci::CircleTanh *node) { - assert(activation_qtype(node) == luci::ActivationQType::PreDefinedValue); - auto qparam = make_predefined_qparam(luci::CircleOpcode::TANH, output_type); + auto qparam = make_predefined_qparam(luci::ActivationQType::PreDefinedTanh, output_type); node->quantparam(std::move(qparam)); } void QuantizeSpecialActivation::visit(luci::CircleSoftmax *node) { - assert(activation_qtype(node) == luci::ActivationQType::PreDefinedValue); - auto qparam = make_predefined_qparam(luci::CircleOpcode::SOFTMAX, output_type); + auto qparam = make_predefined_qparam(luci::ActivationQType::PreDefinedSoftmax, output_type); node->quantparam(std::move(qparam)); } diff --git a/compiler/luci/pass/src/QuantizeBias.cpp b/compiler/luci/pass/src/QuantizeBias.cpp index aa496232a..de97a14dd 100644 --- a/compiler/luci/pass/src/QuantizeBias.cpp +++ b/compiler/luci/pass/src/QuantizeBias.cpp @@ -22,6 +22,7 @@ #include <algorithm> #include <cmath> +#include <limits> using namespace luci; @@ -201,6 +202,18 @@ CircleConst *QuantizeBias::quantized_bias(CircleNode *input, const CircleNode *w std::vector<float> scaling_factor(size); std::vector<int64_t> zp(size); + if (const_bias->rank() == 0) + { + // TODO Support quantization of scalar bias + throw std::runtime_error("Quantization of scalar bias is not yet supported (" + + const_bias->name() + ")"); + } + if (size != const_bias->dim(const_bias->rank() - 1).value()) + { + throw std::runtime_error(const_bias->name() + + " (bias) should have the shape of [1, 1, .. 1, channel]"); + } + if (output_type == loco::DataType::U8) { new_bias = quant_bias_per_channel(const_bias, input_scale, weight_scale, scaling_factor, zp); @@ -218,6 +231,7 @@ CircleConst *QuantizeBias::quantized_bias(CircleNode *input, const CircleNode *w auto quantparam = std::make_unique<CircleQuantParam>(); quantparam->scale = scaling_factor; quantparam->zerop = zp; + quantparam->quantized_dimension = const_bias->rank() - 1; assert(new_bias->quantparam() == nullptr); // bias should not be quantized before new_bias->quantparam(std::move(quantparam)); diff --git a/compiler/luci/pass/src/QuantizeBias.test.cpp b/compiler/luci/pass/src/QuantizeBias.test.cpp new file mode 100644 index 000000000..0104a191b --- /dev/null +++ b/compiler/luci/pass/src/QuantizeBias.test.cpp @@ -0,0 +1,189 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "QuantizeBias.h" + +#include <luci/test/TestIOGraph.h> +#include <luci/IR/CircleNodes.h> +#include <luci/IR/CircleQuantParam.h> + +#include <gtest/gtest.h> + +using namespace luci; + +namespace +{ + +using namespace luci::test; + +// TODO Reduce duplicate codes in ResolveCustomOpMatMulPass.cpp +template <typename T> +luci::CircleConst *create_const_node(loco::Graph *g, const loco::DataType dtype, + const std::vector<uint32_t> &shape, T value) +{ + auto node = g->nodes()->create<luci::CircleConst>(); + node->dtype(dtype); + node->rank(shape.size()); + + uint32_t size = 1; + for (uint32_t i = 0; i < shape.size(); ++i) + { + node->dim(i) = shape.at(i); + size *= shape.at(i); + } + node->shape_status(luci::ShapeStatus::VALID); + +#define INIT_VALUES(DT) \ + { \ + node->size<DT>(size); \ + for (uint32_t i = 0; i < size; ++i) \ + node->at<DT>(i) = value; \ + } + + switch (dtype) + { + case loco::DataType::U8: + INIT_VALUES(loco::DataType::U8); + break; + case loco::DataType::S16: + INIT_VALUES(loco::DataType::S16); + break; + case loco::DataType::S32: + INIT_VALUES(loco::DataType::S32); + break; + case loco::DataType::FLOAT32: + INIT_VALUES(loco::DataType::FLOAT32) + break; + default: + INTERNAL_EXN("create_const_node called with unsupported type"); + break; + } + return node; +} + +/** + * Simple graph for test + * + * BEFORE + * + * [IFM] [WEIGHTS] [BIAS(FP32)] + * \ | / + * [FC] + * | + * [OFM] + * + * AFTER + * + * [IFM] [WEIGHTS] [BIAS(Quantized)] + * \ | / + * [FC] + * | + * [OFM] + */ +struct Q8FCGraphlet +{ +public: + Q8FCGraphlet() = default; + virtual ~Q8FCGraphlet() = default; + + void init(loco::Graph *g, const ShapeU32 out_shape, const ShapeU32 w_shape, + const ShapeU32 bias_shape, const float bv) + { + _fc = g->nodes()->create<luci::CircleFullyConnected>(); + _fc->input(_x); + _x->dtype(loco::DataType::U8); + { + auto quantparam = std::make_unique<CircleQuantParam>(); + quantparam->scale.push_back(1.0); + quantparam->zerop.push_back(0); + quantparam->quantized_dimension = 0; + _x->quantparam(std::move(quantparam)); + } + + auto weights = create_const_node<uint8_t>(g, loco::DataType::U8, w_shape, 1.0); + auto w_qparam = std::make_unique<CircleQuantParam>(); + std::vector<float> w_scale(weights->dim(0).value(), 1.0); + std::vector<int64_t> w_zp(weights->dim(0).value(), 0); + w_qparam->scale = w_scale; + w_qparam->zerop = w_zp; + w_qparam->quantized_dimension = 0; + weights->quantparam(std::move(w_qparam)); + _fc->weights(weights); + _fc->fusedActivationFunction(luci::FusedActFunc::NONE); + _fc->dtype(loco::DataType::U8); + _fc->shape(out_shape); + auto l = _fc->dim(_fc->rank() - 1).value(); + _fc->bias(create_const_node(g, loco::DataType::FLOAT32, bias_shape, bv)); + _fc->name("fc"); + { + auto quantparam = std::make_unique<CircleQuantParam>(); + quantparam->scale.push_back(1.0); + quantparam->zerop.push_back(0); + quantparam->quantized_dimension = 0; + _fc->quantparam(std::move(quantparam)); + } + } + +public: + luci::CircleFullyConnected *fc() { return _fc; } + +protected: + luci::CircleFullyConnected *_fc = nullptr; + luci::CircleInput *_x = nullptr; +}; + +struct Q8FCGraph final : public TestIGraphlet, public TestOGraphlet, public Q8FCGraphlet +{ + void init(const ShapeU32 in_shape, const ShapeU32 w_shape, const ShapeU32 out_shape, + const ShapeU32 bias_shape, const float bv) + { + TestIGraphlet::init(g(), in_shape); + TestOGraphlet::init(g(), out_shape); + _x = input(); + Q8FCGraphlet::init(g(), out_shape, w_shape, bias_shape, bv); + output()->from(_fc); + } +}; + +class CQ8QuantizeBiasFCTest : public ::testing::Test +{ +public: + Q8FCGraph g; + luci::QuantizeBias qb{loco::DataType::FLOAT32, loco::DataType::U8, + luci::QuantizationGranularity::ChannelWise}; +}; + +} // namespace + +TEST_F(CQ8QuantizeBiasFCTest, fully_connected) +{ + g.init({1, 18, 80}, {256, 80}, {18, 256}, {1, 256}, 1); + g.fc()->accept(&qb); + + auto bias = loco::must_cast<CircleConst *>(g.fc()->bias()); + auto qparam = bias->quantparam(); + + EXPECT_NE(nullptr, qparam); + EXPECT_EQ(256, qparam->scale.size()); + EXPECT_EQ(256, qparam->zerop.size()); + EXPECT_EQ(1, qparam->quantized_dimension); +} + +TEST_F(CQ8QuantizeBiasFCTest, wrong_bias_shape_NEG) +{ + g.init({1, 18, 80}, {256, 80}, {18, 256}, {1, 2, 128}, 1); + EXPECT_ANY_THROW(g.fc()->accept(&qb)); // Wrong bias shape +} diff --git a/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.cpp b/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.cpp index c9b35e0be..ef047d35d 100644 --- a/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.cpp +++ b/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.cpp @@ -27,6 +27,7 @@ #include <iostream> #include <cmath> #include <functional> +#include <limits> namespace { @@ -352,15 +353,15 @@ private: private: // Check if // 1. node is const - // 2. node was not quantized + // 2. node's dtype is float32 bool is_quantizable(loco::Node *node) { auto const_node = dynamic_cast<luci::CircleConst *>(node); if (not const_node) return false; - // Skip if this is already quantized - if (is_quantized(const_node)) + // Skip if this is not float32 + if (const_node->dtype() != loco::DataType::FLOAT32) return false; return true; diff --git a/compiler/luci/pass/src/QuantizeWeights.cpp b/compiler/luci/pass/src/QuantizeWeights.cpp index 11322ab44..500ae12ed 100644 --- a/compiler/luci/pass/src/QuantizeWeights.cpp +++ b/compiler/luci/pass/src/QuantizeWeights.cpp @@ -23,6 +23,7 @@ #include <cmath> #include <vector> #include <functional> +#include <limits> using namespace luci; diff --git a/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp b/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp index d9a9d4db7..005144516 100644 --- a/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp +++ b/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp @@ -41,10 +41,28 @@ namespace { using namespace luci; + +bool use_predefined_values(ActivationQType qtype) +{ + switch (qtype) + { + case ActivationQType::PreDefinedLogistic: + case ActivationQType::PreDefinedTanh: + case ActivationQType::PreDefinedSoftmax: + return true; + default: + // This ensures this switch-statement handles all ActivationQTypes + assert(qtype == ActivationQType::IntScale or qtype == ActivationQType::MinMax); + break; + } + + return false; +} + // Create a Quantize Op whose // dtype is out_type // shape is the same with node -// qparam is computed using node's min/max +// qparam is computed according to node's qtype luci::CircleQuantize *create_quantize_op(luci::CircleNode *node, loco::DataType out_type) { auto quantize = node->graph()->nodes()->create<CircleQuantize>(); @@ -60,9 +78,9 @@ luci::CircleQuantize *create_quantize_op(luci::CircleNode *node, loco::DataType assert(qparam); // FIX_CALLER_UNLESS auto qtype = luci::activation_qtype(node); - if (qtype == ActivationQType::PreDefinedValue) + if (use_predefined_values(qtype)) { - quantize->quantparam(luci::make_predefined_qparam(node->opcode(), out_type)); + quantize->quantparam(luci::make_predefined_qparam(qtype, out_type)); return quantize; } @@ -105,6 +123,23 @@ luci::CircleQuantize *create_quantize_op(luci::CircleNode *node, loco::DataType return quantize; } +// Create Dequantize Op whose shape is the same with node +luci::CircleDequantize *create_dequantize(luci::CircleNode *node) +{ + auto dequantize = node->graph()->nodes()->create<luci::CircleDequantize>(); + dequantize->name(node->name() + "_Dequantize"); + dequantize->dtype(loco::DataType::FLOAT32); + dequantize->rank(node->rank()); + for (uint32_t i = 0; i < node->rank(); i++) + dequantize->dim(i).set(node->dim(i).value()); + + dequantize->shape_status(luci::ShapeStatus::VALID); + + luci::add_origin(dequantize, luci::get_origin(node)); + + return dequantize; +} + } // namespace namespace luci @@ -229,11 +264,13 @@ private: INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleFullyConnected, input) INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleGather, params) INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleInstanceNorm, input) + INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleLeakyRelu, features) INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleLocalResponseNormalization, input) INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleLogistic, x) INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleMaxPool2D, value) INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleMean, input) INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleMirrorPad, input) + INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleNeg, x) INSERT_QUANTIZE_TO_UNARY_OP(luci::CirclePad, input) INSERT_QUANTIZE_TO_UNARY_OP(luci::CirclePadV2, input) INSERT_QUANTIZE_TO_UNARY_OP(luci::CirclePRelu, input) @@ -241,6 +278,7 @@ private: INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleReduceMax, input) INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleReduceMin, input) INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleRelu, features) + INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleRelu6, features) INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleReshape, tensor) INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleResizeBilinear, input) INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleResizeNearestNeighbor, input) @@ -250,6 +288,7 @@ private: INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleSoftmax, logits) INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleSpaceToBatchND, input) INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleSpaceToDepth, input) + INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleSqueeze, input) INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleSqrt, x) INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleStridedSlice, input) INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleSum, input) @@ -353,7 +392,9 @@ void QuantizeWithMinMaxPass::set_input_type(loco::Graph *g) const luci::add_origin(quant_op, luci::get_origin(succ)); } - // Requantize input + // Update qparam of input + // This step is skipped if input_type is float32 + if (_ctx->input_type != loco::DataType::FLOAT32) { auto quantparam = input->quantparam(); assert(quantparam); @@ -376,11 +417,13 @@ void QuantizeWithMinMaxPass::set_input_type(loco::Graph *g) const assert(_ctx->input_type == loco::DataType::S16); compute_sym_scale_zp(min, max, scaling_factor, zp, nudged_min, nudged_max); } - input->dtype(_ctx->input_type); input->quantparam()->scale[0] = scaling_factor; input->quantparam()->zerop[0] = zp; } + // Update dtype of input + input->dtype(_ctx->input_type); + auto graph_input = inputs->at(input->index()); graph_input->dtype(_ctx->input_type); } @@ -405,13 +448,26 @@ void QuantizeWithMinMaxPass::set_output_type(loco::Graph *g) const if (not from->quantparam()) continue; - // Insert Quantize Op - auto quant_op = create_quantize_op(from, _ctx->output_type); - loco::replace(from).with(quant_op); - quant_op->input(from); + // Insert Dequantize Op for float32 output_type + if (_ctx->output_type == loco::DataType::FLOAT32) + { + auto dequant_op = create_dequantize(from); + loco::replace(from).with(dequant_op); + dequant_op->input(from); + } + else + { + // Insert Quantize Op for non-float32 output_type + auto quant_op = create_quantize_op(from, _ctx->output_type); + loco::replace(from).with(quant_op); + quant_op->input(from); - // TODO Set a proper origin (Quantize should have its own Origin) - luci::add_origin(quant_op, luci::get_origin(from)); + // TODO Set a proper origin (Quantize should have its own Origin) + luci::add_origin(quant_op, luci::get_origin(from)); + } + + // Update dtype of output + output->dtype(_ctx->output_type); auto graph_output = outputs->at(output->index()); graph_output->dtype(_ctx->output_type); @@ -594,12 +650,25 @@ bool QuantizeWithMinMaxPass::run(loco::Graph *g) // Set output type set_output_type(g); + // Remove redundant Quantize Op + { + logo::Phase phase; + + phase.emplace_back(std::make_unique<luci::RemoveRedundantQuantizePass>()); + + ProgressReporter prog(g, logo::PhaseStrategy::Saturate); + logo::PhaseRunner<logo::PhaseStrategy::Saturate> phase_runner{g}; + phase_runner.attach(&prog); + phase_runner.run(phase); + } + // Remove min/max values for (auto node : loco::active_nodes(loco::output_nodes(g))) { auto circle_node = loco::must_cast<luci::CircleNode *>(node); if (auto qparam = circle_node->quantparam()) { + warn_accuracy_with_range(circle_node); qparam->min.clear(); qparam->max.clear(); } diff --git a/compiler/luci/pass/src/QuantizedModelVerifier.test.cpp b/compiler/luci/pass/src/QuantizedModelVerifier.test.cpp index cebafd32b..21b4fe1c6 100644 --- a/compiler/luci/pass/src/QuantizedModelVerifier.test.cpp +++ b/compiler/luci/pass/src/QuantizedModelVerifier.test.cpp @@ -1088,6 +1088,31 @@ private: luci::CircleConst *_const = nullptr; }; +class ReduceMaxTestGraph final : public SimpleTestGraph +{ +public: + void init(void) override + { + TestIOGraph::init({4, 3, 2}, {2}); + + _axis = create_const<Type::S32, int32_t>(g(), {4}, {1, 0, -3, -3}); + _reduce_max = g()->nodes()->create<luci::CircleReduceMax>(); + { + _reduce_max->input(input()); + _reduce_max->reduction_indices(_axis); + _reduce_max->name("test"); + _reduce_max->keep_dims(false); + } + output()->from(_reduce_max); + + set_minmax_to_non_const(g(), -1, 1); + } + +private: + luci::CircleReduceMax *_reduce_max = nullptr; + luci::CircleConst *_axis = nullptr; +}; + class ResizeBilinearTestGraph final : public SimpleTestGraph { public: @@ -2345,6 +2370,34 @@ TEST(QuantizedModelVerifierTest, Pow_wrong_granularity_NEG) SUCCEED(); } +TEST(QuantizedModelVerifierTest, ReduceMax) +{ + TEST_WITH_GRAPH(ReduceMaxTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_GRAPH(ReduceMaxTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_GRAPH(ReduceMaxTestGraph, Type::S16, Granularity::ChannelWise); + + TEST_WITH_LAYER_INFO(ReduceMaxTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_LAYER_INFO(ReduceMaxTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_LAYER_INFO(ReduceMaxTestGraph, Type::S16, Granularity::ChannelWise); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, ReduceMax_wrong_type_NEG) +{ + TEST_WITH_WRONG_TYPE(ReduceMaxTestGraph, Type::U8, Granularity::LayerWise, Type::S16); + TEST_WITH_WRONG_TYPE(ReduceMaxTestGraph, Type::U8, Granularity::ChannelWise, Type::S16); + TEST_WITH_WRONG_TYPE(ReduceMaxTestGraph, Type::S16, Granularity::ChannelWise, Type::U8); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, ReduceMax_wrong_granularity_NEG) +{ + TEST_WITH_WRONG_GRANULARITY(ReduceMaxTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_WRONG_GRANULARITY(ReduceMaxTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_WRONG_GRANULARITY(ReduceMaxTestGraph, Type::S16, Granularity::ChannelWise); + SUCCEED(); +} + TEST(QuantizedModelVerifierTest, ResizeBilinear) { TEST_WITH_GRAPH(ResizeBilinearTestGraph, Type::U8, Granularity::LayerWise); diff --git a/compiler/luci/pass/src/RemoveRedundantDequantizePass.cpp b/compiler/luci/pass/src/RemoveRedundantDequantizePass.cpp new file mode 100644 index 000000000..66cd9d791 --- /dev/null +++ b/compiler/luci/pass/src/RemoveRedundantDequantizePass.cpp @@ -0,0 +1,80 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/RemoveRedundantDequantizePass.h" + +#include <luci/IR/CircleNodes.h> + +namespace +{ + +bool remove_redundant_dequant(luci::CircleDequantize *dequant) +{ + assert(dequant != nullptr); + + auto prev = loco::must_cast<luci::CircleNode *>(dequant->input()); + if (prev->dtype() != loco::DataType::FLOAT32) + return false; + + replace(dequant).with(prev); + + return true; +} + +} // namespace + +namespace luci +{ +/** + * Dequantize Op does the below things on the ifm. + * 1. Element-wise update of quantized values (u8/s16) to fp32 values + * 2. Update dtype to fp32 + * If the previous node is not quantized, dequantize Op is redundant. + * + * BEFORE + * + * [CircleNode (A)] + * | + * [CircleNode (B)] (fp32) + * | + * [CircleDequantize] + * | + * [CircleNode] + * + * AFTER + * + * [CircleNode (A)] + * | + * [CircleNode (B)] (fp32) + * | + * [CircleNode] + */ +bool RemoveRedundantDequantizePass::run(loco::Graph *g) +{ + bool changed = false; + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + auto target_node = dynamic_cast<luci::CircleDequantize *>(node); + if (target_node != nullptr) + { + if (remove_redundant_dequant(target_node)) + changed = true; + } + } + return changed; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/RemoveRedundantDequantizePass.test.cpp b/compiler/luci/pass/src/RemoveRedundantDequantizePass.test.cpp new file mode 100644 index 000000000..adb2f14a4 --- /dev/null +++ b/compiler/luci/pass/src/RemoveRedundantDequantizePass.test.cpp @@ -0,0 +1,114 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/RemoveRedundantDequantizePass.h" + +#include <luci/IR/CircleNodes.h> + +#include <luci/test/TestIOGraph.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class DequantizeGraphlet +{ +public: + DequantizeGraphlet() = default; + +public: + void init(loco::Graph *g) + { + _dequantize = g->nodes()->create<luci::CircleDequantize>(); + _dequantize->dtype(loco::DataType::FLOAT32); + _dequantize->name("dequantize"); + } + +protected: + luci::CircleDequantize *_dequantize = nullptr; +}; + +class RedundantDequantizeGraph : public TestIOGraph, public DequantizeGraphlet +{ +public: + RedundantDequantizeGraph() = default; + +public: + void init(void) + { + TestIOGraph::init({1}, {1}); + DequantizeGraphlet::init(g()); + + _dequantize->input(input()); + + output()->from(_dequantize); + } + + void init_u8_input(void) + { + TestIOGraph::init({1}, {1}); + DequantizeGraphlet::init(g()); + + // Use u8 input (dequantize is not redundant anymore) + input()->dtype(loco::DataType::U8); + { + auto qparam = std::make_unique<luci::CircleQuantParam>(); + qparam->scale = {1}; + qparam->zerop = {1}; + input()->quantparam(std::move(qparam)); + } + + _dequantize->input(input()); + + output()->from(_dequantize); + } +}; + +} // namespace + +TEST(RemoveRedundantDequantizePass, single_redundant_dequantize) +{ + RedundantDequantizeGraph g; + luci::RemoveRedundantDequantizePass pass; + + g.init(); + + EXPECT_TRUE(pass.run(g.g())); + + int count = 0; + for (auto node : loco::active_nodes(loco::output_nodes(g.g()))) + { + if (dynamic_cast<luci::CircleDequantize *>(node)) + { + count++; + } + } + + ASSERT_EQ(0, count); +} + +TEST(RemoveRedundantDequantizePass, wrong_dtype_NEG) +{ + RedundantDequantizeGraph g; + luci::RemoveRedundantDequantizePass pass; + + g.init_u8_input(); + + EXPECT_FALSE(pass.run(g.g())); +} diff --git a/compiler/luci/pass/src/RemoveUnnecessaryReshapeNetPass.cpp b/compiler/luci/pass/src/RemoveUnnecessaryReshapeNetPass.cpp new file mode 100644 index 000000000..476ec68bf --- /dev/null +++ b/compiler/luci/pass/src/RemoveUnnecessaryReshapeNetPass.cpp @@ -0,0 +1,172 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/RemoveUnnecessaryReshapeNetPass.h" + +#include <luci/IR/CircleNodes.h> + +namespace +{ + +bool acceptable_intermediate_op(const loco::Node *node) +{ + if (not node) + return false; + + const auto opcode = loco::must_cast<const luci::CircleNode *>(node)->opcode(); + + switch (opcode) + { + case luci::CircleOpcode::ADD: + case luci::CircleOpcode::MUL: + case luci::CircleOpcode::TANH: + case luci::CircleOpcode::LOGISTIC: + break; + + default: + return false; + } + + return true; +} + +bool same_shape(const loco::Node *a, const loco::Node *b) +{ + auto a_cnode = loco::must_cast<const luci::CircleNode *>(a); + auto b_cnode = loco::must_cast<const luci::CircleNode *>(b); + + if (a_cnode->rank() != b_cnode->rank()) + return false; + + for (uint32_t i = 0; i < a_cnode->rank(); i++) + { + if (not(a_cnode->dim(i) == b_cnode->dim(i))) + return false; + } + return true; +} + +class PreReshapeFinder +{ +public: + PreReshapeFinder(const luci::CircleReshape *post_reshape) : _post_reshape(post_reshape) + { + assert(post_reshape != nullptr); // FIX_CALLER_UNLESS + } + +public: + // Return true if pre_reshapes are found + bool collect_pre_reshapes(loco::Node *node) + { + // TODO Support diamond case + if (loco::succs(node).size() != 1) + return false; + + if (auto pre_reshape = dynamic_cast<luci::CircleReshape *>(node)) + { + // Check ifm of pre-reshape and ofm of post_reshape + if (not same_shape(pre_reshape->tensor(), _post_reshape)) + return false; + + // Check ofm of pre-reshape and ifm of post_reshape + if (not same_shape(pre_reshape, _post_reshape->tensor())) + return false; + + _pre_reshapes.emplace_back(pre_reshape); + return true; + } + + if (not acceptable_intermediate_op(node)) + return false; + + for (uint32_t i = 0; i < node->arity(); i++) + { + if (not collect_pre_reshapes(node->arg(i))) + return false; + } + + return true; + } + +public: + std::vector<luci::CircleReshape *> pre_reshapes(void) const { return _pre_reshapes; } + +private: + const luci::CircleReshape *_post_reshape = nullptr; + std::vector<luci::CircleReshape *> _pre_reshapes; +}; + +bool remove_unnecessary_reshape_net(luci::CircleReshape *reshape) +{ + PreReshapeFinder finder(reshape); + if (not finder.collect_pre_reshapes(reshape->tensor())) + return false; + + // Remove pre_reshapes + for (auto pre_reshape : finder.pre_reshapes()) + { + loco::replace(pre_reshape).with(pre_reshape->tensor()); + } + + // Remove post_reshape + loco::replace(reshape).with(reshape->tensor()); + + return true; +} + +} // namespace + +namespace luci +{ + +/** + * BEFORE + * + * [CircleNode] + * | + * [CircleReshape_1] (shape: A -> B) + * | + * [CircleNode] (ex: Add/Mul/Tanh/Logistic ..) + * | + * [CircleReshape_2] (shape: B -> A) + * | + * [CircleNode] + * + * AFTER + * + * [CircleNode] + * | \ + * | [CircleReshape_1] + * [CircleNode] + * | \ + * | [CircleReshape_2] + * [CircleNode] + **/ +bool RemoveUnnecessaryReshapeNetPass::run(loco::Graph *g) +{ + bool changed = false; + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + if (auto reshape_node = dynamic_cast<luci::CircleReshape *>(node)) + { + if (remove_unnecessary_reshape_net(reshape_node)) + changed = true; + } + } + return changed; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/RemoveUnnecessaryReshapeNetPass.test.cpp b/compiler/luci/pass/src/RemoveUnnecessaryReshapeNetPass.test.cpp new file mode 100644 index 000000000..4ad707ba3 --- /dev/null +++ b/compiler/luci/pass/src/RemoveUnnecessaryReshapeNetPass.test.cpp @@ -0,0 +1,123 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "luci/Pass/RemoveUnnecessaryReshapeNetPass.h" + +#include <luci/IR/CircleNodes.h> + +#include <gtest/gtest.h> + +namespace +{ + +class RemoveUnnecessaryReshapeNet : public ::testing::Test +{ +public: + RemoveUnnecessaryReshapeNet() {} + + void createReshapeConst(luci::CircleReshape *target, const std::vector<uint32_t> shape) + { + auto shape_const = g.nodes()->create<luci::CircleConst>(); + shape_const->dtype(loco::DataType::S32); + shape_const->size<loco::DataType::S32>(shape.size()); + shape_const->shape_status(luci::ShapeStatus::VALID); + shape_const->rank(1); + shape_const->dim(0).set(shape.size()); + for (int32_t i = 0; i < shape.size(); i++) + { + shape_const->at<loco::DataType::S32>(i) = static_cast<int32_t>(shape.at(i)); + } + shape_const->name("shape_const"); + target->shape(shape_const); + target->rank(shape.size()); + for (uint32_t i = 0; i < shape.size(); i++) + { + target->dim(i) = shape[i]; + } + target->shape_status(luci::ShapeStatus::VALID); + } + + void buildGraph(const std::initializer_list<uint32_t> base_shape, + const std::initializer_list<uint32_t> first_shape, + const std::initializer_list<uint32_t> second_shape) + { + // Input Create. + input = g.nodes()->create<luci::CircleInput>(); + auto graph_input = g.inputs()->create(); + input->index(graph_input->index()); + input->shape_status(luci::ShapeStatus::VALID); + input->shape(base_shape); + input->name("input"); + + // Create first reshape. + first_reshape = g.nodes()->create<luci::CircleReshape>(); + first_reshape->tensor(input); + first_reshape->name("Reshape"); + createReshapeConst(first_reshape, first_shape); + + // Create logistic. + logistic = g.nodes()->create<luci::CircleLogistic>(); + logistic->x(first_reshape); + logistic->name("logistic"); + logistic->shape(first_shape); + logistic->shape_status(luci::ShapeStatus::VALID); + + // Create second reshape. + second_reshape = g.nodes()->create<luci::CircleReshape>(); + second_reshape->tensor(logistic); + second_reshape->name("second_reshape"); + createReshapeConst(second_reshape, second_shape); + + // Output Connect. + output = g.nodes()->create<luci::CircleOutput>(); + output->from(second_reshape); + output->name("output"); + auto graph_output = g.outputs()->create(); + output->index(graph_output->index()); + } + +public: + loco::Graph g; + luci::CircleInput *input = nullptr; + luci::CircleReshape *first_reshape = nullptr; + luci::CircleLogistic *logistic = nullptr; + luci::CircleReshape *second_reshape = nullptr; + luci::CircleOutput *output = nullptr; +}; + +} // namespace + +TEST_F(RemoveUnnecessaryReshapeNet, simple_case) +{ + buildGraph({1, 1, 1, 32}, {1, 1, 32, 1}, {1, 1, 1, 32}); + luci::RemoveUnnecessaryReshapeNetPass pass; + + ASSERT_TRUE(pass.run(&g)); + + int count = 0; + for (auto node : loco::active_nodes(loco::output_nodes(&g))) + { + if (auto reshape = dynamic_cast<luci::CircleReshape *>(node)) + count++; + } + ASSERT_EQ(0, count); +} + +TEST_F(RemoveUnnecessaryReshapeNet, shape_mismatch_NEG) +{ + buildGraph({1, 1, 1, 32}, {1, 1, 32, 1}, {1, 1, 2, 16}); + luci::RemoveUnnecessaryReshapeNetPass pass; + ASSERT_FALSE(pass.run(&g)); +} diff --git a/compiler/luci/pass/src/ReplaceNonConstFCWithBatchMatMulPass.cpp b/compiler/luci/pass/src/ReplaceNonConstFCWithBatchMatMulPass.cpp new file mode 100644 index 000000000..741b70956 --- /dev/null +++ b/compiler/luci/pass/src/ReplaceNonConstFCWithBatchMatMulPass.cpp @@ -0,0 +1,196 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <luci/IR/CircleNodes.h> +#include <luci/Profile/CircleNodeOrigin.h> +#include <luci/Pass/ReplaceNonConstFCWithBatchMatMulPass.h> + +namespace +{ + +// TODO move to global helper list if needed +/** + * @brief Create a node with `inp` as input from fused activation fucntion `act` + */ +luci::CircleNode *fromActivation(luci::CircleNode *inp, luci::FusedActFunc act) +{ + switch (act) + { + case luci::FusedActFunc::NONE: + return inp; + case luci::FusedActFunc::RELU: + { + auto n = inp->graph()->nodes()->create<luci::CircleRelu>(); + n->features(inp); + return n; + } + case luci::FusedActFunc::RELU6: + { + auto n = inp->graph()->nodes()->create<luci::CircleRelu6>(); + n->features(inp); + return n; + } + case luci::FusedActFunc::RELU_N1_TO_1: + { + auto n = inp->graph()->nodes()->create<luci::CircleReluN1To1>(); + n->features(inp); + return n; + } + case luci::FusedActFunc::TANH: + { + auto n = inp->graph()->nodes()->create<luci::CircleTanh>(); + n->x(inp); + return n; + } + case luci::FusedActFunc::SIGN_BIT: + { + throw std::invalid_argument("no matching node to create from fused activation"); + } + default: + throw std::invalid_argument("invalid fused activation"); + } +} + +/** + * Replace Fully Connected with Batched MatMul + * + * BEFORE + * + * [Node1] [Node2] + * | | + * [transpose]? [transpose]? + * \ / + * [FullyConnected] + * + * AFTER + * + * [Node1] [Node2] + * \ / + * [BatchMatMul] [BiasValue]? + * \ / + * [Add]? + * | + * [Activation]? + * + * Nodes with "?" denote optional elements + */ +bool replace_fc_with_matmul(luci::CircleFullyConnected *fc) +{ + luci::CircleNode *x = nullptr; + luci::CircleNode *y = nullptr; + luci::CircleNode *b = nullptr; + luci::CircleTranspose *ty = nullptr; + luci::CircleTranspose *tx = nullptr; + bool adj_x = false; + bool adj_y = true; + + if (dynamic_cast<luci::CircleConst *>(fc->weights())) + return false; // NonConst + + if ((ty = dynamic_cast<luci::CircleTranspose *>(fc->weights()))) // is y a transpose? + { + adj_y = false; + if (dynamic_cast<luci::CircleConst *>(ty->a())) + return false; + else + y = loco::must_cast<luci::CircleNode *>(ty->a()); + } + else + { // y is not transpose and not const + y = loco::must_cast<luci::CircleNode *>(fc->weights()); + } + if ((tx = dynamic_cast<luci::CircleTranspose *>(fc->input()))) + { + adj_x = true; + x = loco::must_cast<luci::CircleNode *>(tx->a()); + } + else + { + x = loco::must_cast<luci::CircleNode *>(fc->input()); + } + + b = loco::must_cast<luci::CircleNode *>(fc->bias()); + + if (x->dtype() != loco::DataType::FLOAT32 || y->dtype() != loco::DataType::FLOAT32 || + b->dtype() != loco::DataType::FLOAT32) + return false; + + auto name = fc->name(); + assert(name.length() > 0); + + auto matmul = fc->graph()->nodes()->create<luci::CircleBatchMatMul>(); + matmul->x(x); + matmul->y(y); + matmul->adj_x(adj_x); + matmul->adj_y(adj_y); + matmul->name(name); + matmul->dtype(fc->dtype()); + + luci::add_origin(matmul, luci::get_origin(fc)); + + auto all_zero = [](const luci::CircleConst *c) { + bool ac = true; + for (uint32_t i = 0; i < c->size<loco::DataType::FLOAT32>() && ac; i++) + { + ac &= c->at<loco::DataType::FLOAT32>(i) == 0.0f; + } + return ac; + }; + + auto bc = dynamic_cast<luci::CircleConst *>(b); + if ((nullptr != bc) && !all_zero(bc)) + { + auto bias_add = fc->graph()->nodes()->create<luci::CircleAdd>(); + bias_add->x(matmul); + bias_add->y(b); + bias_add->name(fc->name() + "/bias_add"); + bias_add->dtype(fc->dtype()); + add_origin(bias_add, get_origin(fc)); + bias_add->fusedActivationFunction(fc->fusedActivationFunction()); + loco::replace(fc).with(bias_add); + } + else + { + auto n = fromActivation(matmul, fc->fusedActivationFunction()); + add_origin(n, luci::get_origin(fc)); + n->name(fc->name() + "fusedActivation"); + n->dtype(fc->dtype()); + loco::replace(fc).with(n); + } + + return true; +} +} // namespace + +namespace luci +{ + +bool ReplaceNonConstFCWithBatchMatMulPass::run(loco::Graph *g) +{ + bool changed = false; + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + if (auto fc = dynamic_cast<luci::CircleFullyConnected *>(node)) + { + if (replace_fc_with_matmul(fc)) + changed = true; + } + } + + return changed; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/ReplaceNonConstFCWithBatchMatMulPass.test.cpp b/compiler/luci/pass/src/ReplaceNonConstFCWithBatchMatMulPass.test.cpp new file mode 100644 index 000000000..7606a6125 --- /dev/null +++ b/compiler/luci/pass/src/ReplaceNonConstFCWithBatchMatMulPass.test.cpp @@ -0,0 +1,189 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/ReplaceNonConstFCWithBatchMatMulPass.h" + +#include <luci/test/TestIOGraph.h> +#include <luci/IR/CircleNodes.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +// TODO Reduce duplicate codes in ResolveCustomOpMatMulPass.cpp +template <typename T> +luci::CircleConst *create_const_node(loco::Graph *g, const loco::DataType dtype, + const std::vector<uint32_t> &shape, + const std::vector<T> &values) +{ + auto node = g->nodes()->create<luci::CircleConst>(); + node->dtype(dtype); + node->rank(shape.size()); + + uint32_t size = 1; + for (uint32_t i = 0; i < shape.size(); ++i) + { + node->dim(i) = shape.at(i); + size *= shape.at(i); + } + node->shape_status(luci::ShapeStatus::VALID); + +#define INIT_VALUES(DT) \ + { \ + node->size<DT>(size); \ + for (uint32_t i = 0; i < values.size(); ++i) \ + node->at<DT>(i) = values[i]; \ + } + + switch (dtype) + { + case loco::DataType::U8: + INIT_VALUES(loco::DataType::U8); + break; + case loco::DataType::S16: + INIT_VALUES(loco::DataType::S16); + break; + case loco::DataType::S32: + INIT_VALUES(loco::DataType::S32); + break; + case loco::DataType::FLOAT32: + INIT_VALUES(loco::DataType::FLOAT32) + break; + default: + INTERNAL_EXN("create_const_node called with unsupported type"); + break; + } + return node; +} + +/** + * Simple graph for test + * + * BEFORE + * + * [IFM1] [IFM2] [BIAS] + * \ | / + * [FC] + * | + * [Res] + * + * AFTER + * [IFM1] [IFM2] + * \ | + * [BatchMatMul] [BIAS] + * \ / + * [Add] + * | + * [Res] + * + */ +struct FCGraphlet +{ +public: + FCGraphlet() = default; + virtual ~FCGraphlet() = default; + + void init(loco::Graph *g, const ShapeU32 r_shape, const float bv) + { + _tr_y = g->nodes()->create<luci::CircleTranspose>(); + _tr_y->a(_y); + std::vector<int32_t> tr_val = {1, 0}; + _tr_y->perm(create_const_node(g, loco::DataType::S32, {2}, tr_val)); + + _fc = g->nodes()->create<luci::CircleFullyConnected>(); + _fc->input(_x); + _fc->weights(_tr_y); + _fc->fusedActivationFunction(luci::FusedActFunc::NONE); + _fc->dtype(loco::DataType::FLOAT32); + _fc->shape(r_shape); + auto l = _fc->dim(_fc->rank() - 1).value(); + std::vector<float> bias_val(l, bv); + _fc->bias(create_const_node(g, loco::DataType::FLOAT32, {l}, bias_val)); + _fc->name("fc"); + } + +public: + luci::CircleFullyConnected *fc() { return _fc; } + +protected: + luci::CircleFullyConnected *_fc = nullptr; + luci::CircleTranspose *_tr_y = nullptr; + luci::CircleInput *_x = nullptr; + luci::CircleInput *_y = nullptr; +}; + +struct FCGraph : public TestIsGraphlet<2>, public TestOGraphlet, public FCGraphlet +{ + FCGraph() = default; + virtual ~FCGraph() = default; + void init(const ShapeU32 x_shape, const ShapeU32 y_shape, const ShapeU32 r_shape, const float bv) + { + TestIsGraphlet<2>::init(g(), {x_shape, y_shape}); + TestOGraphlet::init(g(), r_shape); + _x = input(0); + _y = input(1); + FCGraphlet::init(g(), r_shape, bv); + output()->from(_fc); + } +}; + +class ReplaceNonConstFCWithBatchMatMulPassTest : public ::testing::Test +{ +public: + FCGraph g; + luci::ReplaceNonConstFCWithBatchMatMulPass pass; +}; + +} // namespace + +TEST_F(ReplaceNonConstFCWithBatchMatMulPassTest, simple_test) +{ + g.init({2, 3}, {2, 3}, {2, 2}, 0.0f); + + auto ret = pass.run(g.g()); + EXPECT_EQ(true, ret); + + auto mm = dynamic_cast<luci::CircleBatchMatMul *>(g.output()->from()); + EXPECT_NE(nullptr, mm); +} + +TEST_F(ReplaceNonConstFCWithBatchMatMulPassTest, nonzero_bias_test) +{ + g.init({2, 3}, {2, 3}, {2, 2}, 1.0f); + + auto ret = pass.run(g.g()); + EXPECT_EQ(true, ret); + + auto mm = dynamic_cast<luci::CircleAdd *>(g.output()->from()); + EXPECT_NE(nullptr, mm); +} + +TEST_F(ReplaceNonConstFCWithBatchMatMulPassTest, wrong_op_NEG) +{ + loco::Graph g; + + auto inp = g.nodes()->create<luci::CircleInput>(); + auto relu = g.nodes()->create<luci::CircleRelu>(); + relu->features(inp); + + luci::ReplaceNonConstFCWithBatchMatMulPass pass; + auto changed = pass.run(&g); + + EXPECT_EQ(false, changed); +} diff --git a/compiler/luci/pass/src/ResolveCustomOpSplitVPass.cpp b/compiler/luci/pass/src/ResolveCustomOpSplitVPass.cpp new file mode 100644 index 000000000..a65065800 --- /dev/null +++ b/compiler/luci/pass/src/ResolveCustomOpSplitVPass.cpp @@ -0,0 +1,172 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/ResolveCustomOpSplitVPass.h" + +#include <luci/IR/CircleNodes.h> +#include <luci/Profile/CircleNodeOrigin.h> +#include <luci/Service/Nodes/CircleConst.h> + +namespace +{ + +// Input node is const S64 +// Return s32 version of node +// Return nullptr if s64 value is out of range of s32 +luci::CircleConst *s64_to_s32(luci::CircleConst *node) +{ + assert(node); + assert(node->dtype() == loco::DataType::S64); + + auto cloned = luci::clone(node); + luci::add_origin(cloned, luci::get_origin(node)); + + const auto num_elems = node->size<loco::DataType::S64>(); + + cloned->dtype(loco::DataType::S32); + cloned->size<loco::DataType::S32>(num_elems); + + for (uint32_t i = 0; i < num_elems; i++) + { + int64_t val = node->at<loco::DataType::S64>(i); + if (val < std::numeric_limits<int32_t>::min() or val > std::numeric_limits<int32_t>::max()) + return nullptr; + + cloned->at<loco::DataType::S32>(i) = static_cast<int32_t>(val); + } + + return cloned; +} + +/** BEFORE + * + * [CircleNode] + * \ + * \ [size_splits] [split_dim] + * \ | / + * [CircleCustom(SplitV))] + * | + * [CircleCustomOut] + * | + * [CircleNode] + * + * AFTER + * + * [CircleNode] + * | \ + * | \ [size_splits] [split_dim] + * | \ | / + * | \ | / + * | \ | / + * [CircleCustom(SplitV)] [CircleSplitV] + * | | + * [CircleCustomOut] [CircleSplitVOut] + * | + * [CircleNode] + */ +bool resolve_splitv(luci::CircleCustom *node) +{ + const std::string custom_code = node->custom_code(); + const std::vector<uint8_t> custom_options = node->custom_options(); + + if (custom_code != "SplitV") + return false; + + if (node->numInputs() != 3) + return false; + + auto size_splits = dynamic_cast<luci::CircleConst *>(node->inputs(1)); + if (not size_splits) + return false; + + // Convert size_splits to S32, because luci-interpeter does not support + // S64 size_splits yet + // TODO Support S64 size_splits + if (size_splits->dtype() == loco::DataType::S64) + { + size_splits = s64_to_s32(size_splits); + if (not size_splits) + return false; + } + if (size_splits->dtype() != loco::DataType::S32) + return false; + + auto split_dim = dynamic_cast<luci::CircleConst *>(node->inputs(2)); + if (not split_dim) + return false; + + if (split_dim->dtype() == loco::DataType::S64) + { + split_dim = s64_to_s32(split_dim); + if (not split_dim) + return false; + } + if (split_dim->dtype() != loco::DataType::S32) + return false; + + if (size_splits->rank() != 1) + return false; + + const auto num_split = size_splits->dim(0).value(); + + auto split_v = node->graph()->nodes()->create<luci::CircleSplitV>(); + split_v->input(node->inputs(0)); + split_v->size_splits(size_splits); + split_v->split_dim(split_dim); + split_v->num_split(num_split); + split_v->name(node->name()); + luci::add_origin(split_v, luci::get_origin(node)); + + int32_t i = 0; + const auto succs = loco::succs(node); + for (auto succ : succs) + { + auto custom_out = loco::must_cast<luci::CircleCustomOut *>(succ); // FIX_CALLER_UNLESS + + auto split_v_out = node->graph()->nodes()->create<luci::CircleSplitVOut>(); + split_v_out->input(split_v); + split_v_out->name(node->name() + "_out_" + std::to_string(i)); + split_v_out->index(i++); + luci::add_origin(split_v_out, luci::get_origin(node)); + loco::replace(custom_out).with(split_v_out); + } + + return true; +} + +} // namespace + +namespace luci +{ + +bool ResolveCustomOpSplitVPass::run(loco::Graph *g) +{ + bool changed = false; + + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + auto cop = dynamic_cast<luci::CircleCustom *>(node); + if (not cop) + continue; + + if (resolve_splitv(cop)) + changed = true; + } + + return changed; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/ResolveCustomOpSplitVPass.test.cpp b/compiler/luci/pass/src/ResolveCustomOpSplitVPass.test.cpp new file mode 100644 index 000000000..e7738aadb --- /dev/null +++ b/compiler/luci/pass/src/ResolveCustomOpSplitVPass.test.cpp @@ -0,0 +1,175 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/ResolveCustomOpSplitVPass.h" + +#include <luci/test/TestIOGraph.h> + +#include <luci/IR/CircleNodes.h> +#include <gtest/gtest.h> + +using namespace luci::test; + +namespace +{ + +/** + * graph having Custom operator SplitV + * + * [Input] [Const] [Const] + * \ | / + * [Custom(SplitV)] + * / | \ + * [CustomOut] [CustomOut] [CustomOut] + * | | | + * [Output] [Output] [Output] + */ +class SplitVGraphlet +{ +public: + SplitVGraphlet() = default; + +public: + void init(loco::Graph *g) + { + // CircleCustom(SplitV) + _splitv = g->nodes()->create<luci::CircleCustom>(3, 3); + _splitv->custom_code("SplitV"); + _splitv->shape({1, 2, 2, 192}); + _splitv->dtype(loco::DataType::FLOAT32); + _splitv->name("splitv"); + + // CircleConst + auto size_splits = g->nodes()->create<luci::CircleConst>(); + size_splits->dtype(loco::DataType::S64); + size_splits->shape({3}); + size_splits->size<loco::DataType::S64>(3); + size_splits->at<loco::DataType::S64>(0) = 32; + size_splits->at<loco::DataType::S64>(1) = 32; + size_splits->at<loco::DataType::S64>(2) = 128; + + // CircleConst + auto split_dim = g->nodes()->create<luci::CircleConst>(); + split_dim->dtype(loco::DataType::S32); + split_dim->rank(0); + split_dim->size<loco::DataType::S32>(1); + split_dim->scalar<loco::DataType::S32>() = 3; + + _splitv->inputs(1, size_splits); + _splitv->inputs(2, split_dim); + + // CircleCustomOut + _splitv_out1 = g->nodes()->create<luci::CircleCustomOut>(); + _splitv_out1->shape({1, 2, 2, 32}); + _splitv_out1->dtype(loco::DataType::FLOAT32); + _splitv_out1->index(0); + _splitv_out1->input(_splitv); + + // CircleCustomOut + _splitv_out2 = g->nodes()->create<luci::CircleCustomOut>(); + _splitv_out2->shape({1, 2, 2, 32}); + _splitv_out2->dtype(loco::DataType::FLOAT32); + _splitv_out2->index(1); + _splitv_out2->input(_splitv); + + // CircleCustomOut + _splitv_out3 = g->nodes()->create<luci::CircleCustomOut>(); + _splitv_out3->shape({1, 2, 2, 128}); + _splitv_out3->dtype(loco::DataType::FLOAT32); + _splitv_out3->index(2); + _splitv_out3->input(_splitv); + } + +public: + luci::CircleCustom *splitv() { return _splitv; } + +protected: + luci::CircleCustom *_splitv = nullptr; + luci::CircleCustomOut *_splitv_out1 = nullptr; + luci::CircleCustomOut *_splitv_out2 = nullptr; + luci::CircleCustomOut *_splitv_out3 = nullptr; +}; + +class SplitVGraph : public TestIGraphlet, public TestOsGraphlet<3>, public SplitVGraphlet +{ +public: + SplitVGraph() = default; + + void init(void) + { + TestIGraphlet::init(g(), {1, 2, 2, 192}); + TestOsGraphlet<3>::init(g(), {{1, 2, 2, 32}, {1, 2, 2, 32}, {1, 2, 2, 128}}); + SplitVGraphlet::init(g()); + + // connect graph + _splitv->inputs(0, input()); + + output(0)->from(_splitv_out1); + output(1)->from(_splitv_out2); + output(2)->from(_splitv_out3); + } +}; + +class SplitVGraphTest : public ::testing::Test +{ +public: + SplitVGraph g; + luci::ResolveCustomOpSplitVPass pass; +}; + +} // namespace + +TEST_F(SplitVGraphTest, simple_test) +{ + g.init(); + + auto ret = pass.run(g.g()); + EXPECT_EQ(true, ret); + + auto svo_1 = dynamic_cast<luci::CircleSplitVOut *>(g.output(0)->from()); + EXPECT_NE(nullptr, svo_1); + auto svo_2 = dynamic_cast<luci::CircleSplitVOut *>(g.output(1)->from()); + EXPECT_NE(nullptr, svo_2); + auto svo_3 = dynamic_cast<luci::CircleSplitVOut *>(g.output(2)->from()); + EXPECT_NE(nullptr, svo_3); + + auto sv = dynamic_cast<luci::CircleSplitV *>(svo_1->input()); + EXPECT_NE(nullptr, sv); + sv = dynamic_cast<luci::CircleSplitV *>(svo_2->input()); + EXPECT_NE(nullptr, sv); + sv = dynamic_cast<luci::CircleSplitV *>(svo_3->input()); + EXPECT_NE(nullptr, sv); + + auto size_splits = loco::must_cast<luci::CircleConst *>(sv->size_splits()); + EXPECT_EQ(loco::DataType::S32, size_splits->dtype()); + EXPECT_EQ(32, size_splits->at<loco::DataType::S32>(0)); + EXPECT_EQ(32, size_splits->at<loco::DataType::S32>(1)); + EXPECT_EQ(128, size_splits->at<loco::DataType::S32>(2)); + + auto split_dim = loco::must_cast<luci::CircleConst *>(sv->split_dim()); + EXPECT_EQ(loco::DataType::S32, split_dim->dtype()); + EXPECT_EQ(3, split_dim->scalar<loco::DataType::S32>()); +} + +TEST_F(SplitVGraphTest, wrong_op_NEG) +{ + g.init(); + + g.splitv()->custom_code("AddV2"); + + auto ret = pass.run(g.g()); + EXPECT_EQ(false, ret); +} diff --git a/compiler/luci/pass/src/VerifyQuantizedNodeGranularity.h b/compiler/luci/pass/src/VerifyQuantizedNodeGranularity.h index 442183c18..408e6b8d9 100644 --- a/compiler/luci/pass/src/VerifyQuantizedNodeGranularity.h +++ b/compiler/luci/pass/src/VerifyQuantizedNodeGranularity.h @@ -197,6 +197,13 @@ private: return true; } + bool visit(const luci::CircleReduceMax *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)); + RETURN_FALSE_UNLESS(is_lwq(node->input())); + return true; + } + bool visit(const luci::CircleRelu *node) { RETURN_FALSE_UNLESS(is_lwq(node)); diff --git a/compiler/luci/pass/src/VerifyQuantizedNodeType.cpp b/compiler/luci/pass/src/VerifyQuantizedNodeType.cpp index 4e1c062c0..cf86acabe 100644 --- a/compiler/luci/pass/src/VerifyQuantizedNodeType.cpp +++ b/compiler/luci/pass/src/VerifyQuantizedNodeType.cpp @@ -302,6 +302,15 @@ bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CirclePow *nod } template <loco::DataType Qtype, loco::DataType Btype> +bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleReduceMax *node) +{ + RETURN_FALSE_UNLESS(has_type(node, Qtype)) + RETURN_FALSE_UNLESS(has_type(node->input(), Qtype)) + RETURN_FALSE_UNLESS(has_type(node->reduction_indices(), loco::DataType::S32)) + return true; +} + +template <loco::DataType Qtype, loco::DataType Btype> bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleRelu *node) { return group_has_type(node, Qtype); diff --git a/compiler/luci/pass/src/VerifyQuantizedNodeType.h b/compiler/luci/pass/src/VerifyQuantizedNodeType.h index ff1acbd6f..789d3c7cd 100644 --- a/compiler/luci/pass/src/VerifyQuantizedNodeType.h +++ b/compiler/luci/pass/src/VerifyQuantizedNodeType.h @@ -104,6 +104,7 @@ private: bool visit(const luci::CirclePadV2 *node); bool visit(const luci::CirclePRelu *node); bool visit(const luci::CirclePow *node); + bool visit(const luci::CircleReduceMax *node); bool visit(const luci::CircleRelu *node); bool visit(const luci::CircleReshape *node); bool visit(const luci::CircleResizeBilinear *node); diff --git a/compiler/luci/pass/src/helpers/SparsityFormatConverter.cpp b/compiler/luci/pass/src/helpers/SparsityFormatConverter.cpp new file mode 100644 index 000000000..72b7d60ff --- /dev/null +++ b/compiler/luci/pass/src/helpers/SparsityFormatConverter.cpp @@ -0,0 +1,312 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * Copyright 2020 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// codes under namespace sparsity referenced from +// https://github.com/tensorflow/tensorflow/blob/3f878cff5b698b82eea85db2b60d65a2e320850e/ +// tensorflow/lite/kernels/internal/utils/sparsity_format_converter.h +// tensorflow/lite/kernels/internal/utils/sparsity_format_converter.cc + +#include "SparsityFormatConverter.h" + +#include <oops/InternalExn.h> + +#include <cassert> + +namespace sparsity +{ + +namespace +{ + +uint64_t GetFlattenedIndex(const std::vector<int> &indices, const std::vector<int> &shape) +{ + uint64_t index = 0; + int sub_elements = 1; + for (int i = shape.size() - 1; i >= 0; i--) + { + index += indices[i] * sub_elements; + sub_elements *= shape[i]; + } + return index; +} + +std::vector<int> TfLiteIntArrayToVector(const TfLiteIntArray *int_array) +{ + std::vector<int> values; + if (!int_array) + { + return values; + } + + values.resize(int_array->size); + for (int i = 0; i < int_array->size; i++) + { + values[i] = int_array->data[i]; + } + + return values; +} + +} // namespace + +template <typename T> +FormatConverter<T>::FormatConverter(const std::vector<int> &shape, const TfLiteSparsity &sparsity) +{ + auto traversal_order = TfLiteIntArrayToVector(sparsity.traversal_order); + auto block_map = TfLiteIntArrayToVector(sparsity.block_map); + + std::vector<TfLiteDimensionType> format(sparsity.dim_metadata_size); + std::vector<int> dense_size(sparsity.dim_metadata_size); + std::vector<std::vector<int>> segments(sparsity.dim_metadata_size); + std::vector<std::vector<int>> indices(sparsity.dim_metadata_size); + for (int i = 0; i < sparsity.dim_metadata_size; i++) + { + format[i] = sparsity.dim_metadata[i].format; + dense_size[i] = sparsity.dim_metadata[i].dense_size; + segments[i] = TfLiteIntArrayToVector(sparsity.dim_metadata[i].array_segments); + indices[i] = TfLiteIntArrayToVector(sparsity.dim_metadata[i].array_indices); + } + + InitSparseToDenseConverter(shape, std::move(traversal_order), std::move(format), + std::move(dense_size), std::move(segments), std::move(indices), + std::move(block_map)); +} + +template <typename T> +void FormatConverter<T>::InitSparseToDenseConverter( + std::vector<int> shape, std::vector<int> traversal_order, std::vector<TfLiteDimensionType> format, + std::vector<int> dense_size, std::vector<std::vector<int>> segments, + std::vector<std::vector<int>> indices, std::vector<int> block_map) +{ + dense_shape_ = std::move(shape); + traversal_order_ = std::move(traversal_order); + block_map_ = std::move(block_map); + format_ = std::move(format); + + dense_size_ = 1; + for (size_t i = 0; i < dense_shape_.size(); i++) + { + dense_size_ *= dense_shape_[i]; + } + + dim_metadata_.resize(2 * format_.size()); + for (size_t i = 0; i < format_.size(); i++) + { + if (format_[i] == kTfLiteDimDense) + { + dim_metadata_[2 * i] = {dense_size[i]}; + } + else + { + dim_metadata_[2 * i] = std::move(segments[i]); + dim_metadata_[2 * i + 1] = std::move(indices[i]); + } + } + + int original_rank = dense_shape_.size(); + int block_dim = 0; + + blocked_shape_.resize(original_rank); + block_size_.resize(block_map_.size()); + for (int i = 0; i < original_rank; i++) + { + if (block_dim < (int)block_map_.size() && block_map_[block_dim] == i) + { + if (original_rank + block_dim < (int)traversal_order_.size()) + { + int orig_dim = traversal_order_[original_rank + block_dim]; + block_size_[block_dim] = dense_size[orig_dim]; + blocked_shape_[i] = dense_shape_[i] / dense_size[orig_dim]; + block_dim++; + } + } + else + { + blocked_shape_[i] = dense_shape_[i]; + } + } +} + +template <typename T> +void FormatConverter<T>::Populate(const T *src_data, std::vector<int> indices, int level, + int prev_idx, int *src_data_ptr, T *dest_data) +{ + if (static_cast<size_t>(level) == indices.size()) + { + int orig_rank = dense_shape_.size(); + std::vector<int> orig_idx; + orig_idx.resize(orig_rank); + int i = 0; + for (; static_cast<size_t>(i) < orig_idx.size(); i++) + { + int orig_dim = traversal_order_[i]; + orig_idx[orig_dim] = indices[i]; + } + + for (; static_cast<size_t>(i) < indices.size(); i++) + { + const int block_idx = traversal_order_[i] - orig_rank; + const int orig_dim = block_map_[block_idx]; + orig_idx[orig_dim] = orig_idx[orig_dim] * block_size_[block_idx] + indices[i]; + } + + dest_data[GetFlattenedIndex(orig_idx, dense_shape_)] = src_data[*src_data_ptr]; + + *src_data_ptr = *src_data_ptr + 1; + return; + } + + const int metadata_idx = 2 * level; + const int shape_of_level = dim_metadata_[metadata_idx][0]; + if (format_[level] == kTfLiteDimDense) + { + for (int i = 0; i < shape_of_level; i++) + { + indices[level] = i; + Populate(src_data, indices, level + 1, prev_idx * shape_of_level + i, src_data_ptr, + dest_data); + } + } + else if (static_cast<size_t>(prev_idx + 1) < dim_metadata_[metadata_idx].size()) + { + const auto &array_segments = dim_metadata_[metadata_idx]; + const auto &array_indices = dim_metadata_[metadata_idx + 1]; + for (int i = array_segments[prev_idx]; i < array_segments[prev_idx + 1]; i++) + { + if (static_cast<size_t>(i) < array_indices.size() && + static_cast<size_t>(level) < indices.size()) + { + indices[level] = array_indices[i]; + Populate(src_data, indices, level + 1, i, src_data_ptr, dest_data); + } + } + } +} + +template <typename T> bool FormatConverter<T>::SparseToDense(const T *src_data) +{ + data_.resize(dense_size_); + std::fill(data_.begin(), data_.end(), T(0)); + + int total_rank = traversal_order_.size(); + int src_data_ptr = 0; + std::vector<int> indices(total_rank); + Populate(src_data, indices, 0, 0, &src_data_ptr, data_.data()); + + return true; +} + +template class FormatConverter<float>; +template class FormatConverter<uint16_t>; + +} // namespace sparsity + +#include <luci/IR/SparsityParam.h> + +namespace luci +{ + +sparsity::TfLiteDimensionType to_tflite_sparsity(luci::DimensionType dt) +{ + switch (dt) + { + case luci::DimensionType::DENSE: + return sparsity::TfLiteDimensionType::kTfLiteDimDense; + case luci::DimensionType::SPARSE_CSR: + return sparsity::TfLiteDimensionType::kTfLiteDimSparseCSR; + } + return sparsity::TfLiteDimensionType::kTfLiteDimDense; +} + +sparsity::TfLiteIntArray *to_tflite_sparsity(const luci::SparseIndexVector &data) +{ + auto type = data.type(); + switch (type) + { + case luci::SparseIndexVectorType::NONE: + { + std::vector<int32_t> empty; + return makeTfLiteArray(empty); + } + case luci::SparseIndexVectorType::I32: + return makeTfLiteArray<int32_t>(*data.as_int32_vector()); + case luci::SparseIndexVectorType::U16: + return makeTfLiteArray<uint16_t>(*data.as_uint16_vector()); + case luci::SparseIndexVectorType::U8: + return makeTfLiteArray<uint8_t>(*data.as_uint8_vector()); + default: + INTERNAL_EXN_V("unsupported SparseIndexVectorType", oops::to_uint32(type)); + } +} + +sparsity::TfLiteSparsity to_tflite_sparsity(const luci::SparsityParam *sp) +{ + sparsity::TfLiteSparsity tflsp; + tflsp.traversal_order = makeTfLiteArray(sp->traversal_order); + tflsp.block_map = makeTfLiteArray(sp->block_map); + tflsp.dim_metadata = makeTfLiteDimensionMetadata(sp->dim_metadata); + tflsp.dim_metadata_size = sp->dim_metadata.size(); + return tflsp; +} + +template <typename T> sparsity::TfLiteIntArray *makeTfLiteArray(const std::vector<T> &data) +{ + size_t cn = data.size(); + size_t sz = 1 + data.size(); + sparsity::TfLiteIntArray *sp = (sparsity::TfLiteIntArray *)(new int[sz]); + sp->size = cn; + for (size_t i = 0; i < cn; ++i) + { + sp->data[i] = data[i]; + } + return sp; +} + +sparsity::TfLiteDimensionMetadata * +makeTfLiteDimensionMetadata(const std::vector<luci::DimMetaData> &data) +{ + size_t cn = data.size(); + sparsity::TfLiteDimensionMetadata *tfldm = new sparsity::TfLiteDimensionMetadata[cn]; + + for (size_t i = 0; i < cn; ++i) + { + tfldm[i].format = to_tflite_sparsity(data[i].format()); + tfldm[i].dense_size = data[i].dense_size(); + tfldm[i].array_segments = to_tflite_sparsity(data[i].array_segments()); + tfldm[i].array_indices = to_tflite_sparsity(data[i].array_indices()); + } + + return tfldm; +} + +void freeTfLiteSparsity(sparsity::TfLiteSparsity &tflsp) +{ + assert(tflsp.traversal_order); + assert(tflsp.block_map); + delete[] tflsp.traversal_order; + delete[] tflsp.block_map; + + for (int i = 0; i < tflsp.dim_metadata_size; ++i) + { + assert(tflsp.dim_metadata[i].array_segments); + assert(tflsp.dim_metadata[i].array_indices); + delete[] tflsp.dim_metadata[i].array_segments; + delete[] tflsp.dim_metadata[i].array_indices; + } +} + +} // namespace luci diff --git a/compiler/luci/pass/src/helpers/SparsityFormatConverter.h b/compiler/luci/pass/src/helpers/SparsityFormatConverter.h new file mode 100644 index 000000000..fcd9bbcd0 --- /dev/null +++ b/compiler/luci/pass/src/helpers/SparsityFormatConverter.h @@ -0,0 +1,129 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * Copyright 2020 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_PASS_HELPERS_SPARSITY_FORMAT_CONVERTER_H__ +#define __LUCI_PASS_HELPERS_SPARSITY_FORMAT_CONVERTER_H__ + +#include <cstdint> +#include <vector> + +// codes under namespace sparsity referenced from +// https://github.com/tensorflow/tensorflow/blob/3f878cff5b698b82eea85db2b60d65a2e320850e/ +// tensorflow/lite/kernels/internal/utils/sparsity_format_converter.h +// tensorflow/lite/kernels/internal/utils/sparsity_format_converter.cc + +namespace sparsity +{ + +// Storage format of each dimension in a sparse tensor. +typedef enum TfLiteDimensionType +{ + kTfLiteDimDense = 0, + kTfLiteDimSparseCSR, +} TfLiteDimensionType; + +// Fixed size list of integers. Used for dimensions and inputs/outputs tensor +// indices +typedef struct TfLiteIntArray +{ + int size; + int data[]; +} TfLiteIntArray; + +// Metadata to encode each dimension in a sparse tensor. +typedef struct TfLiteDimensionMetadata +{ + TfLiteDimensionType format; + int dense_size; + TfLiteIntArray *array_segments; + TfLiteIntArray *array_indices; +} TfLiteDimensionMetadata; + +// Parameters used to encode a sparse tensor. For detailed explanation of each +// field please refer to lite/schema/schema.fbs. +typedef struct TfLiteSparsity +{ + TfLiteIntArray *traversal_order; + TfLiteIntArray *block_map; + TfLiteDimensionMetadata *dim_metadata; + int dim_metadata_size; +} TfLiteSparsity; + +// A converter that keeps an internal representation of sparse tensor parameters +// and converts tensors between dense and sparse formats. +template <typename T> class FormatConverter +{ +public: + /* Creates a sparse to dense converter. + * @param shape Shape of the target dense tensor. + * @param sparsity Sparsity parameter of the sparse TfLiteTensor. + */ + FormatConverter(const std::vector<int> &shape, const TfLiteSparsity &sparsity); + + const std::vector<T> &GetData() { return data_; } + const std::vector<std::vector<int>> &GetDimMetadata() { return dim_metadata_; } + + bool SparseToDense(const T *src_data); + +private: + // Helper function for initializing this converter for sparse to dense + // conversion. + void InitSparseToDenseConverter(std::vector<int> shape, std::vector<int> traversal_order, + std::vector<TfLiteDimensionType> format, + std::vector<int> dense_size, + std::vector<std::vector<int>> segments, + std::vector<std::vector<int>> indices, + std::vector<int> block_map); + + void Populate(const T *src_data, std::vector<int> indices, int level, int prev_idx, + int *src_data_ptr, T *dest_data); + +private: + std::vector<int> dense_shape_; + std::vector<int> blocked_shape_; + size_t dense_size_; + std::vector<int> traversal_order_; + std::vector<TfLiteDimensionType> format_; + std::vector<int> block_size_; + std::vector<int> block_map_; + std::vector<std::vector<int>> dim_metadata_; + std::vector<T> data_; +}; + +extern template class FormatConverter<float>; +extern template class FormatConverter<uint16_t>; + +} // namespace sparsity + +#include <luci/IR/SparsityParam.h> + +namespace luci +{ + +sparsity::TfLiteDimensionType to_tflite_sparsity(luci::DimensionType dt); +sparsity::TfLiteIntArray *to_tflite_sparsity(const luci::SparseIndexVector &data); +sparsity::TfLiteSparsity to_tflite_sparsity(const luci::SparsityParam *sp); + +template <typename T> sparsity::TfLiteIntArray *makeTfLiteArray(const std::vector<T> &data); +sparsity::TfLiteDimensionMetadata * +makeTfLiteDimensionMetadata(const std::vector<luci::DimMetaData> &data); + +void freeTfLiteSparsity(sparsity::TfLiteSparsity &tflsp); + +} // namespace luci + +#endif // __LUCI_PASS_HELPERS_SPARSITY_FORMAT_CONVERTER_H__ |