diff options
Diffstat (limited to 'compiler/luci')
194 files changed, 14022 insertions, 8555 deletions
diff --git a/compiler/luci/CMakeLists.txt b/compiler/luci/CMakeLists.txt index b92eefb40..460dc7b23 100644 --- a/compiler/luci/CMakeLists.txt +++ b/compiler/luci/CMakeLists.txt @@ -23,4 +23,8 @@ add_subdirectory(import) add_subdirectory(export) add_subdirectory(tester) +if(NOT ENABLE_TEST) + return() +endif(NOT ENABLE_TEST) + add_subdirectory(tests) diff --git a/compiler/luci/export/CMakeLists.txt b/compiler/luci/export/CMakeLists.txt index a267d0e1f..f46181eb6 100644 --- a/compiler/luci/export/CMakeLists.txt +++ b/compiler/luci/export/CMakeLists.txt @@ -12,7 +12,7 @@ target_include_directories(luci_export PUBLIC include) target_link_libraries(luci_export PRIVATE luci_lang) target_link_libraries(luci_export PRIVATE luci_service) target_link_libraries(luci_export PRIVATE luci_pass) -target_link_libraries(luci_export PRIVATE mio_circle) +target_link_libraries(luci_export PRIVATE mio_circle04) target_link_libraries(luci_export PRIVATE luci_env) target_link_libraries(luci_export PRIVATE luci_log) target_link_libraries(luci_export PRIVATE luci_logex) @@ -36,6 +36,6 @@ target_include_directories(luci_export_test PRIVATE src) target_link_libraries(luci_export_test luci_export) target_link_libraries(luci_export_test luci_plan) target_link_libraries(luci_export_test luci_lang) -target_link_libraries(luci_export_test mio_circle) +target_link_libraries(luci_export_test mio_circle04) target_link_libraries(luci_export_test luci_env) target_link_libraries(luci_export_test oops) diff --git a/compiler/luci/export/src/CircleBuiltinTypesExtractor.h b/compiler/luci/export/src/CircleBuiltinTypesExtractor.h new file mode 100644 index 000000000..0ff21a34b --- /dev/null +++ b/compiler/luci/export/src/CircleBuiltinTypesExtractor.h @@ -0,0 +1,539 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __CIRCLE_BUILTIN_TYPES_EXTRACTOR_H__ +#define __CIRCLE_BUILTIN_TYPES_EXTRACTOR_H__ + +#include "CircleExporterUtils.h" + +#include <luci/IR/CircleNode.h> +#include <luci/IR/CircleNodes.h> +#include <luci/IR/CircleNodeVisitor.h> + +#include <flatbuffers/flexbuffers.h> + +namespace luci +{ + +// NOTE Virtual nodes are not circle builtin operators. +// Therefore, they are not defined here. +class BuiltinOptionsExtractor final + : public luci::CircleNodeMutableVisitor<flatbuffers::Offset<void>> +{ +public: + BuiltinOptionsExtractor(flatbuffers::FlatBufferBuilder &builder) : _builder{builder} + { + // DO NOTHING + } + +public: + flatbuffers::Offset<void> visit(luci::CircleAbs *) + { + return circle::CreateAbsOptions(_builder).Union(); + } + flatbuffers::Offset<void> visit(luci::CircleAdd *node) + { + return circle::CreateAddOptions(_builder, to_circle_actfunc(node->fusedActivationFunction())) + .Union(); + } + flatbuffers::Offset<void> visit(luci::CircleAddN *) + { + return circle::CreateAddNOptions(_builder).Union(); + } + flatbuffers::Offset<void> visit(luci::CircleArgMax *node) + { + return circle::CreateArgMaxOptions(_builder, luci::to_circle_tensortype(node->output_type())) + .Union(); + } + flatbuffers::Offset<void> visit(luci::CircleArgMin *node) + { + return circle::CreateArgMinOptions(_builder, luci::to_circle_tensortype(node->output_type())) + .Union(); + } + flatbuffers::Offset<void> visit(luci::CircleAveragePool2D *node) + { + return circle::CreatePool2DOptions(_builder, getOpPadding(node->padding()), node->stride()->w(), + 
node->stride()->h(), node->filter()->w(), + node->filter()->h(), + to_circle_actfunc(node->fusedActivationFunction())) + .Union(); + } + flatbuffers::Offset<void> visit(luci::CircleBatchMatMul *node) + { + return circle::CreateBatchMatMulOptions(_builder, node->adj_x(), node->adj_y()).Union(); + } + flatbuffers::Offset<void> visit(luci::CircleBatchToSpaceND *) + { + return circle::CreateBatchToSpaceNDOptions(_builder).Union(); + } + flatbuffers::Offset<void> visit(luci::CircleBidirectionalSequenceLSTM *node) + { + return circle::CreateBidirectionalSequenceLSTMOptions( + _builder, to_circle_actfunc(node->fusedActivationFunction()), node->cell_clip(), + node->proj_clip(), node->merge_outputs(), node->time_major(), + node->asymmetric_quantize_inputs()) + .Union(); + } + flatbuffers::Offset<void> visit(luci::CircleCast *node) + { + if (node->out_data_type() == loco::DataType::Unknown) + return _no_option; + else + return circle::CreateCastOptions(_builder, luci::to_circle_tensortype(node->in_data_type()), + luci::to_circle_tensortype(node->out_data_type())) + .Union(); + } + flatbuffers::Offset<void> visit(luci::CircleCeil *) { return _no_option; } + flatbuffers::Offset<void> visit(luci::CircleConcatenation *node) + { + return circle::CreateConcatenationOptions(_builder, node->axis(), + to_circle_actfunc(node->fusedActivationFunction())) + .Union(); + } + // CircleConst is not virtual but not builtinOperator + // flatbuffers::Offset<void> visit(luci::CircleConst *) + flatbuffers::Offset<void> visit(luci::CircleConv2D *node) + { + return circle::CreateConv2DOptions(_builder, getOpPadding(node->padding()), node->stride()->w(), + node->stride()->h(), + to_circle_actfunc(node->fusedActivationFunction()), + node->dilation()->w(), node->dilation()->h()) + .Union(); + } + flatbuffers::Offset<void> visit(luci::CircleCos *) + { + return circle::CreateCosOptions(_builder).Union(); + } + flatbuffers::Offset<void> visit(luci::CircleCustom *) { return _no_option; } + 
flatbuffers::Offset<void> visit(luci::CircleDepthToSpace *node) + { + return circle::CreateDepthToSpaceOptions(_builder, node->block_size()).Union(); + } + flatbuffers::Offset<void> visit(luci::CircleDepthwiseConv2D *node) + { + return circle::CreateDepthwiseConv2DOptions( + _builder, getOpPadding(node->padding()), node->stride()->w(), node->stride()->h(), + node->depthMultiplier(), to_circle_actfunc(node->fusedActivationFunction()), + node->dilation()->w(), node->dilation()->h()) + .Union(); + } + flatbuffers::Offset<void> visit(luci::CircleDequantize *) { return _no_option; } + flatbuffers::Offset<void> visit(luci::CircleDiv *node) + { + return circle::CreateDivOptions(_builder, to_circle_actfunc(node->fusedActivationFunction())) + .Union(); + } + flatbuffers::Offset<void> visit(luci::CircleElu *) { return _no_option; } + flatbuffers::Offset<void> visit(luci::CircleEqual *) + { + return circle::CreateEqualOptions(_builder).Union(); + } + flatbuffers::Offset<void> visit(luci::CircleExp *) + { + return circle::CreateExpOptions(_builder).Union(); + } + flatbuffers::Offset<void> visit(luci::CircleExpandDims *) + { + return circle::CreateExpandDimsOptions(_builder).Union(); + } + flatbuffers::Offset<void> visit(luci::CircleFakeQuant *node) + { + return circle::CreateFakeQuantOptions(_builder, node->min(), node->max(), node->num_bits(), + node->narrow_range()) + .Union(); + } + flatbuffers::Offset<void> visit(luci::CircleFill *) + { + return circle::CreateFillOptions(_builder).Union(); + } + flatbuffers::Offset<void> visit(luci::CircleFloor *) { return _no_option; } + flatbuffers::Offset<void> visit(luci::CircleFloorDiv *) + { + return circle::CreateFloorDivOptions(_builder).Union(); + } + flatbuffers::Offset<void> visit(luci::CircleFloorMod *) + { + return circle::CreateFloorModOptions(_builder).Union(); + } + flatbuffers::Offset<void> visit(luci::CircleFullyConnected *node) + { + return circle::CreateFullyConnectedOptions( + _builder, 
to_circle_actfunc(node->fusedActivationFunction()), + to_circle_weightsformat(node->weights_format()), node->keep_num_dims()) + .Union(); + } + flatbuffers::Offset<void> visit(luci::CircleGather *node) + { + return circle::CreateGatherOptions(_builder, node->axis()).Union(); + } + flatbuffers::Offset<void> visit(luci::CircleGatherNd *) + { + return circle::CreateGatherNdOptions(_builder).Union(); + } + flatbuffers::Offset<void> visit(luci::CircleGreater *) + { + return circle::CreateGreaterOptions(_builder).Union(); + } + flatbuffers::Offset<void> visit(luci::CircleGreaterEqual *) + { + return circle::CreateGreaterEqualOptions(_builder).Union(); + } + flatbuffers::Offset<void> visit(luci::CircleIf *node) + { + return circle::CreateIfOptions(_builder, node->then_branch(), node->else_branch()).Union(); + } + flatbuffers::Offset<void> visit(luci::CircleL2Normalize *node) + { + return circle::CreateL2NormOptions(_builder, to_circle_actfunc(node->fusedActivationFunction())) + .Union(); + } + flatbuffers::Offset<void> visit(luci::CircleL2Pool2D *node) + { + return circle::CreatePool2DOptions(_builder, getOpPadding(node->padding()), node->stride()->w(), + node->stride()->h(), node->filter()->w(), + node->filter()->h(), + to_circle_actfunc(node->fusedActivationFunction())) + .Union(); + } + flatbuffers::Offset<void> visit(luci::CircleLeakyRelu *node) + { + return circle::CreateLeakyReluOptions(_builder, node->alpha()).Union(); + } + flatbuffers::Offset<void> visit(luci::CircleLess *) + { + return circle::CreateLessOptions(_builder).Union(); + } + flatbuffers::Offset<void> visit(luci::CircleLessEqual *) + { + return circle::CreateLessEqualOptions(_builder).Union(); + } + flatbuffers::Offset<void> visit(luci::CircleLocalResponseNormalization *node) + { + return circle::CreateLocalResponseNormalizationOptions(_builder, node->radius(), node->bias(), + node->alpha(), node->beta()) + .Union(); + } + flatbuffers::Offset<void> visit(luci::CircleLog *) { return _no_option; } + 
flatbuffers::Offset<void> visit(luci::CircleLogicalAnd *) + { + return circle::CreateLogicalAndOptions(_builder).Union(); + } + flatbuffers::Offset<void> visit(luci::CircleLogicalNot *) + { + return circle::CreateLogicalNotOptions(_builder).Union(); + } + flatbuffers::Offset<void> visit(luci::CircleLogicalOr *) + { + return circle::CreateLogicalOrOptions(_builder).Union(); + } + flatbuffers::Offset<void> visit(luci::CircleLogistic *) { return _no_option; } + flatbuffers::Offset<void> visit(luci::CircleLogSoftmax *) + { + return circle::CreateLogSoftmaxOptions(_builder).Union(); + } + flatbuffers::Offset<void> visit(luci::CircleMatrixDiag *) + { + return circle::CreateMatrixDiagOptions(_builder).Union(); + } + flatbuffers::Offset<void> visit(luci::CircleMatrixSetDiag *) + { + return circle::CreateMatrixSetDiagOptions(_builder).Union(); + } + flatbuffers::Offset<void> visit(luci::CircleMaximum *) + { + return circle::CreateMaximumMinimumOptions(_builder).Union(); + } + flatbuffers::Offset<void> visit(luci::CircleMaxPool2D *node) + { + return circle::CreatePool2DOptions(_builder, getOpPadding(node->padding()), node->stride()->w(), + node->stride()->h(), node->filter()->w(), + node->filter()->h(), + to_circle_actfunc(node->fusedActivationFunction())) + .Union(); + } + flatbuffers::Offset<void> visit(luci::CircleMean *node) + { + return circle::CreateReducerOptions(_builder, node->keep_dims()).Union(); + } + flatbuffers::Offset<void> visit(luci::CircleMinimum *) + { + return circle::CreateMaximumMinimumOptions(_builder).Union(); + } + flatbuffers::Offset<void> visit(luci::CircleMirrorPad *node) + { + return circle::CreateMirrorPadOptions(_builder, to_circle_mirrorpadmode(node->mode())).Union(); + } + flatbuffers::Offset<void> visit(luci::CircleMul *node) + { + return circle::CreateMulOptions(_builder, to_circle_actfunc(node->fusedActivationFunction())) + .Union(); + } + flatbuffers::Offset<void> visit(luci::CircleNeg *) + { + return 
circle::CreateNegOptions(_builder).Union(); + } + flatbuffers::Offset<void> visit(luci::CircleNonMaxSuppressionV4 *) + { + return circle::CreateNonMaxSuppressionV4Options(_builder).Union(); + } + flatbuffers::Offset<void> visit(luci::CircleNonMaxSuppressionV5 *) + { + return circle::CreateNonMaxSuppressionV5Options(_builder).Union(); + } + flatbuffers::Offset<void> visit(luci::CircleNotEqual *) + { + return circle::CreateNotEqualOptions(_builder).Union(); + } + flatbuffers::Offset<void> visit(luci::CircleOneHot *node) + { + return circle::CreateOneHotOptions(_builder, node->axis()).Union(); + } + flatbuffers::Offset<void> visit(luci::CirclePack *node) + { + return circle::CreatePackOptions(_builder, node->values_count(), node->axis()).Union(); + } + flatbuffers::Offset<void> visit(luci::CirclePad *) + { + return circle::CreatePadOptions(_builder).Union(); + } + flatbuffers::Offset<void> visit(luci::CirclePadV2 *) + { + return circle::CreatePadV2Options(_builder).Union(); + } + flatbuffers::Offset<void> visit(luci::CirclePow *) + { + return circle::CreatePowOptions(_builder).Union(); + } + flatbuffers::Offset<void> visit(luci::CirclePRelu *) { return _no_option; } + flatbuffers::Offset<void> visit(luci::CircleQuantize *) { return _no_option; } + flatbuffers::Offset<void> visit(luci::CircleRange *) + { + return circle::CreateRangeOptions(_builder).Union(); + } + flatbuffers::Offset<void> visit(luci::CircleRank *) + { + return circle::CreateRankOptions(_builder).Union(); + } + flatbuffers::Offset<void> visit(luci::CircleReduceAny *node) + { + return circle::CreateReducerOptions(_builder, node->keep_dims()).Union(); + } + flatbuffers::Offset<void> visit(luci::CircleReduceMax *node) + { + return circle::CreateReducerOptions(_builder, node->keep_dims()).Union(); + } + flatbuffers::Offset<void> visit(luci::CircleReduceMin *node) + { + return circle::CreateReducerOptions(_builder, node->keep_dims()).Union(); + } + flatbuffers::Offset<void> visit(luci::CircleReduceProd 
*node) + { + return circle::CreateReducerOptions(_builder, node->keep_dims()).Union(); + } + flatbuffers::Offset<void> visit(luci::CircleRelu *) { return _no_option; } + flatbuffers::Offset<void> visit(luci::CircleRelu6 *) { return _no_option; } + flatbuffers::Offset<void> visit(luci::CircleReluN1To1 *) { return _no_option; } + flatbuffers::Offset<void> visit(luci::CircleReshape *node) + { + auto new_shape = _builder.CreateVector<int32_t>( + node->newShape()->rank(), [node](size_t i) { return node->newShape()->dim(i); }); + return circle::CreateReshapeOptions(_builder, new_shape).Union(); + } + flatbuffers::Offset<void> visit(luci::CircleResizeBilinear *node) + { + return circle::CreateResizeBilinearOptions(_builder, node->align_corners(), + node->half_pixel_centers()) + .Union(); + } + flatbuffers::Offset<void> visit(luci::CircleResizeNearestNeighbor *node) + { + return circle::CreateResizeNearestNeighborOptions(_builder, node->align_corners()).Union(); + } + flatbuffers::Offset<void> visit(luci::CircleReverseSequence *node) + { + return circle::CreateReverseSequenceOptions(_builder, node->seq_axis(), node->batch_axis()) + .Union(); + } + flatbuffers::Offset<void> visit(luci::CircleReverseV2 *) + { + return circle::CreateReverseV2Options(_builder).Union(); + } + flatbuffers::Offset<void> visit(luci::CircleRound *) { return _no_option; } + flatbuffers::Offset<void> visit(luci::CircleRsqrt *) { return _no_option; } + flatbuffers::Offset<void> visit(luci::CircleScatterNd *) + { + return circle::CreateScatterNdOptions(_builder).Union(); + } + flatbuffers::Offset<void> visit(luci::CircleSegmentSum *) + { + return circle::CreateSegmentSumOptions(_builder).Union(); + } + flatbuffers::Offset<void> visit(luci::CircleSelect *) + { + return circle::CreateSelectOptions(_builder).Union(); + } + flatbuffers::Offset<void> visit(luci::CircleSelectV2 *) + { + return circle::CreateSelectV2Options(_builder).Union(); + } + flatbuffers::Offset<void> visit(luci::CircleShape *node) + { 
+ return circle::CreateShapeOptions(_builder, luci::to_circle_tensortype(node->out_type())) + .Union(); + } + flatbuffers::Offset<void> visit(luci::CircleSin *) { return _no_option; } + flatbuffers::Offset<void> visit(luci::CircleSlice *) + { + return circle::CreateSliceOptions(_builder).Union(); + } + flatbuffers::Offset<void> visit(luci::CircleSoftmax *node) + { + return circle::CreateSoftmaxOptions(_builder, node->beta()).Union(); + } + flatbuffers::Offset<void> visit(luci::CircleSpaceToBatchND *) + { + return circle::CreateSpaceToBatchNDOptions(_builder).Union(); + } + flatbuffers::Offset<void> visit(luci::CircleSpaceToDepth *node) + { + return circle::CreateSpaceToDepthOptions(_builder, node->block_size()).Union(); + } + flatbuffers::Offset<void> visit(luci::CircleSparseToDense *node) + { + return circle::CreateSparseToDenseOptions(_builder, node->validate_indices()).Union(); + } + flatbuffers::Offset<void> visit(luci::CircleSplit *node) + { + return circle::CreateSplitOptions(_builder, node->num_split()).Union(); + } + flatbuffers::Offset<void> visit(luci::CircleSplitV *node) + { + return circle::CreateSplitVOptions(_builder, node->num_split()).Union(); + } + flatbuffers::Offset<void> visit(luci::CircleSqrt *) { return _no_option; } + flatbuffers::Offset<void> visit(luci::CircleSquare *) + { + return circle::CreateSquareOptions(_builder).Union(); + } + flatbuffers::Offset<void> visit(luci::CircleSquaredDifference *) + { + return circle::CreateSquaredDifferenceOptions(_builder).Union(); + } + flatbuffers::Offset<void> visit(luci::CircleSqueeze *node) + { + auto squeeze_dims = _builder.CreateVector<int32_t>(node->squeeze_dims()); + return circle::CreateSqueezeOptions(_builder, squeeze_dims).Union(); + } + flatbuffers::Offset<void> visit(luci::CircleStridedSlice *node) + { + return circle::CreateStridedSliceOptions(_builder, node->begin_mask(), node->end_mask(), + node->ellipsis_mask(), node->new_axis_mask(), + node->shrink_axis_mask()) + .Union(); + } + 
flatbuffers::Offset<void> visit(luci::CircleSub *node) + { + return circle::CreateSubOptions(_builder, to_circle_actfunc(node->fusedActivationFunction())) + .Union(); + } + flatbuffers::Offset<void> visit(luci::CircleSum *node) + { + return circle::CreateReducerOptions(_builder, node->keep_dims()).Union(); + } + flatbuffers::Offset<void> visit(luci::CircleSVDF *node) + { + return circle::CreateSVDFOptions(_builder, node->svdf_rank(), + to_circle_actfunc(node->fusedActivationFunction()), + node->asymmetric_quantize_inputs()) + .Union(); + } + flatbuffers::Offset<void> visit(luci::CircleTanh *) { return _no_option; } + flatbuffers::Offset<void> visit(luci::CircleTile *) + { + return circle::CreateTileOptions(_builder).Union(); + } + flatbuffers::Offset<void> visit(luci::CircleTopKV2 *) + { + return circle::CreateTopKV2Options(_builder).Union(); + } + flatbuffers::Offset<void> visit(luci::CircleTranspose *) + { + return circle::CreateTransposeOptions(_builder).Union(); + } + flatbuffers::Offset<void> visit(luci::CircleTransposeConv *node) + { + return circle::CreateTransposeConvOptions(_builder, getOpPadding(node->padding()), + node->stride()->w(), node->stride()->h()) + .Union(); + } + flatbuffers::Offset<void> visit(luci::CircleUnidirectionalSequenceLSTM *node) + { + return circle::CreateUnidirectionalSequenceLSTMOptions( + _builder, to_circle_actfunc(node->fusedActivationFunction()), node->cell_clip(), + node->proj_clip(), node->time_major(), node->asymmetric_quantize_inputs()) + .Union(); + } + flatbuffers::Offset<void> visit(luci::CircleUnique *node) + { + return circle::CreateUniqueOptions(_builder, luci::to_circle_tensortype(node->idx_out_type())) + .Union(); + } + flatbuffers::Offset<void> visit(luci::CircleUnpack *node) + { + return circle::CreateUnpackOptions(_builder, node->num(), node->axis()).Union(); + } + flatbuffers::Offset<void> visit(luci::CircleWhere *) + { + return circle::CreateWhereOptions(_builder).Union(); + } + flatbuffers::Offset<void> 
visit(luci::CircleWhile *node) + { + return circle::CreateWhileOptions(_builder, node->cond_branch(), node->body_branch()).Union(); + } + flatbuffers::Offset<void> visit(luci::CircleZerosLike *) + { + return circle::CreateZerosLikeOptions(_builder).Union(); + } + // Circle only + flatbuffers::Offset<void> visit(luci::CircleBCQFullyConnected *node) + { + return circle::CreateBCQFullyConnectedOptions( + _builder, node->weights_hidden_size(), + to_circle_actfunc(node->fusedActivationFunction())) + .Union(); + } + flatbuffers::Offset<void> visit(luci::CircleBCQGather *node) + { + return circle::CreateBCQGatherOptions(_builder, node->input_hidden_size(), node->axis()) + .Union(); + } + flatbuffers::Offset<void> visit(luci::CircleInstanceNorm *node) + { + return circle::CreateInstanceNormOptions(_builder, node->epsilon(), + to_circle_actfunc(node->fusedActivationFunction())) + .Union(); + } + +protected: + flatbuffers::FlatBufferBuilder &_builder; + +private: + const flatbuffers::Offset<void> _no_option = 0; +}; + +} // namespace luci + +#endif // __CIRCLE_BUILTIN_TYPES_EXTRACTOR_H__ diff --git a/compiler/luci/export/src/CircleBuiltinTypesMappingRule.h b/compiler/luci/export/src/CircleBuiltinTypesMappingRule.h new file mode 100644 index 000000000..6f7c0f70e --- /dev/null +++ b/compiler/luci/export/src/CircleBuiltinTypesMappingRule.h @@ -0,0 +1,79 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __CIRCLE_EXPORT_BUILTIN_TYPES_MAPPING_RULE_H__ +#define __CIRCLE_EXPORT_BUILTIN_TYPES_MAPPING_RULE_H__ + +#include <luci/IR/CircleNode.h> +#include <luci/IR/CircleNodes.h> +#include <luci/IR/CircleNodeVisitor.h> + +namespace luci +{ + +class BuiltinOperatorMappingRule final : public CircleNodeVisitor<circle::BuiltinOperator> +{ +public: + BuiltinOperatorMappingRule() + { + // DO NOTHING + } + +public: + static BuiltinOperatorMappingRule &get() + { + static BuiltinOperatorMappingRule instance; + return instance; + } + +public: +#define CIRCLE_NODE(CIRCLE_NODE, OP, OPTION) \ + circle::BuiltinOperator visit(const CIRCLE_NODE *) final { return circle::OP; } +// Virtual nodes are not circle builtin operator +#define CIRCLE_VNODE(CIRCLE_NODE) +#include "CircleOps.lst" +#undef CIRCLE_VNODE +#undef CIRCLE_NODE +}; + +class BuiltinOptionsMappingRule final : public CircleNodeVisitor<circle::BuiltinOptions> +{ +public: + BuiltinOptionsMappingRule() + { + // DO NOTHING + } + +public: + static BuiltinOptionsMappingRule &get() + { + static BuiltinOptionsMappingRule instance; + return instance; + } + +public: +#define CIRCLE_NODE(CIRCLE_NODE, OP, OPTION) \ + circle::BuiltinOptions visit(const CIRCLE_NODE *) final { return circle::OPTION; } +// Virtual nodes are not circle builtin operator +#define CIRCLE_VNODE(CIRCLE_NODE) +#include "CircleOps.lst" +#undef CIRCLE_VNODE +#undef CIRCLE_NODE +}; + +} // namespace luci + +#endif // __CIRCLE_EXPORT_BUILTIN_TYPES_MAPPING_RULE_H__ diff --git a/compiler/luci/export/src/CircleExporterImpl.cpp b/compiler/luci/export/src/CircleExporterImpl.cpp index 5868c176c..083add9be 100644 --- a/compiler/luci/export/src/CircleExporterImpl.cpp +++ b/compiler/luci/export/src/CircleExporterImpl.cpp @@ -79,14 +79,19 @@ encodeOperatorCodes(FlatBufferBuilder &builder, std::unordered_map<luci::OpCode, for (auto it : opcodes) { uint32_t idx 
= it.second; + int8_t dep_code = 127; // BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES + if (it.first.opcode < BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES) + dep_code = static_cast<int8_t>(it.first.opcode); if (it.first.opcode != BuiltinOperator_CUSTOM) { - operator_codes_vec[idx] = CreateOperatorCode(builder, it.first.opcode, 0, it.first.version); + operator_codes_vec[idx] = + CreateOperatorCode(builder, dep_code, 0, it.first.version, it.first.opcode); } else { operator_codes_vec[idx] = - CreateOperatorCode(builder, it.first.opcode, builder.CreateString(it.first.custom_code)); + CreateOperatorCode(builder, dep_code, builder.CreateString(it.first.custom_code), + it.first.version, it.first.opcode); } } diff --git a/compiler/luci/export/src/CircleExporterUtils.cpp b/compiler/luci/export/src/CircleExporterUtils.cpp index 3a7ba304f..9473c2c4e 100644 --- a/compiler/luci/export/src/CircleExporterUtils.cpp +++ b/compiler/luci/export/src/CircleExporterUtils.cpp @@ -15,6 +15,7 @@ */ #include "CircleExporterUtils.h" +#include "CircleBuiltinTypesMappingRule.h" #include <oops/InternalExn.h> @@ -163,36 +164,63 @@ circle::SparseIndexVector to_circle_sparse_index_vector_type(luci::SparseIndexVe } } -} // namespace luci +circle::BuiltinOperator circle_builtin_operator(const luci::CircleNode *node) +{ + return node->accept(&BuiltinOperatorMappingRule::get()); +} -namespace luci +circle::BuiltinOptions circle_builtin_options(const luci::CircleNode *node) { + if (auto cast = dynamic_cast<const luci::CircleCast *>(node)) + { + return (cast->out_data_type() == loco::DataType::Unknown) ? 
circle::BuiltinOptions_NONE + : circle::BuiltinOptions_CastOptions; + } -uint32_t SerializedModelData::registerBuiltinOpcode(circle::BuiltinOperator builtin_code, - const int32_t op_version) + return node->accept(&BuiltinOptionsMappingRule::get()); +} + +std::string circle_custom_code(const luci::CircleNode *node) { - assert(op_version > 0); + if (auto custom_node = dynamic_cast<const luci::CircleCustom *>(node)) + { + return custom_node->custom_code(); + } - auto it = _operator_codes.find(OpCode{builtin_code, "", op_version}); - if (it != _operator_codes.end()) + return ""; +} + +flatbuffers::Offset<flatbuffers::Vector<uint8_t>> +circle_custom_options(flatbuffers::FlatBufferBuilder &fb, const luci::CircleNode *node) +{ + if (auto custom_node = dynamic_cast<const luci::CircleCustom *>(node)) { - return it->second; + std::vector<uint8_t> custom_options_vec{custom_node->custom_options().begin(), + custom_node->custom_options().end()}; + return fb.CreateVector(custom_options_vec); } - auto idx = static_cast<uint32_t>(_operator_codes.size()); - _operator_codes.emplace(OpCode{builtin_code, "", op_version}, idx); - return idx; + + return 0; } -uint32_t SerializedModelData::registerCustomOpcode(const std::string &custom_code) +} // namespace luci + +namespace luci { - const circle::BuiltinOperator builtin_code = circle::BuiltinOperator_CUSTOM; - auto it = _operator_codes.find(OpCode{builtin_code, custom_code}); + +uint32_t SerializedModelData::registerBuiltinOpcode(circle::BuiltinOperator builtin_code, + const std::string &custom_code, + const int32_t op_version) +{ + assert(op_version > 0); + + auto it = _operator_codes.find(OpCode{builtin_code, custom_code, op_version}); if (it != _operator_codes.end()) { return it->second; } auto idx = static_cast<uint32_t>(_operator_codes.size()); - _operator_codes.emplace(OpCode{builtin_code, custom_code}, idx); + _operator_codes.emplace(OpCode{builtin_code, custom_code, op_version}, idx); return idx; } diff --git 
a/compiler/luci/export/src/CircleExporterUtils.h b/compiler/luci/export/src/CircleExporterUtils.h index 95310b353..4a4c54a69 100644 --- a/compiler/luci/export/src/CircleExporterUtils.h +++ b/compiler/luci/export/src/CircleExporterUtils.h @@ -39,6 +39,12 @@ flatbuffers::Offset<void> to_circle_sparse_index_vector(flatbuffers::FlatBufferB const SparseIndexVector &sparse_idx_vec); circle::SparseIndexVector to_circle_sparse_index_vector_type(luci::SparseIndexVectorType type); +circle::BuiltinOperator circle_builtin_operator(const luci::CircleNode *node); +circle::BuiltinOptions circle_builtin_options(const luci::CircleNode *node); +std::string circle_custom_code(const luci::CircleNode *node); +flatbuffers::Offset<flatbuffers::Vector<uint8_t>> +circle_custom_options(flatbuffers::FlatBufferBuilder &fb, const luci::CircleNode *node); + } // namespace luci namespace luci diff --git a/compiler/luci/export/src/CircleOperationExporter.cpp b/compiler/luci/export/src/CircleOperationExporter.cpp index be64a52d4..b300a7fcf 100644 --- a/compiler/luci/export/src/CircleOperationExporter.cpp +++ b/compiler/luci/export/src/CircleOperationExporter.cpp @@ -15,1686 +15,30 @@ */ #include "CircleOperationExporter.h" -#include "CircleExporterUtils.h" -#include "Check.h" +#include "CircleOperationExporterRule.h" #include <luci/IR/CircleNode.h> -#include <luci/IR/CircleNodes.h> -#include <luci/IR/CircleNodeVisitor.h> #include <luci/Profile/CircleNodeOrigin.h> #include <luci/Plan/CircleNodeExecutionPlan.h> -#include <luci/UserSettings.h> -#include <luci/Log.h> +#include <loco/IR/Algorithm.h> -#include <loco/IR/CanonicalNodeVisitor.h> -#include <oops/InternalExn.h> - -#include <flatbuffers/flexbuffers.h> - -using namespace flatbuffers; -using namespace circle; - -namespace -{ - -using namespace luci; - -struct ExportContext -{ - FlatBufferBuilder &builder; - SerializedModelData &md; - SerializedGraphData &gd; -}; - -/** - * @brief Exports CircleMaxPool2D or CircleAveragePool2D - * - * @note 
CirclePool2D should be one of CircleMaxPool2D or CircleAveragePool2D - */ -template <class CirclePool2D> -void export_pool_2d(ExportContext &ctx, CirclePool2D *node, circle::BuiltinOperator builtin_op) -{ - LUCI_ASSERT(builtin_op == circle::BuiltinOperator_MAX_POOL_2D || - builtin_op == circle::BuiltinOperator_L2_POOL_2D || - builtin_op == circle::BuiltinOperator_AVERAGE_POOL_2D, - "Should be L2Pool, MaxPool or AvgPool"); - LUCI_ASSERT(node->padding() != luci::Padding::UNDEFINED, "Padding is not set"); - - uint32_t op_idx = ctx.md.registerBuiltinOpcode(builtin_op, node->op_version()); - std::vector<int32_t> inputs_vec{get_tensor_index(node->value())}; - std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; - auto inputs = ctx.builder.CreateVector(inputs_vec); - auto outputs = ctx.builder.CreateVector(outputs_vec); - - circle::Padding padding = getOpPadding(node->padding()); - - auto options = CreatePool2DOptions(ctx.builder, padding, node->stride()->w(), node->stride()->h(), - node->filter()->w(), node->filter()->h(), - to_circle_actfunc(node->fusedActivationFunction())); - auto op_offset = CreateOperator(ctx.builder, op_idx, inputs, outputs, - circle::BuiltinOptions_Pool2DOptions, options.Union()); - ctx.gd._operators.push_back(op_offset); -} - -/** - * @brief export simple nodes - */ -void export_node(ExportContext &ctx, loco::Node *node, circle::BuiltinOperator bop, - circle::BuiltinOptions bot, flatbuffers::Offset<void> options_offset) -{ - uint32_t op_idx = - ctx.md.registerBuiltinOpcode(bop, loco::must_cast<luci::CircleNode *>(node)->op_version()); - std::vector<int32_t> inputs_vec; - std::vector<int32_t> outputs_vec{get_tensor_index(node)}; - for (uint32_t i = 0; i < node->arity(); ++i) - inputs_vec.push_back(get_tensor_index(node->arg(i))); - auto inputs = ctx.builder.CreateVector(inputs_vec); - auto outputs = ctx.builder.CreateVector(outputs_vec); - auto op_offset = CreateOperator(ctx.builder, op_idx, inputs, outputs, bot, 
options_offset); - ctx.gd._operators.push_back(op_offset); -} - -/** - * @brief export simple nodes having void options - */ -void export_node(ExportContext &ctx, loco::Node *node, circle::BuiltinOperator bop) -{ - uint32_t op_idx = - ctx.md.registerBuiltinOpcode(bop, loco::must_cast<luci::CircleNode *>(node)->op_version()); - std::vector<int32_t> inputs_vec; - std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; - for (uint32_t i = 0; i < node->arity(); ++i) - inputs_vec.push_back(get_tensor_index(node->arg(i))); - auto inputs = ctx.builder.CreateVector(inputs_vec); - auto outputs = ctx.builder.CreateVector(outputs_vec); - auto op_offset = CreateOperator(ctx.builder, op_idx, inputs, outputs); - ctx.gd._operators.push_back(op_offset); -} - -void export_node(ExportContext &ctx, luci::CircleAddN *node) -{ - uint32_t op_idx = ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_ADD_N, node->op_version()); - std::vector<int32_t> inputs_vec; - std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; - - for (uint32_t i = 0; i < node->arity(); ++i) - inputs_vec.push_back(get_tensor_index(node->inputs(i))); - - auto inputs = ctx.builder.CreateVector(inputs_vec); - auto outputs = ctx.builder.CreateVector(outputs_vec); - auto options = CreateAddNOptions(ctx.builder); - auto op_offset = CreateOperator(ctx.builder, op_idx, inputs, outputs, - circle::BuiltinOptions_AddNOptions, options.Union()); - ctx.gd._operators.push_back(op_offset); -} - -void export_node(ExportContext &ctx, luci::CircleCast *node) -{ - uint32_t op_idx = ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_CAST, node->op_version()); - std::vector<int32_t> inputs_vec{get_tensor_index(node->x())}; - std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; - auto inputs = ctx.builder.CreateVector(inputs_vec); - auto outputs = ctx.builder.CreateVector(outputs_vec); - - flatbuffers::Offset<Operator> op_offset; - if 
(node->out_data_type() != loco::DataType::Unknown) - { - auto options = CreateCastOptions(ctx.builder, to_circle_tensortype(node->in_data_type()), - to_circle_tensortype(node->out_data_type())); - op_offset = CreateOperator(ctx.builder, op_idx, inputs, outputs, - circle::BuiltinOptions_CastOptions, options.Union()); - } - else - { - op_offset = CreateOperator(ctx.builder, op_idx, inputs, outputs); - } - ctx.gd._operators.push_back(op_offset); -} - -void export_node(ExportContext &ctx, luci::CircleConcatenation *node) -{ - uint32_t op_idx = - ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_CONCATENATION, node->op_version()); - std::vector<int32_t> inputs_vec; - std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; - - for (uint32_t i = 0; i < node->numValues(); ++i) - inputs_vec.push_back(get_tensor_index(node->values(i))); - - auto inputs = ctx.builder.CreateVector(inputs_vec); - auto outputs = ctx.builder.CreateVector(outputs_vec); - auto options = CreateConcatenationOptions(ctx.builder, node->axis(), - to_circle_actfunc(node->fusedActivationFunction())); - auto op_offset = CreateOperator(ctx.builder, op_idx, inputs, outputs, - circle::BuiltinOptions_ConcatenationOptions, options.Union()); - ctx.gd._operators.push_back(op_offset); -} - -void export_node(ExportContext &ctx, luci::CircleCustom *node) -{ - auto custom_outputs = loco::succs(node); - assert(custom_outputs.size() == node->numOutputs()); - - uint32_t op_idx = ctx.md.registerCustomOpcode(node->custom_code()); - std::vector<int32_t> inputs_vec; - std::vector<int32_t> outputs_vec; - - for (uint32_t index = 0; index < node->numInputs(); index++) - { - inputs_vec.push_back(get_tensor_index(node->inputs(index))); - } - for (uint32_t index = 0; index < custom_outputs.size(); index++) - { - // store in order of index - bool found = false; - for (auto out : custom_outputs) - { - auto custom_out = loco::must_cast<luci::CircleCustomOut *>(out); - if (custom_out->index() == 
static_cast<int32_t>(index)) - { - outputs_vec.push_back(get_tensor_index(custom_out)); - found = true; - break; - } - } - if (!found) - { - INTERNAL_EXN("Invalid Custom output"); - } - } - - auto inputs = ctx.builder.CreateVector(inputs_vec); - auto outputs = ctx.builder.CreateVector(outputs_vec); - flatbuffers::Offset<flatbuffers::Vector<uint8_t>> circle_custom_options; - std::vector<uint8_t> custom_options_vec{node->custom_options().begin(), - node->custom_options().end()}; - circle_custom_options = ctx.builder.CreateVector(custom_options_vec); - auto op_offset = CreateOperator(ctx.builder, op_idx, inputs, outputs, circle::BuiltinOptions_NONE, - flatbuffers::Offset<void>(), circle_custom_options); - ctx.gd._operators.push_back(op_offset); -} - -void export_node(ExportContext &ctx, luci::CircleIf *node) -{ - auto if_outs = loco::succs(node); - assert(if_outs.size() == node->output_count()); - - uint32_t op_idx = ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_IF, node->op_version()); - std::vector<int32_t> inputs_vec; - std::vector<int32_t> outputs_vec; - - inputs_vec.push_back(get_tensor_index(node->cond())); - for (uint32_t idx = 0; idx < node->input_count(); ++idx) - inputs_vec.push_back(get_tensor_index(node->input(idx))); - - for (uint32_t idx = 0; idx < node->output_count(); ++idx) - { - // store in order of index - bool found = false; - for (auto out : if_outs) - { - auto if_out = loco::must_cast<luci::CircleIfOut *>(out); - if (if_out->index() == static_cast<int32_t>(idx)) - { - outputs_vec.push_back(get_tensor_index(if_out)); - found = true; - break; - } - } - if (!found) - { - INTERNAL_EXN("Invalid CircleIf output"); - } - } - - auto inputs = ctx.builder.CreateVector(inputs_vec); - auto outputs = ctx.builder.CreateVector(outputs_vec); - auto options = CreateIfOptions(ctx.builder, node->then_branch(), node->else_branch()); - auto op_offset = CreateOperator(ctx.builder, op_idx, inputs, outputs, - circle::BuiltinOptions_IfOptions, options.Union()); - 
ctx.gd._operators.push_back(op_offset); -} - -void export_node(ExportContext &ctx, luci::CircleNonMaxSuppressionV4 *node) -{ - auto nms_outs = loco::succs(node); - assert(nms_outs.size() == 2); - - uint32_t op_idx = ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_NON_MAX_SUPPRESSION_V4, - node->op_version()); - std::vector<int32_t> inputs_vec{ - get_tensor_index(node->boxes()), get_tensor_index(node->scores()), - get_tensor_index(node->max_output_size()), get_tensor_index(node->iou_threshold()), - get_tensor_index(node->score_threshold()), - }; - std::vector<int32_t> outputs_vec; - - for (uint32_t idx = 0; idx < nms_outs.size(); ++idx) - { - // store in order of index - bool found = false; - for (auto out : nms_outs) - { - auto nms_out = loco::must_cast<luci::CircleNonMaxSuppressionV4Out *>(out); - if (nms_out->index() == static_cast<int32_t>(idx)) - { - outputs_vec.push_back(get_tensor_index(nms_out)); - found = true; - break; - } - } - if (!found) - { - INTERNAL_EXN("Invalid NonMaxSuppressionV4 output"); - } - } - - auto inputs = ctx.builder.CreateVector(inputs_vec); - auto outputs = ctx.builder.CreateVector(outputs_vec); - auto options = CreateNonMaxSuppressionV4Options(ctx.builder); - auto op_offset = - CreateOperator(ctx.builder, op_idx, inputs, outputs, - circle::BuiltinOptions_NonMaxSuppressionV4Options, options.Union()); - ctx.gd._operators.push_back(op_offset); -} - -void export_node(ExportContext &ctx, luci::CircleNonMaxSuppressionV5 *node) -{ - auto nms_outs = loco::succs(node); - assert(nms_outs.size() == 3); - - uint32_t op_idx = ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_NON_MAX_SUPPRESSION_V5, - node->op_version()); - std::vector<int32_t> inputs_vec{ - get_tensor_index(node->boxes()), get_tensor_index(node->scores()), - get_tensor_index(node->max_output_size()), get_tensor_index(node->iou_threshold()), - get_tensor_index(node->score_threshold()), get_tensor_index(node->soft_nms_sigma()), - }; - std::vector<int32_t> outputs_vec; - - for 
(uint32_t idx = 0; idx < nms_outs.size(); ++idx) - { - // store in order of index - bool found = false; - for (auto out : nms_outs) - { - auto nms_out = loco::must_cast<luci::CircleNonMaxSuppressionV5Out *>(out); - if (nms_out->index() == static_cast<int32_t>(idx)) - { - outputs_vec.push_back(get_tensor_index(nms_out)); - found = true; - break; - } - } - if (!found) - { - INTERNAL_EXN("Invalid NonMaxSuppressionV5 output"); - } - } - - auto inputs = ctx.builder.CreateVector(inputs_vec); - auto outputs = ctx.builder.CreateVector(outputs_vec); - auto options = CreateNonMaxSuppressionV5Options(ctx.builder); - auto op_offset = - CreateOperator(ctx.builder, op_idx, inputs, outputs, - circle::BuiltinOptions_NonMaxSuppressionV5Options, options.Union()); - ctx.gd._operators.push_back(op_offset); -} - -void export_node(ExportContext &ctx, luci::CircleReverseV2 *node) -{ - uint32_t op_idx = - ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_REVERSE_V2, node->op_version()); - std::vector<int32_t> inputs_vec{get_tensor_index(node->tensor()), get_tensor_index(node->axis())}; - std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; - auto inputs = ctx.builder.CreateVector(inputs_vec); - auto outputs = ctx.builder.CreateVector(outputs_vec); - auto options = CreateReverseV2Options(ctx.builder); - auto op_offset = CreateOperator(ctx.builder, op_idx, inputs, outputs, - circle::BuiltinOptions_ReverseSequenceOptions, options.Union()); - ctx.gd._operators.push_back(op_offset); -} - -void export_node(ExportContext &ctx, luci::CircleSplit *node) -{ - auto split_outs = loco::succs(node); - assert(int32_t(split_outs.size()) == node->num_split()); - - uint32_t op_idx = ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_SPLIT, node->op_version()); - // NOTE BuiltinOperator_SPLIT input is placed at second position - std::vector<int32_t> inputs_vec{get_tensor_index(node->split_dim()), - get_tensor_index(node->input())}; - std::vector<int32_t> outputs_vec; - - 
for (int32_t index = 0; index < node->num_split(); index++) - { - // store in order of index - bool found = false; - for (auto out : split_outs) - { - auto split_out = loco::must_cast<luci::CircleSplitOut *>(out); - if (split_out->index() == index) - { - outputs_vec.push_back(get_tensor_index(split_out)); - found = true; - break; - } - } - if (!found) - { - INTERNAL_EXN("Invalid Split output"); - } - } - - auto inputs = ctx.builder.CreateVector(inputs_vec); - auto outputs = ctx.builder.CreateVector(outputs_vec); - auto options = CreateSplitOptions(ctx.builder, node->num_split()); - auto op_offset = CreateOperator(ctx.builder, op_idx, inputs, outputs, - circle::BuiltinOptions_SplitOptions, options.Union()); - ctx.gd._operators.push_back(op_offset); -} - -void export_node(ExportContext &ctx, luci::CircleSplitV *node) -{ - auto split_outs = loco::succs(node); - assert(int32_t(split_outs.size()) == node->num_split()); - - uint32_t op_idx = - ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_SPLIT_V, node->op_version()); - std::vector<int32_t> inputs_vec{get_tensor_index(node->input()), - get_tensor_index(node->size_splits()), - get_tensor_index(node->split_dim())}; - std::vector<int32_t> outputs_vec; - - for (int32_t index = 0; index < node->num_split(); index++) - { - // store in order of index - bool found = false; - for (auto out : split_outs) - { - auto split_out = loco::must_cast<luci::CircleSplitVOut *>(out); - if (split_out->index() == index) - { - outputs_vec.push_back(get_tensor_index(split_out)); - found = true; - break; - } - } - if (!found) - { - INTERNAL_EXN("Invalid SplitV output"); - } - } - - auto inputs = ctx.builder.CreateVector(inputs_vec); - auto outputs = ctx.builder.CreateVector(outputs_vec); - auto options = CreateSplitVOptions(ctx.builder, node->num_split()); - auto op_offset = CreateOperator(ctx.builder, op_idx, inputs, outputs, - circle::BuiltinOptions_SplitVOptions, options.Union()); - ctx.gd._operators.push_back(op_offset); -} - -void 
export_node(ExportContext &ctx, luci::CircleTopKV2 *node) -{ - auto topkv2_outs = loco::succs(node); - int outs_count = int32_t(topkv2_outs.size()); - assert(outs_count == 2); - - uint32_t op_idx = - ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_TOPK_V2, node->op_version()); - std::vector<int32_t> inputs_vec{get_tensor_index(node->input()), get_tensor_index(node->k())}; - std::vector<int32_t> outputs_vec; - - for (int32_t index = 0; index < outs_count; index++) - { - // store in order of index - bool found = false; - for (auto out : topkv2_outs) - { - auto topkv2_out = loco::must_cast<luci::CircleTopKV2Out *>(out); - if (topkv2_out->index() == index) - { - outputs_vec.push_back(get_tensor_index(topkv2_out)); - found = true; - break; - } - } - if (!found) - { - INTERNAL_EXN("Invalid TopKV2 output"); - } - } - - auto inputs = ctx.builder.CreateVector(inputs_vec); - auto outputs = ctx.builder.CreateVector(outputs_vec); - auto options = CreateTopKV2Options(ctx.builder); - auto op_offset = CreateOperator(ctx.builder, op_idx, inputs, outputs, - circle::BuiltinOptions_TopKV2Options, options.Union()); - ctx.gd._operators.push_back(op_offset); -} - -void export_node(ExportContext &ctx, luci::CircleUnique *node) -{ - auto unique_outs = loco::succs(node); - assert(int32_t(unique_outs.size()) == 2); - uint32_t op_idx = - ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_UNIQUE, node->op_version()); - - std::vector<int32_t> inputs_vec{get_tensor_index(node->input())}; - std::vector<int32_t> outputs_vec; - - for (int32_t index = 0; index < 2; index++) - { - // store in order of index - bool found = false; - for (auto out : unique_outs) - { - auto unique_out = loco::must_cast<luci::CircleUniqueOut *>(out); - if (unique_out->index() == index) - { - outputs_vec.push_back(get_tensor_index(unique_out)); - found = true; - break; - } - } - if (!found) - { - INTERNAL_EXN("Invalid Unique output"); - } - } - - auto inputs = ctx.builder.CreateVector(inputs_vec); - auto outputs = 
ctx.builder.CreateVector(outputs_vec); - auto options = CreateUniqueOptions(ctx.builder, to_circle_tensortype(node->idx_out_type())); - auto op_offset = CreateOperator(ctx.builder, op_idx, inputs, outputs, - circle::BuiltinOptions_UniqueOptions, options.Union()); - ctx.gd._operators.push_back(op_offset); -} - -void export_node(ExportContext &ctx, luci::CircleUnpack *node) -{ - LOGGER(l); - auto settings = luci::UserSettings::settings(); - - auto unpack_outs = loco::succs(node); - // NOTE real models may not use all of the outputs - if (static_cast<int32_t>(unpack_outs.size()) != node->num()) - { - if (settings->get(luci::UserSettings::Key::DisableValidation)) - { - WARN(l) << "Warning: export Unpack(" << node->name() << ") 'num' not same as outputs"; - } - else - assert(false); - } - - uint32_t op_idx = - ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_UNPACK, node->op_version()); - std::vector<int32_t> inputs_vec{get_tensor_index(node->value())}; - std::vector<int32_t> outputs_vec; - - for (int32_t index = 0; index < node->num(); index++) - { - // store in order of index - bool found = false; - for (auto out : unpack_outs) - { - auto unpack_out = loco::must_cast<luci::CircleUnpackOut *>(out); - if (unpack_out->index() == index) - { - outputs_vec.push_back(get_tensor_index(unpack_out)); - found = true; - break; - } - } - // NOTE real models may not use all of the outputs - if (!found) - { - if (settings->get(luci::UserSettings::Key::DisableValidation)) - { - WARN(l) << "Warning: export Unpack(" << node->name() << ") output " << index << " not used"; - } - else - assert(false); - } - } - - auto inputs = ctx.builder.CreateVector(inputs_vec); - auto outputs = ctx.builder.CreateVector(outputs_vec); - auto options = CreateUnpackOptions(ctx.builder, node->num(), node->axis()); - auto op_offset = CreateOperator(ctx.builder, op_idx, inputs, outputs, - circle::BuiltinOptions_UnpackOptions, options.Union()); - ctx.gd._operators.push_back(op_offset); -} - -void 
export_node(ExportContext &ctx, luci::CircleWhile *node) -{ - auto while_outs = loco::succs(node); - assert(while_outs.size() == node->output_count()); - - uint32_t op_idx = ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_WHILE, node->op_version()); - std::vector<int32_t> inputs_vec; - std::vector<int32_t> outputs_vec; - - for (uint32_t idx = 0; idx < node->input_count(); ++idx) - inputs_vec.push_back(get_tensor_index(node->input(idx))); - - for (uint32_t idx = 0; idx < node->output_count(); ++idx) - { - // store in order of index - bool found = false; - for (auto out : while_outs) - { - auto while_out = loco::must_cast<luci::CircleWhileOut *>(out); - if (while_out->index() == static_cast<int32_t>(idx)) - { - outputs_vec.push_back(get_tensor_index(while_out)); - found = true; - break; - } - } - if (!found) - { - INTERNAL_EXN("Invalid CircleWhile output"); - } - } - - auto inputs = ctx.builder.CreateVector(inputs_vec); - auto outputs = ctx.builder.CreateVector(outputs_vec); - auto options = CreateWhileOptions(ctx.builder, node->cond_branch(), node->body_branch()); - auto op_offset = CreateOperator(ctx.builder, op_idx, inputs, outputs, - circle::BuiltinOptions_WhileOptions, options.Union()); - ctx.gd._operators.push_back(op_offset); -} - -class ExportHelper -{ -public: - ExportHelper(ExportContext &ctx) : _ctx{ctx} - { - // DO NOTHING - } - -protected: - /** - * @brief export simple nodes - */ - void export_simple(loco::Node *node, circle::BuiltinOperator bop, circle::BuiltinOptions bot, - flatbuffers::Offset<void> options_offset) - { - export_node(_ctx, node, bop, bot, options_offset); - } - - /** - * @brief export simple nodes having void options - */ - void export_simple(loco::Node *node, circle::BuiltinOperator bop) - { - export_node(_ctx, node, bop); - } - -protected: - ExportContext &_ctx; -}; - -enum class OE -{ - ABC, - DEF, - GHIJ, - KLMN, - OPQR, - STUV, - WXYZ, - CIRC, // circle only - VIRT, // virtual -}; - -class OperationExporter final : public 
ExportHelper -{ -public: - OperationExporter(ExportContext &ctx) : ExportHelper(ctx) - { - // DO NOTHING - } - -public: - void export_node(luci::CircleNode *); -}; - -template <OE oe> class OpExporterLet; - -template <> -class OpExporterLet<OE::ABC> final : public luci::CircleNodeMutableVisitor<void>, - public ExportHelper -{ -public: - OpExporterLet(ExportContext &ctx) : ExportHelper(ctx) - { - // DO NOTHING - } - -public: - // NOTE visit for luci::CircleNode is added NOT to throw NYI - void visit(luci::CircleNode *) final {} - -public: - void visit(luci::CircleAbs *) final; - void visit(luci::CircleAdd *) final; - void visit(luci::CircleAddN *) final; - void visit(luci::CircleArgMax *) final; - void visit(luci::CircleArgMin *) final; - void visit(luci::CircleAveragePool2D *) final; - void visit(luci::CircleBatchMatMul *) final; - void visit(luci::CircleBatchToSpaceND *) final; - void visit(luci::CircleBidirectionalSequenceLSTM *) final; - void visit(luci::CircleCast *) final; - void visit(luci::CircleCeil *) final; - void visit(luci::CircleConcatenation *) final; - void visit(luci::CircleConst *) final{/* skip, everything is done in exportOpDefinedTensors */}; - void visit(luci::CircleConv2D *) final; - void visit(luci::CircleCos *) final; - void visit(luci::CircleCustom *) final; -}; - -template <> -class OpExporterLet<OE::DEF> final : public luci::CircleNodeMutableVisitor<void>, - public ExportHelper -{ -public: - OpExporterLet(ExportContext &ctx) : ExportHelper(ctx) - { - // DO NOTHING - } - -public: - void visit(luci::CircleNode *) final {} - -public: - void visit(luci::CircleDepthToSpace *) final; - void visit(luci::CircleDepthwiseConv2D *) final; - void visit(luci::CircleDequantize *) final; - void visit(luci::CircleDiv *) final; - void visit(luci::CircleElu *) final; - void visit(luci::CircleEqual *) final; - void visit(luci::CircleExp *) final; - void visit(luci::CircleExpandDims *) final; - void visit(luci::CircleFakeQuant *) final; - void 
visit(luci::CircleFill *) final; - void visit(luci::CircleFloor *) final; - void visit(luci::CircleFloorDiv *) final; - void visit(luci::CircleFloorMod *) final; - void visit(luci::CircleFullyConnected *) final; -}; - -template <> -class OpExporterLet<OE::GHIJ> final : public luci::CircleNodeMutableVisitor<void>, - public ExportHelper -{ -public: - OpExporterLet(ExportContext &ctx) : ExportHelper(ctx) - { - // DO NOTHING - } - -public: - void visit(luci::CircleNode *) final {} - -public: - void visit(luci::CircleGather *) final; - void visit(luci::CircleGatherNd *) final; - void visit(luci::CircleGreater *) final; - void visit(luci::CircleGreaterEqual *) final; - void visit(luci::CircleIf *) final; -}; - -template <> -class OpExporterLet<OE::KLMN> final : public luci::CircleNodeMutableVisitor<void>, - public ExportHelper -{ -public: - OpExporterLet(ExportContext &ctx) : ExportHelper(ctx) - { - // DO NOTHING - } - -public: - void visit(luci::CircleNode *) final {} - -public: - void visit(luci::CircleL2Normalize *) final; - void visit(luci::CircleL2Pool2D *) final; - void visit(luci::CircleLeakyRelu *) final; - void visit(luci::CircleLess *) final; - void visit(luci::CircleLessEqual *) final; - void visit(luci::CircleLocalResponseNormalization *) final; - void visit(luci::CircleLog *) final; - void visit(luci::CircleLogicalAnd *) final; - void visit(luci::CircleLogicalNot *) final; - void visit(luci::CircleLogicalOr *) final; - void visit(luci::CircleLogistic *) final; - void visit(luci::CircleLogSoftmax *) final; - void visit(luci::CircleMatrixDiag *) final; - void visit(luci::CircleMatrixSetDiag *) final; - void visit(luci::CircleMaximum *) final; - void visit(luci::CircleMaxPool2D *) final; - void visit(luci::CircleMean *) final; - void visit(luci::CircleMinimum *) final; - void visit(luci::CircleMirrorPad *) final; - void visit(luci::CircleMul *) final; - void visit(luci::CircleNeg *) final; - void visit(luci::CircleNonMaxSuppressionV4 *) final; - void 
visit(luci::CircleNonMaxSuppressionV5 *) final; - void visit(luci::CircleNotEqual *) final; -}; - -template <> -class OpExporterLet<OE::OPQR> final : public luci::CircleNodeMutableVisitor<void>, - public ExportHelper -{ -public: - OpExporterLet(ExportContext &ctx) : ExportHelper(ctx) - { - // DO NOTHING - } - -public: - void visit(luci::CircleNode *) final {} - -public: - void visit(luci::CircleOneHot *) final; - void visit(luci::CirclePack *) final; - void visit(luci::CirclePad *) final; - void visit(luci::CirclePadV2 *) final; - void visit(luci::CirclePow *) final; - void visit(luci::CirclePRelu *) final; - void visit(luci::CircleQuantize *) final; - void visit(luci::CircleRange *) final; - void visit(luci::CircleRank *) final; - void visit(luci::CircleReduceAny *) final; - void visit(luci::CircleReduceMax *) final; - void visit(luci::CircleReduceMin *) final; - void visit(luci::CircleReduceProd *) final; - void visit(luci::CircleRelu *) final; - void visit(luci::CircleRelu6 *) final; - void visit(luci::CircleReluN1To1 *) final; - void visit(luci::CircleReshape *) final; - void visit(luci::CircleResizeBilinear *) final; - void visit(luci::CircleResizeNearestNeighbor *) final; - void visit(luci::CircleReverseSequence *) final; - void visit(luci::CircleReverseV2 *) final; - void visit(luci::CircleRound *) final; - void visit(luci::CircleRsqrt *) final; -}; - -template <> -class OpExporterLet<OE::STUV> final : public luci::CircleNodeMutableVisitor<void>, - public ExportHelper -{ -public: - OpExporterLet(ExportContext &ctx) : ExportHelper(ctx) - { - // DO NOTHING - } - -public: - void visit(luci::CircleNode *) final {} - -public: - void visit(luci::CircleScatterNd *) final; - void visit(luci::CircleSegmentSum *) final; - void visit(luci::CircleSelect *) final; - void visit(luci::CircleSelectV2 *) final; - void visit(luci::CircleShape *) final; - void visit(luci::CircleSin *) final; - void visit(luci::CircleSlice *) final; - void visit(luci::CircleSoftmax *) final; - 
void visit(luci::CircleSpaceToBatchND *) final; - void visit(luci::CircleSpaceToDepth *) final; - void visit(luci::CircleSparseToDense *) final; - void visit(luci::CircleSplit *) final; - void visit(luci::CircleSplitV *) final; - void visit(luci::CircleSqrt *) final; - void visit(luci::CircleSquare *) final; - void visit(luci::CircleSquaredDifference *) final; - void visit(luci::CircleSqueeze *) final; - void visit(luci::CircleStridedSlice *) final; - void visit(luci::CircleSub *) final; - void visit(luci::CircleSum *) final; - void visit(luci::CircleTanh *) final; - void visit(luci::CircleTile *) final; - void visit(luci::CircleTopKV2 *) final; - void visit(luci::CircleTranspose *) final; - void visit(luci::CircleTransposeConv *) final; - void visit(luci::CircleUnidirectionalSequenceLSTM *) final; - void visit(luci::CircleUnique *) final; - void visit(luci::CircleUnpack *) final; -}; - -template <> -class OpExporterLet<OE::WXYZ> final : public luci::CircleNodeMutableVisitor<void>, - public ExportHelper -{ -public: - OpExporterLet(ExportContext &ctx) : ExportHelper(ctx) - { - // DO NOTHING - } - -public: - void visit(luci::CircleNode *) final {} - -public: - void visit(luci::CircleWhere *) final; - void visit(luci::CircleWhile *) final; - void visit(luci::CircleZerosLike *) final; -}; - -template <> -class OpExporterLet<OE::CIRC> final : public luci::CircleNodeMutableVisitor<void>, - public ExportHelper -{ -public: - OpExporterLet(ExportContext &ctx) : ExportHelper(ctx) - { - // DO NOTHING - } - -public: - void visit(luci::CircleNode *) final {} - -public: - // Circle only - void visit(luci::CircleBCQFullyConnected *) final; - void visit(luci::CircleBCQGather *) final; - void visit(luci::CircleInstanceNorm *) final; -}; - -template <> -class OpExporterLet<OE::VIRT> final : public luci::CircleNodeMutableVisitor<void>, - public ExportHelper -{ -public: - OpExporterLet(ExportContext &ctx) : ExportHelper(ctx) - { - // DO NOTHING - } - -public: - void 
visit(luci::CircleNode *) final {} - -public: - // Virtual - void visit(luci::CircleInput *) final {} - void visit(luci::CircleOutput *) final {} - void visit(luci::CircleOutputDummy *) final {} - void visit(luci::CircleOutputExclude *) final {} - // Virtual for multiple-outputs - void visit(luci::CircleBidirectionalSequenceLSTMOut *) final {} - void visit(luci::CircleCustomOut *) final {} - void visit(luci::CircleIfOut *) final {} - void visit(luci::CircleNonMaxSuppressionV4Out *) final {} - void visit(luci::CircleNonMaxSuppressionV5Out *) final {} - void visit(luci::CircleSplitOut *) final {} - void visit(luci::CircleSplitVOut *) final {} - void visit(luci::CircleTopKV2Out *) final {} - void visit(luci::CircleUniqueOut *) final {} - void visit(luci::CircleUnpackOut *) final {} - void visit(luci::CircleWhileOut *) final {} -}; - -void OperationExporter::export_node(luci::CircleNode *node) -{ - // TODO revise return type to bool and return if handled -#define VISIT_OE(GRP) \ - do \ - { \ - OpExporterLet<OE::GRP> oe(_ctx); \ - node->accept(&oe); \ - } while (false) - - VISIT_OE(ABC); - VISIT_OE(DEF); - VISIT_OE(GHIJ); - VISIT_OE(KLMN); - VISIT_OE(OPQR); - VISIT_OE(STUV); - VISIT_OE(WXYZ); - VISIT_OE(CIRC); - VISIT_OE(VIRT); - -#undef VISIT_OE -} - -void OpExporterLet<OE::ABC>::visit(luci::CircleAbs *node) -{ - export_simple(node, circle::BuiltinOperator_ABS, circle::BuiltinOptions_AbsOptions, - CreateAbsOptions(_ctx.builder).Union()); -} - -void OpExporterLet<OE::ABC>::visit(luci::CircleAdd *node) -{ - export_simple( - node, circle::BuiltinOperator_ADD, circle::BuiltinOptions_AddOptions, - CreateAddOptions(_ctx.builder, to_circle_actfunc(node->fusedActivationFunction())).Union()); -} - -void OpExporterLet<OE::ABC>::visit(luci::CircleAddN *node) { export_node(_ctx, node); } - -void OpExporterLet<OE::ABC>::visit(luci::CircleArgMax *node) -{ - export_simple( - node, circle::BuiltinOperator_ARG_MAX, circle::BuiltinOptions_ArgMaxOptions, - 
CreateArgMaxOptions(_ctx.builder, to_circle_tensortype(node->output_type())).Union()); -} - -void OpExporterLet<OE::ABC>::visit(luci::CircleArgMin *node) -{ - export_simple( - node, circle::BuiltinOperator_ARG_MIN, circle::BuiltinOptions_ArgMinOptions, - CreateArgMinOptions(_ctx.builder, to_circle_tensortype(node->output_type())).Union()); -} - -void OpExporterLet<OE::ABC>::visit(luci::CircleAveragePool2D *node) -{ - export_pool_2d<luci::CircleAveragePool2D>(_ctx, node, circle::BuiltinOperator_AVERAGE_POOL_2D); -} - -void OpExporterLet<OE::ABC>::visit(luci::CircleBatchMatMul *node) -{ - export_simple(node, circle::BuiltinOperator_BATCH_MATMUL, - circle::BuiltinOptions_BatchMatMulOptions, - CreateBatchMatMulOptions(_ctx.builder, node->adj_x(), node->adj_y()).Union()); -} - -void OpExporterLet<OE::ABC>::visit(luci::CircleBidirectionalSequenceLSTM *node) -{ - auto bidi_lstm_outs = loco::succs(node); - assert((bidi_lstm_outs.size() == 1) || (bidi_lstm_outs.size() == 2)); - uint32_t op_idx = _ctx.md.registerBuiltinOpcode( - circle::BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM, node->op_version()); - - std::vector<int32_t> inputs_vec{get_tensor_index(node->input())}; - std::vector<int32_t> outputs_vec; - - for (int32_t index = 0; index < 2; index++) - { - // store in order of index - bool found = false; - for (auto out : bidi_lstm_outs) - { - auto bidi_lstm_out = loco::must_cast<luci::CircleBidirectionalSequenceLSTMOut *>(out); - if (bidi_lstm_out->index() == index) - { - outputs_vec.push_back(get_tensor_index(bidi_lstm_out)); - found = true; - break; - } - } - if (!found) - { - INTERNAL_EXN("Invalid BidirectionalSequenceLSTM output"); - } - } - - auto inputs = _ctx.builder.CreateVector(inputs_vec); - auto outputs = _ctx.builder.CreateVector(outputs_vec); - auto options = CreateBidirectionalSequenceLSTMOptions( - _ctx.builder, to_circle_actfunc(node->fusedActivationFunction()), node->cell_clip(), - node->proj_clip(), node->merge_outputs(), node->time_major(), - 
node->asymmetric_quantize_inputs()); - auto op_offset = - CreateOperator(_ctx.builder, op_idx, inputs, outputs, - circle::BuiltinOptions_BidirectionalSequenceLSTMOptions, options.Union()); - _ctx.gd._operators.push_back(op_offset); -} - -void OpExporterLet<OE::ABC>::visit(luci::CircleCast *node) { export_node(_ctx, node); } - -void OpExporterLet<OE::ABC>::visit(luci::CircleCeil *node) -{ - export_simple(node, circle::BuiltinOperator_CEIL); -} - -void OpExporterLet<OE::ABC>::visit(luci::CircleConcatenation *node) { export_node(_ctx, node); } - -void OpExporterLet<OE::ABC>::visit(luci::CircleBatchToSpaceND *node) -{ - export_simple(node, circle::BuiltinOperator_BATCH_TO_SPACE_ND, - circle::BuiltinOptions_BatchToSpaceNDOptions, - CreateBatchToSpaceNDOptions(_ctx.builder).Union()); -} - -void OpExporterLet<OE::ABC>::visit(luci::CircleConv2D *node) -{ - export_simple(node, circle::BuiltinOperator_CONV_2D, circle::BuiltinOptions_Conv2DOptions, - CreateConv2DOptions(_ctx.builder, getOpPadding(node->padding()), - node->stride()->w(), node->stride()->h(), - to_circle_actfunc(node->fusedActivationFunction()), - node->dilation()->w(), node->dilation()->h()) - .Union()); -} - -void OpExporterLet<OE::ABC>::visit(luci::CircleCos *node) -{ - export_simple(node, circle::BuiltinOperator_COS, circle::BuiltinOptions_CosOptions, - CreateCosOptions(_ctx.builder).Union()); -} - -void OpExporterLet<OE::ABC>::visit(luci::CircleCustom *node) { export_node(_ctx, node); } - -void OpExporterLet<OE::DEF>::visit(luci::CircleDepthToSpace *node) -{ - export_simple(node, circle::BuiltinOperator_DEPTH_TO_SPACE, - circle::BuiltinOptions_DepthToSpaceOptions, - CreateDepthToSpaceOptions(_ctx.builder, node->block_size()).Union()); -} - -void OpExporterLet<OE::DEF>::visit(luci::CircleDepthwiseConv2D *node) -{ - export_simple( - node, circle::BuiltinOperator_DEPTHWISE_CONV_2D, circle::BuiltinOptions_DepthwiseConv2DOptions, - CreateDepthwiseConv2DOptions(_ctx.builder, getOpPadding(node->padding()), 
node->stride()->w(), - node->stride()->h(), node->depthMultiplier(), - to_circle_actfunc(node->fusedActivationFunction()), - node->dilation()->w(), node->dilation()->h()) - .Union()); -} - -void OpExporterLet<OE::DEF>::visit(luci::CircleDequantize *node) -{ - export_simple(node, circle::BuiltinOperator_DEQUANTIZE); -} - -void OpExporterLet<OE::DEF>::visit(luci::CircleDiv *node) -{ - export_simple( - node, circle::BuiltinOperator_DIV, circle::BuiltinOptions_DivOptions, - CreateDivOptions(_ctx.builder, to_circle_actfunc(node->fusedActivationFunction())).Union()); -} - -void OpExporterLet<OE::DEF>::visit(luci::CircleElu *node) -{ - export_simple(node, circle::BuiltinOperator_ELU); -} - -void OpExporterLet<OE::DEF>::visit(luci::CircleEqual *node) -{ - export_simple(node, circle::BuiltinOperator_EQUAL, circle::BuiltinOptions_EqualOptions, - CreateEqualOptions(_ctx.builder).Union()); -} - -void OpExporterLet<OE::DEF>::visit(luci::CircleExp *node) -{ - export_simple(node, circle::BuiltinOperator_EXP, circle::BuiltinOptions_ExpOptions, - CreateExpOptions(_ctx.builder).Union()); -} - -void OpExporterLet<OE::DEF>::visit(luci::CircleExpandDims *node) -{ - export_simple(node, circle::BuiltinOperator_EXPAND_DIMS, circle::BuiltinOptions_ExpandDimsOptions, - CreateExpandDimsOptions(_ctx.builder).Union()); -} - -void OpExporterLet<OE::DEF>::visit(luci::CircleFakeQuant *node) -{ - export_simple(node, circle::BuiltinOperator_FAKE_QUANT, circle::BuiltinOptions_FakeQuantOptions, - CreateFakeQuantOptions(_ctx.builder, node->min(), node->max(), node->num_bits(), - node->narrow_range()) - .Union()); -} - -void OpExporterLet<OE::DEF>::visit(luci::CircleFill *node) -{ - export_simple(node, circle::BuiltinOperator_FILL, circle::BuiltinOptions_FillOptions, - CreateFillOptions(_ctx.builder).Union()); -} - -void OpExporterLet<OE::DEF>::visit(luci::CircleFloor *node) -{ - export_simple(node, circle::BuiltinOperator_FLOOR); -} - -void OpExporterLet<OE::DEF>::visit(luci::CircleFloorDiv *node) -{ 
- export_simple(node, circle::BuiltinOperator_FLOOR_DIV, circle::BuiltinOptions_FloorDivOptions, - CreateFloorDivOptions(_ctx.builder).Union()); -} - -void OpExporterLet<OE::DEF>::visit(luci::CircleFloorMod *node) -{ - export_simple(node, circle::BuiltinOperator_FLOOR_MOD, circle::BuiltinOptions_FloorModOptions, - CreateFloorModOptions(_ctx.builder).Union()); -} - -void OpExporterLet<OE::DEF>::visit(luci::CircleFullyConnected *node) -{ - export_simple( - node, circle::BuiltinOperator_FULLY_CONNECTED, circle::BuiltinOptions_FullyConnectedOptions, - CreateFullyConnectedOptions(_ctx.builder, to_circle_actfunc(node->fusedActivationFunction()), - to_circle_weightsformat(node->weights_format())) - .Union()); -} - -void OpExporterLet<OE::GHIJ>::visit(luci::CircleGather *node) -{ - export_simple(node, circle::BuiltinOperator_GATHER, circle::BuiltinOptions_GatherOptions, - CreateGatherOptions(_ctx.builder, node->axis()).Union()); -} - -void OpExporterLet<OE::GHIJ>::visit(luci::CircleGatherNd *node) -{ - export_simple(node, circle::BuiltinOperator_GATHER_ND, circle::BuiltinOptions_GatherNdOptions, - CreateGatherNdOptions(_ctx.builder).Union()); -} - -void OpExporterLet<OE::GHIJ>::visit(luci::CircleGreater *node) -{ - export_simple(node, circle::BuiltinOperator_GREATER, circle::BuiltinOptions_GreaterOptions, - CreateGreaterOptions(_ctx.builder).Union()); -} - -void OpExporterLet<OE::GHIJ>::visit(luci::CircleGreaterEqual *node) -{ - export_simple(node, circle::BuiltinOperator_GREATER_EQUAL, - circle::BuiltinOptions_GreaterEqualOptions, - CreateGreaterEqualOptions(_ctx.builder).Union()); -} - -void OpExporterLet<OE::GHIJ>::visit(luci::CircleIf *node) { export_node(_ctx, node); } - -void OpExporterLet<OE::KLMN>::visit(luci::CircleL2Normalize *node) -{ - export_simple( - node, circle::BuiltinOperator_L2_NORMALIZATION, circle::BuiltinOptions_L2NormOptions, - CreateL2NormOptions(_ctx.builder, to_circle_actfunc(node->fusedActivationFunction())).Union()); -} - -void 
OpExporterLet<OE::KLMN>::visit(luci::CircleL2Pool2D *node) -{ - export_pool_2d<luci::CircleL2Pool2D>(_ctx, node, circle::BuiltinOperator_L2_POOL_2D); -} - -void OpExporterLet<OE::KLMN>::visit(luci::CircleLeakyRelu *node) -{ - export_simple(node, circle::BuiltinOperator_LEAKY_RELU, circle::BuiltinOptions_LeakyReluOptions, - CreateLeakyReluOptions(_ctx.builder, node->alpha()).Union()); -} - -void OpExporterLet<OE::KLMN>::visit(luci::CircleLess *node) -{ - export_simple(node, circle::BuiltinOperator_LESS, circle::BuiltinOptions_LessOptions, - CreateLessOptions(_ctx.builder).Union()); -} - -void OpExporterLet<OE::KLMN>::visit(luci::CircleLessEqual *node) -{ - export_simple(node, circle::BuiltinOperator_LESS_EQUAL, circle::BuiltinOptions_LessEqualOptions, - CreateLessEqualOptions(_ctx.builder).Union()); -} - -void OpExporterLet<OE::KLMN>::visit(luci::CircleLocalResponseNormalization *node) -{ - export_simple(node, circle::BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION, - circle::BuiltinOptions_LocalResponseNormalizationOptions, - CreateLocalResponseNormalizationOptions(_ctx.builder, node->radius(), node->bias(), - node->alpha(), node->beta()) - .Union()); -} - -void OpExporterLet<OE::KLMN>::visit(luci::CircleLog *node) -{ - export_simple(node, circle::BuiltinOperator_LOG); -} - -void OpExporterLet<OE::KLMN>::visit(luci::CircleLogicalAnd *node) -{ - export_simple(node, circle::BuiltinOperator_LOGICAL_AND, circle::BuiltinOptions_LogicalAndOptions, - CreateLogicalAndOptions(_ctx.builder).Union()); -} - -void OpExporterLet<OE::KLMN>::visit(luci::CircleLogicalNot *node) -{ - export_simple(node, circle::BuiltinOperator_LOGICAL_NOT, circle::BuiltinOptions_LogicalNotOptions, - CreateLogicalNotOptions(_ctx.builder).Union()); -} - -void OpExporterLet<OE::KLMN>::visit(luci::CircleLogicalOr *node) -{ - export_simple(node, circle::BuiltinOperator_LOGICAL_OR, circle::BuiltinOptions_LogicalOrOptions, - CreateLogicalOrOptions(_ctx.builder).Union()); -} - -void 
OpExporterLet<OE::KLMN>::visit(luci::CircleLogistic *node) -{ - export_simple(node, circle::BuiltinOperator_LOGISTIC); -} - -void OpExporterLet<OE::KLMN>::visit(luci::CircleLogSoftmax *node) -{ - export_simple(node, circle::BuiltinOperator_LOG_SOFTMAX, circle::BuiltinOptions_LogSoftmaxOptions, - CreateLogSoftmaxOptions(_ctx.builder).Union()); -} - -void OpExporterLet<OE::KLMN>::visit(luci::CircleMatrixDiag *node) -{ - export_simple(node, circle::BuiltinOperator_MATRIX_DIAG, circle::BuiltinOptions_MatrixDiagOptions, - CreateMatrixDiagOptions(_ctx.builder).Union()); -} - -void OpExporterLet<OE::KLMN>::visit(luci::CircleMatrixSetDiag *node) -{ - export_simple(node, circle::BuiltinOperator_MATRIX_SET_DIAG, - circle::BuiltinOptions_MatrixSetDiagOptions, - CreateMatrixSetDiagOptions(_ctx.builder).Union()); -} - -void OpExporterLet<OE::KLMN>::visit(luci::CircleMaximum *node) -{ - export_simple(node, circle::BuiltinOperator_MAXIMUM, circle::BuiltinOptions_MaximumMinimumOptions, - CreateMaximumMinimumOptions(_ctx.builder).Union()); -} - -void OpExporterLet<OE::KLMN>::visit(luci::CircleMaxPool2D *node) -{ - export_pool_2d<luci::CircleMaxPool2D>(_ctx, node, circle::BuiltinOperator_MAX_POOL_2D); -} - -void OpExporterLet<OE::KLMN>::visit(luci::CircleMean *node) -{ - export_simple(node, circle::BuiltinOperator_MEAN, circle::BuiltinOptions_ReducerOptions, - CreateReducerOptions(_ctx.builder, node->keep_dims()).Union()); -} - -void OpExporterLet<OE::KLMN>::visit(luci::CircleMinimum *node) -{ - export_simple(node, circle::BuiltinOperator_MINIMUM, circle::BuiltinOptions_MaximumMinimumOptions, - CreateMaximumMinimumOptions(_ctx.builder).Union()); -} - -void OpExporterLet<OE::KLMN>::visit(luci::CircleMirrorPad *node) -{ - export_simple( - node, circle::BuiltinOperator_MIRROR_PAD, circle::BuiltinOptions_MirrorPadOptions, - CreateMirrorPadOptions(_ctx.builder, to_circle_mirrorpadmode(node->mode())).Union()); -} - -void OpExporterLet<OE::KLMN>::visit(luci::CircleMul *node) -{ - 
export_simple( - node, circle::BuiltinOperator_MUL, circle::BuiltinOptions_MulOptions, - CreateMulOptions(_ctx.builder, to_circle_actfunc(node->fusedActivationFunction())).Union()); -} - -void OpExporterLet<OE::KLMN>::visit(luci::CircleNeg *node) -{ - export_simple(node, circle::BuiltinOperator_NEG, circle::BuiltinOptions_NegOptions, - CreateNegOptions(_ctx.builder).Union()); -} - -void OpExporterLet<OE::KLMN>::visit(luci::CircleNonMaxSuppressionV4 *node) -{ - export_node(_ctx, node); -} - -void OpExporterLet<OE::KLMN>::visit(luci::CircleNonMaxSuppressionV5 *node) -{ - export_node(_ctx, node); -} - -void OpExporterLet<OE::KLMN>::visit(luci::CircleNotEqual *node) -{ - export_simple(node, circle::BuiltinOperator_NOT_EQUAL, circle::BuiltinOptions_NotEqualOptions, - CreateNotEqualOptions(_ctx.builder).Union()); -} - -void OpExporterLet<OE::OPQR>::visit(luci::CircleOneHot *node) -{ - export_simple(node, circle::BuiltinOperator_ONE_HOT, circle::BuiltinOptions_OneHotOptions, - CreateOneHotOptions(_ctx.builder, node->axis()).Union()); -} - -void OpExporterLet<OE::OPQR>::visit(luci::CirclePack *node) -{ - export_simple(node, circle::BuiltinOperator_PACK, circle::BuiltinOptions_PackOptions, - CreatePackOptions(_ctx.builder, node->values_count(), node->axis()).Union()); -} - -void OpExporterLet<OE::OPQR>::visit(luci::CirclePad *node) -{ - export_simple(node, circle::BuiltinOperator_PAD, circle::BuiltinOptions_PadOptions, - CreatePadOptions(_ctx.builder).Union()); -} - -void OpExporterLet<OE::OPQR>::visit(luci::CirclePadV2 *node) -{ - export_simple(node, circle::BuiltinOperator_PADV2, circle::BuiltinOptions_PadV2Options, - CreatePadV2Options(_ctx.builder).Union()); -} - -void OpExporterLet<OE::OPQR>::visit(luci::CirclePow *node) -{ - export_simple(node, circle::BuiltinOperator_POW, circle::BuiltinOptions_PowOptions, - CreatePowOptions(_ctx.builder).Union()); -} - -void OpExporterLet<OE::OPQR>::visit(luci::CirclePRelu *node) -{ - export_simple(node, 
circle::BuiltinOperator_PRELU); -} - -void OpExporterLet<OE::OPQR>::visit(luci::CircleQuantize *node) -{ - export_simple(node, circle::BuiltinOperator_QUANTIZE); -} - -void OpExporterLet<OE::OPQR>::visit(luci::CircleRange *node) -{ - export_simple(node, circle::BuiltinOperator_RANGE, circle::BuiltinOptions_RangeOptions, - CreateRangeOptions(_ctx.builder).Union()); -} - -void OpExporterLet<OE::OPQR>::visit(luci::CircleRank *node) -{ - export_simple(node, circle::BuiltinOperator_RANK, circle::BuiltinOptions_RankOptions, - CreateRankOptions(_ctx.builder).Union()); -} - -void OpExporterLet<OE::OPQR>::visit(luci::CircleReduceAny *node) -{ - export_simple(node, circle::BuiltinOperator_REDUCE_ANY, circle::BuiltinOptions_ReducerOptions, - CreateReducerOptions(_ctx.builder, node->keep_dims()).Union()); -} - -void OpExporterLet<OE::OPQR>::visit(luci::CircleReduceMax *node) -{ - export_simple(node, circle::BuiltinOperator_REDUCE_MAX, circle::BuiltinOptions_ReducerOptions, - CreateReducerOptions(_ctx.builder, node->keep_dims()).Union()); -} - -void OpExporterLet<OE::OPQR>::visit(luci::CircleReduceMin *node) -{ - export_simple(node, circle::BuiltinOperator_REDUCE_MIN, circle::BuiltinOptions_ReducerOptions, - CreateReducerOptions(_ctx.builder, node->keep_dims()).Union()); -} - -void OpExporterLet<OE::OPQR>::visit(luci::CircleReduceProd *node) -{ - export_simple(node, circle::BuiltinOperator_REDUCE_PROD, circle::BuiltinOptions_ReducerOptions, - CreateReducerOptions(_ctx.builder, node->keep_dims()).Union()); -} - -void OpExporterLet<OE::OPQR>::visit(luci::CircleRelu *node) -{ - export_simple(node, circle::BuiltinOperator_RELU); -} - -void OpExporterLet<OE::OPQR>::visit(luci::CircleRelu6 *node) -{ - export_simple(node, circle::BuiltinOperator_RELU6); -} - -void OpExporterLet<OE::OPQR>::visit(luci::CircleReluN1To1 *node) -{ - export_simple(node, circle::BuiltinOperator_RELU_N1_TO_1); -} - -void OpExporterLet<OE::OPQR>::visit(luci::CircleReshape *node) -{ - auto new_shape = 
_ctx.builder.CreateVector<int32_t>( - node->newShape()->rank(), [node](size_t i) { return node->newShape()->dim(i); }); - - export_simple(node, circle::BuiltinOperator_RESHAPE, circle::BuiltinOptions_ReshapeOptions, - CreateReshapeOptions(_ctx.builder, new_shape).Union()); -} - -void OpExporterLet<OE::OPQR>::visit(luci::CircleResizeBilinear *node) -{ - export_simple( - node, circle::BuiltinOperator_RESIZE_BILINEAR, circle::BuiltinOptions_ResizeBilinearOptions, - CreateResizeBilinearOptions(_ctx.builder, node->align_corners(), node->half_pixel_centers()) - .Union()); -} - -void OpExporterLet<OE::OPQR>::visit(luci::CircleResizeNearestNeighbor *node) -{ - export_simple(node, circle::BuiltinOperator_RESIZE_NEAREST_NEIGHBOR, - circle::BuiltinOptions_ResizeNearestNeighborOptions, - CreateResizeNearestNeighborOptions(_ctx.builder, node->align_corners()).Union()); -} - -void OpExporterLet<OE::OPQR>::visit(luci::CircleReverseSequence *node) -{ - export_simple( - node, circle::BuiltinOperator_REVERSE_SEQUENCE, circle::BuiltinOptions_ReverseSequenceOptions, - CreateReverseSequenceOptions(_ctx.builder, node->seq_axis(), node->batch_axis()).Union()); -} - -void OpExporterLet<OE::OPQR>::visit(luci::CircleReverseV2 *node) { export_node(_ctx, node); } - -void OpExporterLet<OE::OPQR>::visit(luci::CircleRound *node) -{ - export_simple(node, circle::BuiltinOperator_ROUND); -} - -void OpExporterLet<OE::OPQR>::visit(luci::CircleRsqrt *node) -{ - export_simple(node, circle::BuiltinOperator_RSQRT); -} - -void OpExporterLet<OE::STUV>::visit(luci::CircleScatterNd *node) -{ - export_simple(node, circle::BuiltinOperator_SCATTER_ND, circle::BuiltinOptions_ScatterNdOptions, - CreateScatterNdOptions(_ctx.builder).Union()); -} - -void OpExporterLet<OE::STUV>::visit(luci::CircleSegmentSum *node) -{ - export_simple(node, circle::BuiltinOperator_SEGMENT_SUM, circle::BuiltinOptions_SegmentSumOptions, - CreateSegmentSumOptions(_ctx.builder).Union()); -} - -void 
OpExporterLet<OE::STUV>::visit(luci::CircleSelect *node) -{ - export_simple(node, circle::BuiltinOperator_SELECT, circle::BuiltinOptions_SelectOptions, - CreateSelectOptions(_ctx.builder).Union()); -} - -void OpExporterLet<OE::STUV>::visit(luci::CircleSelectV2 *node) -{ - export_simple(node, circle::BuiltinOperator_SELECT_V2, circle::BuiltinOptions_SelectV2Options, - CreateSelectV2Options(_ctx.builder).Union()); -} - -void OpExporterLet<OE::STUV>::visit(luci::CircleShape *node) -{ - export_simple(node, circle::BuiltinOperator_SHAPE, circle::BuiltinOptions_ShapeOptions, - CreateShapeOptions(_ctx.builder, to_circle_tensortype(node->out_type())).Union()); -} - -void OpExporterLet<OE::STUV>::visit(luci::CircleSin *node) -{ - export_simple(node, circle::BuiltinOperator_SIN); -} - -void OpExporterLet<OE::STUV>::visit(luci::CircleSlice *node) -{ - export_simple(node, circle::BuiltinOperator_SLICE, circle::BuiltinOptions_SliceOptions, - CreateSliceOptions(_ctx.builder).Union()); -} - -void OpExporterLet<OE::STUV>::visit(luci::CircleSoftmax *node) -{ - export_simple(node, circle::BuiltinOperator_SOFTMAX, circle::BuiltinOptions_SoftmaxOptions, - CreateSoftmaxOptions(_ctx.builder, node->beta()).Union()); -} - -void OpExporterLet<OE::STUV>::visit(luci::CircleSpaceToBatchND *node) -{ - export_simple(node, circle::BuiltinOperator_SPACE_TO_BATCH_ND, - circle::BuiltinOptions_SpaceToBatchNDOptions, - CreateSpaceToBatchNDOptions(_ctx.builder).Union()); -} - -void OpExporterLet<OE::STUV>::visit(luci::CircleSpaceToDepth *node) -{ - export_simple(node, circle::BuiltinOperator_SPACE_TO_DEPTH, - circle::BuiltinOptions_SpaceToDepthOptions, - CreateSpaceToDepthOptions(_ctx.builder, node->block_size()).Union()); -} - -void OpExporterLet<OE::STUV>::visit(luci::CircleSparseToDense *node) -{ - export_simple(node, circle::BuiltinOperator_SPARSE_TO_DENSE, - circle::BuiltinOptions_SparseToDenseOptions, - CreateSparseToDenseOptions(_ctx.builder, node->validate_indices()).Union()); -} - -void 
OpExporterLet<OE::STUV>::visit(luci::CircleSplit *node) { export_node(_ctx, node); } - -void OpExporterLet<OE::STUV>::visit(luci::CircleSplitV *node) { export_node(_ctx, node); } - -void OpExporterLet<OE::STUV>::visit(luci::CircleSqrt *node) -{ - export_simple(node, circle::BuiltinOperator_SQRT); -} - -void OpExporterLet<OE::STUV>::visit(luci::CircleSquare *node) -{ - export_simple(node, circle::BuiltinOperator_SQUARE, circle::BuiltinOptions_SquareOptions, - CreateSquareOptions(_ctx.builder).Union()); -} - -void OpExporterLet<OE::STUV>::visit(luci::CircleSquaredDifference *node) -{ - export_simple(node, circle::BuiltinOperator_SQUARED_DIFFERENCE, - circle::BuiltinOptions_SquaredDifferenceOptions, - CreateSquaredDifferenceOptions(_ctx.builder).Union()); -} - -void OpExporterLet<OE::STUV>::visit(luci::CircleSqueeze *node) -{ - auto squeeze_dims = _ctx.builder.CreateVector<int32_t>(node->squeeze_dims()); - export_simple(node, circle::BuiltinOperator_SQUEEZE, circle::BuiltinOptions_SqueezeOptions, - CreateSqueezeOptions(_ctx.builder, squeeze_dims).Union()); -} - -void OpExporterLet<OE::STUV>::visit(luci::CircleStridedSlice *node) -{ - export_simple(node, circle::BuiltinOperator_STRIDED_SLICE, - circle::BuiltinOptions_StridedSliceOptions, - CreateStridedSliceOptions(_ctx.builder, node->begin_mask(), node->end_mask(), - node->ellipsis_mask(), node->new_axis_mask(), - node->shrink_axis_mask()) - .Union()); -} - -void OpExporterLet<OE::STUV>::visit(luci::CircleSub *node) -{ - export_simple( - node, circle::BuiltinOperator_SUB, circle::BuiltinOptions_SubOptions, - CreateSubOptions(_ctx.builder, to_circle_actfunc(node->fusedActivationFunction())).Union()); -} - -void OpExporterLet<OE::STUV>::visit(luci::CircleSum *node) -{ - export_simple(node, circle::BuiltinOperator_SUM, circle::BuiltinOptions_ReducerOptions, - CreateReducerOptions(_ctx.builder, node->keep_dims()).Union()); -} - -void OpExporterLet<OE::STUV>::visit(luci::CircleTanh *node) -{ - export_simple(node, 
circle::BuiltinOperator_TANH); -} - -void OpExporterLet<OE::STUV>::visit(luci::CircleTile *node) -{ - export_simple(node, circle::BuiltinOperator_TILE, circle::BuiltinOptions_TileOptions, - CreateTileOptions(_ctx.builder).Union()); -} - -void OpExporterLet<OE::STUV>::visit(luci::CircleTopKV2 *node) { export_node(_ctx, node); } - -void OpExporterLet<OE::STUV>::visit(luci::CircleTranspose *node) -{ - export_simple(node, circle::BuiltinOperator_TRANSPOSE, circle::BuiltinOptions_TransposeOptions, - CreateTransposeOptions(_ctx.builder).Union()); -} - -void OpExporterLet<OE::STUV>::visit(luci::CircleTransposeConv *node) -{ - export_simple(node, circle::BuiltinOperator_TRANSPOSE_CONV, - circle::BuiltinOptions_TransposeConvOptions, - CreateTransposeConvOptions(_ctx.builder, getOpPadding(node->padding()), - node->stride()->w(), node->stride()->h()) - .Union()); -} - -void OpExporterLet<OE::STUV>::visit(luci::CircleUnidirectionalSequenceLSTM *node) -{ - export_simple(node, circle::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM, - circle::BuiltinOptions_UnidirectionalSequenceLSTMOptions, - CreateUnidirectionalSequenceLSTMOptions( - _ctx.builder, to_circle_actfunc(node->fusedActivationFunction()), - node->cell_clip(), node->proj_clip(), node->time_major(), - node->asymmetric_quantize_inputs()) - .Union()); -} - -void OpExporterLet<OE::STUV>::visit(luci::CircleUnique *node) { export_node(_ctx, node); } - -void OpExporterLet<OE::STUV>::visit(luci::CircleUnpack *node) { export_node(_ctx, node); } - -void OpExporterLet<OE::WXYZ>::visit(luci::CircleWhere *node) -{ - export_simple(node, circle::BuiltinOperator_WHERE, circle::BuiltinOptions_WhereOptions, - CreateWhereOptions(_ctx.builder).Union()); -} - -void OpExporterLet<OE::WXYZ>::visit(luci::CircleWhile *node) { export_node(_ctx, node); } - -void OpExporterLet<OE::WXYZ>::visit(luci::CircleZerosLike *node) -{ - export_simple(node, circle::BuiltinOperator_ZEROS_LIKE, circle::BuiltinOptions_ZerosLikeOptions, - 
CreateZerosLikeOptions(_ctx.builder).Union()); -} - -void OpExporterLet<OE::CIRC>::visit(luci::CircleBCQFullyConnected *node) -{ - export_simple(node, circle::BuiltinOperator_BCQ_FULLY_CONNECTED, - circle::BuiltinOptions_BCQFullyConnectedOptions, - CreateBCQFullyConnectedOptions(_ctx.builder, node->weights_hidden_size(), - to_circle_actfunc(node->fusedActivationFunction())) - .Union()); -} - -void OpExporterLet<OE::CIRC>::visit(luci::CircleBCQGather *node) -{ - export_simple( - node, circle::BuiltinOperator_BCQ_GATHER, circle::BuiltinOptions_BCQGatherOptions, - CreateBCQGatherOptions(_ctx.builder, node->input_hidden_size(), node->axis()).Union()); -} - -void OpExporterLet<OE::CIRC>::visit(luci::CircleInstanceNorm *node) +namespace luci { - export_simple(node, circle::BuiltinOperator_INSTANCE_NORM, - circle::BuiltinOptions_InstanceNormOptions, - CreateInstanceNormOptions(_ctx.builder, node->epsilon(), - to_circle_actfunc(node->fusedActivationFunction())) - .Union()); -} -void exportNode(loco::Node *node, flatbuffers::FlatBufferBuilder &builder, SerializedModelData &md, - SerializedGraphData &gd, uint32_t node_position) +void exportNodes(loco::Graph *g, flatbuffers::FlatBufferBuilder &builder, SerializedModelData &md, + SerializedGraphData &gd) { - if (auto circle_node = dynamic_cast<luci::CircleNode *>(node)) + uint32_t node_position = 0; + for (auto node : loco::postorder_traversal(loco::output_nodes(g))) { ExportContext ctx{builder, md, gd}; - OperationExporter exporter{ctx}; + OperationExporterRule exporter_rule{ctx}; + + auto circle_node = loco::must_cast<luci::CircleNode *>(node); + circle_node->accept(&exporter_rule); const auto ops_size = gd._operators.size(); - exporter.export_node(circle_node); if (has_origin(circle_node) && ops_size != gd._operators.size()) { const auto node_id = gd._operators.size() - 1; @@ -1716,25 +60,7 @@ void exportNode(loco::Node *node, flatbuffers::FlatBufferBuilder &builder, Seria } 
md._metadata.add_execution_plan_table(node_position, execution_plan_vector); } - } - else - { - INTERNAL_EXN("Node with unsupported dialect found"); - } -} -} // namespace - -namespace luci -{ - -void exportNodes(loco::Graph *g, FlatBufferBuilder &builder, SerializedModelData &md, - SerializedGraphData &gd) -{ - uint32_t node_position = 0; - for (auto node : loco::postorder_traversal(loco::output_nodes(g))) - { - exportNode(node, builder, md, gd, node_position); node_position++; } } diff --git a/compiler/luci/export/src/CircleOperationExporter.h b/compiler/luci/export/src/CircleOperationExporter.h index de6abfc54..f2b3cfd6b 100644 --- a/compiler/luci/export/src/CircleOperationExporter.h +++ b/compiler/luci/export/src/CircleOperationExporter.h @@ -17,7 +17,7 @@ #ifndef __CIRCLE_OPERATION_EXPORTER_H__ #define __CIRCLE_OPERATION_EXPORTER_H__ -#include "CircleExporterUtils.h" +#include "SerializedData.h" #include <loco/IR/Graph.h> diff --git a/compiler/luci/export/src/CircleOperationExporterRule.cpp b/compiler/luci/export/src/CircleOperationExporterRule.cpp new file mode 100644 index 000000000..8dc59fa9c --- /dev/null +++ b/compiler/luci/export/src/CircleOperationExporterRule.cpp @@ -0,0 +1,277 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "CircleOperationExporterRule.h" +#include "CircleBuiltinTypesExtractor.h" +#include "Check.h" + +#include <loco/IR/Graph.h> +#include <luci/IR/CircleNode.h> +#include <luci/IR/CircleNodes.h> +#include <luci/IR/CircleNodeVisitor.h> +#include <oops/InternalExn.h> + +#include <vector> + +namespace +{ +class OutputVectorExtractor final : public luci::CircleNodeMutableVisitor<std::vector<int32_t>> +{ +public: + OutputVectorExtractor() + { + // DO NOTHING + } + +public: + std::vector<int32_t> visit(luci::CircleNode *node) final + { + std::vector<int32_t> outputs_vec{luci::get_tensor_index(node)}; + return outputs_vec; + } + + std::vector<int32_t> visit(luci::CircleBidirectionalSequenceLSTM *node) final + { + auto bidi_lstm_outs = loco::succs(node); + assert((bidi_lstm_outs.size() == 1) || (bidi_lstm_outs.size() == 2)); + + std::vector<int32_t> outputs_vec(bidi_lstm_outs.size()); + + for (auto out : bidi_lstm_outs) + { + auto bidi_lstm_out = loco::must_cast<luci::CircleBidirectionalSequenceLSTMOut *>(out); + if (bidi_lstm_out->index() >= int32_t(bidi_lstm_outs.size())) + INTERNAL_EXN("Invalid BidirectionalSequenceLSTM output"); + outputs_vec[bidi_lstm_out->index()] = luci::get_tensor_index(bidi_lstm_out); + } + + return outputs_vec; + } + + std::vector<int32_t> visit(luci::CircleCustom *node) final + { + auto custom_outputs = loco::succs(node); + assert(custom_outputs.size() == node->numOutputs()); + + std::vector<int32_t> outputs_vec(node->numOutputs()); + + for (auto out : custom_outputs) + { + auto custom_out = loco::must_cast<luci::CircleCustomOut *>(out); + if (custom_out->index() >= int32_t(node->numOutputs())) + INTERNAL_EXN("Invalid Custom output"); + outputs_vec[custom_out->index()] = luci::get_tensor_index(custom_out); + } + + return outputs_vec; + } + + std::vector<int32_t> visit(luci::CircleIf *node) final + { + auto if_outs = loco::succs(node); + assert(if_outs.size() == node->output_count()); + + std::vector<int32_t> 
outputs_vec(node->output_count()); + + for (auto out : if_outs) + { + auto if_out = loco::must_cast<luci::CircleIfOut *>(out); + if (if_out->index() >= int32_t(node->output_count())) + INTERNAL_EXN("Invalid If output"); + outputs_vec[if_out->index()] = luci::get_tensor_index(if_out); + } + + return outputs_vec; + } + + std::vector<int32_t> visit(luci::CircleNonMaxSuppressionV4 *node) final + { + auto nms_outs = loco::succs(node); + assert(nms_outs.size() == 2); + + std::vector<int32_t> outputs_vec(2); + + for (auto out : nms_outs) + { + auto nms_out = loco::must_cast<luci::CircleNonMaxSuppressionV4Out *>(out); + if (nms_out->index() >= 2) + INTERNAL_EXN("Invalid NonMaxSuppressionV4 output"); + outputs_vec[nms_out->index()] = luci::get_tensor_index(nms_out); + } + + return outputs_vec; + } + + std::vector<int32_t> visit(luci::CircleNonMaxSuppressionV5 *node) final + { + auto nms_outs = loco::succs(node); + assert(nms_outs.size() == 3); + + std::vector<int32_t> outputs_vec(3); + + for (auto out : nms_outs) + { + auto nms_out = loco::must_cast<luci::CircleNonMaxSuppressionV5Out *>(out); + if (nms_out->index() >= 3) + INTERNAL_EXN("Invalid NonMaxSuppressionV5 output"); + outputs_vec[nms_out->index()] = luci::get_tensor_index(nms_out); + } + + return outputs_vec; + } + + std::vector<int32_t> visit(luci::CircleSplit *node) final + { + auto split_outs = loco::succs(node); + assert(int32_t(split_outs.size()) == node->num_split()); + + std::vector<int32_t> outputs_vec(node->num_split()); + + for (auto out : split_outs) + { + auto split_out = loco::must_cast<luci::CircleSplitOut *>(out); + if (split_out->index() >= node->num_split()) + INTERNAL_EXN("Invalid Split output"); + outputs_vec[split_out->index()] = luci::get_tensor_index(split_out); + } + + return outputs_vec; + } + + std::vector<int32_t> visit(luci::CircleSplitV *node) final + { + auto split_outs = loco::succs(node); + assert(int32_t(split_outs.size()) == node->num_split()); + + std::vector<int32_t> 
outputs_vec(node->num_split()); + + for (auto out : split_outs) + { + auto split_out = loco::must_cast<luci::CircleSplitVOut *>(out); + if (split_out->index() >= node->num_split()) + INTERNAL_EXN("Invalid SplitV output"); + outputs_vec[split_out->index()] = luci::get_tensor_index(split_out); + } + + return outputs_vec; + } + + std::vector<int32_t> visit(luci::CircleTopKV2 *node) final + { + auto topkv2_outs = loco::succs(node); + assert(topkv2_outs.size() == 2); + + std::vector<int32_t> outputs_vec(2); + + for (auto out : topkv2_outs) + { + auto topkv2_out = loco::must_cast<luci::CircleTopKV2Out *>(out); + if (topkv2_out->index() >= 2) + INTERNAL_EXN("Invalid TopKV2 output"); + outputs_vec[topkv2_out->index()] = luci::get_tensor_index(topkv2_out); + } + + return outputs_vec; + } + + std::vector<int32_t> visit(luci::CircleUnique *node) final + { + auto unique_outs = loco::succs(node); + assert(unique_outs.size() == 2); + + std::vector<int32_t> outputs_vec(2); + + for (auto out : unique_outs) + { + auto unique_out = loco::must_cast<luci::CircleUniqueOut *>(out); + if (unique_out->index() >= 2) + INTERNAL_EXN("Invalid Unique output"); + outputs_vec[unique_out->index()] = luci::get_tensor_index(unique_out); + } + + return outputs_vec; + } + + std::vector<int32_t> visit(luci::CircleUnpack *node) final + { + auto unpack_outs = loco::succs(node); + assert(int32_t(unpack_outs.size()) == node->num()); + + std::vector<int32_t> outputs_vec(node->num()); + + for (auto out : unpack_outs) + { + auto unpack_out = loco::must_cast<luci::CircleUnpackOut *>(out); + if (unpack_out->index() >= node->num()) + INTERNAL_EXN("Invalid Unpack output"); + outputs_vec[unpack_out->index()] = luci::get_tensor_index(unpack_out); + } + + return outputs_vec; + } + + std::vector<int32_t> visit(luci::CircleWhile *node) final + { + auto while_outs = loco::succs(node); + assert(while_outs.size() == node->output_count()); + + std::vector<int32_t> outputs_vec(node->output_count()); + + for (auto out : 
while_outs) + { + auto while_out = loco::must_cast<luci::CircleWhileOut *>(out); + if (while_out->index() >= int32_t(node->output_count())) + INTERNAL_EXN("Invalid While output"); + outputs_vec[while_out->index()] = luci::get_tensor_index(while_out); + } + + return outputs_vec; + } +}; + +} // namespace + +namespace luci +{ + +void OperationExporterRule::visit(luci::CircleNode *node) +{ + auto op_idx = _ctx.md.registerBuiltinOpcode(circle_builtin_operator(node), + circle_custom_code(node), node->op_version()); + + std::vector<int32_t> inputs_vec; + for (uint32_t i = 0; i < node->arity(); ++i) + inputs_vec.push_back(luci::get_tensor_index(node->arg(i))); + auto inputs = _ctx.builder.CreateVector(inputs_vec); + + OutputVectorExtractor outputs_vec_extractor; + auto outputs_vec = node->accept(&outputs_vec_extractor); + auto outputs = _ctx.builder.CreateVector(outputs_vec); + + auto builtin_options = circle_builtin_options(node); + + luci::BuiltinOptionsExtractor builtin_options_extractor(_ctx.builder); + auto options_offset = node->accept(&builtin_options_extractor); + + // If node is not CircleCustom, null offset(0) is returned + auto custom_options = circle_custom_options(_ctx.builder, node); + + auto op_offset = circle::CreateOperator(_ctx.builder, op_idx, inputs, outputs, builtin_options, + options_offset, custom_options); + _ctx.gd._operators.push_back(op_offset); +} + +} // namespace luci diff --git a/compiler/luci/export/src/CircleOperationExporterRule.h b/compiler/luci/export/src/CircleOperationExporterRule.h new file mode 100644 index 000000000..23e7546cf --- /dev/null +++ b/compiler/luci/export/src/CircleOperationExporterRule.h @@ -0,0 +1,76 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __CIRCLE_OPERATION_EXPORTER_RULE_H__ +#define __CIRCLE_OPERATION_EXPORTER_RULE_H__ + +#include "CircleOperationExporter.h" + +#include <luci/IR/CircleNode.h> +#include <luci/IR/CircleNodes.h> +#include <luci/IR/CircleNodeVisitor.h> + +namespace luci +{ + +struct ExportContext +{ + flatbuffers::FlatBufferBuilder &builder; + luci::SerializedModelData &md; + luci::SerializedGraphData &gd; +}; + +class OperationExporterRule final : public luci::CircleNodeMutableVisitor<void> +{ +public: + OperationExporterRule(ExportContext &ctx) : _ctx{ctx} + { + // DO NOTHING + } + +public: + // Default export rule + void visit(luci::CircleNode *node) final; + + // Non-virtual + void visit(luci::CircleConst *) final{/* skip, everything is done in exportOpDefinedTensors */}; + + // Virtual + void visit(luci::CircleInput *) final {} + void visit(luci::CircleOutput *) final {} + void visit(luci::CircleOutputDummy *) final {} + void visit(luci::CircleOutputExclude *) final {} + // Virtual for multiple-outputs + void visit(luci::CircleBidirectionalSequenceLSTMOut *) final {} + void visit(luci::CircleCustomOut *) final {} + void visit(luci::CircleIfOut *) final {} + void visit(luci::CircleNonMaxSuppressionV4Out *) final {} + void visit(luci::CircleNonMaxSuppressionV5Out *) final {} + void visit(luci::CircleSplitOut *) final {} + void visit(luci::CircleSplitVOut *) final {} + void visit(luci::CircleTopKV2Out *) final {} + void visit(luci::CircleUniqueOut *) final {} + void visit(luci::CircleUnpackOut *) final {} + void visit(luci::CircleVariable *) final {} 
+ void visit(luci::CircleWhileOut *) final {} + +protected: + ExportContext &_ctx; +}; + +} // namespace luci + +#endif // __CIRCLE_OPERATION_EXPORTER_RULE_H__ diff --git a/compiler/luci/export/src/CircleOps.lst b/compiler/luci/export/src/CircleOps.lst new file mode 100644 index 000000000..1b6909303 --- /dev/null +++ b/compiler/luci/export/src/CircleOps.lst @@ -0,0 +1,154 @@ +#ifndef CIRCLE_NODE +#error "Define CIRCLE_NODE" +#endif // CIRCLE_NODE + +#ifndef CIRCLE_VNODE +#error "Define CIRCLE_VNODE" +#endif // CIRCLE_VNODE + +// +// PLEASE SORT NODE DECLS IN ALPHABETICAL ORDER +// +// NOTE : CIRCLE_VNODE does not have any additional parameters +// because they are not circle builtin operators +// Please add parameters when they are needed. +// +// CIRCLE_NODE(CircleNode, circle::BuiltinOperator, circle::BuiltinOptions) +// CIRCLE_VNODE(CircleNode) +// + +CIRCLE_NODE(CircleAbs, BuiltinOperator_ABS, BuiltinOptions_AbsOptions) +CIRCLE_NODE(CircleAdd, BuiltinOperator_ADD, BuiltinOptions_AddOptions) +CIRCLE_NODE(CircleAddN, BuiltinOperator_ADD_N, BuiltinOptions_AddNOptions) +CIRCLE_NODE(CircleArgMax, BuiltinOperator_ARG_MAX, BuiltinOptions_ArgMaxOptions) +CIRCLE_NODE(CircleArgMin, BuiltinOperator_ARG_MIN, BuiltinOptions_ArgMinOptions) +CIRCLE_NODE(CircleAveragePool2D, BuiltinOperator_AVERAGE_POOL_2D , BuiltinOptions_Pool2DOptions) +CIRCLE_NODE(CircleBatchToSpaceND, BuiltinOperator_BATCH_TO_SPACE_ND, BuiltinOptions_BatchToSpaceNDOptions) +CIRCLE_NODE(CircleBatchMatMul, BuiltinOperator_BATCH_MATMUL, BuiltinOptions_BatchMatMulOptions) +CIRCLE_NODE(CircleBidirectionalSequenceLSTM, BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM, BuiltinOptions_BidirectionalSequenceLSTMOptions) +CIRCLE_NODE(CircleCast, BuiltinOperator_CAST, BuiltinOptions_CastOptions) +CIRCLE_NODE(CircleCeil, BuiltinOperator_CEIL, BuiltinOptions_NONE) +CIRCLE_NODE(CircleConcatenation, BuiltinOperator_CONCATENATION, BuiltinOptions_ConcatenationOptions) +CIRCLE_NODE(CircleConv2D, BuiltinOperator_CONV_2D, 
BuiltinOptions_Conv2DOptions) +CIRCLE_NODE(CircleCos, BuiltinOperator_COS, BuiltinOptions_CosOptions) +CIRCLE_NODE(CircleCustom, BuiltinOperator_CUSTOM, BuiltinOptions_NONE) +CIRCLE_NODE(CircleDepthToSpace, BuiltinOperator_DEPTH_TO_SPACE, BuiltinOptions_DepthToSpaceOptions) +CIRCLE_NODE(CircleDepthwiseConv2D, BuiltinOperator_DEPTHWISE_CONV_2D, BuiltinOptions_DepthwiseConv2DOptions) +CIRCLE_NODE(CircleDequantize, BuiltinOperator_DEQUANTIZE, BuiltinOptions_DequantizeOptions) +CIRCLE_NODE(CircleDiv, BuiltinOperator_DIV, BuiltinOptions_DivOptions) +CIRCLE_NODE(CircleElu, BuiltinOperator_ELU, BuiltinOptions_NONE) +CIRCLE_NODE(CircleEqual, BuiltinOperator_EQUAL, BuiltinOptions_EqualOptions) +CIRCLE_NODE(CircleExp, BuiltinOperator_EXP, BuiltinOptions_ExpOptions) +CIRCLE_NODE(CircleExpandDims, BuiltinOperator_EXPAND_DIMS, BuiltinOptions_ExpandDimsOptions) +CIRCLE_NODE(CircleFakeQuant, BuiltinOperator_FAKE_QUANT, BuiltinOptions_FakeQuantOptions) +CIRCLE_NODE(CircleFill, BuiltinOperator_FILL, BuiltinOptions_FillOptions) +CIRCLE_NODE(CircleFloor, BuiltinOperator_FLOOR, BuiltinOptions_NONE) +CIRCLE_NODE(CircleFloorDiv, BuiltinOperator_FLOOR_DIV, BuiltinOptions_FloorDivOptions) +CIRCLE_NODE(CircleFloorMod, BuiltinOperator_FLOOR_MOD, BuiltinOptions_FloorModOptions) +CIRCLE_NODE(CircleFullyConnected, BuiltinOperator_FULLY_CONNECTED, BuiltinOptions_FullyConnectedOptions) +CIRCLE_NODE(CircleGather, BuiltinOperator_GATHER, BuiltinOptions_GatherOptions) +CIRCLE_NODE(CircleGatherNd, BuiltinOperator_GATHER_ND, BuiltinOptions_GatherNdOptions) +CIRCLE_NODE(CircleGreater, BuiltinOperator_GREATER, BuiltinOptions_GreaterOptions) +CIRCLE_NODE(CircleGreaterEqual, BuiltinOperator_GREATER_EQUAL, BuiltinOptions_GreaterEqualOptions) +CIRCLE_NODE(CircleIf, BuiltinOperator_IF, BuiltinOptions_IfOptions) +CIRCLE_NODE(CircleL2Normalize, BuiltinOperator_L2_NORMALIZATION, BuiltinOptions_L2NormOptions) +CIRCLE_NODE(CircleL2Pool2D, BuiltinOperator_L2_POOL_2D, BuiltinOptions_Pool2DOptions) 
+CIRCLE_NODE(CircleLeakyRelu, BuiltinOperator_LEAKY_RELU, BuiltinOptions_LeakyReluOptions) +CIRCLE_NODE(CircleLess, BuiltinOperator_LESS, BuiltinOptions_LessOptions) +CIRCLE_NODE(CircleLessEqual, BuiltinOperator_LESS_EQUAL, BuiltinOptions_LessEqualOptions) +CIRCLE_NODE(CircleLocalResponseNormalization, BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION, BuiltinOptions_LocalResponseNormalizationOptions) +CIRCLE_NODE(CircleLog, BuiltinOperator_LOG, BuiltinOptions_NONE) +CIRCLE_NODE(CircleLogicalAnd, BuiltinOperator_LOGICAL_AND, BuiltinOptions_LogicalAndOptions) +CIRCLE_NODE(CircleLogicalNot, BuiltinOperator_LOGICAL_NOT, BuiltinOptions_LogicalNotOptions) +CIRCLE_NODE(CircleLogicalOr, BuiltinOperator_LOGICAL_OR, BuiltinOptions_LogicalOrOptions) +CIRCLE_NODE(CircleLogistic, BuiltinOperator_LOGISTIC, BuiltinOptions_NONE) +CIRCLE_NODE(CircleLogSoftmax, BuiltinOperator_LOG_SOFTMAX, BuiltinOptions_LogSoftmaxOptions) +CIRCLE_NODE(CircleMatrixDiag, BuiltinOperator_MATRIX_DIAG, BuiltinOptions_MatrixDiagOptions) +CIRCLE_NODE(CircleMaxPool2D, BuiltinOperator_MAX_POOL_2D, BuiltinOptions_Pool2DOptions) +CIRCLE_NODE(CircleMatrixSetDiag, BuiltinOperator_MATRIX_SET_DIAG, BuiltinOptions_MatrixSetDiagOptions) +CIRCLE_NODE(CircleMaximum, BuiltinOperator_MAXIMUM, BuiltinOptions_MaximumMinimumOptions) +CIRCLE_NODE(CircleMean, BuiltinOperator_MEAN, BuiltinOptions_ReducerOptions) +CIRCLE_NODE(CircleMinimum, BuiltinOperator_MINIMUM, BuiltinOptions_MaximumMinimumOptions) +CIRCLE_NODE(CircleMirrorPad, BuiltinOperator_MIRROR_PAD, BuiltinOptions_MirrorPadOptions) +CIRCLE_NODE(CircleMul, BuiltinOperator_MUL, BuiltinOptions_MulOptions) +CIRCLE_NODE(CircleNeg, BuiltinOperator_NEG, BuiltinOptions_NegOptions) +CIRCLE_NODE(CircleNonMaxSuppressionV4, BuiltinOperator_NON_MAX_SUPPRESSION_V4, BuiltinOptions_NonMaxSuppressionV4Options) +CIRCLE_NODE(CircleNonMaxSuppressionV5, BuiltinOperator_NON_MAX_SUPPRESSION_V5, BuiltinOptions_NonMaxSuppressionV5Options) +CIRCLE_NODE(CircleNotEqual, BuiltinOperator_NOT_EQUAL, 
BuiltinOptions_NotEqualOptions) +CIRCLE_NODE(CircleOneHot, BuiltinOperator_ONE_HOT, BuiltinOptions_OneHotOptions) +CIRCLE_NODE(CirclePack, BuiltinOperator_PACK, BuiltinOptions_PackOptions) +CIRCLE_NODE(CirclePad, BuiltinOperator_PAD, BuiltinOptions_PadOptions) +CIRCLE_NODE(CirclePadV2, BuiltinOperator_PADV2, BuiltinOptions_PadV2Options) +CIRCLE_NODE(CirclePow, BuiltinOperator_POW, BuiltinOptions_PowOptions) +CIRCLE_NODE(CirclePRelu, BuiltinOperator_PRELU, BuiltinOptions_NONE) +CIRCLE_NODE(CircleQuantize, BuiltinOperator_QUANTIZE, BuiltinOptions_QuantizeOptions) +CIRCLE_NODE(CircleRange, BuiltinOperator_RANGE, BuiltinOptions_RangeOptions) +CIRCLE_NODE(CircleRank, BuiltinOperator_RANK, BuiltinOptions_RankOptions) +CIRCLE_NODE(CircleReduceAny, BuiltinOperator_REDUCE_ANY, BuiltinOptions_ReducerOptions) +CIRCLE_NODE(CircleReduceMax, BuiltinOperator_REDUCE_MAX, BuiltinOptions_ReducerOptions) +CIRCLE_NODE(CircleReduceMin, BuiltinOperator_REDUCE_MIN, BuiltinOptions_ReducerOptions) +CIRCLE_NODE(CircleReduceProd, BuiltinOperator_REDUCE_PROD, BuiltinOptions_ReducerOptions) +CIRCLE_NODE(CircleRelu, BuiltinOperator_RELU, BuiltinOptions_NONE) +CIRCLE_NODE(CircleRelu6, BuiltinOperator_RELU6, BuiltinOptions_NONE) +CIRCLE_NODE(CircleReluN1To1, BuiltinOperator_RELU_N1_TO_1, BuiltinOptions_NONE) +CIRCLE_NODE(CircleReshape, BuiltinOperator_RESHAPE, BuiltinOptions_ReshapeOptions) +CIRCLE_NODE(CircleResizeBilinear, BuiltinOperator_RESIZE_BILINEAR, BuiltinOptions_ResizeBilinearOptions) +CIRCLE_NODE(CircleResizeNearestNeighbor, BuiltinOperator_RESIZE_NEAREST_NEIGHBOR, BuiltinOptions_ResizeNearestNeighborOptions) +CIRCLE_NODE(CircleReverseSequence, BuiltinOperator_REVERSE_SEQUENCE, BuiltinOptions_ReverseSequenceOptions) +CIRCLE_NODE(CircleReverseV2, BuiltinOperator_REVERSE_V2, BuiltinOptions_ReverseV2Options) +CIRCLE_NODE(CircleRound, BuiltinOperator_ROUND, BuiltinOptions_NONE) +CIRCLE_NODE(CircleRsqrt, BuiltinOperator_RSQRT, BuiltinOptions_NONE) +CIRCLE_NODE(CircleScatterNd, 
BuiltinOperator_SCATTER_ND, BuiltinOptions_ScatterNdOptions) +CIRCLE_NODE(CircleSegmentSum, BuiltinOperator_SEGMENT_SUM, BuiltinOptions_SegmentSumOptions) +CIRCLE_NODE(CircleSelect, BuiltinOperator_SELECT, BuiltinOptions_SelectOptions) +CIRCLE_NODE(CircleSelectV2, BuiltinOperator_SELECT_V2, BuiltinOptions_SelectV2Options) +CIRCLE_NODE(CircleShape, BuiltinOperator_SHAPE, BuiltinOptions_ShapeOptions) +CIRCLE_NODE(CircleSin, BuiltinOperator_SIN, BuiltinOptions_NONE) +CIRCLE_NODE(CircleSlice, BuiltinOperator_SLICE, BuiltinOptions_SliceOptions) +CIRCLE_NODE(CircleSoftmax, BuiltinOperator_SOFTMAX, BuiltinOptions_SoftmaxOptions) +CIRCLE_NODE(CircleSpaceToBatchND, BuiltinOperator_SPACE_TO_BATCH_ND, BuiltinOptions_SpaceToBatchNDOptions) +CIRCLE_NODE(CircleSpaceToDepth, BuiltinOperator_SPACE_TO_DEPTH, BuiltinOptions_SpaceToDepthOptions) +CIRCLE_NODE(CircleSparseToDense, BuiltinOperator_SPARSE_TO_DENSE, BuiltinOptions_SparseToDenseOptions) +CIRCLE_NODE(CircleSplit, BuiltinOperator_SPLIT, BuiltinOptions_SplitOptions) +CIRCLE_NODE(CircleSplitV, BuiltinOperator_SPLIT_V, BuiltinOptions_SplitVOptions) +CIRCLE_NODE(CircleSqrt, BuiltinOperator_SQRT, BuiltinOptions_NONE) +CIRCLE_NODE(CircleSquare, BuiltinOperator_SQUARE, BuiltinOptions_SquareOptions) +CIRCLE_NODE(CircleSquaredDifference, BuiltinOperator_SQUARED_DIFFERENCE, BuiltinOptions_SquaredDifferenceOptions) +CIRCLE_NODE(CircleSqueeze, BuiltinOperator_SQUEEZE, BuiltinOptions_SqueezeOptions) +CIRCLE_NODE(CircleStridedSlice, BuiltinOperator_STRIDED_SLICE, BuiltinOptions_StridedSliceOptions) +CIRCLE_NODE(CircleSub, BuiltinOperator_SUB, BuiltinOptions_SubOptions) +CIRCLE_NODE(CircleSum, BuiltinOperator_SUM, BuiltinOptions_ReducerOptions) +CIRCLE_NODE(CircleSVDF, BuiltinOperator_SVDF, BuiltinOptions_SVDFOptions) +CIRCLE_NODE(CircleTanh, BuiltinOperator_TANH, BuiltinOptions_NONE) +CIRCLE_NODE(CircleTile, BuiltinOperator_TILE, BuiltinOptions_TileOptions) +CIRCLE_NODE(CircleTopKV2, BuiltinOperator_TOPK_V2, BuiltinOptions_TopKV2Options) 
+CIRCLE_NODE(CircleTranspose, BuiltinOperator_TRANSPOSE, BuiltinOptions_TransposeOptions) +CIRCLE_NODE(CircleTransposeConv, BuiltinOperator_TRANSPOSE_CONV, BuiltinOptions_TransposeConvOptions) +CIRCLE_NODE(CircleUnidirectionalSequenceLSTM, BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM, BuiltinOptions_UnidirectionalSequenceLSTMOptions) +CIRCLE_NODE(CircleUnique, BuiltinOperator_UNIQUE, BuiltinOptions_UniqueOptions) +CIRCLE_NODE(CircleUnpack, BuiltinOperator_UNPACK, BuiltinOptions_UnpackOptions) +CIRCLE_NODE(CircleWhere, BuiltinOperator_WHERE, BuiltinOptions_WhereOptions) +CIRCLE_NODE(CircleWhile, BuiltinOperator_WHILE, BuiltinOptions_WhileOptions) +CIRCLE_NODE(CircleZerosLike, BuiltinOperator_ZEROS_LIKE, BuiltinOptions_ZerosLikeOptions) +// Circle Only +CIRCLE_NODE(CircleBCQFullyConnected, BuiltinOperator_BCQ_FULLY_CONNECTED, BuiltinOptions_BCQFullyConnectedOptions) +CIRCLE_NODE(CircleBCQGather, BuiltinOperator_BCQ_GATHER, BuiltinOptions_BCQGatherOptions) +CIRCLE_NODE(CircleInstanceNorm, BuiltinOperator_INSTANCE_NORM, BuiltinOptions_InstanceNormOptions) +// Virtual node(s) +CIRCLE_VNODE(CircleBidirectionalSequenceLSTMOut) +CIRCLE_VNODE(CircleConst) +CIRCLE_VNODE(CircleInput) +CIRCLE_VNODE(CircleOutput) +CIRCLE_VNODE(CircleOutputDummy) +CIRCLE_VNODE(CircleOutputExclude) +CIRCLE_VNODE(CircleCustomOut) +CIRCLE_VNODE(CircleIfOut) +CIRCLE_VNODE(CircleNonMaxSuppressionV4Out) +CIRCLE_VNODE(CircleNonMaxSuppressionV5Out) +CIRCLE_VNODE(CircleSplitOut) +CIRCLE_VNODE(CircleSplitVOut) +CIRCLE_VNODE(CircleTopKV2Out) +CIRCLE_VNODE(CircleUniqueOut) +CIRCLE_VNODE(CircleUnpackOut) +CIRCLE_VNODE(CircleVariable) +CIRCLE_VNODE(CircleWhileOut) diff --git a/compiler/luci/export/src/CircleTensorExporter.cpp b/compiler/luci/export/src/CircleTensorExporter.cpp index 615402aa8..b3bb850cc 100644 --- a/compiler/luci/export/src/CircleTensorExporter.cpp +++ b/compiler/luci/export/src/CircleTensorExporter.cpp @@ -67,6 +67,9 @@ public: luci::SparsityParam *sparsityparam(void) const { return 
_sparsityparam; } void sparsityparam(luci::SparsityParam *sp) { _sparsityparam = sp; } + bool is_variable(void) const { return _is_variable; } + void is_variable(bool v) { _is_variable = v; } + private: std::string _name; @@ -77,6 +80,8 @@ private: luci::CircleConst *_content = nullptr; luci::CircleQuantParam *_quantparam = nullptr; luci::SparsityParam *_sparsityparam = nullptr; + + bool _is_variable = false; }; class CircleTensorContext @@ -145,6 +150,8 @@ void allocateCircleTensorInfo(CircleNode *node, CircleTensorContext &ctx) tensor_info.quantparam(node->quantparam()); tensor_info.sparsityparam(node->sparsityparam()); + tensor_info.is_variable(dynamic_cast<luci::CircleVariable *>(node) != nullptr); + set_tensor_index(node, tensor_index); ctx.emplace_back(tensor_info); @@ -592,9 +599,11 @@ void exportOpDefinedTensor(const CircleTensorInfo &info, FlatBufferBuilder &buil auto buffer_id = get_buffer_id(builder, md, info.content()); auto name_offset = builder.CreateString(info.name()); - auto tensor_offset = - CreateTensor(builder, shape_offset, info.dtype(), buffer_id, name_offset, quantparam, - /*is_variable*/ false, sparsityparam, shape_signature_offset); + + auto is_variable = info.is_variable(); + + auto tensor_offset = CreateTensor(builder, shape_offset, info.dtype(), buffer_id, name_offset, + quantparam, is_variable, sparsityparam, shape_signature_offset); gd._tensors.push_back(tensor_offset); } diff --git a/compiler/luci/export/src/SerializedData.h b/compiler/luci/export/src/SerializedData.h index a945eecf7..136a8ac49 100644 --- a/compiler/luci/export/src/SerializedData.h +++ b/compiler/luci/export/src/SerializedData.h @@ -23,7 +23,7 @@ #include <luci/IR/ExecutionPlanTable.h> #include <vector> - +#include <string> #include <unordered_map> #include <map> @@ -131,8 +131,8 @@ struct SerializedModelData final * @param builtin_code * @return idx of opcode in table of opcodes (see schema) */ - uint32_t registerBuiltinOpcode(circle::BuiltinOperator builtin_code, 
const int32_t op_version); - uint32_t registerCustomOpcode(const std::string &custom_op); + uint32_t registerBuiltinOpcode(circle::BuiltinOperator builtin_code, + const std::string &custom_code, const int32_t op_version); }; // Prerequisites for circle::Model object creation diff --git a/compiler/luci/import/CMakeLists.txt b/compiler/luci/import/CMakeLists.txt index 6630cab9f..1b2db23ae 100644 --- a/compiler/luci/import/CMakeLists.txt +++ b/compiler/luci/import/CMakeLists.txt @@ -12,13 +12,14 @@ target_include_directories(luci_import PUBLIC include) target_link_libraries(luci_import PUBLIC luci_lang) target_link_libraries(luci_import PUBLIC luci_profile) target_link_libraries(luci_import PUBLIC luci_plan) -target_link_libraries(luci_import PUBLIC mio_circle) +target_link_libraries(luci_import PUBLIC mio_circle04) target_link_libraries(luci_import PRIVATE luci_env) target_link_libraries(luci_import PRIVATE luci_log) target_link_libraries(luci_import PRIVATE luci_logex) target_link_libraries(luci_import PRIVATE nncc_common) target_link_libraries(luci_import PRIVATE locop) target_link_libraries(luci_import PRIVATE oops) +target_link_libraries(luci_import PRIVATE mio_circle04_helper) install(TARGETS luci_import DESTINATION lib) install(DIRECTORY include/ DESTINATION include FILES_MATCHING PATTERN "*.h") @@ -32,7 +33,3 @@ nnas_find_package(GTest REQUIRED) GTest_AddTest(luci_import_test ${TESTS}) target_include_directories(luci_import_test PRIVATE src) target_link_libraries(luci_import_test luci_import) -target_link_libraries(luci_import_test oops) -target_link_libraries(luci_import_test luci_plan) -target_link_libraries(luci_import_test luci_lang) -target_link_libraries(luci_import_test mio_circle) diff --git a/compiler/luci/import/include/luci/Import/CircleReader.h b/compiler/luci/import/include/luci/Import/CircleReader.h index fb38ba90b..a0519f661 100644 --- a/compiler/luci/import/include/luci/Import/CircleReader.h +++ 
b/compiler/luci/import/include/luci/Import/CircleReader.h @@ -35,19 +35,7 @@ namespace luci { -bool is_valid(const circle::OperatorCodeT &opcode); -bool is_valid(const circle::OperatorCode *opcode); - -bool is_custom(const circle::OperatorCodeT &opcode); -bool is_custom(const circle::OperatorCode *opcode); - -std::string opcode_name(const circle::OperatorCodeT &opcode); -std::string opcode_name(const circle::OperatorCode *opcode); - -const char *tensor_name(const circle::TensorT &tensor); const char *tensor_name(const circle::Tensor *tensor); - -const circle::QuantizationParametersT *tensor_quantization(const circle::TensorT &tensor); const circle::QuantizationParameters *tensor_quantization(const circle::Tensor *tensor); loco::DataType luci_datatype(circle::TensorType type); @@ -57,14 +45,13 @@ MirrorPadMode luci_mirrorpad_mode(const circle::MirrorPadMode mode); luci::CircleFullyConnected::WeightsFormat luci_weights_format(const circle::FullyConnectedOptionsWeightsFormat weights_format); std::unique_ptr<CircleQuantParam> -luci_quantparam(const circle::QuantizationParametersT *quantization); -std::unique_ptr<CircleQuantParam> luci_quantparam(const circle::QuantizationParameters *quantization); /// @brief Copy common tensor attributes such as name, type, etc. to node. 
-void copy_tensor_attributes(const circle::TensorT &tensor, CircleNode *node); void copy_tensor_attributes(const circle::Tensor *tensor, CircleNode *node); +std::string fb_string2std_string(const flatbuffers::String *fb_str); + /** * @brief Wrapper to use flatbuffers::Vector pointer as std::vector entity */ @@ -101,13 +88,6 @@ template <typename T> VectorWrapper<T> wrap(const flatbuffers::Vector<T> *vec) */ class CircleReader { -private: // unpack API - using CircleBuffers_t = std::vector<std::unique_ptr<circle::BufferT>>; - using CircleTensors_t = std::vector<std::unique_ptr<circle::TensorT>>; - using CircleOperators_t = std::vector<std::unique_ptr<circle::OperatorT>>; - using CircleOperatorCodes_t = std::vector<std::unique_ptr<circle::OperatorCodeT>>; - using CircleMetadata_t = std::vector<std::unique_ptr<circle::MetadataT>>; - private: // direct API using CircleBuffers = VectorWrapper<flatbuffers::Offset<circle::Buffer>>; using CircleTensors = VectorWrapper<flatbuffers::Offset<circle::Tensor>>; @@ -115,40 +95,21 @@ private: // direct API using CircleOperatorCodes = VectorWrapper<flatbuffers::Offset<circle::OperatorCode>>; using CircleMetadataSet = VectorWrapper<flatbuffers::Offset<circle::Metadata>>; - using CircleSubGraphsPtr_t = flatbuffers::Vector<flatbuffers::Offset<circle::SubGraph>>; - using CircleTensorsPtr_t = flatbuffers::Vector<flatbuffers::Offset<circle::Tensor>>; - public: CircleReader() = default; -public: // unpack API - const CircleOperatorCodes_t &opcodes() const { return _model->operator_codes; } - const CircleBuffers_t &buffers() const { return _model->buffers; } - const CircleTensors_t &tensors() const { return _current_subgraph->tensors; } - const CircleOperators_t &operators() const { return _current_subgraph->operators; } - const std::vector<int32_t> &inputs() const { return _current_subgraph->inputs; } - const std::vector<int32_t> &outputs() const { return _current_subgraph->outputs; } - const std::string &name() const { return 
_current_subgraph->name; } - const circle::DataFormat &data_format() const { return _current_subgraph->data_format; } - const CircleMetadata_t &metadata() const { return _model->metadata; } - - const CircleTensorsPtr_t *tensors_ptr() const { return _tensors_ptr; } - - uint32_t num_subgraph() const { return _model->subgraphs.size(); } - - circle::BuiltinOperator builtin_code(const circle::OperatorT &op) const; - std::string opcode_name(const circle::OperatorT &op) const; - public: // direct API - CircleOperatorCodes native_opcodes() const { return wrap(_native_model->operator_codes()); } - CircleBuffers native_buffers() const { return wrap(_native_model->buffers()); } - CircleTensors native_tensors() const { return wrap(_native_subgraph->tensors()); } - CircleOperators native_operators() const { return wrap(_native_subgraph->operators()); } - VectorWrapper<int32_t> native_inputs() const { return wrap(_native_subgraph->inputs()); } - VectorWrapper<int32_t> native_outputs() const { return wrap(_native_subgraph->outputs()); } - std::string native_name() const { return _native_subgraph->name()->str(); } - circle::DataFormat native_data_format() const { return _native_subgraph->data_format(); } - CircleMetadataSet native_metadata() const { return wrap(_native_model->metadata()); } + CircleOperatorCodes opcodes() const { return wrap(_model->operator_codes()); } + CircleBuffers buffers() const { return wrap(_model->buffers()); } + CircleTensors tensors() const { return wrap(_current_subgraph->tensors()); } + CircleOperators operators() const { return wrap(_current_subgraph->operators()); } + VectorWrapper<int32_t> inputs() const { return wrap(_current_subgraph->inputs()); } + VectorWrapper<int32_t> outputs() const { return wrap(_current_subgraph->outputs()); } + std::string name() const { return fb_string2std_string(_current_subgraph->name()); } + circle::DataFormat data_format() const { return _current_subgraph->data_format(); } + CircleMetadataSet metadata() const { 
return wrap(_model->metadata()); } + + uint32_t num_subgraph() const { return wrap(_model->subgraphs()).size(); } circle::BuiltinOperator builtin_code(const circle::Operator *op) const; std::string opcode_name(const circle::Operator *op) const; @@ -158,12 +119,8 @@ public: bool select_subgraph(uint32_t subgraph); private: - std::unique_ptr<const circle::ModelT> _model; - const circle::SubGraphT *_current_subgraph{nullptr}; - - const circle::Model *_native_model{nullptr}; - const CircleTensorsPtr_t *_tensors_ptr{nullptr}; - const circle::SubGraph *_native_subgraph{nullptr}; + const circle::Model *_model{nullptr}; + const circle::SubGraph *_current_subgraph{nullptr}; }; } // namespace luci diff --git a/compiler/luci/import/include/luci/Import/GraphBuilderRegistry.h b/compiler/luci/import/include/luci/Import/GraphBuilderRegistry.h index b8dc22fdd..93e34a56b 100644 --- a/compiler/luci/import/include/luci/Import/GraphBuilderRegistry.h +++ b/compiler/luci/import/include/luci/Import/GraphBuilderRegistry.h @@ -18,6 +18,7 @@ #define __LUCI_IMPORT_GRAPH_BUILDER_REGISTRY_H__ #include "GraphBuilderBase.h" +#include "NodeBuilder.h" #include <map> @@ -32,6 +33,11 @@ struct GraphBuilderSource * @brief Returns registered GraphBuilder pointer for operator (nullptr if not present) */ virtual const GraphBuilderBase *lookup(const circle::BuiltinOperator &op) const = 0; + + /** + * @brief Returns registered NodeBuilderBase pointer for type (nullptr if not present) + */ + virtual const NodeBuilderBase *lookup(const NodeBuilderType type) const = 0; }; /** @@ -61,6 +67,17 @@ public: return _builder_map.at(op).get(); } + /** + * @brief Returns registered NodeBuilderBase pointer for type or nullptr if not registered + */ + const NodeBuilderBase *lookup(const NodeBuilderType type) const final + { + if (_node_builders.find(type) == _node_builders.end()) + return (_parent == nullptr) ? 
nullptr : _parent->lookup(type); + + return _node_builders.at(type).get(); + } + static GraphBuilderRegistry &get() { static GraphBuilderRegistry me; @@ -73,11 +90,17 @@ public: _builder_map[op] = std::move(builder); } + void add(std::unique_ptr<NodeBuilderBase> &&builder) + { + _node_builders[builder->builder_type()] = std::move(builder); + } + private: const GraphBuilderSource *_parent = nullptr; private: std::map<const circle::BuiltinOperator, std::unique_ptr<GraphBuilderBase>> _builder_map; + std::map<const NodeBuilderType, std::unique_ptr<NodeBuilderBase>> _node_builders; }; } // namespace luci diff --git a/compiler/luci/import/include/luci/Import/NodeBuilder.h b/compiler/luci/import/include/luci/Import/NodeBuilder.h new file mode 100644 index 000000000..440b491b0 --- /dev/null +++ b/compiler/luci/import/include/luci/Import/NodeBuilder.h @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __LUCI_IMPORT_NODE_BUILDER_H__ +#define __LUCI_IMPORT_NODE_BUILDER_H__ + +#include "GraphBuilderContext.h" +#include "GraphBuilderBase.h" + +#include <mio/circle/schema_generated.h> + +namespace luci +{ + +/** + * @brief Tensor types which requires separated node + */ +enum class NodeBuilderType +{ + BUFFER, + // TODO Extend this struct here if needed to add new type of NodeBuilderBase +}; + +/** + * @brief Creates nodes from given Tensor and context + */ +class NodeBuilderBase +{ +public: + virtual CircleNode *build(TensorIndex tensor_idx, GraphBuilderContext *context) const = 0; + virtual NodeBuilderType builder_type() const = 0; +}; + +/** + * @brief Placeholder for builders of tensors with different types + */ +template <NodeBuilderType Type> class TypedNodeBuilder : public NodeBuilderBase +{ +public: + NodeBuilderType builder_type() const final { return Type; } +}; + +} // namespace luci + +#endif // __LUCI_IMPORT_NODE_BUILDER_H__ diff --git a/compiler/luci/import/include/luci/Import/Nodes.h b/compiler/luci/import/include/luci/Import/Nodes.h index f7d22e7aa..7a5045ede 100644 --- a/compiler/luci/import/include/luci/Import/Nodes.h +++ b/compiler/luci/import/include/luci/Import/Nodes.h @@ -122,6 +122,7 @@ #include "Nodes/CircleStridedSlice.h" #include "Nodes/CircleSub.h" #include "Nodes/CircleSum.h" +#include "Nodes/CircleSVDF.h" #include "Nodes/CircleTanh.h" #include "Nodes/CircleTile.h" #include "Nodes/CircleTopKV2.h" @@ -130,6 +131,7 @@ #include "Nodes/CircleUnidirectionalSequenceLSTM.h" #include "Nodes/CircleUnique.h" #include "Nodes/CircleUnpack.h" +#include "Nodes/CircleVariable.h" #include "Nodes/CircleWhere.h" #include "Nodes/CircleWhile.h" #include "Nodes/CircleZerosLike.h" diff --git a/compiler/luci/import/include/luci/Import/Nodes/CircleConst.h b/compiler/luci/import/include/luci/Import/Nodes/CircleConst.h index 7d4f10a59..9e50ddbde 100644 --- a/compiler/luci/import/include/luci/Import/Nodes/CircleConst.h +++ 
b/compiler/luci/import/include/luci/Import/Nodes/CircleConst.h @@ -17,20 +17,21 @@ #ifndef __LUCI_IMPORT_OP_CIRCLE_CONST_H__ #define __LUCI_IMPORT_OP_CIRCLE_CONST_H__ -#include "luci/Import/GraphBuilderContext.h" +#include "luci/Import/NodeBuilder.h" #include <luci/IR/Nodes/CircleConst.h> -/* - * @note Circle does not have Const operator. - * Methods here provide helper that creates CircleConst from - * Tensor and Buffer in circle flatbuffer file. - */ - namespace luci { -CircleConst *create_circleconst(GraphBuilderContext *context, int32_t tensor_index); +/** + * @brief Builder creates CircleConst node from Tensor with buffer. + */ +class CircleConstNodeBuilder : public TypedNodeBuilder<NodeBuilderType::BUFFER> +{ +public: + CircleNode *build(TensorIndex tensor_index, GraphBuilderContext *ctx) const final; +}; } // namespace luci diff --git a/compiler/luci/import/include/luci/Import/Nodes/CircleSVDF.h b/compiler/luci/import/include/luci/Import/Nodes/CircleSVDF.h new file mode 100644 index 000000000..a91f66019 --- /dev/null +++ b/compiler/luci/import/include/luci/Import/Nodes/CircleSVDF.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __LUCI_IMPORT_OP_CIRCLE_SVDF_H__ +#define __LUCI_IMPORT_OP_CIRCLE_SVDF_H__ + +#include "luci/Import/GraphBuilder.h" + +namespace luci +{ + +class CircleSVDFBuilder : public GraphBuilder +{ +public: + bool validate(const ValidateArgs &args) const final; + +private: + CircleNode *build_node(const circle::OperatorT &op, const std::vector<CircleNode *> &inputs, + loco::Graph *graph) const final; +}; + +} // namespace luci + +#endif // __LUCI_IMPORT_OP_CIRCLE_SVDF_H__ diff --git a/compiler/luci/import/include/luci/Import/Nodes/CircleVariable.h b/compiler/luci/import/include/luci/Import/Nodes/CircleVariable.h new file mode 100644 index 000000000..4d8961fa5 --- /dev/null +++ b/compiler/luci/import/include/luci/Import/Nodes/CircleVariable.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_IMPORT_OP_CIRCLE_VARIABLE_H__ +#define __LUCI_IMPORT_OP_CIRCLE_VARIABLE_H__ + +#include "luci/Import/GraphBuilderContext.h" + +#include <luci/IR/Nodes/CircleVariable.h> + +/* + * @note Circle does not have node for variable tensor + * Methods here provide helper that creates CircleVariable from + * Tensor having is_variable true value. 
+ */ + +namespace luci +{ + +CircleVariable *create_circlevariable(GraphBuilderContext *context, int32_t tensor_index); + +} // namespace luci + +#endif // __LUCI_IMPORT_OP_CIRCLE_VARIABLE_H__ diff --git a/compiler/luci/import/src/CircleImportMetadata.cpp b/compiler/luci/import/src/CircleImportMetadata.cpp index 42dcebdaa..9c1fe7356 100644 --- a/compiler/luci/import/src/CircleImportMetadata.cpp +++ b/compiler/luci/import/src/CircleImportMetadata.cpp @@ -21,8 +21,10 @@ namespace { -uint32_t read_u32(const std::vector<uint8_t> &buffer, uint32_t idx) +template <typename VECTORTYPE> uint32_t read_u32(const VECTORTYPE &buffer, uint32_t idx) { + static_assert(std::is_same<typename VECTORTYPE::value_type, uint8_t>::value, "Types mismatch!"); + uint32_t val = 0; val += (buffer.at(idx + 0) << 0 * 8); val += (buffer.at(idx + 1) << 1 * 8); @@ -37,9 +39,11 @@ namespace { // 'source_table' is decoded to std::map<uint32_t, std::string> format. -const std::map<uint32_t, std::string> -decoded_source_table(const std::vector<uint8_t> &source_table_data) +template <typename VECTORTYPE> +const std::map<uint32_t, std::string> decoded_source_table(const VECTORTYPE &source_table_data) { + static_assert(std::is_same<typename VECTORTYPE::value_type, uint8_t>::value, "Types mismatch!"); + std::map<uint32_t, std::string> source_id_name_map; uint32_t idx = 0; @@ -86,9 +90,11 @@ decoded_source_table(const std::vector<uint8_t> &source_table_data) } // 'op_table' is decoded to std::map<uint32_t, std::set<uint32_t>> format. 
-const std::map<uint32_t, std::set<uint32_t>> -decoded_op_table(const std::vector<uint8_t> &op_table_data) +template <typename VECTORTYPE> +const std::map<uint32_t, std::set<uint32_t>> decoded_op_table(const VECTORTYPE &op_table_data) { + static_assert(std::is_same<typename VECTORTYPE::value_type, uint8_t>::value, "Types mismatch!"); + std::map<uint32_t, std::set<uint32_t>> node_source_ids_map; uint32_t idx = 0; @@ -135,9 +141,11 @@ decoded_op_table(const std::vector<uint8_t> &op_table_data) } // 'execution_plan_table' is decoded to std::map<uint32_t, std::vector<uint32_t>> format. -const luci::ExecutionPlanTable -decoded_execution_plan(const std::vector<uint8_t> &execution_plan_data) +template <typename VECTORTYPE> +const luci::ExecutionPlanTable decoded_execution_plan(const VECTORTYPE &execution_plan_data) { + static_assert(std::is_same<typename VECTORTYPE::value_type, uint8_t>::value, "Types mismatch!"); + luci::ExecutionPlanTable execution_plan_table; uint32_t idx = 0; @@ -156,6 +164,10 @@ decoded_execution_plan(const std::vector<uint8_t> &execution_plan_data) idx += sizeof(uint32_t); uint32_t size = read_u32(execution_plan_data, idx); + + if (size == 0) + throw std::runtime_error("Op table decode error : empty execution plan entry"); + idx += sizeof(uint32_t); if (idx + sizeof(uint32_t) * size > execution_plan_data.size()) @@ -190,19 +202,22 @@ namespace luci CircleImportMetadata::CircleImportMetadata(const luci::CircleReader &reader) { - const auto &metadata = reader.metadata(); + const auto metadata = reader.metadata(); for (uint32_t i = 0; i < metadata.size(); ++i) { - const circle::MetadataT &meta = *metadata[i]; + const auto *meta = metadata[i]; + assert(meta != nullptr); - assert(meta.buffer < reader.buffers().size()); - const std::vector<uint8_t> &buffer = reader.buffers()[meta.buffer]->data; + assert(meta->buffer() < reader.buffers().size()); + assert(reader.buffers()[meta->buffer()] != nullptr); + const auto buffer = 
luci::wrap(reader.buffers()[meta->buffer()]->data()); - if (meta.name.compare("ONE_op_table") == 0) + assert(meta->name() != nullptr); + if (meta->name()->str().compare("ONE_op_table") == 0) _op_table = decoded_op_table(buffer); - else if (meta.name.compare("ONE_source_table") == 0) + else if (meta->name()->str().compare("ONE_source_table") == 0) _source_table = decoded_source_table(buffer); - else if (meta.name.compare("ONE_execution_plan_table") == 0) + else if (meta->name()->str().compare("ONE_execution_plan_table") == 0) _execution_plan_table = decoded_execution_plan(buffer); } } diff --git a/compiler/luci/import/src/CircleReader.cpp b/compiler/luci/import/src/CircleReader.cpp index 14917ba06..a42c3f913 100644 --- a/compiler/luci/import/src/CircleReader.cpp +++ b/compiler/luci/import/src/CircleReader.cpp @@ -16,6 +16,9 @@ #include "luci/Import/CircleReader.h" +#include <mio_circle/Helper.h> + +#include <algorithm> #include <memory> #include <sstream> #include <string> @@ -23,103 +26,14 @@ namespace luci { -bool is_valid(const circle::OperatorCodeT &opcode) -{ - circle::BuiltinOperator code = opcode.builtin_code; - return (circle::BuiltinOperator_MIN <= code && code <= circle::BuiltinOperator_MAX); -} - -bool is_valid(const circle::OperatorCode *opcode) -{ - assert(opcode != nullptr); - circle::BuiltinOperator code = opcode->builtin_code(); - return (circle::BuiltinOperator_MIN <= code && code <= circle::BuiltinOperator_MAX); -} - -bool is_custom(const circle::OperatorCodeT &opcode) -{ - circle::BuiltinOperator code = opcode.builtin_code; - return (code == circle::BuiltinOperator_CUSTOM); -} - -bool is_custom(const circle::OperatorCode *opcode) -{ - assert(opcode != nullptr); - circle::BuiltinOperator code = opcode->builtin_code(); - return (code == circle::BuiltinOperator_CUSTOM); -} - -std::string opcode_name(const circle::OperatorCodeT &opcode) -{ - if (!is_valid(opcode)) - { - std::ostringstream oss; - oss << "(invalid)"; - return oss.str(); - } - - if 
(is_custom(opcode)) - { - if (opcode.custom_code.empty()) - return "(invalid custom)"; - - return opcode.custom_code; - } - - circle::BuiltinOperator code = opcode.builtin_code; - return circle::EnumNameBuiltinOperator(code); -} - -std::string opcode_name(const circle::OperatorCode *opcode) -{ - assert(opcode != nullptr); - - if (!is_valid(opcode)) - { - std::ostringstream oss; - oss << "(invalid)"; - return oss.str(); - } - - if (is_custom(opcode)) - { - auto custom_code = opcode->custom_code()->str(); - if (custom_code.empty()) - return "(invalid custom)"; - - return custom_code; - } - - circle::BuiltinOperator code = opcode->builtin_code(); - return circle::EnumNameBuiltinOperator(code); -} - -const char *tensor_name(const circle::TensorT &tensor) -{ - static const char *kEmptyTensorName = "(noname)"; - - if (!tensor.name.empty()) - return tensor.name.c_str(); - - return kEmptyTensorName; -} - const char *tensor_name(const circle::Tensor *tensor) { assert(tensor != nullptr); - static const char *kEmptyTensorName = "(noname)"; - const auto tensor_name = tensor->name()->c_str(); - - if (!std::string(tensor_name).empty()) - return tensor_name; + if (tensor->name() == nullptr || std::string(tensor->name()->c_str()).empty()) + return "(noname)"; - return kEmptyTensorName; -} - -const circle::QuantizationParametersT *tensor_quantization(const circle::TensorT &tensor) -{ - return tensor.quantization.get(); + return tensor->name()->c_str(); } const circle::QuantizationParameters *tensor_quantization(const circle::Tensor *tensor) @@ -334,41 +248,6 @@ std::unique_ptr<SparsityParam> luci_sparsityparam(const circle::SparsityParamete return luci_sparsityparam(&sparsity); } -void copy_tensor_attributes(const circle::TensorT &tensor, CircleNode *node) -{ - node->name(tensor_name(tensor)); - node->dtype(luci_datatype(tensor.type)); - - assert(tensor.shape_signature.size() == 0 || - tensor.shape_signature.size() == tensor.shape.size()); - - std::vector<int32_t> dims = 
tensor.shape; // in NHWC - node->rank(dims.size()); - for (uint32_t r = 0; r < dims.size(); ++r) - { - if (tensor.shape_signature.size() > 0 && tensor.shape_signature.at(r) == -1) - node->dim(r).unset(); - else - node->dim(r).set(dims[r]); - } - - const auto *quantization = tensor.quantization.get(); - if (quantization != nullptr) - { - auto quantparam = luci_quantparam(quantization); - if (quantparam) - node->quantparam(std::move(quantparam)); - } - - const auto *sparsity = tensor.sparsity.get(); - if (sparsity != nullptr) - { - auto sparsityparam = luci_sparsityparam(sparsity); - if (sparsityparam) - node->sparsityparam(std::move(sparsityparam)); - } -} - void copy_tensor_attributes(const circle::Tensor *tensor, CircleNode *node) { assert(tensor != nullptr); @@ -408,63 +287,60 @@ void copy_tensor_attributes(const circle::Tensor *tensor, CircleNode *node) } } -circle::BuiltinOperator CircleReader::builtin_code(const circle::OperatorT &op) const +std::string fb_string2std_string(const flatbuffers::String *fb_str) { - const auto &op_codes = opcodes(); - uint32_t index = op.opcode_index; + return fb_str == nullptr ? 
"" : fb_str->str(); +} + +circle::BuiltinOperator CircleReader::builtin_code(const circle::Operator *op) const +{ + assert(op != nullptr); + + const auto op_codes = opcodes(); + uint32_t index = op->opcode_index(); assert(index < op_codes.size()); - const circle::OperatorCodeT &opcode = *op_codes[index]; + const auto opcode = op_codes[index]; + assert(opcode != nullptr); - return opcode.builtin_code; + return mio::circle::builtin_code_neutral(opcode); } -std::string CircleReader::opcode_name(const circle::OperatorT &op) const +std::string CircleReader::opcode_name(const circle::Operator *op) const { - const auto &op_codes = opcodes(); - uint32_t index = op.opcode_index; - assert(index < op_codes.size()); - const circle::OperatorCodeT &opcode = *op_codes[index]; + assert(op != nullptr); - if (!is_valid(opcode)) - { - std::ostringstream oss; - oss << "(invalid: " << index << ")"; - return oss.str(); - } + const auto op_codes = opcodes(); + uint32_t index = op->opcode_index(); + assert(index < op_codes.size()); + const auto opcode = op_codes[index]; - return ::luci::opcode_name(opcode); + return mio::circle::opcode_name(opcode); } bool CircleReader::parse(const circle::Model *model) { assert(model != nullptr); - _model.reset(model->UnPack()); - // for direct pointer access - _native_model = model; + _model = model; return true; } bool CircleReader::select_subgraph(uint32_t sgindex) { - if (_model->subgraphs.size() <= sgindex) + if (num_subgraph() <= sgindex) { assert(false); return false; } - _current_subgraph = _model->subgraphs[sgindex].get(); - // for direct pointer access - auto subgraphs = _native_model->subgraphs(); + auto subgraphs = _model->subgraphs(); assert(subgraphs != nullptr); - _native_subgraph = subgraphs->Get(sgindex); - assert(_native_subgraph != nullptr); - - _tensors_ptr = _native_subgraph->tensors(); + _current_subgraph = subgraphs->Get(sgindex); + assert(_current_subgraph != nullptr); return true; } diff --git 
a/compiler/luci/import/src/GraphBuilder.cpp b/compiler/luci/import/src/GraphBuilder.cpp index 356501c2f..59a08b546 100644 --- a/compiler/luci/import/src/GraphBuilder.cpp +++ b/compiler/luci/import/src/GraphBuilder.cpp @@ -29,10 +29,9 @@ CircleNode *GraphBuilder::build(const circle::OperatorT &op, GraphBuilderContext const std::vector<int32_t> &inputs = op.inputs; const std::vector<int32_t> &outputs = op.outputs; - const auto &tensors = context->reader()->tensors(); - const auto &opcodes = context->reader()->opcodes(); - auto tensors_ptr = context->reader()->tensors_ptr(); - assert(tensors_ptr != nullptr); + const auto tensors = context->reader()->tensors(); + const auto opcodes = context->reader()->opcodes(); + assert(!tensors.null()); std::vector<CircleNode *> input_nodes; for (const int32_t input_tensor_index : inputs) @@ -60,16 +59,18 @@ CircleNode *GraphBuilder::build(const circle::OperatorT &op, GraphBuilderContext // Set up node parameters. assert(outputs.size() == 1); { - const circle::TensorT &output_tensor = *tensors[outputs[0]]; + const auto output_tensor = tensors[outputs[0]]; + assert(output_tensor != nullptr); copy_tensor_attributes(output_tensor, node); // mark shape_status - if (tensors_ptr->Get(outputs[0])->shape() == nullptr) + if (output_tensor->shape() == nullptr) node->shape_status(ShapeStatus::NOSHAPE); else node->shape_status(ShapeStatus::VALID); // mark operator version - node->op_version(opcodes[op.opcode_index].get()->version); + assert(opcodes[op.opcode_index] != nullptr); + node->op_version(opcodes[op.opcode_index]->version()); } // Register node's only output. 
diff --git a/compiler/luci/import/src/GraphBuilderMultiOutput.cpp b/compiler/luci/import/src/GraphBuilderMultiOutput.cpp index be553f4c0..4df8d1e5a 100644 --- a/compiler/luci/import/src/GraphBuilderMultiOutput.cpp +++ b/compiler/luci/import/src/GraphBuilderMultiOutput.cpp @@ -30,10 +30,9 @@ CircleNode *GraphBuilderMultiOutput::build(const circle::OperatorT &op, const std::vector<int32_t> &inputs = op.inputs; const std::vector<int32_t> &outputs = op.outputs; - const auto &tensors = context->reader()->tensors(); - const auto &opcodes = context->reader()->opcodes(); - auto tensors_ptr = context->reader()->tensors_ptr(); - assert(tensors_ptr != nullptr); + const auto tensors = context->reader()->tensors(); + const auto opcodes = context->reader()->opcodes(); + assert(!tensors.null()); std::vector<CircleNode *> input_nodes; for (const int32_t input_tensor_index : inputs) @@ -64,12 +63,14 @@ CircleNode *GraphBuilderMultiOutput::build(const circle::OperatorT &op, if (output_count > 0) { // Let's use attributes from output 0 for this node - const circle::TensorT &output_tensor = *tensors[outputs[0]]; + const auto output_tensor = tensors[outputs[0]]; + assert(output_tensor != nullptr); node->name(tensor_name(output_tensor)); - node->dtype(luci_datatype(output_tensor.type)); + node->dtype(luci_datatype(output_tensor->type())); // mark operator version - node->op_version(opcodes[op.opcode_index].get()->version); + assert(opcodes[op.opcode_index] != nullptr); + node->op_version(opcodes[op.opcode_index]->version()); // NOTE We don't set quantization for multiple output nodes but to virtual outputs } @@ -77,7 +78,8 @@ CircleNode *GraphBuilderMultiOutput::build(const circle::OperatorT &op, // Create virtual outputs of Virtual Output node(s) for (uint32_t n = 0; n < output_count; ++n) { - const circle::TensorT &output_tensor = *tensors[outputs[n]]; + const auto output_tensor = tensors[outputs[n]]; + assert(output_tensor != nullptr); BuildOutArgs boa(node, n); auto *nodeout = 
build_out(boa); @@ -85,7 +87,7 @@ CircleNode *GraphBuilderMultiOutput::build(const circle::OperatorT &op, copy_tensor_attributes(output_tensor, nodeout); // NOTE name of CxxxOut nodes may have same name // mark shape_status - if (tensors_ptr->Get(outputs[n])->shape() == nullptr) + if (output_tensor->shape() == nullptr) nodeout->shape_status(ShapeStatus::NOSHAPE); else nodeout->shape_status(ShapeStatus::VALID); diff --git a/compiler/luci/import/src/GraphBuilderRegistry.cpp b/compiler/luci/import/src/GraphBuilderRegistry.cpp index df07d9e48..fe2d830e9 100644 --- a/compiler/luci/import/src/GraphBuilderRegistry.cpp +++ b/compiler/luci/import/src/GraphBuilderRegistry.cpp @@ -131,6 +131,7 @@ GraphBuilderRegistry::GraphBuilderRegistry() CIRCLE_NODE(STRIDED_SLICE, CircleStridedSliceGraphBuilder); // 45 CIRCLE_NODE(SUB, CircleSubGraphBuilder); // 41 CIRCLE_NODE(SUM, CircleSumGraphBuilder); // 74 + CIRCLE_NODE(SVDF, CircleSVDFBuilder); // 27 CIRCLE_NODE(TANH, CircleTanhGraphBuilder); // 28 CIRCLE_NODE(TILE, CircleTileGraphBuilder); // 69 CIRCLE_NODE(TOPK_V2, CircleTopKV2GraphBuilder); // 48 @@ -150,7 +151,6 @@ GraphBuilderRegistry::GraphBuilderRegistry() // BuiltinOperator_LSH_PROJECTION = 15, // BuiltinOperator_LSTM = 16, // BuiltinOperator_RNN = 24, - // BuiltinOperator_SVDF = 27, // BuiltinOperator_CONCAT_EMBEDDINGS = 29, // BuiltinOperator_SKIP_GRAM = 30, // BuiltinOperator_CALL = 31, @@ -161,6 +161,13 @@ GraphBuilderRegistry::GraphBuilderRegistry() // BuiltinOperator_ARG_MAX = 56, // BuiltinOperator_HARD_SWISH = 117, // BuiltinOperator_DENSIFY = 124, + + // Register builders for nodes which not handles in builders registered above. 
+#define CIRCLE_NODE(CLASS) add(std::make_unique<CLASS>()) + + CIRCLE_NODE(CircleConstNodeBuilder); + +#undef CIRCLE_NODE } } // namespace luci diff --git a/compiler/luci/import/src/Importer.cpp b/compiler/luci/import/src/Importer.cpp index 3f7f78591..15de03df2 100644 --- a/compiler/luci/import/src/Importer.cpp +++ b/compiler/luci/import/src/Importer.cpp @@ -23,6 +23,7 @@ #include "luci/Import/GraphBuilderRegistry.h" #include "luci/Import/CircleReader.h" #include "luci/Import/Nodes/CircleConst.h" +#include "luci/Import/Nodes/CircleVariable.h" #include <luci/IR/Module.h> #include <luci/IR/CircleNodes.h> @@ -50,18 +51,18 @@ void convert_graph(const luci::GraphBuilderSource &source, luci::CircleReader &r luci::GraphBuilderContext gb_context(graph, &reader, nodefinder.get(), tensoroutputs.get()); - const auto &operators = reader.operators(); - const auto &tensors = reader.tensors(); - auto tensors_ptr = reader.tensors_ptr(); - assert(tensors_ptr != nullptr); + const auto operators = reader.operators(); + const auto tensors = reader.tensors(); + assert(!tensors.null()); auto circle_metadata = std::make_unique<luci::CircleImportMetadata>(reader); // build a cache to identify if a tensor is output of an operator // if this is set, we should not create a CircleConst for this tensor for (uint32_t i = 0; i < operators.size(); ++i) { - const circle::OperatorT &op = *operators[i]; - const auto &outputs = op.outputs; + const auto op = operators[i]; + assert(op != nullptr); + const auto outputs = luci::wrap(op->outputs()); for (uint32_t j = 0; j < outputs.size(); ++j) { @@ -77,10 +78,11 @@ void convert_graph(const luci::GraphBuilderSource &source, luci::CircleReader &r { auto input_node = graph->nodes()->create<luci::CircleInput>(); assert(input_node != nullptr); - const circle::TensorT &tensor = *tensors[input]; + const auto tensor = tensors[input]; + assert(tensor != nullptr); luci::copy_tensor_attributes(tensor, input_node); - if (tensors_ptr->Get(input)->shape() == nullptr) 
+ if (tensor->shape() == nullptr) input_node->shape_status(luci::ShapeStatus::NOSHAPE); else input_node->shape_status(luci::ShapeStatus::VALID); @@ -101,16 +103,18 @@ void convert_graph(const luci::GraphBuilderSource &source, luci::CircleReader &r // Data type graph_input->dtype(input_node->dtype()); - assert(tensor.shape_signature.size() == 0 || - tensor.shape_signature.size() == tensor.shape.size()); + const auto tensor_shape_signature = luci::wrap(tensor->shape_signature()); + const auto tensor_shape = luci::wrap(tensor->shape()); + assert(tensor_shape_signature.size() == 0 || + tensor_shape_signature.size() == tensor_shape.size()); // Shape of GraphInput auto input_shape = std::make_unique<loco::TensorShape>(); - const std::vector<int32_t> &input_dims = tensor.shape; // in NHWC + const auto &input_dims = tensor_shape; // in NHWC input_shape->rank(input_dims.size()); for (uint32_t r = 0; r < input_dims.size(); ++r) { - if (tensor.shape_signature.size() > 0 && tensor.shape_signature.at(r) == -1) + if (tensor_shape_signature.size() > 0 && tensor_shape_signature.at(r) == -1) input_shape->dim(r).unset(); else input_shape->dim(r).set(input_dims[r]); @@ -118,15 +122,28 @@ void convert_graph(const luci::GraphBuilderSource &source, luci::CircleReader &r graph_input->shape(std::move(input_shape)); } - // Create CircleConst nodes for constant tensors. + // Create CircleNodes for constant tensors. // NOTE Origin is intentionally not provided for constants. 
+ auto const_builder = source.lookup(luci::NodeBuilderType::BUFFER); + if (not const_builder) + throw oops::UserExn("Not supported", "tensor with buffer builder"); + for (uint32_t i = 0; i < tensors.size(); ++i) { - luci::CircleConst *const_node = luci::create_circleconst(&gb_context, i); + auto *const_node = const_builder->build(i, &gb_context); if (const_node != nullptr) nodefinder->enroll(i, const_node); } + // Create CircleVariable nodes for variable tensors + // TODO Add Origin if needed, skip for now + for (uint32_t i = 0; i < tensors.size(); ++i) + { + luci::CircleVariable *variable_node = luci::create_circlevariable(&gb_context, i); + if (variable_node != nullptr) + nodefinder->enroll(i, variable_node); + } + // Import the operators. // Note that operators in model are stored in execution order. This means that when importing // an operator, its input operators have already been imported. We exploit this fact to set up @@ -134,18 +151,23 @@ void convert_graph(const luci::GraphBuilderSource &source, luci::CircleReader &r auto origin_table = circle_metadata->origin_table(); for (uint32_t i = 0; i < operators.size(); ++i) { - const circle::OperatorT &op = *operators[i]; + const auto op = operators[i]; + assert(op != nullptr); circle::BuiltinOperator builtincode = reader.builtin_code(op); if (const auto *builder = source.lookup(builtincode)) { - luci::GraphBuilder::ValidateArgs args(op, reader); + // create temporary unpack API obj + circle::OperatorT oper_t; + op->UnPackTo(&oper_t); + + luci::GraphBuilder::ValidateArgs args(oper_t, reader); if (!builder->validate(args)) { throw oops::UserExn("Invalid operator", reader.opcode_name(op)); } - auto built_op = builder->build(op, &gb_context); + auto built_op = builder->build(oper_t, &gb_context); set_node_id(built_op, i); if (origin_table.find(i) != origin_table.end()) add_origin(built_op, origin_table.at(i)); @@ -161,7 +183,8 @@ void convert_graph(const luci::GraphBuilderSource &source, luci::CircleReader &r // 
graph outputs for (auto output : reader.outputs()) { - const circle::TensorT &tensor = *tensors[output]; + const auto tensor = tensors[output]; + assert(tensor != nullptr); auto output_node = graph->nodes()->create<luci::CircleOutput>(); assert(output_node != nullptr); @@ -178,7 +201,7 @@ void convert_graph(const luci::GraphBuilderSource &source, luci::CircleReader &r output_node->from(output_dummy); luci::copy_tensor_attributes(tensor, output_dummy); - if (tensors_ptr->Get(output)->shape() == nullptr) + if (tensor->shape() == nullptr) output_dummy->shape_status(luci::ShapeStatus::NOSHAPE); else output_dummy->shape_status(luci::ShapeStatus::VALID); @@ -197,16 +220,18 @@ void convert_graph(const luci::GraphBuilderSource &source, luci::CircleReader &r // Set GraphInputOutputIndex for graph output_node->index(graph_output->index()); - assert(tensor.shape_signature.size() == 0 || - tensor.shape_signature.size() == tensor.shape.size()); + const auto tensor_shape_signature = luci::wrap(tensor->shape_signature()); + const auto tensor_shape = luci::wrap(tensor->shape()); + assert(tensor_shape_signature.size() == 0 || + tensor_shape_signature.size() == tensor_shape.size()); // Shape of Output auto output_shape = std::make_unique<loco::TensorShape>(); - const std::vector<int32_t> &output_dims = tensor.shape; // in NHWC + const auto &output_dims = tensor_shape; // in NHWC output_shape->rank(output_dims.size()); for (uint32_t r = 0; r < output_dims.size(); ++r) { - if (tensor.shape_signature.size() > 0 && tensor.shape_signature.at(r) == -1) + if (tensor_shape_signature.size() > 0 && tensor_shape_signature.at(r) == -1) output_shape->dim(r).unset(); else output_shape->dim(r).set(output_dims[r]); @@ -214,7 +239,7 @@ void convert_graph(const luci::GraphBuilderSource &source, luci::CircleReader &r graph_output->shape(std::move(output_shape)); // Data type - auto dtype = luci::luci_datatype(tensor.type); + auto dtype = luci::luci_datatype(tensor->type()); graph_output->dtype(dtype); 
} } @@ -355,7 +380,12 @@ std::unique_ptr<Module> Importer::importModule(const circle::Model *model) const { if (auto circle_node = dynamic_cast<luci::CircleNode *>(node)) { + if (execution_plan_table.count(node_position) == 0) + continue; + auto node_plan = execution_plan_table[node_position]; + assert(node_plan.size() > 0); + luci::add_execution_plan( circle_node, luci::CircleNodeExecutionPlan( diff --git a/compiler/luci/import/src/Importer.test.cpp b/compiler/luci/import/src/Importer.test.cpp index d963b4d49..91e4860ea 100644 --- a/compiler/luci/import/src/Importer.test.cpp +++ b/compiler/luci/import/src/Importer.test.cpp @@ -23,7 +23,7 @@ #include <mio/circle/schema_generated.h> #include <flatbuffers/flatbuffers.h> -TEST(TensorFlowLiteImport, Dummy) +TEST(CircleImport, Dummy) { luci::Importer import; @@ -68,6 +68,7 @@ struct BasicCircleModel { uint32_t id = model->operator_codes.size(); model->operator_codes.push_back(std::make_unique<circle::OperatorCodeT>()); + model->operator_codes[id]->deprecated_builtin_code = opcode; model->operator_codes[id]->builtin_code = opcode; model->operator_codes[id]->version = 1; return id; @@ -179,7 +180,7 @@ struct SimpleRELUModel : public BasicCircleModel /** * This test checks that one op RELU model with execution plan is successfully imported */ -TEST(TensorFlowLiteImport, simple_plan) +TEST(CircleImport, simple_plan) { SimpleRELUModel model; auto metadata_buffer_id = model.add_buffer(); @@ -240,7 +241,7 @@ TEST(TensorFlowLiteImport, simple_plan) /** * This test checks that model with incomplete execution plan is successfully imported */ -TEST(TensorFlowLiteImport, DISABLED_incomplete_plan_NEG) +TEST(CircleImport, incomplete_plan_NEG) { SimpleRELUModel model; auto metadata_buffer_id = model.add_buffer(); @@ -287,7 +288,7 @@ TEST(TensorFlowLiteImport, DISABLED_incomplete_plan_NEG) /** * This test checks that corrupted execution plan induce exception */ -TEST(TensorFlowLiteImport, corrupted_plan_NEG) +TEST(CircleImport, 
corrupted_plan_NEG) { SimpleRELUModel model; auto metadata_buffer_id = model.add_buffer(); @@ -309,3 +310,44 @@ TEST(TensorFlowLiteImport, corrupted_plan_NEG) ASSERT_ANY_THROW(import.importModule(model_ptr)); } + +/** + * This test checks that empty execution plan entry induce exception + */ +TEST(CircleImport, corrupted_plan_entry_NEG) +{ + SimpleRELUModel model; + auto metadata_buffer_id = model.add_buffer(); + model.add_plan_metadata(metadata_buffer_id); + + model.add_plan_entry(metadata_buffer_id, 1, {100}); + + // add corrupted entry with 0 size + { + auto &buffer = model.model->buffers[metadata_buffer_id]->data; + auto old_size = buffer.size(); + + // Allocate space for new entry: + // 4 bytes for entry id + // 4 bytes for entry size + buffer.resize(old_size + 8); + uint32_t *number_of_entries_ptr = reinterpret_cast<uint32_t *>(buffer.data()); + *number_of_entries_ptr += 1; + + uint32_t *entry_data_ptr = reinterpret_cast<uint32_t *>(buffer.data() + old_size); + + entry_data_ptr[0] = *number_of_entries_ptr - 1; // entry id + entry_data_ptr[1] = 0; // entry size + } + + model.add_plan_entry(metadata_buffer_id, 3, {200}); + + flatbuffers::FlatBufferBuilder fbb; + auto model_offset = circle::Model::Pack(fbb, model.model.get(), nullptr); + circle::FinishModelBuffer(fbb, model_offset); + + auto model_ptr = circle::GetModel(fbb.GetBufferPointer()); + luci::Importer import; + + ASSERT_ANY_THROW(import.importModule(model_ptr)); +} diff --git a/compiler/luci/import/src/Nodes/CircleCast.cpp b/compiler/luci/import/src/Nodes/CircleCast.cpp index 3e8c08bfa..acde823b1 100644 --- a/compiler/luci/import/src/Nodes/CircleCast.cpp +++ b/compiler/luci/import/src/Nodes/CircleCast.cpp @@ -42,12 +42,14 @@ bool CircleCastGraphBuilder::validate(const ValidateArgs &args) const const auto *options = args.op.builtin_options.AsCastOptions(); if (options != nullptr) { - const auto &tensors = args.reader.tensors(); - const circle::TensorT &output_tensor = *tensors[outputs[0]]; + const auto 
tensors = args.reader.tensors(); + const auto output_tensor = tensors[outputs[0]]; + assert(output_tensor != nullptr); auto name = tensor_name(output_tensor); - const auto &tensor_in = tensors.at(inputs.at(0)); - if (tensor_in->type != options->in_data_type) + const auto tensor_in = tensors.at(inputs.at(0)); + assert(tensor_in != nullptr); + if (tensor_in->type() != options->in_data_type) { if (settings->get(luci::UserSettings::Key::DisableValidation)) { @@ -57,7 +59,7 @@ bool CircleCastGraphBuilder::validate(const ValidateArgs &args) const return false; } const auto &tensor_out = tensors.at(outputs[0]); - if (tensor_out->type != options->out_data_type) + if (tensor_out->type() != options->out_data_type) { if (settings->get(luci::UserSettings::Key::DisableValidation)) { diff --git a/compiler/luci/import/src/Nodes/CircleConst.cpp b/compiler/luci/import/src/Nodes/CircleConst.cpp index 11fbb4e54..a4f190dd9 100644 --- a/compiler/luci/import/src/Nodes/CircleConst.cpp +++ b/compiler/luci/import/src/Nodes/CircleConst.cpp @@ -30,10 +30,10 @@ namespace { -std::ostream &operator<<(std::ostream &os, const std::vector<int32_t> &vect) +std::ostream &operator<<(std::ostream &os, const luci::VectorWrapper<int32_t> &vect) { uint32_t seq = 0; - for (auto &v : vect) + for (const auto &v : vect) { if (seq) os << ", "; @@ -46,7 +46,8 @@ std::ostream &operator<<(std::ostream &os, const std::vector<int32_t> &vect) using namespace luci; template <loco::DataType DT> -void copy_data(const std::vector<uint8_t> &raw_data, uint32_t num_elements, CircleConst *const_node) +void copy_data(const VectorWrapper<uint8_t> &raw_data, uint32_t num_elements, + CircleConst *const_node) { using T = typename loco::DataTypeImpl<DT>::Type; @@ -67,8 +68,8 @@ void copy_data(const std::vector<uint8_t> &raw_data, uint32_t num_elements, Circ } template <> -void copy_data<loco::DataType::STRING>(const std::vector<uint8_t> &raw_data, uint32_t num_elements, - CircleConst *const_node) +void 
copy_data<loco::DataType::STRING>(const VectorWrapper<uint8_t> &raw_data, + uint32_t num_elements, CircleConst *const_node) { assert(const_node->sparsityparam() == nullptr); @@ -106,17 +107,26 @@ void copy_data<loco::DataType::STRING>(const std::vector<uint8_t> &raw_data, uin namespace luci { -CircleConst *create_circleconst(GraphBuilderContext *context, int32_t tensor_index) +CircleNode *CircleConstNodeBuilder::build(TensorIndex tensor_index, + GraphBuilderContext *context) const { + assert(tensor_index >= 0); LOGGER(l); auto graph = context->graph(); auto reader = context->reader(); - const auto &tensors = reader->tensors(); - const circle::TensorT &const_tensor = *tensors[tensor_index]; + const auto tensors = reader->tensors(); + const auto const_tensor = tensors[tensor_index]; + assert(const_tensor != nullptr); + if (const_tensor->is_variable()) + { + // Create CircleVariable for variable + return nullptr; + } - const std::vector<uint8_t> &buffer = reader->buffers()[const_tensor.buffer]->data; - std::vector<int32_t> const_dims = const_tensor.shape; // in NHWC + assert(reader->buffers()[const_tensor->buffer()] != nullptr); + const auto buffer = wrap(reader->buffers()[const_tensor->buffer()]->data()); + const auto const_dims = wrap(const_tensor->shape()); // in NHWC if (const_dims.size() == 0 && buffer.empty()) { // unknown shape tensor and scalar tensor @@ -150,7 +160,7 @@ CircleConst *create_circleconst(GraphBuilderContext *context, int32_t tensor_ind << const_dims << std::endl; if (num_elements > 0) { - switch (luci_datatype(const_tensor.type)) + switch (luci_datatype(const_tensor->type())) { case loco::DataType::FLOAT32: copy_data<loco::DataType::FLOAT32>(buffer, num_elements, const_node); @@ -186,7 +196,7 @@ CircleConst *create_circleconst(GraphBuilderContext *context, int32_t tensor_ind default: throw oops::UserExn("Unsupported tensor type", - circle::EnumNameTensorType(const_tensor.type)); + circle::EnumNameTensorType(const_tensor->type())); } } diff --git 
a/compiler/luci/import/src/Nodes/CircleCustom.cpp b/compiler/luci/import/src/Nodes/CircleCustom.cpp index 01ac3e2a0..4e78d5fb7 100644 --- a/compiler/luci/import/src/Nodes/CircleCustom.cpp +++ b/compiler/luci/import/src/Nodes/CircleCustom.cpp @@ -39,13 +39,15 @@ CircleNode *CircleCustomGraphBuilder::build_node(const BuildNodeArgs &bna) const node->inputs(idx, bna.input_nodes[idx]); } - const auto &opcodes = bna.context->reader()->opcodes(); + const auto opcodes = bna.context->reader()->opcodes(); const uint32_t opcode_index = bna.op.opcode_index; - const circle::OperatorCodeT &opcode = *opcodes[opcode_index]; + const auto opcode = opcodes[opcode_index]; + assert(opcode != nullptr); node->custom_options( std::vector<uint8_t>{bna.op.custom_options.begin(), bna.op.custom_options.end()}); - node->custom_code(opcode.custom_code); + assert(opcode->custom_code() != nullptr); + node->custom_code(opcode->custom_code()->c_str()); // NOTE Operator version of custom is always 1 diff --git a/compiler/luci/import/src/Nodes/CircleDepthToSpace.cpp b/compiler/luci/import/src/Nodes/CircleDepthToSpace.cpp index 49eb30a83..83fc2e37d 100644 --- a/compiler/luci/import/src/Nodes/CircleDepthToSpace.cpp +++ b/compiler/luci/import/src/Nodes/CircleDepthToSpace.cpp @@ -34,9 +34,10 @@ bool CircleDepthToSpaceGraphBuilder::validate(const ValidateArgs &args) const const auto &outputs = args.op.outputs; const auto *options = args.op.builtin_options.AsDepthToSpaceOptions(); - const auto &tensors = args.reader.tensors(); + const auto tensors = args.reader.tensors(); + assert(tensors[outputs[0]] != nullptr && tensors[inputs.at(0)] != nullptr); - if (tensors[outputs[0]]->type != tensors[inputs.at(0)]->type) + if (tensors[outputs[0]]->type() != tensors[inputs.at(0)]->type()) { return false; } diff --git a/compiler/luci/import/src/Nodes/CircleDepthwiseConv2D.cpp b/compiler/luci/import/src/Nodes/CircleDepthwiseConv2D.cpp index 727487c6a..a24e4160d 100644 --- 
a/compiler/luci/import/src/Nodes/CircleDepthwiseConv2D.cpp +++ b/compiler/luci/import/src/Nodes/CircleDepthwiseConv2D.cpp @@ -32,19 +32,21 @@ bool CircleDepthwiseConv2DGraphBuilder::validate(const ValidateArgs &args) const if (args.op.outputs.size() != 1) return false; - const auto &tensors = args.reader.tensors(); + const auto tensors = args.reader.tensors(); // input shape - const auto &input = tensors.at(args.op.inputs.at(0)); - const auto &input_shape = input->shape; + const auto input = tensors.at(args.op.inputs.at(0)); + assert(input != nullptr); + const auto input_shape = wrap(input->shape()); // input shape must be rank 4 if (input_shape.size() != 4) return false; // filter shape - const auto &filter = tensors.at(args.op.inputs.at(1)); - const auto &filter_shape = filter->shape; + const auto filter = tensors.at(args.op.inputs.at(1)); + assert(filter != nullptr); + const auto filter_shape = wrap(filter->shape()); // filter shape must be rank 4 if (filter_shape.size() != 4) diff --git a/compiler/luci/import/src/Nodes/CircleElu.cpp b/compiler/luci/import/src/Nodes/CircleElu.cpp index 41696a65a..e5d7a4c7a 100644 --- a/compiler/luci/import/src/Nodes/CircleElu.cpp +++ b/compiler/luci/import/src/Nodes/CircleElu.cpp @@ -31,10 +31,11 @@ bool CircleEluGraphBuilder::validate(const ValidateArgs &args) const const auto &inputs = args.op.inputs; const auto &outputs = args.op.outputs; - const auto &tensors = args.reader.tensors(); - const auto &tensor = tensors.at(inputs.at(0)); + const auto tensors = args.reader.tensors(); + const auto tensor = tensors.at(inputs.at(0)); + assert(tensor != nullptr); - switch (tensor->type) + switch (tensor->type()) { case circle::TensorType_FLOAT64: break; @@ -48,7 +49,8 @@ bool CircleEluGraphBuilder::validate(const ValidateArgs &args) const return false; } - if (tensors[outputs[0]]->type != tensor->type) + assert(tensors[outputs[0]] != nullptr); + if (tensors[outputs[0]]->type() != tensor->type()) return false; return true; diff --git 
a/compiler/luci/import/src/Nodes/CircleEqual.cpp b/compiler/luci/import/src/Nodes/CircleEqual.cpp index 4909692b4..b326d9b5d 100644 --- a/compiler/luci/import/src/Nodes/CircleEqual.cpp +++ b/compiler/luci/import/src/Nodes/CircleEqual.cpp @@ -29,9 +29,10 @@ bool CircleEqualGraphBuilder::validate(const ValidateArgs &args) const return false; const auto &inputs = args.op.inputs; - const auto &tensors = args.reader.tensors(); + const auto tensors = args.reader.tensors(); - return tensors[inputs.at(0)]->type == tensors[inputs.at(1)]->type; + assert(tensors[inputs.at(0)] != nullptr && tensors[inputs.at(1)] != nullptr); + return tensors[inputs.at(0)]->type() == tensors[inputs.at(1)]->type(); } CircleNode *CircleEqualGraphBuilder::build_node(const circle::OperatorT &, diff --git a/compiler/luci/import/src/Nodes/CircleExp.cpp b/compiler/luci/import/src/Nodes/CircleExp.cpp index 5bb7bb664..82c26f0e5 100644 --- a/compiler/luci/import/src/Nodes/CircleExp.cpp +++ b/compiler/luci/import/src/Nodes/CircleExp.cpp @@ -30,9 +30,10 @@ bool CircleExpGraphBuilder::validate(const ValidateArgs &args) const const auto &inputs = args.op.inputs; // input type check - const auto &tensors = args.reader.tensors(); - const auto &tensor = tensors.at(inputs.at(0)); - switch (tensor->type) + const auto tensors = args.reader.tensors(); + const auto tensor = tensors.at(inputs.at(0)); + assert(tensor != nullptr); + switch (tensor->type()) { case circle::TensorType_FLOAT16: case circle::TensorType_FLOAT32: diff --git a/compiler/luci/import/src/Nodes/CircleExpandDims.cpp b/compiler/luci/import/src/Nodes/CircleExpandDims.cpp index ee0fbdc7e..67d9b7e9e 100644 --- a/compiler/luci/import/src/Nodes/CircleExpandDims.cpp +++ b/compiler/luci/import/src/Nodes/CircleExpandDims.cpp @@ -29,9 +29,10 @@ bool CircleExpandDimsGraphBuilder::validate(const ValidateArgs &args) const return false; const auto &inputs = args.op.inputs; - const auto &tensors = args.reader.tensors(); + const auto tensors = 
args.reader.tensors(); - return tensors[inputs.at(1)]->type == circle::TensorType_INT32; + assert(tensors[inputs.at(1)] != nullptr); + return tensors[inputs.at(1)]->type() == circle::TensorType_INT32; } CircleNode *CircleExpandDimsGraphBuilder::build_node(const circle::OperatorT &, diff --git a/compiler/luci/import/src/Nodes/CircleFloorDiv.cpp b/compiler/luci/import/src/Nodes/CircleFloorDiv.cpp index ce329326a..67eeddf91 100644 --- a/compiler/luci/import/src/Nodes/CircleFloorDiv.cpp +++ b/compiler/luci/import/src/Nodes/CircleFloorDiv.cpp @@ -30,15 +30,18 @@ bool CircleFloorDivGraphBuilder::validate(const ValidateArgs &args) const const auto &inputs = args.op.inputs; const auto &outputs = args.op.outputs; - const auto &tensors = args.reader.tensors(); - const auto &tensor_in_0 = tensors.at(inputs.at(0)); - const auto &tensor_in_1 = tensors.at(inputs.at(1)); - const auto &tensor_out = tensors.at(outputs[0]); - - if (tensor_in_0->type != tensor_in_1->type) + const auto tensors = args.reader.tensors(); + const auto tensor_in_0 = tensors.at(inputs.at(0)); + const auto tensor_in_1 = tensors.at(inputs.at(1)); + const auto tensor_out = tensors.at(outputs[0]); + assert(tensor_in_0 != nullptr); + assert(tensor_in_1 != nullptr); + assert(tensor_out != nullptr); + + if (tensor_in_0->type() != tensor_in_1->type()) return false; - if (tensor_out->type != tensor_in_1->type) + if (tensor_out->type() != tensor_in_1->type()) { return false; } diff --git a/compiler/luci/import/src/Nodes/CircleFloorMod.cpp b/compiler/luci/import/src/Nodes/CircleFloorMod.cpp index d8420a43c..d2a275b62 100644 --- a/compiler/luci/import/src/Nodes/CircleFloorMod.cpp +++ b/compiler/luci/import/src/Nodes/CircleFloorMod.cpp @@ -29,10 +29,11 @@ bool CircleFloorModGraphBuilder::validate(const ValidateArgs &args) const return false; const auto &inputs = args.op.inputs; - const auto &tensors = args.reader.tensors(); - const auto &tensor_in_0 = tensors.at(inputs.at(0)); - const auto &tensor_in_1 = 
tensors.at(inputs.at(1)); - if (tensor_in_0->type != tensor_in_1->type) + const auto tensors = args.reader.tensors(); + const auto tensor_in_0 = tensors.at(inputs.at(0)); + const auto tensor_in_1 = tensors.at(inputs.at(1)); + assert(tensor_in_0 != nullptr && tensor_in_1 != nullptr); + if (tensor_in_0->type() != tensor_in_1->type()) return false; // TODO dtype check diff --git a/compiler/luci/import/src/Nodes/CircleFullyConnected.cpp b/compiler/luci/import/src/Nodes/CircleFullyConnected.cpp index 58750d79a..cc7be1693 100644 --- a/compiler/luci/import/src/Nodes/CircleFullyConnected.cpp +++ b/compiler/luci/import/src/Nodes/CircleFullyConnected.cpp @@ -42,6 +42,7 @@ CircleNode *CircleFullyConnectedGraphBuilder::build_node(const circle::OperatorT const auto *options = op.builtin_options.AsFullyConnectedOptions(); node->fusedActivationFunction(luci_actfunc(options->fused_activation_function)); node->weights_format(luci_weights_format(options->weights_format)); + node->keep_num_dims(options->keep_num_dims); return node; } diff --git a/compiler/luci/import/src/Nodes/CircleGatherNd.cpp b/compiler/luci/import/src/Nodes/CircleGatherNd.cpp index a4bb26a10..d336878ad 100644 --- a/compiler/luci/import/src/Nodes/CircleGatherNd.cpp +++ b/compiler/luci/import/src/Nodes/CircleGatherNd.cpp @@ -31,10 +31,11 @@ bool CircleGatherNdGraphBuilder::validate(const ValidateArgs &args) const return false; const auto &inputs = args.op.inputs; - auto &indices_tensor = args.reader.tensors()[inputs.at(1)]; + auto indices_tensor = args.reader.tensors()[inputs.at(1)]; + assert(indices_tensor != nullptr); - if (!(indices_tensor->type == circle::TensorType::TensorType_INT32 || - indices_tensor->type == circle::TensorType::TensorType_INT64)) + if (!(indices_tensor->type() == circle::TensorType::TensorType_INT32 || + indices_tensor->type() == circle::TensorType::TensorType_INT64)) { return false; } diff --git a/compiler/luci/import/src/Nodes/CircleGreater.cpp 
b/compiler/luci/import/src/Nodes/CircleGreater.cpp index f9c00346c..7f031b0ba 100644 --- a/compiler/luci/import/src/Nodes/CircleGreater.cpp +++ b/compiler/luci/import/src/Nodes/CircleGreater.cpp @@ -37,17 +37,19 @@ bool CircleGreaterGraphBuilder::validate(const ValidateArgs &args) const const auto &inputs = args.op.inputs; const auto &outputs = args.op.outputs; - const auto &tensors = args.reader.tensors(); + const auto tensors = args.reader.tensors(); - if (tensors[inputs.at(0)]->type != tensors[inputs.at(1)]->type) + assert(tensors[inputs.at(0)] != nullptr && tensors[inputs.at(1)] != nullptr); + if (tensors[inputs.at(0)]->type() != tensors[inputs.at(1)]->type()) return false; // NOTE: real models do have output dtype NOT BOOL - if (tensors[outputs[0]]->type != circle::TensorType_BOOL) + assert(tensors[outputs[0]] != nullptr); + if (tensors[outputs[0]]->type() != circle::TensorType_BOOL) { if (settings->get(luci::UserSettings::Key::DisableValidation)) { - const circle::TensorT &output_tensor = *tensors[outputs[0]]; + const auto output_tensor = tensors[outputs[0]]; auto name = tensor_name(output_tensor); WARN(l) << "Warning: import Greater(" << name << ") output dtype is not boolean"; } diff --git a/compiler/luci/import/src/Nodes/CircleGreaterEqual.cpp b/compiler/luci/import/src/Nodes/CircleGreaterEqual.cpp index e20038fd9..ac4ce62f5 100644 --- a/compiler/luci/import/src/Nodes/CircleGreaterEqual.cpp +++ b/compiler/luci/import/src/Nodes/CircleGreaterEqual.cpp @@ -30,14 +30,16 @@ bool CircleGreaterEqualGraphBuilder::validate(const ValidateArgs &args) const const auto &inputs = args.op.inputs; const auto &outputs = args.op.outputs; - const auto &tensors = args.reader.tensors(); + const auto tensors = args.reader.tensors(); - if (tensors[inputs.at(0)]->type != tensors[inputs.at(1)]->type) + assert(tensors[inputs.at(0)] != nullptr && tensors[inputs.at(1)] != nullptr); + if (tensors[inputs.at(0)]->type() != tensors[inputs.at(1)]->type()) { return false; } - return 
tensors[outputs[0]]->type == circle::TensorType::TensorType_BOOL; + assert(tensors[outputs[0]] != nullptr); + return tensors[outputs[0]]->type() == circle::TensorType::TensorType_BOOL; } CircleNode *CircleGreaterEqualGraphBuilder::build_node(const circle::OperatorT &, diff --git a/compiler/luci/import/src/Nodes/CircleIf.cpp b/compiler/luci/import/src/Nodes/CircleIf.cpp index ffdbf0b79..e8a50ff32 100644 --- a/compiler/luci/import/src/Nodes/CircleIf.cpp +++ b/compiler/luci/import/src/Nodes/CircleIf.cpp @@ -42,12 +42,13 @@ bool CircleIfGraphBuilder::validate(const ValidateArgs &args) const return false; // input 0 should be BOOL type - const auto &tensors = args.reader.tensors(); - const auto &tensor = tensors.at(inputs.at(0)); - if (tensor->type != circle::TensorType_BOOL) + const auto tensors = args.reader.tensors(); + const auto tensor = tensors.at(inputs.at(0)); + assert(tensor != nullptr); + if (tensor->type() != circle::TensorType_BOOL) return false; - const auto &shape = tensor->shape; + const auto shape = wrap(tensor->shape()); if (shape.size() != 1 && shape.size() != 0) return false; diff --git a/compiler/luci/import/src/Nodes/CircleLess.cpp b/compiler/luci/import/src/Nodes/CircleLess.cpp index f9b99bebe..5c5ae51e1 100644 --- a/compiler/luci/import/src/Nodes/CircleLess.cpp +++ b/compiler/luci/import/src/Nodes/CircleLess.cpp @@ -30,10 +30,11 @@ bool CircleLessGraphBuilder::validate(const ValidateArgs &args) const const auto &inputs = args.op.inputs; const auto &outputs = args.op.outputs; - const auto &tensors = args.reader.tensors(); - const auto &tensor = tensors.at(inputs.at(0)); + const auto tensors = args.reader.tensors(); + const auto tensor = tensors.at(inputs.at(0)); + assert(tensor != nullptr); - switch (tensor->type) + switch (tensor->type()) { case circle::TensorType_FLOAT32: case circle::TensorType_FLOAT64: @@ -48,12 +49,14 @@ bool CircleLessGraphBuilder::validate(const ValidateArgs &args) const return false; } - if (tensors[inputs.at(1)]->type != 
tensor->type) + assert(tensors[inputs.at(1)] != nullptr); + if (tensors[inputs.at(1)]->type() != tensor->type()) { return false; } - return tensors[outputs[0]]->type == circle::TensorType_BOOL; + assert(tensors[outputs[0]] != nullptr); + return tensors[outputs[0]]->type() == circle::TensorType_BOOL; } CircleNode *CircleLessGraphBuilder::build_node(const circle::OperatorT &, diff --git a/compiler/luci/import/src/Nodes/CircleLessEqual.cpp b/compiler/luci/import/src/Nodes/CircleLessEqual.cpp index bb1712137..8a2aea8db 100644 --- a/compiler/luci/import/src/Nodes/CircleLessEqual.cpp +++ b/compiler/luci/import/src/Nodes/CircleLessEqual.cpp @@ -30,14 +30,16 @@ bool CircleLessEqualGraphBuilder::validate(const ValidateArgs &args) const const auto &inputs = args.op.inputs; const auto &outputs = args.op.outputs; - const auto &tensors = args.reader.tensors(); + const auto tensors = args.reader.tensors(); - if (tensors[inputs.at(0)]->type != tensors[inputs.at(1)]->type) + assert(tensors[inputs.at(0)] != nullptr && tensors[inputs.at(1)] != nullptr); + if (tensors[inputs.at(0)]->type() != tensors[inputs.at(1)]->type()) { return false; } - return tensors[outputs[0]]->type == circle::TensorType::TensorType_BOOL; + assert(tensors[outputs[0]] != nullptr); + return tensors[outputs[0]]->type() == circle::TensorType::TensorType_BOOL; } CircleNode *CircleLessEqualGraphBuilder::build_node(const circle::OperatorT &, diff --git a/compiler/luci/import/src/Nodes/CircleLog.cpp b/compiler/luci/import/src/Nodes/CircleLog.cpp index 26b575070..f41926829 100644 --- a/compiler/luci/import/src/Nodes/CircleLog.cpp +++ b/compiler/luci/import/src/Nodes/CircleLog.cpp @@ -32,9 +32,10 @@ bool CircleLogGraphBuilder::validate(const ValidateArgs &args) const // input type check // Must be one of bfloat16, half, float32, float64, complex64, complex128. // Currently circle supports half(float16), float32, float64, complex64. 
- const auto &tensors = args.reader.tensors(); - const auto &tensor = tensors.at(inputs.at(0)); - switch (tensor->type) + const auto tensors = args.reader.tensors(); + const auto tensor = tensors.at(inputs.at(0)); + assert(tensor != nullptr); + switch (tensor->type()) { case circle::TensorType_FLOAT16: case circle::TensorType_FLOAT32: diff --git a/compiler/luci/import/src/Nodes/CircleLogicalAnd.cpp b/compiler/luci/import/src/Nodes/CircleLogicalAnd.cpp index b13fc2735..b61fb6f3e 100644 --- a/compiler/luci/import/src/Nodes/CircleLogicalAnd.cpp +++ b/compiler/luci/import/src/Nodes/CircleLogicalAnd.cpp @@ -30,11 +30,12 @@ bool CircleLogicalAndGraphBuilder::validate(const ValidateArgs &args) const // Only BOOL type is allowed for inputs const auto &inputs = args.op.inputs; - const auto &tensors = args.reader.tensors(); + const auto tensors = args.reader.tensors(); for (auto input : inputs) { - const auto &tensor = tensors.at(input); - if (tensor->type != circle::TensorType::TensorType_BOOL) + const auto tensor = tensors.at(input); + assert(tensor != nullptr); + if (tensor->type() != circle::TensorType::TensorType_BOOL) return false; } diff --git a/compiler/luci/import/src/Nodes/CircleLogicalNot.cpp b/compiler/luci/import/src/Nodes/CircleLogicalNot.cpp index f68218349..43e9ed39f 100644 --- a/compiler/luci/import/src/Nodes/CircleLogicalNot.cpp +++ b/compiler/luci/import/src/Nodes/CircleLogicalNot.cpp @@ -30,9 +30,10 @@ bool CircleLogicalNotGraphBuilder::validate(const ValidateArgs &args) const // Only BOOL type is allowed for the input const auto &inputs = args.op.inputs; - const auto &tensors = args.reader.tensors(); - const auto &tensor = tensors.at(inputs.at(0)); - if (tensor->type != circle::TensorType::TensorType_BOOL) + const auto tensors = args.reader.tensors(); + const auto tensor = tensors.at(inputs.at(0)); + assert(tensor != nullptr); + if (tensor->type() != circle::TensorType::TensorType_BOOL) return false; return true; diff --git 
a/compiler/luci/import/src/Nodes/CircleLogicalOr.cpp b/compiler/luci/import/src/Nodes/CircleLogicalOr.cpp index 8c9023dd3..6354e7dc1 100644 --- a/compiler/luci/import/src/Nodes/CircleLogicalOr.cpp +++ b/compiler/luci/import/src/Nodes/CircleLogicalOr.cpp @@ -30,11 +30,12 @@ bool CircleLogicalOrGraphBuilder::validate(const ValidateArgs &args) const // Only BOOL type is allowed for inputs const auto &inputs = args.op.inputs; - const auto &tensors = args.reader.tensors(); + const auto tensors = args.reader.tensors(); for (auto input : inputs) { - const auto &tensor = tensors.at(input); - if (tensor->type != circle::TensorType::TensorType_BOOL) + const auto tensor = tensors.at(input); + assert(tensor != nullptr); + if (tensor->type() != circle::TensorType::TensorType_BOOL) return false; } diff --git a/compiler/luci/import/src/Nodes/CircleLogistic.cpp b/compiler/luci/import/src/Nodes/CircleLogistic.cpp index 0f92a9bb4..b0d08e039 100644 --- a/compiler/luci/import/src/Nodes/CircleLogistic.cpp +++ b/compiler/luci/import/src/Nodes/CircleLogistic.cpp @@ -30,8 +30,9 @@ bool CircleLogisticGraphBuilder::validate(const ValidateArgs &args) const const auto &inputs = args.op.inputs; const auto &outputs = args.op.outputs; - const auto &tensors = args.reader.tensors(); - if (tensors.at(inputs.at(0))->type != tensors.at(outputs[0])->type) + const auto tensors = args.reader.tensors(); + assert(tensors.at(inputs.at(0)) != nullptr && tensors.at(outputs[0]) != nullptr); + if (tensors.at(inputs.at(0))->type() != tensors.at(outputs[0])->type()) return false; return true; diff --git a/compiler/luci/import/src/Nodes/CircleMatrixDiag.cpp b/compiler/luci/import/src/Nodes/CircleMatrixDiag.cpp index 590a07f2d..384b98586 100644 --- a/compiler/luci/import/src/Nodes/CircleMatrixDiag.cpp +++ b/compiler/luci/import/src/Nodes/CircleMatrixDiag.cpp @@ -30,10 +30,11 @@ bool CircleMatrixDiagGraphBuilder::validate(const ValidateArgs &args) const const auto &inputs = args.op.inputs; const auto &outputs = 
args.op.outputs; - const auto &tensors = args.reader.tensors(); - const auto &tensor = tensors.at(inputs.at(0)); + const auto tensors = args.reader.tensors(); + const auto tensor = tensors.at(inputs.at(0)); - if (tensors[outputs[0]]->type != tensor->type) + assert(tensors[outputs[0]] != nullptr && tensor != nullptr); + if (tensors[outputs[0]]->type() != tensor->type()) return false; return true; diff --git a/compiler/luci/import/src/Nodes/CircleMatrixSetDiag.cpp b/compiler/luci/import/src/Nodes/CircleMatrixSetDiag.cpp index edd7d2ae2..64870c057 100644 --- a/compiler/luci/import/src/Nodes/CircleMatrixSetDiag.cpp +++ b/compiler/luci/import/src/Nodes/CircleMatrixSetDiag.cpp @@ -30,10 +30,11 @@ bool CircleMatrixSetDiagGraphBuilder::validate(const ValidateArgs &args) const const auto &inputs = args.op.inputs; const auto &outputs = args.op.outputs; - const auto &tensors = args.reader.tensors(); - const auto &tensor = tensors.at(inputs.at(0)); + const auto tensors = args.reader.tensors(); + const auto tensor = tensors.at(inputs.at(0)); - if (tensors[outputs[0]]->type != tensor->type) + assert(tensors[outputs[0]] != nullptr && tensor != nullptr); + if (tensors[outputs[0]]->type() != tensor->type()) return false; return true; diff --git a/compiler/luci/import/src/Nodes/CircleNonMaxSuppressionV4.cpp b/compiler/luci/import/src/Nodes/CircleNonMaxSuppressionV4.cpp index d3d69506b..e86f2ba81 100644 --- a/compiler/luci/import/src/Nodes/CircleNonMaxSuppressionV4.cpp +++ b/compiler/luci/import/src/Nodes/CircleNonMaxSuppressionV4.cpp @@ -35,20 +35,26 @@ bool CircleNonMaxSuppressionV4GraphBuilder::validate(const ValidateArgs &args) c if (outputs.size() != 2) return false; - const auto &tensors = args.reader.tensors(); - const auto &boxes_tensor = tensors.at(inputs[0]); - if (boxes_tensor->shape.size() != 2) + const auto tensors = args.reader.tensors(); + const auto boxes_tensor = tensors.at(inputs[0]); + assert(boxes_tensor != nullptr); + const auto boxes_tensor_shape = 
wrap(boxes_tensor->shape()); + if (boxes_tensor_shape.size() != 2) return false; - if (boxes_tensor->shape.at(1) != 4) + if (boxes_tensor_shape.at(1) != 4) return false; - if (boxes_tensor->shape.at(0) != tensors.at(inputs[1])->shape.at(0)) + assert(tensors.at(inputs[1]) != nullptr); + if (boxes_tensor_shape.at(0) != wrap(tensors.at(inputs[1])->shape()).at(0)) return false; - if (tensors.at(inputs[2])->type != circle::TensorType_INT32) + assert(tensors.at(inputs[2]) != nullptr); + if (tensors.at(inputs[2])->type() != circle::TensorType_INT32) return false; - if (tensors.at(inputs[3])->type != circle::TensorType_FLOAT32) + assert(tensors.at(inputs[3]) != nullptr); + if (tensors.at(inputs[3])->type() != circle::TensorType_FLOAT32) return false; - if (tensors.at(inputs[4])->type != circle::TensorType_FLOAT32) + assert(tensors.at(inputs[4]) != nullptr); + if (tensors.at(inputs[4])->type() != circle::TensorType_FLOAT32) return false; return true; diff --git a/compiler/luci/import/src/Nodes/CircleNonMaxSuppressionV5.cpp b/compiler/luci/import/src/Nodes/CircleNonMaxSuppressionV5.cpp index d797d4cb7..a60eed4e4 100644 --- a/compiler/luci/import/src/Nodes/CircleNonMaxSuppressionV5.cpp +++ b/compiler/luci/import/src/Nodes/CircleNonMaxSuppressionV5.cpp @@ -35,22 +35,29 @@ bool CircleNonMaxSuppressionV5GraphBuilder::validate(const ValidateArgs &args) c if (outputs.size() != 3) return false; - const auto &tensors = args.reader.tensors(); - const auto &boxes_tensor = tensors.at(inputs[0]); - if (boxes_tensor->shape.size() != 2) + const auto tensors = args.reader.tensors(); + const auto boxes_tensor = tensors.at(inputs[0]); + assert(boxes_tensor != nullptr); + const auto boxes_tensor_shape = wrap(boxes_tensor->shape()); + if (boxes_tensor_shape.size() != 2) return false; - if (boxes_tensor->shape.at(1) != 4) + if (boxes_tensor_shape.at(1) != 4) return false; - if (boxes_tensor->shape.at(0) != tensors.at(inputs[1])->shape.at(0)) + assert(tensors.at(inputs[1]) != nullptr); + if 
(boxes_tensor_shape.at(0) != wrap(tensors.at(inputs[1])->shape()).at(0)) return false; - if (tensors.at(inputs[2])->type != circle::TensorType_INT32) + assert(tensors.at(inputs[2]) != nullptr); + if (tensors.at(inputs[2])->type() != circle::TensorType_INT32) return false; - if (tensors.at(inputs[3])->type != circle::TensorType_FLOAT32) + assert(tensors.at(inputs[3]) != nullptr); + if (tensors.at(inputs[3])->type() != circle::TensorType_FLOAT32) return false; - if (tensors.at(inputs[4])->type != circle::TensorType_FLOAT32) + assert(tensors.at(inputs[4]) != nullptr); + if (tensors.at(inputs[4])->type() != circle::TensorType_FLOAT32) return false; - if (tensors.at(inputs[5])->type != circle::TensorType_FLOAT32) + assert(tensors.at(inputs[5]) != nullptr); + if (tensors.at(inputs[5])->type() != circle::TensorType_FLOAT32) return false; return true; diff --git a/compiler/luci/import/src/Nodes/CircleNotEqual.cpp b/compiler/luci/import/src/Nodes/CircleNotEqual.cpp index a0b8f9e4f..3f5c1e033 100644 --- a/compiler/luci/import/src/Nodes/CircleNotEqual.cpp +++ b/compiler/luci/import/src/Nodes/CircleNotEqual.cpp @@ -30,14 +30,16 @@ bool CircleNotEqualGraphBuilder::validate(const ValidateArgs &args) const const auto &inputs = args.op.inputs; const auto &outputs = args.op.outputs; - const auto &tensors = args.reader.tensors(); + const auto tensors = args.reader.tensors(); - if (tensors[inputs.at(0)]->type != tensors[inputs.at(1)]->type) + assert(tensors[inputs.at(0)] != nullptr && tensors[inputs.at(1)] != nullptr); + if (tensors[inputs.at(0)]->type() != tensors[inputs.at(1)]->type()) { return false; } - return tensors[outputs[0]]->type == circle::TensorType::TensorType_BOOL; + assert(tensors[outputs[0]] != nullptr); + return tensors[outputs[0]]->type() == circle::TensorType::TensorType_BOOL; } CircleNode *CircleNotEqualGraphBuilder::build_node(const circle::OperatorT &, diff --git a/compiler/luci/import/src/Nodes/CircleOneHot.cpp b/compiler/luci/import/src/Nodes/CircleOneHot.cpp 
index 3952cc21a..6e5f8e16f 100644 --- a/compiler/luci/import/src/Nodes/CircleOneHot.cpp +++ b/compiler/luci/import/src/Nodes/CircleOneHot.cpp @@ -32,21 +32,25 @@ bool CircleOneHotGraphBuilder::validate(const ValidateArgs &args) const const auto &inputs = args.op.inputs; const auto *options = args.op.builtin_options.AsOneHotOptions(); - const auto &tensors = args.reader.tensors(); - const auto &indices = tensors.at(inputs.at(0)); - const auto &depth = tensors.at(inputs.at(1)); - const auto &on_value = tensors.at(inputs.at(2)); - const auto &off_value = tensors.at(inputs.at(3)); + const auto tensors = args.reader.tensors(); + const auto indices = tensors.at(inputs.at(0)); + const auto depth = tensors.at(inputs.at(1)); + const auto on_value = tensors.at(inputs.at(2)); + const auto off_value = tensors.at(inputs.at(3)); + assert(indices != nullptr); + assert(depth != nullptr); + assert(on_value != nullptr); + assert(off_value != nullptr); - if (options->axis < -1 || options->axis > static_cast<int32_t>(indices->shape.size())) + if (options->axis < -1 || options->axis > static_cast<int32_t>(wrap(indices->shape()).size())) return false; - if (depth->shape.size() != 0) + if (wrap(depth->shape()).size() != 0) return false; - if (on_value->shape.size() != 0) + if (wrap(on_value->shape()).size() != 0) return false; - if (off_value->shape.size() != 0) + if (wrap(off_value->shape()).size() != 0) return false; - if (on_value->type != off_value->type) + if (on_value->type() != off_value->type()) return false; return true; diff --git a/compiler/luci/import/src/Nodes/CircleReduceAny.cpp b/compiler/luci/import/src/Nodes/CircleReduceAny.cpp index 13205dd7a..ebe2368e0 100644 --- a/compiler/luci/import/src/Nodes/CircleReduceAny.cpp +++ b/compiler/luci/import/src/Nodes/CircleReduceAny.cpp @@ -28,17 +28,20 @@ bool CircleReduceAnyGraphBuilder::validate(const ValidateArgs &args) const const auto &inputs = args.op.inputs; const auto &outputs = args.op.outputs; - const auto &tensors = 
args.reader.tensors(); - const auto &tensor_0 = tensors.at(inputs.at(0)); - const auto &tensor_1 = tensors.at(inputs.at(1)); - const auto &tensor_o = tensors.at(outputs[0]); + const auto tensors = args.reader.tensors(); + const auto tensor_0 = tensors.at(inputs.at(0)); + const auto tensor_1 = tensors.at(inputs.at(1)); + const auto tensor_o = tensors.at(outputs[0]); + assert(tensor_0 != nullptr); + assert(tensor_1 != nullptr); + assert(tensor_o != nullptr); - if (tensor_0->type != circle::TensorType_BOOL) + if (tensor_0->type() != circle::TensorType_BOOL) return false; - if (tensor_o->type != circle::TensorType_BOOL) + if (tensor_o->type() != circle::TensorType_BOOL) return false; - switch (tensor_1->type) + switch (tensor_1->type()) { case circle::TensorType_INT32: case circle::TensorType_INT64: diff --git a/compiler/luci/import/src/Nodes/CircleReduceProd.cpp b/compiler/luci/import/src/Nodes/CircleReduceProd.cpp index 3549c1a18..3b874b7c9 100644 --- a/compiler/luci/import/src/Nodes/CircleReduceProd.cpp +++ b/compiler/luci/import/src/Nodes/CircleReduceProd.cpp @@ -27,13 +27,14 @@ bool CircleReduceProdGraphBuilder::validate(const ValidateArgs &args) const return false; const auto &inputs = args.op.inputs; - const auto &tensors = args.reader.tensors(); - const auto &tensor_1 = tensors.at(inputs.at(1)); + const auto tensors = args.reader.tensors(); + const auto tensor_1 = tensors.at(inputs.at(1)); + assert(tensor_1 != nullptr); // TODO check input types // Check for reduction_indices types - switch (tensor_1->type) + switch (tensor_1->type()) { case circle::TensorType_INT32: case circle::TensorType_INT64: diff --git a/compiler/luci/import/src/Nodes/CircleReshape.cpp b/compiler/luci/import/src/Nodes/CircleReshape.cpp index 401dff0fc..3421620ce 100644 --- a/compiler/luci/import/src/Nodes/CircleReshape.cpp +++ b/compiler/luci/import/src/Nodes/CircleReshape.cpp @@ -34,12 +34,13 @@ bool CircleReshapeGraphBuilder::validate(const ValidateArgs &args) const if 
(args.op.inputs.size() == 2) { const auto &inputs = args.op.inputs; - const auto &tensors = args.reader.tensors(); - const auto &tensor_in = tensors.at(inputs.at(1)); + const auto tensors = args.reader.tensors(); + const auto tensor_in = tensors.at(inputs.at(1)); + assert(tensor_in != nullptr); // NOTE fix this if there is any other case // TensorFlow lite and circle only supports S32 - if (tensor_in->type != circle::TensorType::TensorType_INT32) + if (tensor_in->type() != circle::TensorType::TensorType_INT32) return false; } diff --git a/compiler/luci/import/src/Nodes/CircleReverseSequence.cpp b/compiler/luci/import/src/Nodes/CircleReverseSequence.cpp index 2fbb7a87c..c9cc792bb 100644 --- a/compiler/luci/import/src/Nodes/CircleReverseSequence.cpp +++ b/compiler/luci/import/src/Nodes/CircleReverseSequence.cpp @@ -30,12 +30,15 @@ bool CircleReverseSequenceGraphBuilder::validate(const ValidateArgs &args) const const auto &inputs = args.op.inputs; const auto &outputs = args.op.outputs; - const auto &tensors = args.reader.tensors(); - const auto &tensor_in = tensors.at(inputs.at(0)); - const auto &tensor_lengths = tensors.at(inputs.at(1)); - const auto &tensor_out = tensors.at(outputs[0]); + const auto tensors = args.reader.tensors(); + const auto tensor_in = tensors.at(inputs.at(0)); + const auto tensor_lengths = tensors.at(inputs.at(1)); + const auto tensor_out = tensors.at(outputs[0]); + assert(tensor_in != nullptr); + assert(tensor_lengths != nullptr); + assert(tensor_out != nullptr); - switch (tensor_lengths->type) + switch (tensor_lengths->type()) { case circle::TensorType_INT32: case circle::TensorType_INT64: @@ -44,7 +47,7 @@ bool CircleReverseSequenceGraphBuilder::validate(const ValidateArgs &args) const return false; } - if (tensor_in->type != tensor_out->type) + if (tensor_in->type() != tensor_out->type()) return false; return true; diff --git a/compiler/luci/import/src/Nodes/CircleReverseV2.cpp b/compiler/luci/import/src/Nodes/CircleReverseV2.cpp index 
ca7653201..c19a0fdd2 100644 --- a/compiler/luci/import/src/Nodes/CircleReverseV2.cpp +++ b/compiler/luci/import/src/Nodes/CircleReverseV2.cpp @@ -30,12 +30,15 @@ bool CircleReverseV2GraphBuilder::validate(const ValidateArgs &args) const const auto &inputs = args.op.inputs; const auto &outputs = args.op.outputs; - const auto &tensors = args.reader.tensors(); - const auto &tensor_in = tensors.at(inputs.at(0)); - const auto &tensor_axis = tensors.at(inputs.at(1)); - const auto &tensor_out = tensors.at(outputs[0]); + const auto tensors = args.reader.tensors(); + const auto tensor_in = tensors.at(inputs.at(0)); + const auto tensor_axis = tensors.at(inputs.at(1)); + const auto tensor_out = tensors.at(outputs[0]); + assert(tensor_in != nullptr); + assert(tensor_axis != nullptr); + assert(tensor_out != nullptr); - switch (tensor_axis->type) + switch (tensor_axis->type()) { case circle::TensorType_INT32: case circle::TensorType_INT64: @@ -44,7 +47,7 @@ bool CircleReverseV2GraphBuilder::validate(const ValidateArgs &args) const return false; } - if (tensor_out->type != tensor_in->type) + if (tensor_out->type() != tensor_in->type()) return false; return true; diff --git a/compiler/luci/import/src/Nodes/CircleRound.cpp b/compiler/luci/import/src/Nodes/CircleRound.cpp index d13e0fafe..08cfae6c2 100644 --- a/compiler/luci/import/src/Nodes/CircleRound.cpp +++ b/compiler/luci/import/src/Nodes/CircleRound.cpp @@ -33,11 +33,13 @@ bool CircleRoundGraphBuilder::validate(const ValidateArgs &args) const // Must be one of the following types // bfloat16, half (float16), float32, float64, complex64, complex128 // Currently, circle supports float16, float32, complex64 - const auto &tensors = args.reader.tensors(); - const auto &tensor_in = tensors.at(inputs.at(0)); - const auto &tensor_out = tensors.at(outputs[0]); + const auto tensors = args.reader.tensors(); + const auto tensor_in = tensors.at(inputs.at(0)); + const auto tensor_out = tensors.at(outputs[0]); + assert(tensor_in != nullptr); 
+ assert(tensor_out != nullptr); - switch (tensor_in->type) + switch (tensor_in->type()) { case circle::TensorType_FLOAT16: case circle::TensorType_FLOAT32: @@ -49,7 +51,7 @@ bool CircleRoundGraphBuilder::validate(const ValidateArgs &args) const return false; } - if (tensor_out->type != tensor_in->type) + if (tensor_out->type() != tensor_in->type()) return false; return true; diff --git a/compiler/luci/import/src/Nodes/CircleRsqrt.cpp b/compiler/luci/import/src/Nodes/CircleRsqrt.cpp index a9ca90832..e3bc68f8b 100644 --- a/compiler/luci/import/src/Nodes/CircleRsqrt.cpp +++ b/compiler/luci/import/src/Nodes/CircleRsqrt.cpp @@ -32,9 +32,10 @@ bool CircleRsqrtGraphBuilder::validate(const ValidateArgs &args) const // Must be one of the following types // bfloat16, half (float16), float32, float64, complex64, complex128 // Currently, circle supports float16, float32, complex64 - const auto &tensors = args.reader.tensors(); - const auto &tensor = tensors.at(inputs.at(0)); - switch (tensor->type) + const auto tensors = args.reader.tensors(); + const auto tensor = tensors.at(inputs.at(0)); + assert(tensor != nullptr); + switch (tensor->type()) { case circle::TensorType_UINT8: case circle::TensorType_INT16: diff --git a/compiler/luci/import/src/Nodes/CircleSVDF.cpp b/compiler/luci/import/src/Nodes/CircleSVDF.cpp new file mode 100644 index 000000000..83a025177 --- /dev/null +++ b/compiler/luci/import/src/Nodes/CircleSVDF.cpp @@ -0,0 +1,67 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Import/Nodes/CircleSVDF.h" + +#include <luci/IR/Nodes/CircleSVDF.h> + +#include <loco.h> + +namespace luci +{ + +bool CircleSVDFBuilder::validate(const ValidateArgs &args) const +{ + const auto &inputs = args.op.inputs; + if (!(inputs.size() == 4 || inputs.size() == 5)) + return false; + + return true; +} + +CircleNode *CircleSVDFBuilder::build_node(const circle::OperatorT &op, + const std::vector<CircleNode *> &inputs, + loco::Graph *graph) const +{ + auto *node = graph->nodes()->create<CircleSVDF>(); + node->input(inputs.at(0)); + node->weight_feature(inputs.at(1)); + node->weight_time(inputs.at(2)); + if (inputs.size() == 4) + { + auto *bias = graph->nodes()->create<CircleOutputExclude>(); + // CircleOutputExclude doesn't need a type, but since all nodes must have a type, + // a dummy type is inserted. + bias->dtype(inputs.at(0)->dtype()); + node->bias(bias); + + node->input_activation_state(inputs.at(3)); + } + else + { + node->bias(inputs.at(3)); + node->input_activation_state(inputs.at(4)); + } + + const auto *options = op.builtin_options.AsSVDFOptions(); + node->svdf_rank(options->rank); + node->fusedActivationFunction(luci_actfunc(options->fused_activation_function)); + node->asymmetric_quantize_inputs(options->asymmetric_quantize_inputs); + + return node; +} + +} // namespace luci diff --git a/compiler/luci/import/src/Nodes/CircleScatterNd.cpp b/compiler/luci/import/src/Nodes/CircleScatterNd.cpp index f8c175110..ebe252527 100644 --- a/compiler/luci/import/src/Nodes/CircleScatterNd.cpp +++ b/compiler/luci/import/src/Nodes/CircleScatterNd.cpp @@ -30,14 +30,15 @@ bool CircleScatterNdGraphBuilder::validate(const ValidateArgs &args) const const auto &inputs = args.op.inputs; // indices must have the same type as shape - const auto &tensors = args.reader.tensors(); + const auto tensors = args.reader.tensors(); - if 
(tensors[inputs.at(0)]->type != tensors[inputs.at(2)]->type) + assert(tensors[inputs.at(0)] != nullptr && tensors[inputs.at(2)] != nullptr); + if (tensors[inputs.at(0)]->type() != tensors[inputs.at(2)]->type()) return false; // indices must be either int32 or int64 - if (tensors[inputs.at(0)]->type != circle::TensorType_INT32 && - tensors[inputs.at(0)]->type != circle::TensorType_INT64) + if (tensors[inputs.at(0)]->type() != circle::TensorType_INT32 && + tensors[inputs.at(0)]->type() != circle::TensorType_INT64) return false; return true; diff --git a/compiler/luci/import/src/Nodes/CircleSegmentSum.cpp b/compiler/luci/import/src/Nodes/CircleSegmentSum.cpp index bfa333e8d..01d1aab44 100644 --- a/compiler/luci/import/src/Nodes/CircleSegmentSum.cpp +++ b/compiler/luci/import/src/Nodes/CircleSegmentSum.cpp @@ -30,12 +30,15 @@ bool CircleSegmentSumGraphBuilder::validate(const ValidateArgs &args) const const auto &inputs = args.op.inputs; const auto &outputs = args.op.outputs; - const auto &tensors = args.reader.tensors(); - const auto &tensor_in = tensors.at(inputs.at(0)); - const auto &tensor_out = tensors.at(outputs[0]); - const auto &tensor_ids = tensors.at(inputs.at(1)); + const auto tensors = args.reader.tensors(); + const auto tensor_in = tensors.at(inputs.at(0)); + const auto tensor_out = tensors.at(outputs[0]); + const auto tensor_ids = tensors.at(inputs.at(1)); + assert(tensor_in != nullptr); + assert(tensor_out != nullptr); + assert(tensor_ids != nullptr); - switch (tensor_ids->type) + switch (tensor_ids->type()) { case circle::TensorType_INT32: case circle::TensorType_INT64: @@ -44,7 +47,7 @@ bool CircleSegmentSumGraphBuilder::validate(const ValidateArgs &args) const return false; } - if (tensor_out->type != tensor_in->type) + if (tensor_out->type() != tensor_in->type()) { return false; } diff --git a/compiler/luci/import/src/Nodes/CircleSelect.cpp b/compiler/luci/import/src/Nodes/CircleSelect.cpp index 36a5fa8a8..002f62f6c 100644 --- 
a/compiler/luci/import/src/Nodes/CircleSelect.cpp +++ b/compiler/luci/import/src/Nodes/CircleSelect.cpp @@ -29,9 +29,10 @@ bool CircleSelectGraphBuilder::validate(const ValidateArgs &args) const return false; const auto &inputs = args.op.inputs; - const auto &tensors = args.reader.tensors(); - const auto &tensor = tensors.at(inputs.at(0)); - if (tensor->type != circle::TensorType_BOOL) + const auto tensors = args.reader.tensors(); + const auto tensor = tensors.at(inputs.at(0)); + assert(tensor != nullptr); + if (tensor->type() != circle::TensorType_BOOL) return false; // TODO check dtypes for input 1, 2 diff --git a/compiler/luci/import/src/Nodes/CircleSelectV2.cpp b/compiler/luci/import/src/Nodes/CircleSelectV2.cpp index 556c8fa33..062fdc143 100644 --- a/compiler/luci/import/src/Nodes/CircleSelectV2.cpp +++ b/compiler/luci/import/src/Nodes/CircleSelectV2.cpp @@ -29,14 +29,16 @@ bool CircleSelectV2GraphBuilder::validate(const ValidateArgs &args) const return false; const auto &inputs = args.op.inputs; - const auto &tensors = args.reader.tensors(); - const auto &condition = tensors.at(inputs.at(0)); - if (condition->type != circle::TensorType_BOOL) + const auto tensors = args.reader.tensors(); + const auto condition = tensors.at(inputs.at(0)); + assert(condition != nullptr); + if (condition->type() != circle::TensorType_BOOL) return false; - const auto &t = tensors.at(inputs.at(1)); - const auto &e = tensors.at(inputs.at(2)); - if (t->type != e->type) + const auto t = tensors.at(inputs.at(1)); + const auto e = tensors.at(inputs.at(2)); + assert(t != nullptr && e != nullptr); + if (t->type() != e->type()) return false; return true; diff --git a/compiler/luci/import/src/Nodes/CircleSin.cpp b/compiler/luci/import/src/Nodes/CircleSin.cpp index 22f461123..51ebf0355 100644 --- a/compiler/luci/import/src/Nodes/CircleSin.cpp +++ b/compiler/luci/import/src/Nodes/CircleSin.cpp @@ -30,9 +30,10 @@ bool CircleSinGraphBuilder::validate(const ValidateArgs &args) const const auto 
&inputs = args.op.inputs; // input type check - const auto &tensors = args.reader.tensors(); - const auto &tensor = tensors.at(inputs.at(0)); - switch (tensor->type) + const auto tensors = args.reader.tensors(); + const auto tensor = tensors.at(inputs.at(0)); + assert(tensor != nullptr); + switch (tensor->type()) { case circle::TensorType_FLOAT16: case circle::TensorType_FLOAT32: diff --git a/compiler/luci/import/src/Nodes/CircleSquare.cpp b/compiler/luci/import/src/Nodes/CircleSquare.cpp index 7ff2b84e6..bec84b4c0 100644 --- a/compiler/luci/import/src/Nodes/CircleSquare.cpp +++ b/compiler/luci/import/src/Nodes/CircleSquare.cpp @@ -29,13 +29,13 @@ bool CircleSquareGraphBuilder::validate(const ValidateArgs &args) const return false; const auto &inputs = args.op.inputs; - // Must be one of the following types - // bfloat16, half (float16), float32, float64, complex64, complex128 - // Currently, circle supports float16, float32, complex64 - const auto &tensors = args.reader.tensors(); - const auto &tensor = tensors.at(inputs.at(0)); - switch (tensor->type) + const auto tensors = args.reader.tensors(); + const auto tensor = tensors.at(inputs.at(0)); + assert(tensor != nullptr); + switch (tensor->type()) { + case circle::TensorType_UINT8: + case circle::TensorType_INT16: case circle::TensorType_INT32: case circle::TensorType_INT64: case circle::TensorType_FLOAT16: diff --git a/compiler/luci/import/src/Nodes/CircleSquaredDifference.cpp b/compiler/luci/import/src/Nodes/CircleSquaredDifference.cpp index 33440d5ab..1983465d3 100644 --- a/compiler/luci/import/src/Nodes/CircleSquaredDifference.cpp +++ b/compiler/luci/import/src/Nodes/CircleSquaredDifference.cpp @@ -32,9 +32,10 @@ bool CircleSquaredDifferenceGraphBuilder::validate(const ValidateArgs &args) con const auto &outputs = args.op.outputs; // Inputs must be one of the following types // bfloat16, half(float16), float32, float64, int32, int64, complex64, complex128 - const auto &tensors = args.reader.tensors(); - const 
auto &tensor = tensors.at(inputs.at(0)); - switch (tensor->type) + const auto tensors = args.reader.tensors(); + const auto tensor = tensors.at(inputs.at(0)); + assert(tensor != nullptr); + switch (tensor->type()) { case circle::TensorType_FLOAT16: case circle::TensorType_FLOAT32: @@ -53,11 +54,13 @@ bool CircleSquaredDifferenceGraphBuilder::validate(const ValidateArgs &args) con } // Input types must match - if (tensors.at(inputs.at(0))->type != tensors.at(inputs.at(1))->type) + assert(tensors.at(inputs.at(0)) != nullptr && tensors.at(inputs.at(1)) != nullptr); + if (tensors.at(inputs.at(0))->type() != tensors.at(inputs.at(1))->type()) return false; // Input and output types must match - if (tensors.at(inputs.at(0))->type != tensors.at(outputs[0])->type) + assert(tensors.at(outputs[0]) != nullptr); + if (tensors.at(inputs.at(0))->type() != tensors.at(outputs[0])->type()) return false; return true; diff --git a/compiler/luci/import/src/Nodes/CircleTanh.cpp b/compiler/luci/import/src/Nodes/CircleTanh.cpp index 95625a0e4..80a0e887f 100644 --- a/compiler/luci/import/src/Nodes/CircleTanh.cpp +++ b/compiler/luci/import/src/Nodes/CircleTanh.cpp @@ -30,8 +30,9 @@ bool CircleTanhGraphBuilder::validate(const ValidateArgs &args) const const auto &inputs = args.op.inputs; const auto &outputs = args.op.outputs; - const auto &tensors = args.reader.tensors(); - if (tensors.at(inputs.at(0))->type != tensors.at(outputs[0])->type) + const auto tensors = args.reader.tensors(); + assert(tensors.at(inputs.at(0)) != nullptr && tensors.at(outputs[0]) != nullptr); + if (tensors.at(inputs.at(0))->type() != tensors.at(outputs[0])->type()) return false; return true; diff --git a/compiler/luci/import/src/Nodes/CircleTile.cpp b/compiler/luci/import/src/Nodes/CircleTile.cpp index 6da44130c..c41a6ba3f 100644 --- a/compiler/luci/import/src/Nodes/CircleTile.cpp +++ b/compiler/luci/import/src/Nodes/CircleTile.cpp @@ -32,9 +32,10 @@ bool CircleTileGraphBuilder::validate(const ValidateArgs &args) 
const auto outputs = args.op.outputs; // Multiples (inputs.at(1)) must be one of the following types // int32, int64 - const auto &tensors = args.reader.tensors(); - const auto &tensor = tensors.at(inputs.at(1)); - switch (tensor->type) + const auto tensors = args.reader.tensors(); + const auto tensor = tensors.at(inputs.at(1)); + assert(tensor != nullptr); + switch (tensor->type()) { case circle::TensorType_INT32: case circle::TensorType_INT64: @@ -44,7 +45,8 @@ bool CircleTileGraphBuilder::validate(const ValidateArgs &args) const } // Type of input and output must be the same - if (tensors.at(inputs.at(0))->type != tensors.at(outputs[0])->type) + assert(tensors.at(inputs.at(0)) != nullptr && tensors.at(outputs[0]) != nullptr); + if (tensors.at(inputs.at(0))->type() != tensors.at(outputs[0])->type()) return false; return true; diff --git a/compiler/luci/import/src/Nodes/CircleTopKV2.cpp b/compiler/luci/import/src/Nodes/CircleTopKV2.cpp index 49f858798..9f9173738 100644 --- a/compiler/luci/import/src/Nodes/CircleTopKV2.cpp +++ b/compiler/luci/import/src/Nodes/CircleTopKV2.cpp @@ -35,9 +35,10 @@ bool CircleTopKV2GraphBuilder::validate(const ValidateArgs &args) const if (outputs.size() != 2) return false; - const auto &tensors = args.reader.tensors(); - const auto &tensor = tensors.at(inputs.at(1)); - if (tensor->type != circle::TensorType_INT32) + const auto tensors = args.reader.tensors(); + const auto tensor = tensors.at(inputs.at(1)); + assert(tensor != nullptr); + if (tensor->type() != circle::TensorType_INT32) return false; return true; diff --git a/compiler/luci/import/src/Nodes/CircleTransposeConv.cpp b/compiler/luci/import/src/Nodes/CircleTransposeConv.cpp index 5a60e2f54..041983dac 100644 --- a/compiler/luci/import/src/Nodes/CircleTransposeConv.cpp +++ b/compiler/luci/import/src/Nodes/CircleTransposeConv.cpp @@ -31,11 +31,13 @@ bool CircleTransposeConvGraphBuilder::validate(const ValidateArgs &args) const return false; const auto &inputs = args.op.inputs; - 
const auto &tensors = args.reader.tensors(); - const auto &filter_tensor = tensors.at(inputs.at(1)); - const auto &filter_shape = filter_tensor.get()->shape; - const auto &ifm_tensor = tensors.at(inputs.at(2)); - const auto &ifm_shape = ifm_tensor.get()->shape; + const auto tensors = args.reader.tensors(); + const auto filter_tensor = tensors.at(inputs.at(1)); + assert(filter_tensor != nullptr); + const auto filter_shape = wrap(filter_tensor->shape()); + const auto ifm_tensor = tensors.at(inputs.at(2)); + assert(ifm_tensor != nullptr); + const auto ifm_shape = wrap(ifm_tensor->shape()); // ifm and filters must be 4-D tensor if (ifm_shape.size() != 4) @@ -45,7 +47,7 @@ bool CircleTransposeConvGraphBuilder::validate(const ValidateArgs &args) const // input shape : [batch, height, width, in_channels] // filters shape : [output_channels, height, weight, in_channels] - if (ifm_tensor.get()->shape.at(3) != filter_tensor.get()->shape.at(3)) + if (ifm_shape.at(3) != filter_shape.at(3)) return false; return true; diff --git a/compiler/luci/import/src/Nodes/CircleUnpack.cpp b/compiler/luci/import/src/Nodes/CircleUnpack.cpp index 9bfc76b57..6b3401609 100644 --- a/compiler/luci/import/src/Nodes/CircleUnpack.cpp +++ b/compiler/luci/import/src/Nodes/CircleUnpack.cpp @@ -46,8 +46,8 @@ bool CircleUnpackGraphBuilder::validate(const ValidateArgs &args) const { if (settings->get(luci::UserSettings::Key::DisableValidation)) { - const auto &tensors = args.reader.tensors(); - const circle::TensorT &output_tensor = *tensors[outputs[0]]; + const auto tensors = args.reader.tensors(); + const auto output_tensor = tensors[outputs[0]]; auto name = tensor_name(output_tensor); WARN(l) << "Warning: import Unpack(" << name << ") 'num' is not same as outputs used"; } @@ -58,9 +58,10 @@ bool CircleUnpackGraphBuilder::validate(const ValidateArgs &args) const if (options->num < 0) return false; - const auto &tensors = args.reader.tensors(); - const auto &tensor = tensors.at(inputs.at(0)); - const 
auto &shape = tensor->shape; + const auto tensors = args.reader.tensors(); + const auto tensor = tensors.at(inputs.at(0)); + assert(tensor != nullptr); + const auto shape = wrap(tensor->shape()); auto shape_size = static_cast<int32_t>(shape.size()); if (shape_size > 0) { diff --git a/compiler/luci/import/src/Nodes/CircleVariable.cpp b/compiler/luci/import/src/Nodes/CircleVariable.cpp new file mode 100644 index 000000000..23ae9e7be --- /dev/null +++ b/compiler/luci/import/src/Nodes/CircleVariable.cpp @@ -0,0 +1,80 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "luci/Import/Nodes/CircleVariable.h" + +#include <luci/IR/Nodes/CircleVariable.h> +#include <luci/Log.h> + +#include <cassert> +#include <ostream> +#include <string> +#include <vector> + +namespace +{ + +std::ostream &operator<<(std::ostream &os, const luci::VectorWrapper<int32_t> &vect) +{ + uint32_t seq = 0; + for (const auto &v : vect) + { + if (seq) + os << ", "; + os << v; + seq++; + } + return os; +} + +} // namespace + +namespace luci +{ + +CircleVariable *create_circlevariable(GraphBuilderContext *context, int32_t tensor_index) +{ + LOGGER(l); + + auto graph = context->graph(); + auto reader = context->reader(); + const auto tensors = reader->tensors(); + const auto variable_tensor = tensors[tensor_index]; + assert(variable_tensor != nullptr); + + if (not variable_tensor->is_variable()) + { + // not a variable + return nullptr; + } + { + // check if there is no buffer as we don't support this for now + // TODO use buffer when this is enabled in Kernel + assert(reader->buffers()[variable_tensor->buffer()] != nullptr); + assert(reader->buffers()[variable_tensor->buffer()]->data() == nullptr); + } + + auto variable_node = graph->nodes()->create<CircleVariable>(); + copy_tensor_attributes(variable_tensor, variable_node); + variable_node->shape_status(luci::ShapeStatus::VALID); + + INFO(l) << "[luci] NodeFinder variable node(" << tensor_index << ") -> " << variable_node << " " + << wrap(variable_tensor->shape()) << std::endl; + + return variable_node; +} + +} // namespace luci diff --git a/compiler/luci/import/src/Nodes/CircleWhere.cpp b/compiler/luci/import/src/Nodes/CircleWhere.cpp index 8e4f1a0c4..bc6199ace 100644 --- a/compiler/luci/import/src/Nodes/CircleWhere.cpp +++ b/compiler/luci/import/src/Nodes/CircleWhere.cpp @@ -30,14 +30,16 @@ bool CircleWhereGraphBuilder::validate(const ValidateArgs &args) const const auto &inputs = args.op.inputs; const auto &outputs = args.op.outputs; - const auto &tensors = args.reader.tensors(); - const auto 
&tensor_condition = tensors.at(inputs.at(0)); - const auto &tensor_out = tensors.at(outputs[0]); + const auto tensors = args.reader.tensors(); + const auto tensor_condition = tensors.at(inputs.at(0)); + const auto tensor_out = tensors.at(outputs[0]); + assert(tensor_condition != nullptr); + assert(tensor_out != nullptr); - if (tensor_condition->type != circle::TensorType_BOOL) + if (tensor_condition->type() != circle::TensorType_BOOL) return false; - if (tensor_out->type != circle::TensorType_INT64) + if (tensor_out->type() != circle::TensorType_INT64) return false; return true; diff --git a/compiler/luci/import/src/Nodes/CircleWhile.cpp b/compiler/luci/import/src/Nodes/CircleWhile.cpp index 26147562f..27a392b2a 100644 --- a/compiler/luci/import/src/Nodes/CircleWhile.cpp +++ b/compiler/luci/import/src/Nodes/CircleWhile.cpp @@ -67,8 +67,8 @@ CircleNode *CircleWhileGraphBuilder::build(const circle::OperatorT &op, const std::vector<int32_t> &inputs = op.inputs; const std::vector<int32_t> &outputs = op.outputs; - const auto &tensors = context->reader()->tensors(); - const auto &opcodes = context->reader()->opcodes(); + const auto tensors = context->reader()->tensors(); + const auto opcodes = context->reader()->opcodes(); std::vector<CircleNode *> input_nodes; for (const int32_t input_tensor_index : inputs) @@ -96,9 +96,11 @@ CircleNode *CircleWhileGraphBuilder::build(const circle::OperatorT &op, assert(outputs.size() > 0); { // Lets use name of output 0 as While name - const circle::TensorT &output_tensor = *tensors[outputs[0]]; + const auto output_tensor = tensors[outputs[0]]; + assert(output_tensor != nullptr); node->name(tensor_name(output_tensor)); - node->op_version(opcodes[op.opcode_index].get()->version); + assert(opcodes[op.opcode_index] != nullptr); + node->op_version(opcodes[op.opcode_index]->version()); // NOTE We don't set quantization for While itself but to virtual outputs } @@ -106,7 +108,8 @@ CircleNode *CircleWhileGraphBuilder::build(const 
circle::OperatorT &op, // Create virtual outputs of While for (uint32_t n = 0; n < output_count; ++n) { - const circle::TensorT &output_tensor = *tensors[outputs[n]]; + const auto output_tensor = tensors[outputs[n]]; + assert(output_tensor != nullptr); auto *nodeout = graph->nodes()->create<CircleWhileOut>(); diff --git a/compiler/luci/import/src/ValidateHelpers.cpp b/compiler/luci/import/src/ValidateHelpers.cpp index 27306ba90..fc027704b 100644 --- a/compiler/luci/import/src/ValidateHelpers.cpp +++ b/compiler/luci/import/src/ValidateHelpers.cpp @@ -26,9 +26,10 @@ bool validate_batch_space_nd(const GraphBuilderBase::ValidateArgs &args) return false; // input 1 and 2 should have INT32/INT64 type - const auto &tensors = args.reader.tensors(); - const auto &tensor_1 = tensors.at(inputs.at(1)); - switch (tensor_1->type) + const auto tensors = args.reader.tensors(); + const auto tensor_1 = tensors.at(inputs.at(1)); + assert(tensor_1 != nullptr); + switch (tensor_1->type()) { case circle::TensorType_INT32: case circle::TensorType_INT64: @@ -36,8 +37,9 @@ bool validate_batch_space_nd(const GraphBuilderBase::ValidateArgs &args) default: return false; } - const auto &tensor_2 = tensors.at(inputs.at(2)); - switch (tensor_2->type) + const auto tensor_2 = tensors.at(inputs.at(2)); + assert(tensor_2 != nullptr); + switch (tensor_2->type()) { case circle::TensorType_INT32: case circle::TensorType_INT64: @@ -47,8 +49,9 @@ bool validate_batch_space_nd(const GraphBuilderBase::ValidateArgs &args) } // Only support input shape dimension 3 and 4 only - const auto &tensor_0 = tensors.at(inputs.at(0)); - const auto t_0_s = tensor_0->shape.size(); + const auto tensor_0 = tensors.at(inputs.at(0)); + assert(tensor_0 != nullptr); + const auto t_0_s = wrap(tensor_0->shape()).size(); if (t_0_s != 3 && t_0_s != 4) return false; @@ -68,10 +71,10 @@ bool validate_minmax(const GraphBuilderBase::ValidateArgs &args) if (outputs.size() != 1) return false; - const auto &tensors = 
args.reader.tensors(); - const auto &tensor = tensors.at(inputs.at(0)); - - switch (tensor->type) + const auto tensors = args.reader.tensors(); + const auto tensor = tensors.at(inputs.at(0)); + assert(tensor != nullptr); + switch (tensor->type()) { case circle::TensorType_FLOAT16: case circle::TensorType_FLOAT32: @@ -84,10 +87,12 @@ bool validate_minmax(const GraphBuilderBase::ValidateArgs &args) return false; } - if (tensors[inputs.at(1)]->type != tensor->type) + assert(tensors[inputs.at(1)] != nullptr); + if (tensors[inputs.at(1)]->type() != tensor->type()) return false; - if (tensors[outputs[0]]->type != tensor->type) + assert(tensors[outputs[0]] != nullptr); + if (tensors[outputs[0]]->type() != tensor->type()) return false; return true; @@ -104,10 +109,10 @@ bool validate_reduce_minmax(const GraphBuilderBase::ValidateArgs &args) if (outputs.size() != 1) return false; - const auto &tensors = args.reader.tensors(); - const auto &tensor_axis = tensors.at(inputs.at(1)); - - switch (tensor_axis->type) + const auto tensors = args.reader.tensors(); + const auto tensor_axis = tensors.at(inputs.at(1)); + assert(tensor_axis != nullptr); + switch (tensor_axis->type()) { case circle::TensorType_INT32: case circle::TensorType_INT64: diff --git a/compiler/luci/lang/include/luci/IR/CircleNodes.h b/compiler/luci/lang/include/luci/IR/CircleNodes.h index a313f9d5b..d89ea03cc 100644 --- a/compiler/luci/lang/include/luci/IR/CircleNodes.h +++ b/compiler/luci/lang/include/luci/IR/CircleNodes.h @@ -29,7 +29,6 @@ #include "Nodes/CircleCast.h" #include "Nodes/CircleCeil.h" #include "Nodes/CircleConcatenation.h" -#include "Nodes/CircleConst.h" #include "Nodes/CircleConv2D.h" #include "Nodes/CircleCos.h" #include "Nodes/CircleCustom.h" @@ -119,6 +118,7 @@ #include "Nodes/CircleStridedSlice.h" #include "Nodes/CircleSub.h" #include "Nodes/CircleSum.h" +#include "Nodes/CircleSVDF.h" #include "Nodes/CircleTanh.h" #include "Nodes/CircleTile.h" #include "Nodes/CircleTopKV2.h" @@ -135,18 
+135,21 @@ #include "Nodes/CircleBCQGather.h" #include "Nodes/CircleInstanceNorm.h" // Virtual nodes +#include "Nodes/CircleConst.h" #include "Nodes/CircleInput.h" #include "Nodes/CircleOutput.h" +#include "Nodes/CircleVariable.h" +// Multi-output virtual nodes #include "Nodes/CircleBidirectionalSequenceLSTMOut.h" #include "Nodes/CircleCustomOut.h" #include "Nodes/CircleIfOut.h" #include "Nodes/CircleNonMaxSuppressionV4Out.h" #include "Nodes/CircleNonMaxSuppressionV5Out.h" -#include "Nodes/CircleUnpackOut.h" -#include "Nodes/CircleUniqueOut.h" #include "Nodes/CircleSplitOut.h" #include "Nodes/CircleSplitVOut.h" #include "Nodes/CircleTopKV2Out.h" +#include "Nodes/CircleUniqueOut.h" +#include "Nodes/CircleUnpackOut.h" #include "Nodes/CircleWhileOut.h" #include <loco/IR/Graph.h> diff --git a/compiler/luci/lang/include/luci/IR/CircleNodes.lst b/compiler/luci/lang/include/luci/IR/CircleNodes.lst index 914aa16e4..1472008df 100644 --- a/compiler/luci/lang/include/luci/IR/CircleNodes.lst +++ b/compiler/luci/lang/include/luci/IR/CircleNodes.lst @@ -116,6 +116,7 @@ CIRCLE_NODE(SQUEEZE, CircleSqueeze) CIRCLE_NODE(STRIDED_SLICE, CircleStridedSlice) CIRCLE_NODE(SUB, CircleSub) CIRCLE_NODE(SUM, CircleSum) +CIRCLE_NODE(SVDF, CircleSVDF) CIRCLE_NODE(TANH, CircleTanh) CIRCLE_NODE(TILE, CircleTile) CIRCLE_NODE(TOPK_V2, CircleTopKV2) @@ -132,12 +133,14 @@ CIRCLE_NODE(BCQ_FULLY_CONNECTED, CircleBCQFullyConnected) CIRCLE_NODE(BCQ_GATHER, CircleBCQGather) CIRCLE_NODE(INSTANCE_NORM, CircleInstanceNorm) // Virtual node(s) -CIRCLE_VNODE(CIRCLEBIDIRECTIONAL_SEQUENCE_LSTM_OUT, CircleBidirectionalSequenceLSTMOut) CIRCLE_VNODE(CIRCLECONST, CircleConst) CIRCLE_VNODE(CIRCLEINPUT, CircleInput) CIRCLE_VNODE(CIRCLEOUTPUT, CircleOutput) CIRCLE_VNODE(CIRCLEOUTPUTDUMMY, CircleOutputDummy) CIRCLE_VNODE(CIRCLEOUTPUTEXCLUDE, CircleOutputExclude) +CIRCLE_VNODE(CIRCLEVARIABLE, CircleVariable) +// Multi-output virtual nodes +CIRCLE_VNODE(CIRCLEBIDIRECTIONAL_SEQUENCE_LSTM_OUT, 
CircleBidirectionalSequenceLSTMOut) CIRCLE_VNODE(CIRCLECUSTOMOUT, CircleCustomOut) CIRCLE_VNODE(CIRCLEIFOUT, CircleIfOut) CIRCLE_VNODE(CIRCLENONMAXSUPPRESSIONV4OUT, CircleNonMaxSuppressionV4Out) diff --git a/compiler/luci/lang/include/luci/IR/CircleQuantParam.h b/compiler/luci/lang/include/luci/IR/CircleQuantParam.h index 694437303..8afc80a76 100644 --- a/compiler/luci/lang/include/luci/IR/CircleQuantParam.h +++ b/compiler/luci/lang/include/luci/IR/CircleQuantParam.h @@ -32,6 +32,10 @@ struct CircleQuantParam int32_t quantized_dimension{0}; }; +struct CircleNode; + +void copy_quantparam(const luci::CircleNode *src, luci::CircleNode *dst); + } // namespace luci #endif // __LUCI_IR_CIRCLEQUANTPARAM_H__ diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleFullyConnected.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleFullyConnected.h index 2862cadb2..dc5aeb267 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleFullyConnected.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleFullyConnected.h @@ -58,8 +58,12 @@ public: WeightsFormat weights_format(void) const { return _weights_format; } void weights_format(WeightsFormat weights_format) { _weights_format = weights_format; } + bool keep_num_dims(void) const { return _keep_num_dims; } + void keep_num_dims(bool keep_num_dims) { _keep_num_dims = keep_num_dims; } + private: WeightsFormat _weights_format{WeightsFormat::DEFAULT}; + bool _keep_num_dims{false}; }; } // namespace luci diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleSVDF.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleSVDF.h new file mode 100644 index 000000000..839d11e04 --- /dev/null +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleSVDF.h @@ -0,0 +1,70 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_IR_CIRCLE_SVDF_H__ +#define __LUCI_IR_CIRCLE_SVDF_H__ + +#include "luci/IR/CircleNodeDecl.h" +#include "luci/IR/CircleOpcode.h" + +#include "luci/IR/LuciNodeMixins.h" + +namespace luci +{ + +/** + * @brief SVDF in Circle + */ +class CircleSVDF final : public FixedArityNode<5, CircleNodeImpl<CircleOpcode::SVDF>>, + public CircleNodeMixin<CircleNodeTrait::FusedActFunc> +{ +public: + CircleSVDF() = default; + +public: + loco::Node *input(void) const { return at(0)->node(); } + void input(loco::Node *node) { at(0)->node(node); } + + loco::Node *weight_feature(void) const { return at(1)->node(); } + void weight_feature(loco::Node *node) { at(1)->node(node); } + + loco::Node *weight_time(void) const { return at(2)->node(); } + void weight_time(loco::Node *node) { at(2)->node(node); } + + loco::Node *bias(void) const { return at(3)->node(); } + void bias(loco::Node *node) { at(3)->node(node); } + + loco::Node *input_activation_state(void) const { return at(4)->node(); } + void input_activation_state(loco::Node *node) { at(4)->node(node); } + +public: + bool asymmetric_quantize_inputs() const { return _asymmetric_quantize_inputs; } + void asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) + { + _asymmetric_quantize_inputs = asymmetric_quantize_inputs; + } + + int32_t svdf_rank() const { return _rank; } + void svdf_rank(int32_t svdf_rank) { _rank = svdf_rank; } + +private: + bool _asymmetric_quantize_inputs = false; + int32_t _rank = 0; +}; + +} // namespace luci + +#endif // __LUCI_IR_CIRCLE_SVDF_H__ diff --git 
a/compiler/luci/lang/include/luci/IR/Nodes/CircleVariable.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleVariable.h new file mode 100644 index 000000000..8c15b66c9 --- /dev/null +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleVariable.h @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_IR_CIRCLE_VARIABLE_H__ +#define __LUCI_IR_CIRCLE_VARIABLE_H__ + +#include "luci/IR/CircleNodeDecl.h" +#include "luci/IR/CircleOpcode.h" + +#include "luci/IR/CircleNodeMixins.h" + +namespace luci +{ + +/** + * @brief Virtual CircleVariable in Circle for 'variable' Tensor + */ +class CircleVariable final : public FixedArityNode<0, CircleNodeImpl<CircleOpcode::CIRCLEVARIABLE>> +{ +public: + CircleVariable() = default; +}; + +} // namespace luci + +#endif // __LUCI_IR_CIRCLE_VARIABLE_H__ diff --git a/compiler/luci/lang/src/CircleQuantParam.cpp b/compiler/luci/lang/src/CircleQuantParam.cpp new file mode 100644 index 000000000..89671d3c3 --- /dev/null +++ b/compiler/luci/lang/src/CircleQuantParam.cpp @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/IR/CircleQuantParam.h" +#include "luci/IR/CircleNode.h" + +#include <memory> + +namespace luci +{ + +/** + * @brief copy CircleQuantParam of src to dst + */ +void copy_quantparam(const luci::CircleNode *src, luci::CircleNode *dst) +{ + auto q = src->quantparam(); + if (q == nullptr) + dst->quantparam(nullptr); + else + { + auto qparam = std::make_unique<luci::CircleQuantParam>(); + qparam->scale = q->scale; + qparam->zerop = q->zerop; + qparam->min = q->min; + qparam->max = q->max; + qparam->quantized_dimension = q->quantized_dimension; + + dst->quantparam(std::move(qparam)); + } +} + +} // namespace luci diff --git a/compiler/luci/lang/src/CircleQuantParam.test.cpp b/compiler/luci/lang/src/CircleQuantParam.test.cpp new file mode 100644 index 000000000..520ca05cc --- /dev/null +++ b/compiler/luci/lang/src/CircleQuantParam.test.cpp @@ -0,0 +1,78 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// NOTE any node will do for testing +#include "luci/IR/Nodes/CircleAdd.h" + +#include <loco/IR/Graph.h> + +#include <gtest/gtest.h> + +namespace +{ + +luci::CircleAdd *build_simple_add_graph(loco::Graph *g) +{ + auto node = g->nodes()->create<luci::CircleAdd>(); + + node->name("name"); + node->dtype(loco::DataType::FLOAT32); + node->rank(1); + node->dim(0).set(3); + node->shape_status(luci::ShapeStatus::VALID); + node->fusedActivationFunction(luci::FusedActFunc::NONE); + + auto qparam = std::make_unique<luci::CircleQuantParam>(); + qparam->scale = {1.0}; + qparam->zerop = {0}; + qparam->min = {0.0}; + qparam->max = {1.0}; + qparam->quantized_dimension = 0; + node->quantparam(std::move(qparam)); + + return node; +} + +} // namespace + +TEST(CircleNodeCloneTest, copy_quantparam) +{ + auto g = loco::make_graph(); + auto node = build_simple_add_graph(g.get()); + + auto copy = g->nodes()->create<luci::CircleAdd>(); + luci::copy_quantparam(node, copy); + + const auto *qparam_node = node->quantparam(); + const auto *qparam_copy = copy->quantparam(); + ASSERT_EQ(qparam_node->scale, qparam_copy->scale); + ASSERT_EQ(qparam_node->zerop, qparam_copy->zerop); + ASSERT_EQ(qparam_node->quantized_dimension, qparam_copy->quantized_dimension); +} + +TEST(CircleNodeCloneTest, copy_quantparam_NEG) +{ + auto g = loco::make_graph(); + auto node = build_simple_add_graph(g.get()); + + node->quantparam(nullptr); + + auto copy = g->nodes()->create<luci::CircleAdd>(); + luci::copy_quantparam(node, copy); + + const auto *qparam_copy = copy->quantparam(); + ASSERT_EQ(qparam_copy, nullptr); +} diff --git a/compiler/luci/lang/src/Nodes/CircleFullyConnected.test.cpp b/compiler/luci/lang/src/Nodes/CircleFullyConnected.test.cpp index bb0e3c51b..15a780085 100644 --- a/compiler/luci/lang/src/Nodes/CircleFullyConnected.test.cpp +++ b/compiler/luci/lang/src/Nodes/CircleFullyConnected.test.cpp @@ -32,6 +32,7 @@ TEST(CircleFullyConnectedTest, constructor) ASSERT_EQ(nullptr, fc_node.weights()); 
ASSERT_EQ(nullptr, fc_node.bias()); ASSERT_EQ(luci::FusedActFunc::UNDEFINED, fc_node.fusedActivationFunction()); + ASSERT_EQ(false, fc_node.keep_num_dims()); } TEST(CircleFullyConnectedTest, input_NEG) diff --git a/compiler/luci/lang/src/Nodes/CircleSVDF.test.cpp b/compiler/luci/lang/src/Nodes/CircleSVDF.test.cpp new file mode 100644 index 000000000..833ae0732 --- /dev/null +++ b/compiler/luci/lang/src/Nodes/CircleSVDF.test.cpp @@ -0,0 +1,101 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "luci/IR/Nodes/CircleSVDF.h" + +#include "luci/IR/CircleDialect.h" +#include "luci/IR/CircleNodeVisitor.h" + +#include <gtest/gtest.h> + +TEST(CircleSVDFTest, constructor) +{ + luci::CircleSVDF svdf_node; + + ASSERT_EQ(luci::CircleDialect::get(), svdf_node.dialect()); + ASSERT_EQ(luci::CircleOpcode::SVDF, svdf_node.opcode()); + + ASSERT_EQ(nullptr, svdf_node.input()); + ASSERT_EQ(nullptr, svdf_node.weight_feature()); + ASSERT_EQ(nullptr, svdf_node.weight_time()); + ASSERT_EQ(nullptr, svdf_node.bias()); + ASSERT_EQ(nullptr, svdf_node.input_activation_state()); + + ASSERT_EQ(false, svdf_node.asymmetric_quantize_inputs()); + ASSERT_EQ(0, svdf_node.svdf_rank()); +} + +TEST(CircleSVDFTest, input_NEG) +{ + luci::CircleSVDF svdf_node; + luci::CircleSVDF node; + + svdf_node.input(&node); + svdf_node.weight_feature(&node); + svdf_node.weight_time(&node); + svdf_node.bias(&node); + svdf_node.input_activation_state(&node); + + ASSERT_NE(nullptr, svdf_node.input()); + ASSERT_NE(nullptr, svdf_node.weight_feature()); + ASSERT_NE(nullptr, svdf_node.weight_time()); + ASSERT_NE(nullptr, svdf_node.bias()); + ASSERT_NE(nullptr, svdf_node.input_activation_state()); + + svdf_node.input(nullptr); + svdf_node.weight_feature(nullptr); + svdf_node.weight_time(nullptr); + svdf_node.bias(nullptr); + svdf_node.input_activation_state(nullptr); + + ASSERT_EQ(nullptr, svdf_node.input()); + ASSERT_EQ(nullptr, svdf_node.weight_feature()); + ASSERT_EQ(nullptr, svdf_node.weight_time()); + ASSERT_EQ(nullptr, svdf_node.bias()); + ASSERT_EQ(nullptr, svdf_node.input_activation_state()); +} + +TEST(CircleSVDFTest, arity_NEG) +{ + luci::CircleSVDF svdf_node; + + ASSERT_NO_THROW(svdf_node.arg(4)); + ASSERT_THROW(svdf_node.arg(5), std::out_of_range); +} + +TEST(CircleSVDFTest, visit_mutable_NEG) +{ + struct TestVisitor final : public luci::CircleNodeMutableVisitor<void> + { + }; + + luci::CircleSVDF svdf_node; + + TestVisitor tv; + ASSERT_THROW(svdf_node.accept(&tv), std::exception); +} + 
+TEST(CircleSVDFTest, visit_NEG) +{ + struct TestVisitor final : public luci::CircleNodeVisitor<void> + { + }; + + luci::CircleSVDF svdf_node; + + TestVisitor tv; + ASSERT_THROW(svdf_node.accept(&tv), std::exception); +} diff --git a/compiler/luci/lang/src/Nodes/CircleVariable.test.cpp b/compiler/luci/lang/src/Nodes/CircleVariable.test.cpp new file mode 100644 index 000000000..e1864f8da --- /dev/null +++ b/compiler/luci/lang/src/Nodes/CircleVariable.test.cpp @@ -0,0 +1,61 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "luci/IR/Nodes/CircleVariable.h" + +#include "luci/IR/CircleDialect.h" +#include "luci/IR/CircleNodeVisitor.h" + +#include <gtest/gtest.h> + +TEST(CircleVariableTest, constructor) +{ + luci::CircleVariable var_node; + + ASSERT_EQ(luci::CircleDialect::get(), var_node.dialect()); + ASSERT_EQ(luci::CircleOpcode::CIRCLEVARIABLE, var_node.opcode()); +} + +TEST(CircleVariableTest, arity_NEG) +{ + luci::CircleVariable var_node; + + ASSERT_THROW(var_node.arg(0), std::out_of_range); +} + +TEST(CircleVariableTest, visit_mutable_NEG) +{ + struct TestVisitor final : public luci::CircleNodeMutableVisitor<void> + { + }; + + luci::CircleVariable var_node; + + TestVisitor tv; + ASSERT_THROW(var_node.accept(&tv), std::exception); +} + +TEST(CircleVariableTest, visit_NEG) +{ + struct TestVisitor final : public luci::CircleNodeVisitor<void> + { + }; + + luci::CircleVariable var_node; + + TestVisitor tv; + ASSERT_THROW(var_node.accept(&tv), std::exception); +} diff --git a/compiler/luci/logex/CMakeLists.txt b/compiler/luci/logex/CMakeLists.txt index aed9fb79b..b8a2111dd 100644 --- a/compiler/luci/logex/CMakeLists.txt +++ b/compiler/luci/logex/CMakeLists.txt @@ -1,5 +1,7 @@ # TODO Find how to test logging-ex utility file(GLOB_RECURSE SOURCES "src/*.cpp") +file(GLOB_RECURSE TESTS "src/*.test.cpp") +list(REMOVE_ITEM SOURCES ${TESTS}) if (NOT LUCI_LIBRARY_TYPE) set(LUCI_LIBRARY_TYPE "SHARED") @@ -13,7 +15,17 @@ target_link_libraries(luci_logex PRIVATE luci_log) target_link_libraries(luci_logex PRIVATE luci_lang) target_link_libraries(luci_logex PRIVATE hermes_std) target_link_libraries(luci_logex PRIVATE nncc_common) -target_link_libraries(luci_logex PRIVATE pepper_str) install(TARGETS luci_logex DESTINATION lib) install(DIRECTORY include/ DESTINATION include FILES_MATCHING PATTERN "*.h") + +if(NOT ENABLE_TEST) + return() +endif(NOT ENABLE_TEST) + +nnas_find_package(GTest REQUIRED) + +GTest_AddTest(luci_logex_test ${TESTS}) +target_include_directories(luci_logex_test 
PRIVATE src) +target_link_libraries(luci_logex_test luci_logex) +target_link_libraries(luci_logex_test luci_lang) diff --git a/compiler/luci/logex/src/CircleNodeSummaryBuilder.cpp b/compiler/luci/logex/src/CircleNodeSummaryBuilder.cpp new file mode 100644 index 000000000..eff0830b4 --- /dev/null +++ b/compiler/luci/logex/src/CircleNodeSummaryBuilder.cpp @@ -0,0 +1,265 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "CircleNodeSummaryBuilder.h" +#include "CircleNodeSummaryBuilders.h" + +#include <luci/IR/CircleDialect.h> + +#include <memory> + +namespace +{ + +std::string circle_opname(luci::CircleOpcode opcode) +{ /* Returns "circle." followed by the opcode identifier, using one case per CIRCLE_NODE/CIRCLE_VNODE entry expanded from CircleNodes.lst; any other value yields "circle.Invalid". */ + static const std::string prefix{"circle."}; + + switch (opcode) + { +#define CIRCLE_NODE(OPCODE, CLASS) \ + case luci::CircleOpcode::OPCODE: \ + return prefix + #OPCODE; +#define CIRCLE_VNODE CIRCLE_NODE +#include <luci/IR/CircleNodes.lst> +#undef CIRCLE_VNODE +#undef CIRCLE_NODE + default: + break; + }; + + return prefix + "Invalid"; +} + +} // namespace + +namespace luci +{ + +bool CircleNodeSummaryBuilder::build(const loco::Node *node, const locop::SymbolTable *tbl, + locop::NodeSummary &s) +{ /* Summarizes a Circle-dialect node into s. Returns false when the node belongs to another dialect, when create_builder() has no builder for its opcode, or when the builder's validate() fails (s is then marked Invalid). */ + if (node->dialect() != luci::CircleDialect::get()) + return false; + + auto ptr_to_str = [](const void *ptr) { /* renders the pointer value as text via operator<< */ + std::stringstream ss; + ss << ptr; + return ss.str(); + }; + + auto circle_node = loco::must_cast<const luci::CircleNode *>(node); + if (const auto builder = create_builder(circle_node)) + { + if (!builder->validate(circle_node)) + { + s.state(locop::NodeDesc::State::Invalid); + return false; + } + + auto input_names = builder->get_input_names(circle_node); /* one name per argument is expected; checked by the assert below */ + assert(node->arity() == input_names.size()); + for (uint32_t i = 0; i < node->arity(); ++i) + s.args().append(input_names.at(i), tbl->lookup(node->arg(i))); + + builder->build_attributes(circle_node, s); + builder->update_status(s); + + s.opname(circle_opname(circle_node->opcode())); + s.comments().append("[" + circle_node->name() + "] = " + ptr_to_str(node)); /* comment shows the node name and its address */ + + return true; + } + else + { + // When SummaryBuilder is not implemented, return false + return false; + } +} + +bool CircleNodeSummaryBuilder::validate(const luci::CircleNode *) { return true; } + +std::vector<std::string> CircleNodeSummaryBuilder::get_input_names(const luci::CircleNode *) +{ + // Return empty names for default + return std::vector<std::string>(); +} + +void CircleNodeSummaryBuilder::build_attributes(const 
luci::CircleNode *, locop::NodeSummary &) +{ + // Do nothing for default +} + +void CircleNodeSummaryBuilder::update_status(locop::NodeSummary &s) +{ /* Default status written by this hook is Complete. */ + s.state(locop::NodeDesc::State::Complete); +} + +std::unique_ptr<CircleNodeSummaryBuilder> +CircleNodeSummaryBuilder::create_builder(const luci::CircleNode *node) +{ /* Factory: each CIRCLE_NODE(OPCODE, CLASS) entry below expands to a case that returns a new CLASS for that opcode; opcodes without an entry reach the default case of this switch. */ + switch (node->opcode()) + { +#define CIRCLE_NODE(OPCODE, CLASS) \ + case luci::CircleOpcode::OPCODE: \ + { \ + return std::make_unique<CLASS>(); \ + } + + CIRCLE_NODE(ABS, CircleAbsSummaryBuilder) + CIRCLE_NODE(ADD, CircleAddSummaryBuilder) + CIRCLE_NODE(ADD_N, CircleAddNSummaryBuilder) + CIRCLE_NODE(ARG_MAX, CircleArgMaxSummaryBuilder) + CIRCLE_NODE(ARG_MIN, CircleArgMinSummaryBuilder) + CIRCLE_NODE(AVERAGE_POOL_2D, CircleAveragePool2DSummaryBuilder) + CIRCLE_NODE(BATCH_MATMUL, CircleBatchMatMulSummaryBuilder) + CIRCLE_NODE(BATCH_TO_SPACE_ND, CircleBatchToSpaceNDSummaryBuilder) + CIRCLE_NODE(BCQ_FULLY_CONNECTED, CircleBCQFullyConnectedSummaryBuilder) + CIRCLE_NODE(BCQ_GATHER, CircleBCQGatherSummaryBuilder) + CIRCLE_NODE(BIDIRECTIONAL_SEQUENCE_LSTM, CircleBidirectionalSequenceLSTMSummaryBuilder) + CIRCLE_NODE(CAST, CircleCastSummaryBuilder) + CIRCLE_NODE(CEIL, CircleCeilSummaryBuilder) + CIRCLE_NODE(CONCATENATION, CircleConcatenationSummaryBuilder) + CIRCLE_NODE(CIRCLECONST, CircleConstSummaryBuilder) + CIRCLE_NODE(CONV_2D, CircleConv2DSummaryBuilder) + CIRCLE_NODE(COS, CircleCosSummaryBuilder) + CIRCLE_NODE(CUSTOM, CircleCustomSummaryBuilder) + CIRCLE_NODE(DEPTH_TO_SPACE, CircleDepthToSpaceSummaryBuilder) + CIRCLE_NODE(DEPTHWISE_CONV_2D, CircleDepthwiseConv2DSummaryBuilder) + CIRCLE_NODE(DEQUANTIZE, CircleDequantizeSummaryBuilder) + CIRCLE_NODE(DIV, CircleDivSummaryBuilder) + CIRCLE_NODE(ELU, CircleEluSummaryBuilder) + CIRCLE_NODE(EQUAL, CircleEqualSummaryBuilder) + CIRCLE_NODE(EXP, CircleExpSummaryBuilder) + CIRCLE_NODE(EXPAND_DIMS, CircleExpandDimsSummaryBuilder) + CIRCLE_NODE(FAKE_QUANT, CircleFakeQuantSummaryBuilder) + CIRCLE_NODE(FILL, 
CircleFillSummaryBuilder) + CIRCLE_NODE(FLOOR, CircleFloorSummaryBuilder) + CIRCLE_NODE(FLOOR_DIV, CircleFloorDivSummaryBuilder) + CIRCLE_NODE(FLOOR_MOD, CircleFloorModSummaryBuilder) + CIRCLE_NODE(FULLY_CONNECTED, CircleFullyConnectedSummaryBuilder) + CIRCLE_NODE(GATHER, CircleGatherSummaryBuilder) + CIRCLE_NODE(GATHER_ND, CircleGatherNdSummaryBuilder) + CIRCLE_NODE(GREATER, CircleGreaterSummaryBuilder) + CIRCLE_NODE(GREATER_EQUAL, CircleGreaterEqualSummaryBuilder) + CIRCLE_NODE(IF, CircleIfSummaryBuilder) + CIRCLE_NODE(INSTANCE_NORM, CircleInstanceNormSummaryBuilder) + CIRCLE_NODE(L2_NORMALIZATION, CircleL2NormalizeSummaryBuilder) + CIRCLE_NODE(L2_POOL_2D, CircleL2Pool2DSummaryBuilder) + CIRCLE_NODE(LEAKY_RELU, CircleLeakyReluSummaryBuilder) + CIRCLE_NODE(LESS, CircleLessSummaryBuilder) + CIRCLE_NODE(LESS_EQUAL, CircleLessEqualSummaryBuilder) + CIRCLE_NODE(LOCAL_RESPONSE_NORMALIZATION, CircleLocalResponseNormalizationSummaryBuilder) + CIRCLE_NODE(LOG, CircleLogSummaryBuilder) + CIRCLE_NODE(LOGICAL_AND, CircleLogicalAndSummaryBuilder) + CIRCLE_NODE(LOGICAL_NOT, CircleLogicalNotSummaryBuilder) + CIRCLE_NODE(LOGICAL_OR, CircleLogicalOrSummaryBuilder) + CIRCLE_NODE(LOGISTIC, CircleLogisticSummaryBuilder) + CIRCLE_NODE(LOG_SOFTMAX, CircleLogSoftmaxSummaryBuilder) + CIRCLE_NODE(MATRIX_DIAG, CircleMatrixDiagSummaryBuilder) + CIRCLE_NODE(MATRIX_SET_DIAG, CircleMatrixSetDiagSummaryBuilder) + CIRCLE_NODE(MAXIMUM, CircleMaximumSummaryBuilder) + CIRCLE_NODE(MAX_POOL_2D, CircleMaxPool2DSummaryBuilder) + CIRCLE_NODE(MEAN, CircleMeanSummaryBuilder) + CIRCLE_NODE(MINIMUM, CircleMinimumSummaryBuilder) + CIRCLE_NODE(MIRROR_PAD, CircleMirrorPadSummaryBuilder) + CIRCLE_NODE(MUL, CircleMulSummaryBuilder) + CIRCLE_NODE(NEG, CircleNegSummaryBuilder) + CIRCLE_NODE(NON_MAX_SUPPRESSION_V4, CircleNonMaxSuppressionV4SummaryBuilder) + CIRCLE_NODE(NON_MAX_SUPPRESSION_V5, CircleNonMaxSuppressionV5SummaryBuilder) + CIRCLE_NODE(NOT_EQUAL, CircleNotEqualSummaryBuilder) + CIRCLE_NODE(ONE_HOT, 
CircleOneHotSummaryBuilder) + CIRCLE_NODE(PACK, CirclePackSummaryBuilder) + CIRCLE_NODE(PAD, CirclePadSummaryBuilder) + CIRCLE_NODE(PADV2, CirclePadV2SummaryBuilder) + CIRCLE_NODE(POW, CirclePowSummaryBuilder) + CIRCLE_NODE(PRELU, CirclePReluSummaryBuilder) + CIRCLE_NODE(QUANTIZE, CircleQuantizeSummaryBuilder) + CIRCLE_NODE(RANGE, CircleRangeSummaryBuilder) + CIRCLE_NODE(RANK, CircleRankSummaryBuilder) + CIRCLE_NODE(REDUCE_ANY, CircleReduceAnySummaryBuilder) + CIRCLE_NODE(REDUCE_MAX, CircleReduceMaxSummaryBuilder) + CIRCLE_NODE(REDUCE_MIN, CircleReduceMinSummaryBuilder) + CIRCLE_NODE(REDUCE_PROD, CircleReduceProdSummaryBuilder) + CIRCLE_NODE(RELU, CircleReluSummaryBuilder) + CIRCLE_NODE(RELU6, CircleRelu6SummaryBuilder) + CIRCLE_NODE(RELU_N1_TO_1, CircleReluN1To1SummaryBuilder) + CIRCLE_NODE(RESHAPE, CircleReshapeSummaryBuilder) + CIRCLE_NODE(RESIZE_BILINEAR, CircleResizeBilinearSummaryBuilder) + CIRCLE_NODE(RESIZE_NEAREST_NEIGHBOR, CircleResizeNearestNeighborSummaryBuilder) + CIRCLE_NODE(REVERSE_SEQUENCE, CircleReverseSequenceSummaryBuilder) + CIRCLE_NODE(REVERSE_V2, CircleReverseV2SummaryBuilder) + CIRCLE_NODE(ROUND, CircleRoundSummaryBuilder) + CIRCLE_NODE(RSQRT, CircleRsqrtSummaryBuilder) + CIRCLE_NODE(SCATTER_ND, CircleScatterNdSummaryBuilder) + CIRCLE_NODE(SEGMENT_SUM, CircleSegmentSumSummaryBuilder) + CIRCLE_NODE(SELECT, CircleSelectSummaryBuilder) + CIRCLE_NODE(SELECT_V2, CircleSelectV2SummaryBuilder) + CIRCLE_NODE(SHAPE, CircleShapeSummaryBuilder) + CIRCLE_NODE(SIN, CircleSinSummaryBuilder) + CIRCLE_NODE(SLICE, CircleSliceSummaryBuilder) + CIRCLE_NODE(SOFTMAX, CircleSoftmaxSummaryBuilder) + CIRCLE_NODE(SPACE_TO_BATCH_ND, CircleSpaceToBatchNDSummaryBuilder) + CIRCLE_NODE(SPACE_TO_DEPTH, CircleSpaceToDepthSummaryBuilder) + CIRCLE_NODE(SPARSE_TO_DENSE, CircleSparseToDenseSummaryBuilder) + CIRCLE_NODE(SPLIT, CircleSplitSummaryBuilder) + CIRCLE_NODE(SPLIT_V, CircleSplitVSummaryBuilder) + CIRCLE_NODE(SQRT, CircleSqrtSummaryBuilder) + CIRCLE_NODE(SQUARE, 
CircleSquareSummaryBuilder) + CIRCLE_NODE(SQUARED_DIFFERENCE, CircleSquaredDifferenceSummaryBuilder) + CIRCLE_NODE(SQUEEZE, CircleSqueezeSummaryBuilder) + CIRCLE_NODE(STRIDED_SLICE, CircleStridedSliceSummaryBuilder) + CIRCLE_NODE(SUB, CircleSubSummaryBuilder) + CIRCLE_NODE(SUM, CircleSumSummaryBuilder) + CIRCLE_NODE(SVDF, CircleSVDFSummaryBuilder) + CIRCLE_NODE(TANH, CircleTanhSummaryBuilder) + CIRCLE_NODE(TILE, CircleTileSummaryBuilder) + CIRCLE_NODE(TOPK_V2, CircleTopKV2SummaryBuilder) + CIRCLE_NODE(TRANSPOSE, CircleTransposeSummaryBuilder) + CIRCLE_NODE(TRANSPOSE_CONV, CircleTransposeConvSummaryBuilder) + CIRCLE_NODE(UNIDIRECTIONAL_SEQUENCE_LSTM, CircleUnidirectionalSequenceLSTMSummaryBuilder) + CIRCLE_NODE(UNIQUE, CircleUniqueSummaryBuilder) + CIRCLE_NODE(UNPACK, CircleUnpackSummaryBuilder) + CIRCLE_NODE(WHERE, CircleWhereSummaryBuilder) + CIRCLE_NODE(WHILE, CircleWhileSummaryBuilder) + CIRCLE_NODE(ZEROS_LIKE, CircleZerosLikeSummaryBuilder) + /* virtual (non-operator) nodes follow */ + CIRCLE_NODE(CIRCLEBIDIRECTIONAL_SEQUENCE_LSTM_OUT, + CircleBidirectionalSequenceLSTMOutSummaryBuilder) + CIRCLE_NODE(CIRCLECUSTOMOUT, CircleCustomOutSummaryBuilder) + CIRCLE_NODE(CIRCLEIFOUT, CircleIfOutSummaryBuilder) + CIRCLE_NODE(CIRCLEINPUT, CircleInputSummaryBuilder) + CIRCLE_NODE(CIRCLENONMAXSUPPRESSIONV4OUT, CircleNonMaxSuppressionV4OutSummaryBuilder) + CIRCLE_NODE(CIRCLENONMAXSUPPRESSIONV5OUT, CircleNonMaxSuppressionV5OutSummaryBuilder) + CIRCLE_NODE(CIRCLEOUTPUT, CircleOutputSummaryBuilder) + CIRCLE_NODE(CIRCLEOUTPUTDUMMY, CircleOutputDummySummaryBuilder) + CIRCLE_NODE(CIRCLEOUTPUTEXCLUDE, CircleOutputExcludeSummaryBuilder) + CIRCLE_NODE(CIRCLESPLITOUT, CircleSplitOutSummaryBuilder) + CIRCLE_NODE(CIRCLESPLITVOUT, CircleSplitVOutSummaryBuilder) + CIRCLE_NODE(CIRCLETOPKV2OUT, CircleTopKV2OutSummaryBuilder) + CIRCLE_NODE(CIRCLEUNIQUEOUT, CircleUniqueOutSummaryBuilder) + CIRCLE_NODE(CIRCLEUNPACKOUT, CircleUnpackOutSummaryBuilder) + CIRCLE_NODE(CIRCLEVARIABLE, CircleVariableSummaryBuilder) + 
CIRCLE_NODE(CIRCLEWHILEOUT, CircleWhileOutSummaryBuilder) + + default: + return nullptr; + +#undef CIRCLE_NODE + } +} + +} // namespace luci diff --git a/compiler/luci/logex/src/CircleNodeSummaryBuilder.h b/compiler/luci/logex/src/CircleNodeSummaryBuilder.h new file mode 100644 index 000000000..e21d77310 --- /dev/null +++ b/compiler/luci/logex/src/CircleNodeSummaryBuilder.h @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_LOGEX_CIRCLE_NODE_SUMMARY_BUILDER__ +#define __LUCI_LOGEX_CIRCLE_NODE_SUMMARY_BUILDER__ + +#include <luci/IR/CircleNode.h> +#include <locop/NodeSummary.h> +#include <locop/SymbolTable.h> + +#include <memory> +#include <sstream> +#include <vector> + +namespace luci +{ + +class CircleNodeSummaryBuilder +{ +public: + /** @brief Virtual because create_builder() returns derived builders owned through std::unique_ptr<CircleNodeSummaryBuilder>; destroying them via this base type requires a virtual destructor. */ + virtual ~CircleNodeSummaryBuilder() = default; + + bool build(const loco::Node *node, const locop::SymbolTable *tbl, locop::NodeSummary &s); /* fills s for a Circle node; returns false when no summary could be built */ + +private: + /** + * @brief Template methods for building node summary. + * Default behavior is building a node which has no input. 
+ */ + virtual bool validate(const luci::CircleNode *node); /* default implementation returns true */ + virtual std::vector<std::string> get_input_names(const luci::CircleNode *node); /* default implementation returns an empty list */ + virtual void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s); /* default implementation adds nothing */ + virtual void update_status(locop::NodeSummary &s); /* default implementation sets State::Complete */ + +private: /* factory used by build(): builder for the node's opcode, or nullptr when none is registered */ + std::unique_ptr<CircleNodeSummaryBuilder> create_builder(const luci::CircleNode *node); +}; + +} // namespace luci + +#endif // __LUCI_LOGEX_CIRCLE_NODE_SUMMARY_BUILDER__ diff --git a/compiler/luci/logex/src/CircleNodeSummaryBuilder.test.cpp b/compiler/luci/logex/src/CircleNodeSummaryBuilder.test.cpp new file mode 100644 index 000000000..89ea213e0 --- /dev/null +++ b/compiler/luci/logex/src/CircleNodeSummaryBuilder.test.cpp @@ -0,0 +1,309 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "CircleNodeSummaryBuilder.h" + +#include <luci/IR/CircleNodes.h> +#include <locop/NodeSummary.h> +#include <locop/SymbolTable.h> + +#include <gtest/gtest.h> + +namespace +{ + +class MockSymbolTable : public locop::SymbolTable +{ /* Stub symbol table: lookup() returns the same fixed string for every node. */ + std::string lookup(const loco::Node *) const override + { + return "Do nothing because it is mocking Symbol Table!"; + } +}; + +class CircleNodeSummaryBuilderTest : public ::testing::Test +{ +protected: + bool mock_build(const loco::Node *node) + { /* Runs build() on a temporary CircleNodeSummaryBuilder with the fixture's table and summary. */ + return luci::CircleNodeSummaryBuilder().build(node, &_tbl, _s); + } + +protected: + MockSymbolTable _tbl; + locop::NodeSummary _s; +}; + +} // namespace + +TEST_F(CircleNodeSummaryBuilderTest, Add_validate) +{ /* Pattern used by the tests below: *_validate sets every checked attribute and expects build() to succeed; *_NEG leaves one attribute UNDEFINED and expects build() to return false. */ + luci::CircleAdd node; + node.fusedActivationFunction(luci::FusedActFunc::RELU); + EXPECT_TRUE(mock_build(&node)); +} + +TEST_F(CircleNodeSummaryBuilderTest, Add_validate_fused_NEG) +{ + luci::CircleAdd node; + node.fusedActivationFunction(luci::FusedActFunc::UNDEFINED); + EXPECT_FALSE(mock_build(&node)); +} + +TEST_F(CircleNodeSummaryBuilderTest, AveragePool2D_validate) +{ + luci::CircleAveragePool2D node; + node.fusedActivationFunction(luci::FusedActFunc::RELU); + node.padding(luci::Padding::SAME); + EXPECT_TRUE(mock_build(&node)); +} + +TEST_F(CircleNodeSummaryBuilderTest, AveragePool2D_validate_fused_NEG) +{ + luci::CircleAveragePool2D node; + node.fusedActivationFunction(luci::FusedActFunc::UNDEFINED); + node.padding(luci::Padding::SAME); + EXPECT_FALSE(mock_build(&node)); +} + +TEST_F(CircleNodeSummaryBuilderTest, AveragePool2D_validate_padding_NEG) +{ + luci::CircleAveragePool2D node; + node.fusedActivationFunction(luci::FusedActFunc::RELU); + node.padding(luci::Padding::UNDEFINED); + EXPECT_FALSE(mock_build(&node)); +} + +TEST_F(CircleNodeSummaryBuilderTest, BCQFullyConnected_validate) +{ + luci::CircleBCQFullyConnected node; + node.fusedActivationFunction(luci::FusedActFunc::RELU); + EXPECT_TRUE(mock_build(&node)); +} + +TEST_F(CircleNodeSummaryBuilderTest, 
BCQFullyConnected_validate_fused_NEG) +{ + luci::CircleBCQFullyConnected node; + node.fusedActivationFunction(luci::FusedActFunc::UNDEFINED); + EXPECT_FALSE(mock_build(&node)); +} + +TEST_F(CircleNodeSummaryBuilderTest, Concatenation_validate) +{ + luci::CircleConcatenation node(2); + node.fusedActivationFunction(luci::FusedActFunc::RELU); + EXPECT_TRUE(mock_build(&node)); +} + +TEST_F(CircleNodeSummaryBuilderTest, Concatenation_validate_fused_NEG) +{ + luci::CircleConcatenation node(2); + node.fusedActivationFunction(luci::FusedActFunc::UNDEFINED); + EXPECT_FALSE(mock_build(&node)); +} + +TEST_F(CircleNodeSummaryBuilderTest, Conv2D_validate) +{ + luci::CircleConv2D node; + node.fusedActivationFunction(luci::FusedActFunc::RELU); + node.padding(luci::Padding::SAME); + EXPECT_TRUE(mock_build(&node)); +} + +TEST_F(CircleNodeSummaryBuilderTest, Conv2D_validate_fused_NEG) +{ + luci::CircleConv2D node; + node.fusedActivationFunction(luci::FusedActFunc::UNDEFINED); + node.padding(luci::Padding::SAME); + EXPECT_FALSE(mock_build(&node)); +} + +TEST_F(CircleNodeSummaryBuilderTest, Conv2D_validate_padding_NEG) +{ + luci::CircleConv2D node; + node.fusedActivationFunction(luci::FusedActFunc::RELU); + node.padding(luci::Padding::UNDEFINED); + EXPECT_FALSE(mock_build(&node)); +} + +TEST_F(CircleNodeSummaryBuilderTest, DepthwiseConv2D_validate) +{ + luci::CircleDepthwiseConv2D node; + node.fusedActivationFunction(luci::FusedActFunc::RELU); + node.padding(luci::Padding::SAME); + EXPECT_TRUE(mock_build(&node)); +} + +TEST_F(CircleNodeSummaryBuilderTest, DepthwiseConv2D_validate_fused_NEG) +{ + luci::CircleDepthwiseConv2D node; + node.fusedActivationFunction(luci::FusedActFunc::UNDEFINED); + node.padding(luci::Padding::SAME); + EXPECT_FALSE(mock_build(&node)); +} + +TEST_F(CircleNodeSummaryBuilderTest, DepthwiseConv2D_validate_padding_NEG) +{ + luci::CircleDepthwiseConv2D node; + node.fusedActivationFunction(luci::FusedActFunc::RELU); + node.padding(luci::Padding::UNDEFINED); + 
EXPECT_FALSE(mock_build(&node)); +} + +TEST_F(CircleNodeSummaryBuilderTest, FullyConnected_validate) +{ + luci::CircleFullyConnected node; + node.fusedActivationFunction(luci::FusedActFunc::RELU); + EXPECT_TRUE(mock_build(&node)); +} + +TEST_F(CircleNodeSummaryBuilderTest, FullyConnected_validate_fused_NEG) +{ + luci::CircleFullyConnected node; + node.fusedActivationFunction(luci::FusedActFunc::UNDEFINED); + EXPECT_FALSE(mock_build(&node)); +} + +TEST_F(CircleNodeSummaryBuilderTest, InstanceNorm_validate) +{ + luci::CircleInstanceNorm node; + node.fusedActivationFunction(luci::FusedActFunc::RELU); + EXPECT_TRUE(mock_build(&node)); +} + +TEST_F(CircleNodeSummaryBuilderTest, InstanceNorm_validate_fused_NEG) +{ + luci::CircleInstanceNorm node; + node.fusedActivationFunction(luci::FusedActFunc::UNDEFINED); + EXPECT_FALSE(mock_build(&node)); +} + +TEST_F(CircleNodeSummaryBuilderTest, L2Normalize_validate) +{ + luci::CircleL2Normalize node; + node.fusedActivationFunction(luci::FusedActFunc::RELU); + EXPECT_TRUE(mock_build(&node)); +} + +TEST_F(CircleNodeSummaryBuilderTest, L2Normalize_validate_fused_NEG) +{ + luci::CircleL2Normalize node; + node.fusedActivationFunction(luci::FusedActFunc::UNDEFINED); + EXPECT_FALSE(mock_build(&node)); +} + +TEST_F(CircleNodeSummaryBuilderTest, L2Pool2D_validate) +{ + luci::CircleL2Pool2D node; + node.fusedActivationFunction(luci::FusedActFunc::RELU); + node.padding(luci::Padding::SAME); + EXPECT_TRUE(mock_build(&node)); +} + +TEST_F(CircleNodeSummaryBuilderTest, L2Pool2D_validate_fused_NEG) +{ + luci::CircleL2Pool2D node; + node.fusedActivationFunction(luci::FusedActFunc::UNDEFINED); + node.padding(luci::Padding::SAME); + EXPECT_FALSE(mock_build(&node)); +} + +TEST_F(CircleNodeSummaryBuilderTest, L2Pool2D_validate_padding_NEG) +{ + luci::CircleL2Pool2D node; + node.fusedActivationFunction(luci::FusedActFunc::RELU); + node.padding(luci::Padding::UNDEFINED); + EXPECT_FALSE(mock_build(&node)); +} + +TEST_F(CircleNodeSummaryBuilderTest, 
MaxPool2D_validate) +{ + luci::CircleMaxPool2D node; + node.fusedActivationFunction(luci::FusedActFunc::RELU); + node.padding(luci::Padding::SAME); + EXPECT_TRUE(mock_build(&node)); +} + +TEST_F(CircleNodeSummaryBuilderTest, MaxPool2D_validate_fused_NEG) +{ + luci::CircleMaxPool2D node; + node.fusedActivationFunction(luci::FusedActFunc::UNDEFINED); + node.padding(luci::Padding::SAME); + EXPECT_FALSE(mock_build(&node)); +} + +TEST_F(CircleNodeSummaryBuilderTest, MaxPool2D_validate_padding_NEG) +{ + luci::CircleMaxPool2D node; + node.fusedActivationFunction(luci::FusedActFunc::RELU); + node.padding(luci::Padding::UNDEFINED); + EXPECT_FALSE(mock_build(&node)); +} + +TEST_F(CircleNodeSummaryBuilderTest, MirrorPad_validate) +{ + luci::CircleMirrorPad node; + node.mode(luci::MirrorPadMode::REFLECT); + EXPECT_TRUE(mock_build(&node)); +} + +TEST_F(CircleNodeSummaryBuilderTest, MirrorPad_validate_mirror_padding_NEG) +{ + luci::CircleMirrorPad node; + node.mode(luci::MirrorPadMode::UNDEFINED); + EXPECT_FALSE(mock_build(&node)); +} + +TEST_F(CircleNodeSummaryBuilderTest, Mul_validate) +{ + luci::CircleMul node; + node.fusedActivationFunction(luci::FusedActFunc::RELU); + EXPECT_TRUE(mock_build(&node)); +} + +TEST_F(CircleNodeSummaryBuilderTest, Mul_validate_fused_NEG) +{ + luci::CircleMul node; + node.fusedActivationFunction(luci::FusedActFunc::UNDEFINED); + EXPECT_FALSE(mock_build(&node)); +} + +TEST_F(CircleNodeSummaryBuilderTest, SVDF_validate) +{ + luci::CircleSVDF node; + node.fusedActivationFunction(luci::FusedActFunc::RELU); + EXPECT_TRUE(mock_build(&node)); +} + +TEST_F(CircleNodeSummaryBuilderTest, SVDF_validate_fused_NEG) +{ + luci::CircleSVDF node; + node.fusedActivationFunction(luci::FusedActFunc::UNDEFINED); + EXPECT_FALSE(mock_build(&node)); +} + +TEST_F(CircleNodeSummaryBuilderTest, TransposeConv_validate) +{ + luci::CircleTransposeConv node; + node.padding(luci::Padding::SAME); + EXPECT_TRUE(mock_build(&node)); +} + +TEST_F(CircleNodeSummaryBuilderTest, 
TransposeConv_validate_padding_NEG) +{ + luci::CircleTransposeConv node; + node.padding(luci::Padding::UNDEFINED); + EXPECT_FALSE(mock_build(&node)); +} diff --git a/compiler/luci/logex/src/CircleNodeSummaryBuilders.cpp b/compiler/luci/logex/src/CircleNodeSummaryBuilders.cpp new file mode 100644 index 000000000..6df9270e3 --- /dev/null +++ b/compiler/luci/logex/src/CircleNodeSummaryBuilders.cpp @@ -0,0 +1,1128 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleNodeSummaryBuilders.h" + +#include <luci/IR/CircleNode.h> +#include <luci/IR/CircleNodes.h> +#include <loco/IR/Node.h> + +#include <string> +#include <vector> + +namespace +{ + +std::string to_str(loco::DataType type) +{ + switch (type) + { + case loco::DataType::U8: + return "UINT8"; + case loco::DataType::U16: + return "UINT16"; + case loco::DataType::U32: + return "UINT32"; + case loco::DataType::U64: + return "UINT64"; + + case loco::DataType::S8: + return "INT8"; + case loco::DataType::S16: + return "INT16"; + case loco::DataType::S32: + return "INT32"; + case loco::DataType::S64: + return "INT64"; + + case loco::DataType::FLOAT16: + return "FLOAT16"; + case loco::DataType::FLOAT32: + return "FLOAT32"; + case loco::DataType::FLOAT64: + return "FLOAT64"; + + case loco::DataType::BOOL: + return "BOOL"; + + default: + return "Error"; + } +} + +std::string to_str(bool value) { return value ? 
"true" : "false"; } + +std::string to_str(luci::FusedActFunc fused) +{ + switch (fused) + { + case luci::FusedActFunc::NONE: + return "NONE"; + case luci::FusedActFunc::RELU: + return "RELU"; + case luci::FusedActFunc::RELU_N1_TO_1: + return "RELU_N1_TO_1"; + case luci::FusedActFunc::RELU6: + return "RELU6"; + case luci::FusedActFunc::TANH: + return "TANH"; + case luci::FusedActFunc::SIGN_BIT: + return "SIGN_BIT"; + default: + return "Error"; + } +} + +std::string to_str(luci::Padding padding) +{ + switch (padding) + { + case luci::Padding::SAME: + return "SAME"; + case luci::Padding::VALID: + return "VALID"; + default: + return "Error"; + } +} + +std::string to_str(const luci::Stride *stride) +{ + return std::to_string(stride->h()) + "," + std::to_string(stride->w()); +} + +std::string to_str(const luci::Filter *filter) +{ + return std::to_string(filter->h()) + "," + std::to_string(filter->w()); +} + +std::string to_str(luci::MirrorPadMode mode) +{ + switch (mode) + { + case luci::MirrorPadMode::REFLECT: + return "REFLECT"; + case luci::MirrorPadMode::SYMMETRIC: + return "SYMMETRIC"; + default: + return "Error"; + } +} + +} // namespace + +namespace luci +{ + +std::vector<std::string> CircleNodeWithXSummaryBuilder::get_input_names(const luci::CircleNode *) +{ + return {"x"}; +} + +std::vector<std::string> +CircleNodeWithINPUTSummaryBuilder::get_input_names(const luci::CircleNode *) +{ + return {"input"}; +} + +std::vector<std::string> CircleNodeWithXYSummaryBuilder::get_input_names(const luci::CircleNode *) +{ + return {"x", "y"}; +} + +std::vector<std::string> +CircleNodeWithFEATURESSummaryBuilder::get_input_names(const luci::CircleNode *) +{ + return {"features"}; +} + +} // namespace luci + +namespace luci +{ + +bool CircleAddSummaryBuilder::validate(const luci::CircleNode *node) +{ + auto add = loco::must_cast<const luci::CircleAdd *>(node); + if (add->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED) + return false; + + return true; +} + +void 
CircleAddSummaryBuilder::build_attributes(const luci::CircleNode *node, locop::NodeSummary &s) +{ /* Appends the fused activation name as an extra summary argument. */ + auto add = loco::must_cast<const luci::CircleAdd *>(node); + s.args().append("fused_activation_function", to_str(add->fusedActivationFunction())); +} + +std::vector<std::string> CircleAddNSummaryBuilder::get_input_names(const luci::CircleNode *node) +{ /* Every argument gets the same name, one entry per arity slot. */ + return std::vector<std::string>(node->arity(), "inputs"); +} + +std::vector<std::string> CircleArgMaxSummaryBuilder::get_input_names(const luci::CircleNode *) +{ + return {"input", "dimension"}; +} + +void CircleArgMaxSummaryBuilder::build_attributes(const luci::CircleNode *node, + locop::NodeSummary &s) +{ + auto argmax = loco::must_cast<const luci::CircleArgMax *>(node); + s.args().append("output_type", to_str(argmax->output_type())); +} + +std::vector<std::string> CircleArgMinSummaryBuilder::get_input_names(const luci::CircleNode *) +{ + return {"input", "dimension"}; +} + +void CircleArgMinSummaryBuilder::build_attributes(const luci::CircleNode *node, + locop::NodeSummary &s) +{ + auto argmin = loco::must_cast<const luci::CircleArgMin *>(node); + s.args().append("output_type", to_str(argmin->output_type())); +} + +bool CircleAveragePool2DSummaryBuilder::validate(const luci::CircleNode *node) +{ /* Valid only when both the fused activation and the padding are set (not UNDEFINED). */ + auto avgpool = loco::must_cast<const luci::CircleAveragePool2D *>(node); + if (avgpool->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED) + return false; + if (avgpool->padding() == luci::Padding::UNDEFINED) + return false; + + return true; +} + +std::vector<std::string> +CircleAveragePool2DSummaryBuilder::get_input_names(const luci::CircleNode *) +{ + return {"value"}; +} + +void CircleAveragePool2DSummaryBuilder::build_attributes(const luci::CircleNode *node, + locop::NodeSummary &s) +{ /* Appends filter, stride, padding and fused activation, in that order. */ + auto avgpool = loco::must_cast<const luci::CircleAveragePool2D *>(node); + s.args().append("filter(h,w)", to_str(avgpool->filter())); + s.args().append("stride(h,w)", to_str(avgpool->stride())); + s.args().append("padding", 
to_str(avgpool->padding())); + s.args().append("fused_activation_function", to_str(avgpool->fusedActivationFunction())); +} + +void CircleBatchMatMulSummaryBuilder::build_attributes(const luci::CircleNode *node, + locop::NodeSummary &s) +{ + auto batchmatmul = loco::must_cast<const luci::CircleBatchMatMul *>(node); + s.args().append("adj_x", to_str(batchmatmul->adj_x())); + s.args().append("adj_y", to_str(batchmatmul->adj_y())); +} + +std::vector<std::string> +CircleBatchToSpaceNDSummaryBuilder::get_input_names(const luci::CircleNode *) +{ + return {"input", "block_shape", "crops"}; +} + +bool CircleBCQFullyConnectedSummaryBuilder::validate(const luci::CircleNode *node) +{ /* Rejects a node whose fused activation is still UNDEFINED. */ + auto bcq_fc = loco::must_cast<const luci::CircleBCQFullyConnected *>(node); + if (bcq_fc->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED) + return false; + + return true; +} + +std::vector<std::string> +CircleBCQFullyConnectedSummaryBuilder::get_input_names(const luci::CircleNode *) +{ + return {"input", "weights_scales", "weights_binary", "bias", "weights_clusters"}; +} + +void CircleBCQFullyConnectedSummaryBuilder::build_attributes(const luci::CircleNode *node, + locop::NodeSummary &s) +{ + auto bcq_fc = loco::must_cast<const luci::CircleBCQFullyConnected *>(node); + s.args().append("fused_activation_function", to_str(bcq_fc->fusedActivationFunction())); + s.args().append("weights_hidden_size", std::to_string(bcq_fc->weights_hidden_size())); +} + +std::vector<std::string> CircleBCQGatherSummaryBuilder::get_input_names(const luci::CircleNode *) +{ + return {"input_scales", "input_binary", "indices", "input_clusters"}; +} + +void CircleBCQGatherSummaryBuilder::build_attributes(const luci::CircleNode *node, + locop::NodeSummary &s) +{ /* Numeric attributes are formatted with std::to_string. */ + auto bcq_gather = loco::must_cast<const luci::CircleBCQGather *>(node); + s.args().append("axis", std::to_string(bcq_gather->axis())); + s.args().append("input_hidden_size", std::to_string(bcq_gather->input_hidden_size())); +} + 
+// ---- CircleBidirectionalSequenceLSTM ----
+// Labels for the inputs; "fw_"/"bw_" prefixes are part of the emitted strings.
+std::vector<std::string>
+CircleBidirectionalSequenceLSTMSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+  return {"input",
+          // "fw_" weights and biases
+          "fw_input_to_input_weights", "fw_input_to_forget_weights",
+          "fw_input_to_cell_weights", "fw_input_to_output_weights",
+          "fw_recurrent_to_input_weights", "fw_recurrent_to_forget_weights",
+          "fw_recurrent_to_cell_weights", "fw_recurrent_to_output_weights",
+          "fw_cell_to_input_weights", "fw_cell_to_forget_weights", "fw_cell_to_output_weights",
+          "fw_input_gate_bias", "fw_forget_gate_bias", "fw_cell_gate_bias", "fw_output_gate_bias",
+          "fw_projection_weights", "fw_projection_bias",
+          // "bw_" weights and biases
+          "bw_input_to_input_weights", "bw_input_to_forget_weights",
+          "bw_input_to_cell_weights", "bw_input_to_output_weights",
+          "bw_recurrent_to_input_weights", "bw_recurrent_to_forget_weights",
+          "bw_recurrent_to_cell_weights", "bw_recurrent_to_output_weights",
+          "bw_cell_to_input_weights", "bw_cell_to_forget_weights", "bw_cell_to_output_weights",
+          "bw_input_gate_bias", "bw_forget_gate_bias", "bw_cell_gate_bias", "bw_output_gate_bias",
+          "bw_projection_weights", "bw_projection_bias",
+          // states
+          "fw_activation_state", "fw_cell_state", "bw_activation_state", "bw_cell_state",
+          // auxiliary input and its weights (spelling "auxillary" is kept as emitted)
+          "auxillary_input",
+          "fw_auxillary_input_to_input_weights", "fw_auxillary_input_to_forget_weights",
+          "fw_auxillary_input_to_cell_weights", "fw_auxillary_input_to_output_weights",
+          "bw_auxillary_input_to_input_weights", "bw_auxillary_input_to_forget_weights",
+          "bw_auxillary_input_to_cell_weights", "bw_auxillary_input_to_output_weights"};
+}
+
+// Appends the clip values, merge/time-major flags and the quantization flag.
+void CircleBidirectionalSequenceLSTMSummaryBuilder::build_attributes(const luci::CircleNode *node,
+                                                                     locop::NodeSummary &s)
+{
+  const auto op = loco::must_cast<const luci::CircleBidirectionalSequenceLSTM *>(node);
+  s.args().append("cell_clip", to_str(op->cell_clip()));
+  s.args().append("proj_clip", to_str(op->proj_clip()));
+  s.args().append("merge_outputs", to_str(op->merge_outputs()));
+  s.args().append("time_major", to_str(op->time_major()));
+  s.args().append("asymmetric_quantize_inputs", to_str(op->asymmetric_quantize_inputs()));
+}
+
+// ---- CircleCast ----
+std::vector<std::string> CircleCastSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+  return {"x"};
+}
+
+// Appends the source and destination data types of the cast.
+void CircleCastSummaryBuilder::build_attributes(const luci::CircleNode *node, locop::NodeSummary &s)
+{
+  const auto op = loco::must_cast<const luci::CircleCast *>(node);
+  s.args().append("in_data_type", to_str(op->in_data_type()));
+  s.args().append("out_data_type", to_str(op->out_data_type()));
+}
+
+// ---- CircleConcatenation ----
+// Valid only when the fused activation function is defined.
+bool CircleConcatenationSummaryBuilder::validate(const luci::CircleNode *node)
+{
+  const auto op = loco::must_cast<const luci::CircleConcatenation *>(node);
+  return op->fusedActivationFunction() != luci::FusedActFunc::UNDEFINED;
+}
+
+// One "values" label per argument of the node.
+std::vector<std::string>
+CircleConcatenationSummaryBuilder::get_input_names(const luci::CircleNode *node)
+{
+  const auto input_count = node->arity();
+  return std::vector<std::string>(input_count, "values");
+}
+
+void CircleConcatenationSummaryBuilder::build_attributes(const luci::CircleNode *node,
+                                                         locop::NodeSummary &s)
+{
+  const auto op = loco::must_cast<const luci::CircleConcatenation *>(node);
+  s.args().append("axis", std::to_string(op->axis()));
+  s.args().append("fused_activation_function", to_str(op->fusedActivationFunction()));
+}
+
+// ---- CircleConst ----
+// Constants are always reported as PartiallyKnown.
+void CircleConstSummaryBuilder::update_status(locop::NodeSummary &s)
+{
+  s.state(locop::NodeDesc::State::PartiallyKnown);
+}
+
+// ---- CircleConv2D ----
+// Valid only when both the fused activation function and the padding are defined.
+bool CircleConv2DSummaryBuilder::validate(const luci::CircleNode *node)
+{
+  const auto op = loco::must_cast<const luci::CircleConv2D *>(node);
+  return op->fusedActivationFunction() != luci::FusedActFunc::UNDEFINED &&
+         op->padding() != luci::Padding::UNDEFINED;
+}
+
+std::vector<std::string> CircleConv2DSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+  return {"input", "filter", "bias"};
+}
+
+void CircleConv2DSummaryBuilder::build_attributes(const luci::CircleNode *node,
+                                                  locop::NodeSummary &s)
+{
+  const auto op = loco::must_cast<const luci::CircleConv2D *>(node);
+  s.args().append("stride(h,w)", to_str(op->stride()));
+  s.args().append("dilation(h,w)", to_str(op->dilation()));
+  s.args().append("padding", to_str(op->padding()));
+  s.args().append("fused_activation_function", to_str(op->fusedActivationFunction()));
+}
+
+// ---- CircleCustom ----
+// Labels are "input0", "input1", ... one per argument of the node.
+std::vector<std::string> CircleCustomSummaryBuilder::get_input_names(const luci::CircleNode *node)
+{
+  const auto input_count = node->arity();
+
+  std::vector<std::string> labels;
+  labels.reserve(input_count);
+  for (uint32_t index = 0; index < input_count; ++index)
+    labels.emplace_back("input" + std::to_string(index));
+
+  return labels;
+}
+
+void CircleCustomSummaryBuilder::build_attributes(const luci::CircleNode *node,
+                                                  locop::NodeSummary &s)
+{
+  const auto op = loco::must_cast<const luci::CircleCustom *>(node);
+  s.args().append("custom_code", op->custom_code());
+}
+
+// ---- CircleDepthToSpace ----
+void CircleDepthToSpaceSummaryBuilder::build_attributes(const luci::CircleNode *node,
+                                                        locop::NodeSummary &s)
+{
+  const auto op = loco::must_cast<const luci::CircleDepthToSpace *>(node);
+  s.args().append("block_size", std::to_string(op->block_size()));
+}
+
+// ---- CircleDepthwiseConv2D ----
+// Valid only when both the fused activation function and the padding are defined.
+bool CircleDepthwiseConv2DSummaryBuilder::validate(const luci::CircleNode *node)
+{
+  const auto op = loco::must_cast<const luci::CircleDepthwiseConv2D *>(node);
+  return op->fusedActivationFunction() != luci::FusedActFunc::UNDEFINED &&
+         op->padding() != luci::Padding::UNDEFINED;
+}
+
+std::vector<std::string>
+CircleDepthwiseConv2DSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+  return {"input", "filter", "bias"};
+}
+
+void CircleDepthwiseConv2DSummaryBuilder::build_attributes(const luci::CircleNode *node,
+                                                           locop::NodeSummary &s)
+{
+  const auto op = loco::must_cast<const luci::CircleDepthwiseConv2D *>(node);
+  s.args().append("stride(h,w)", to_str(op->stride()));
+  s.args().append("dilation(h,w)", to_str(op->dilation()));
+  s.args().append("padding", to_str(op->padding()));
+  s.args().append("depthMultiplier", std::to_string(op->depthMultiplier()));
+  s.args().append("fused_activation_function", to_str(op->fusedActivationFunction()));
+}
+
+// ---- CircleExpandDims ----
+std::vector<std::string> CircleExpandDimsSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+  return {"input", "axis"};
+}
+
+// ---- CircleFakeQuant ----
+std::vector<std::string> CircleFakeQuantSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+  return {"inputs"};
+}
+
+void CircleFakeQuantSummaryBuilder::build_attributes(const luci::CircleNode *node,
+                                                     locop::NodeSummary &s)
+{
+  const auto op = loco::must_cast<const luci::CircleFakeQuant *>(node);
+  s.args().append("min", std::to_string(op->min()));
+  s.args().append("max", std::to_string(op->max()));
+  s.args().append("num_bits", std::to_string(op->num_bits()));
+  s.args().append("narrow_range", to_str(op->narrow_range()));
+}
+
+// ---- CircleFill ----
+std::vector<std::string> CircleFillSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+  return {"dims", "value"};
+}
+
+// ---- CircleFullyConnected ----
+// Valid only when the fused activation function is defined.
+bool CircleFullyConnectedSummaryBuilder::validate(const luci::CircleNode *node)
+{
+  const auto op = loco::must_cast<const luci::CircleFullyConnected *>(node);
+  return op->fusedActivationFunction() != luci::FusedActFunc::UNDEFINED;
+}
+
+std::vector<std::string>
+CircleFullyConnectedSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+  return {"input", "weights", "bias"};
+}
+
+void CircleFullyConnectedSummaryBuilder::build_attributes(const luci::CircleNode *node,
+                                                          locop::NodeSummary &s)
+{
+  const auto op = loco::must_cast<const luci::CircleFullyConnected *>(node);
+  s.args().append("fused_activation_function", to_str(op->fusedActivationFunction()));
+}
+
+// ---- CircleGather ----
+std::vector<std::string> CircleGatherSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+  return {"params", "indices"};
+}
+
+void CircleGatherSummaryBuilder::build_attributes(const luci::CircleNode *node,
+                                                  locop::NodeSummary &s)
+{
+  const auto op = loco::must_cast<const luci::CircleGather *>(node);
+  s.args().append("axis", std::to_string(op->axis()));
+}
+
+// ---- CircleGatherNd ----
+std::vector<std::string> CircleGatherNdSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+  return {"params", "indices"};
+}
+
+// ---- CircleIf ----
+// "cond" first, then one "input" label per input_count().
+std::vector<std::string> CircleIfSummaryBuilder::get_input_names(const luci::CircleNode *node)
+{
+  const auto op = loco::must_cast<const luci::CircleIf *>(node);
+
+  std::vector<std::string> labels{"cond"};
+  labels.insert(labels.end(), op->input_count(), "input");
+
+  return labels;
+}
+
+// Reports each branch by graph name when the graph is set, otherwise by branch index.
+void CircleIfSummaryBuilder::build_attributes(const luci::CircleNode *node, locop::NodeSummary &s)
+{
+  const auto op = loco::must_cast<const luci::CircleIf *>(node);
+
+  if (const auto then_graph = op->then_graph())
+    s.args().append("then_graph", then_graph->name());
+  else
+    s.args().append("then_branch", std::to_string(op->then_branch()));
+
+  if (const auto else_graph = op->else_graph())
+    s.args().append("else_graph", else_graph->name());
+  else
+    s.args().append("else_branch", std::to_string(op->else_branch()));
+}
+
+// ---- CircleInstanceNorm ----
+// Valid only when the fused activation function is defined.
+bool CircleInstanceNormSummaryBuilder::validate(const luci::CircleNode *node)
+{
+  const auto op = loco::must_cast<const luci::CircleInstanceNorm *>(node);
+  return op->fusedActivationFunction() != luci::FusedActFunc::UNDEFINED;
+}
+
+std::vector<std::string> CircleInstanceNormSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+  return {"input", "gamma", "beta"};
+}
+
+void CircleInstanceNormSummaryBuilder::build_attributes(const luci::CircleNode *node,
+                                                        locop::NodeSummary &s)
+{
+  const auto op = loco::must_cast<const luci::CircleInstanceNorm *>(node);
+  s.args().append("epsilon", std::to_string(op->epsilon()));
+  s.args().append("fused_activation_function", to_str(op->fusedActivationFunction()));
+}
+
+// ---- CircleL2Normalize ----
+// Valid only when the fused activation function is defined.
+bool CircleL2NormalizeSummaryBuilder::validate(const luci::CircleNode *node)
+{
+  const auto op = loco::must_cast<const luci::CircleL2Normalize *>(node);
+  return op->fusedActivationFunction() != luci::FusedActFunc::UNDEFINED;
+}
+
+std::vector<std::string> CircleL2NormalizeSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+  return {"x"};
+}
+
+void CircleL2NormalizeSummaryBuilder::build_attributes(const luci::CircleNode *node,
+                                                       locop::NodeSummary &s)
+{
+  const auto op = loco::must_cast<const luci::CircleL2Normalize *>(node);
+  s.args().append("fused_activation_function", to_str(op->fusedActivationFunction()));
+}
+
+// ---- CircleL2Pool2D ----
+// Valid only when both the fused activation function and the padding are defined.
+bool CircleL2Pool2DSummaryBuilder::validate(const luci::CircleNode *node)
+{
+  const auto op = loco::must_cast<const luci::CircleL2Pool2D *>(node);
+  return op->fusedActivationFunction() != luci::FusedActFunc::UNDEFINED &&
+         op->padding() != luci::Padding::UNDEFINED;
+}
+
+std::vector<std::string> CircleL2Pool2DSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+  return {"value"};
+}
+
+void CircleL2Pool2DSummaryBuilder::build_attributes(const luci::CircleNode *node,
+                                                    locop::NodeSummary &s)
+{
+  const auto op = loco::must_cast<const luci::CircleL2Pool2D *>(node);
+  s.args().append("filter(h,w)", to_str(op->filter()));
+  s.args().append("stride(h,w)", to_str(op->stride()));
+  s.args().append("padding", to_str(op->padding()));
+  s.args().append("fused_activation_function", to_str(op->fusedActivationFunction()));
+}
+
+// ---- CircleLeakyRelu ----
+void CircleLeakyReluSummaryBuilder::build_attributes(const luci::CircleNode *node,
+                                                     locop::NodeSummary &s)
+{
+  const auto op = loco::must_cast<const luci::CircleLeakyRelu *>(node);
+  s.args().append("alpha", std::to_string(op->alpha()));
+}
+
+// ---- CircleLocalResponseNormalization ----
+void CircleLocalResponseNormalizationSummaryBuilder::build_attributes(const luci::CircleNode *node,
+                                                                      locop::NodeSummary &s)
+{
+  const auto op = loco::must_cast<const luci::CircleLocalResponseNormalization *>(node);
+  s.args().append("radius", std::to_string(op->radius()));
+  s.args().append("bias", std::to_string(op->bias()));
+  s.args().append("alpha", std::to_string(op->alpha()));
+  s.args().append("beta", std::to_string(op->beta()));
+}
+
+// ---- CircleLogSoftmax ----
+std::vector<std::string> CircleLogSoftmaxSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+  return {"logits"};
+}
+
+// ---- CircleMatrixDiag ----
+std::vector<std::string> CircleMatrixDiagSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+  return {"diagonal"};
+}
+
+// ---- CircleMatrixSetDiag ----
+std::vector<std::string>
+CircleMatrixSetDiagSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+  return {"input", "diagonal"};
+}
+
+// ---- CircleMaxPool2D ----
+// Valid only when both the fused activation function and the padding are defined.
+bool CircleMaxPool2DSummaryBuilder::validate(const luci::CircleNode *node)
+{
+  const auto op = loco::must_cast<const luci::CircleMaxPool2D *>(node);
+  return op->fusedActivationFunction() != luci::FusedActFunc::UNDEFINED &&
+         op->padding() != luci::Padding::UNDEFINED;
+}
+
+std::vector<std::string> CircleMaxPool2DSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+  return {"value"};
+}
+
+void CircleMaxPool2DSummaryBuilder::build_attributes(const luci::CircleNode *node,
+                                                     locop::NodeSummary &s)
+{
+  const auto op = loco::must_cast<const luci::CircleMaxPool2D *>(node);
+  s.args().append("filter(h,w)", to_str(op->filter()));
+  s.args().append("stride(h,w)", to_str(op->stride()));
+  s.args().append("padding", to_str(op->padding()));
+  s.args().append("fused_activation_function", to_str(op->fusedActivationFunction()));
+}
+
+// ---- CircleMirrorPad ----
+// Valid only when the mirror-pad mode is defined.
+bool CircleMirrorPadSummaryBuilder::validate(const luci::CircleNode *node)
+{
+  const auto op = loco::must_cast<const luci::CircleMirrorPad *>(node);
+  return op->mode() != luci::MirrorPadMode::UNDEFINED;
+}
+
+std::vector<std::string> CircleMirrorPadSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+  return {"input", "paddings"};
+}
+
+void CircleMirrorPadSummaryBuilder::build_attributes(const luci::CircleNode *node,
+                                                     locop::NodeSummary &s)
+{
+  const auto op = loco::must_cast<const luci::CircleMirrorPad *>(node);
+  s.args().append("mode", to_str(op->mode()));
+}
+
+// ---- CircleMul ----
+// Valid only when the fused activation function is defined.
+bool CircleMulSummaryBuilder::validate(const luci::CircleNode *node)
+{
+  const auto op = loco::must_cast<const luci::CircleMul *>(node);
+  return op->fusedActivationFunction() != luci::FusedActFunc::UNDEFINED;
+}
+
+void CircleMulSummaryBuilder::build_attributes(const luci::CircleNode *node, locop::NodeSummary &s)
+{
+  const auto op = loco::must_cast<const luci::CircleMul *>(node);
+  s.args().append("fused_activation_function", to_str(op->fusedActivationFunction()));
+}
+
+// ---- CircleNonMaxSuppressionV4 / V5 ----
+std::vector<std::string>
+CircleNonMaxSuppressionV4SummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+  return {"boxes", "scores", "max_output_size", "iou_threshold", "score_threshold"};
+}
+
+std::vector<std::string>
+CircleNonMaxSuppressionV5SummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+  return {"boxes", "scores", "max_output_size", "iou_threshold", "score_threshold",
+          "soft_nms_sigma"};
+}
+
+// ---- CircleOneHot ----
+std::vector<std::string> CircleOneHotSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+  return {"indices", "depth", "on_value", "off_value"};
+}
+
+void CircleOneHotSummaryBuilder::build_attributes(const luci::CircleNode *node,
+                                                  locop::NodeSummary &s)
+{
+  const auto op = loco::must_cast<const luci::CircleOneHot *>(node);
+  s.args().append("axis", std::to_string(op->axis()));
+}
+
+// ---- CirclePack ----
+// One "values" label per argument of the node.
+std::vector<std::string> CirclePackSummaryBuilder::get_input_names(const luci::CircleNode *node)
+{
+  const auto input_count = node->arity();
+  return std::vector<std::string>(input_count, "values");
+}
+
+void CirclePackSummaryBuilder::build_attributes(const luci::CircleNode *node, locop::NodeSummary &s)
+{
+  const auto op = loco::must_cast<const luci::CirclePack *>(node);
+  s.args().append("values_count", std::to_string(op->values_count()));
+  s.args().append("axis", std::to_string(op->axis()));
+}
+
+// ---- CirclePad / CirclePadV2 ----
+std::vector<std::string> CirclePadSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+  return {"input", "paddings"};
+}
+
+std::vector<std::string> CirclePadV2SummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+  return {"input", "paddings", "constant_values"};
+}
+
+// ---- CirclePRelu ----
+std::vector<std::string> CirclePReluSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+  return {"input", "alpha"};
+}
+
+// ---- CircleRange ----
+std::vector<std::string> CircleRangeSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+  return {"start", "limit", "delta"};
+}
+
+// ---- CircleReshape ----
+std::vector<std::string> CircleReshapeSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+  return {"tensor", "shape"};
+}
+
+// Reshape summaries are always reported as PartiallyKnown.
+void CircleReshapeSummaryBuilder::update_status(locop::NodeSummary &s)
+{
+  s.state(locop::NodeDesc::State::PartiallyKnown);
+}
+
+// ---- CircleResizeBilinear ----
+std::vector<std::string>
+CircleResizeBilinearSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+  return {"input", "size"};
+}
+
+void CircleResizeBilinearSummaryBuilder::build_attributes(const luci::CircleNode *node,
+                                                          locop::NodeSummary &s)
+{
+  const auto op = loco::must_cast<const luci::CircleResizeBilinear *>(node);
+  s.args().append("align_corners", to_str(op->align_corners()));
+  s.args().append("half_pixel_centers", to_str(op->half_pixel_centers()));
+}
+
+// ---- CircleResizeNearestNeighbor ----
+std::vector<std::string>
+CircleResizeNearestNeighborSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+  return {"input", "size"};
+}
+
+void CircleResizeNearestNeighborSummaryBuilder::build_attributes(const luci::CircleNode *node,
+                                                                 locop::NodeSummary &s)
+{
+  const auto op = loco::must_cast<const luci::CircleResizeNearestNeighbor *>(node);
+  s.args().append("align_corners", to_str(op->align_corners()));
+}
+
+// ---- CircleReverseSequence ----
+std::vector<std::string>
+CircleReverseSequenceSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+  return {"input", "seq_lengths"};
+}
+
+void CircleReverseSequenceSummaryBuilder::build_attributes(const luci::CircleNode *node,
+                                                           locop::NodeSummary &s)
+{
+  const auto op = loco::must_cast<const luci::CircleReverseSequence *>(node);
+  s.args().append("seq_axis", std::to_string(op->seq_axis()));
+  s.args().append("batch_axis", std::to_string(op->batch_axis()));
+}
+
+// ---- CircleReverseV2 ----
+std::vector<std::string> CircleReverseV2SummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+  return {"tensor", "axis"};
+}
+
+// ---- CircleScatterNd ----
+std::vector<std::string> CircleScatterNdSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+  return {"indices", "updates", "shape"};
+}
+
+// ---- CircleSegmentSum ----
+std::vector<std::string> CircleSegmentSumSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+  return {"input", "segment_ids"};
+}
+
+// ---- CircleSelect / CircleSelectV2 ----
+std::vector<std::string> CircleSelectSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+  return {"condition", "t", "e"};
+}
+
+std::vector<std::string> CircleSelectV2SummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+  return {"condition", "t", "e"};
+}
+
+// ---- CircleShape ----
+void CircleShapeSummaryBuilder::build_attributes(const luci::CircleNode *node,
+                                                 locop::NodeSummary &s)
+{
+  const auto op = loco::must_cast<const luci::CircleShape *>(node);
+  s.args().append("out_type", to_str(op->out_type()));
+}
+
+// ---- CircleSlice ----
+std::vector<std::string> CircleSliceSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+  return {"input", "begin", "size"};
+}
+
+// ---- CircleSoftmax ----
+std::vector<std::string> CircleSoftmaxSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+  return {"logits"};
+}
+
+void CircleSoftmaxSummaryBuilder::build_attributes(const luci::CircleNode *node,
+                                                   locop::NodeSummary &s)
+{
+  const auto op = loco::must_cast<const luci::CircleSoftmax *>(node);
+  s.args().append("beta", to_str(op->beta()));
+}
+
+// ---- CircleSpaceToBatchND ----
+std::vector<std::string>
+CircleSpaceToBatchNDSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+  return {"input", "block_shape", "paddings"};
+}
+
+// ---- CircleSpaceToDepth ----
+void CircleSpaceToDepthSummaryBuilder::build_attributes(const luci::CircleNode *node,
+                                                        locop::NodeSummary &s)
+{
+  const auto op = loco::must_cast<const luci::CircleSpaceToDepth *>(node);
+  s.args().append("block_size", to_str(op->block_size()));
+}
+
+// ---- CircleSparseToDense ----
+std::vector<std::string>
+CircleSparseToDenseSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+  return {"indices", "output_shape", "values", "default_value"};
+}
+
+void CircleSparseToDenseSummaryBuilder::build_attributes(const luci::CircleNode *node,
+                                                         locop::NodeSummary &s)
+{
+  const auto op = loco::must_cast<const luci::CircleSparseToDense *>(node);
+  s.args().append("validate_indices", to_str(op->validate_indices()));
+}
+
+// ---- CircleSplit ----
+std::vector<std::string> CircleSplitSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+  return {"split_dim", "input"};
+}
+
+void CircleSplitSummaryBuilder::build_attributes(const luci::CircleNode *node,
+                                                 locop::NodeSummary &s)
+{
+  const auto op = loco::must_cast<const luci::CircleSplit *>(node);
+  s.args().append("num_split", std::to_string(op->num_split()));
+}
+
+// ---- CircleSplitV ----
+std::vector<std::string> CircleSplitVSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+  return {"input", "size_splits", "split_dim"};
+}
+
+void CircleSplitVSummaryBuilder::build_attributes(const luci::CircleNode *node,
+                                                  locop::NodeSummary &s)
+{
+  const auto op = loco::must_cast<const luci::CircleSplitV *>(node);
+  s.args().append("num_split", std::to_string(op->num_split()));
+}
+
+// ---- CircleSqueeze ----
+// Renders squeeze_dims as "(d0, d1, ...)"; an empty list renders as "()".
+void CircleSqueezeSummaryBuilder::build_attributes(const luci::CircleNode *node,
+                                                   locop::NodeSummary &s)
+{
+  const auto op = loco::must_cast<const luci::CircleSqueeze *>(node);
+
+  const auto &dims = op->squeeze_dims();
+  const size_t dim_count = dims.size();
+
+  std::string joined;
+  for (size_t pos = 0; pos < dim_count; ++pos)
+  {
+    // Separator goes before every element except the first.
+    joined += (pos == 0) ? "" : ", ";
+    joined += std::to_string(dims.at(pos));
+  }
+
+  s.args().append("squeeze_dims", "(" + joined + ")");
+}
+
+// ---- CircleStridedSlice ----
+std::vector<std::string> CircleStridedSliceSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+  return {"input", "begin", "end", "strides"};
+}
+
+void CircleStridedSliceSummaryBuilder::build_attributes(const luci::CircleNode *node,
+                                                        locop::NodeSummary &s)
+{
+  const auto op = loco::must_cast<const luci::CircleStridedSlice *>(node);
+  s.args().append("begin_mask", std::to_string(op->begin_mask()));
+  s.args().append("end_mask", std::to_string(op->end_mask()));
+  s.args().append("ellipsis_mask", std::to_string(op->ellipsis_mask()));
+  s.args().append("new_axis_mask", std::to_string(op->new_axis_mask()));
+  s.args().append("shrink_axis_mask", std::to_string(op->shrink_axis_mask()));
+}
+
+// ---- CircleSVDF ----
+// Valid only when the fused activation function is defined.
+bool CircleSVDFSummaryBuilder::validate(const luci::CircleNode *node)
+{
+  const auto op = loco::must_cast<const luci::CircleSVDF *>(node);
+  return op->fusedActivationFunction() != luci::FusedActFunc::UNDEFINED;
+}
+
+std::vector<std::string> CircleSVDFSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+  return {"input", "weight_feature", "weight_time", "bias", "State"};
+}
+
+void CircleSVDFSummaryBuilder::build_attributes(const luci::CircleNode *node, locop::NodeSummary &s)
+{
+  const auto op = loco::must_cast<const luci::CircleSVDF *>(node);
+  s.args().append("rank", to_str(op->svdf_rank()));
+  s.args().append("asymmetric_quantize_inputs", to_str(op->asymmetric_quantize_inputs()));
+  s.args().append("fused_activation_function", to_str(op->fusedActivationFunction()));
+}
+
+// ---- CircleTile ----
+std::vector<std::string> CircleTileSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+  return {"input", "multiples"};
+}
+
+// ---- CircleTopKV2 ----
+std::vector<std::string> CircleTopKV2SummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+  return {"input", "k"};
+}
+
+// ---- CircleTranspose ----
+std::vector<std::string> CircleTransposeSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+  return {"a", "perm"};
+}
+
+// ---- CircleTransposeConv ----
+// Valid only when the padding is defined.
+bool CircleTransposeConvSummaryBuilder::validate(const luci::CircleNode *node)
+{
+  auto transpose_conv = loco::must_cast<const luci::CircleTransposeConv *>(node);
+  if (transpose_conv->padding() == luci::Padding::UNDEFINED)
+    return false;
+
+  return true;
+}
+
+std::vector<std::string>
+CircleTransposeConvSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+  return {"inputSizes", "filter", "outBackProp", "bias"};
+}
+
+void CircleTransposeConvSummaryBuilder::build_attributes(const luci::CircleNode *node,
+                                                         locop::NodeSummary &s)
+{
+  auto transpose_conv = loco::must_cast<const luci::CircleTransposeConv *>(node);
+  s.args().append("stride(h,w)", to_str(transpose_conv->stride()));
+  s.args().append("padding", to_str(transpose_conv->padding()));
+}
+
+// ---- CircleUnidirectionalSequenceLSTM ----
+std::vector<std::string>
+CircleUnidirectionalSequenceLSTMSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+  return {"input",
+          "input_to_input_weights",
+          "input_to_forget_weights",
+          "input_to_cell_weights",
+          "input_to_output_weights",
+          "recurrent_to_input_weights",
+          "recurrent_to_forget_weights",
+          "recurrent_to_cell_weights",
+          "recurrent_to_output_weights",
+          "cell_to_input_weights",
+          "cell_to_forget_weights",
+          "cell_to_output_weights",
+          "input_gate_bias",
+          "forget_gate_bias",
+          "cell_gate_bias",
+          "output_gate_bias",
+          "projection_weights",
+          "projection_bias",
+          "activation_state",
+          "cell_state",
+          "input_layer_norm_coefficients",
+          "forget_layer_norm_coefficients",
+          "cell_layer_norm_coefficients",
+          "output_layer_norm_coefficients"};
+}
+
+void CircleUnidirectionalSequenceLSTMSummaryBuilder::build_attributes(const luci::CircleNode *node,
+                                                                      locop::NodeSummary &s)
+{
+  auto lstm = loco::must_cast<const luci::CircleUnidirectionalSequenceLSTM *>(node);
+  s.args().append("cell_clip", to_str(lstm->cell_clip()));
+  s.args().append("proj_clip", to_str(lstm->proj_clip()));
+  s.args().append("time_major", to_str(lstm->time_major()));
+  s.args().append("asymmetric_quantize_inputs", to_str(lstm->asymmetric_quantize_inputs()));
+}
+
+// ---- CircleUnique ----
+void CircleUniqueSummaryBuilder::build_attributes(const luci::CircleNode *node,
+                                                  locop::NodeSummary &s)
+{
+  auto unique = loco::must_cast<const luci::CircleUnique *>(node);
+  s.args().append("idx_out_type", to_str(unique->idx_out_type()));
+}
+
+// ---- CircleUnpack ----
+std::vector<std::string> CircleUnpackSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+  return {"value"};
+}
+
+void CircleUnpackSummaryBuilder::build_attributes(const luci::CircleNode *node,
+                                                  locop::NodeSummary &s)
+{
+  auto unpack = loco::must_cast<const luci::CircleUnpack *>(node);
+  s.args().append("num", std::to_string(unpack->num()));
+  s.args().append("axis", std::to_string(unpack->axis()));
+}
+
+// ---- CircleWhere ----
+std::vector<std::string> CircleWhereSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+  return {"condition"};
+}
+
+// ---- CircleWhile ----
+// One "input" label per input_count().
+std::vector<std::string> CircleWhileSummaryBuilder::get_input_names(const luci::CircleNode *node)
+{
+  auto circle_while = loco::must_cast<const luci::CircleWhile *>(node);
+
+  return std::vector<std::string>(circle_while->input_count(), "input");
+}
+
+// Reports the condition and body subgraphs, by graph name when the graph is set and by
+// branch index otherwise. The labels follow the CircleWhile accessors (cond_* / body_*);
+// the previous "then_*" / "else_*" labels were carried over from the If builder and
+// mislabelled the While subgraphs.
+void CircleWhileSummaryBuilder::build_attributes(const luci::CircleNode *node,
+                                                 locop::NodeSummary &s)
+{
+  auto circle_while = loco::must_cast<const luci::CircleWhile *>(node);
+
+  if (circle_while->cond_graph() != nullptr)
+    s.args().append("cond_graph", circle_while->cond_graph()->name());
+  else
+    s.args().append("cond_branch", std::to_string(circle_while->cond_branch()));
+
+  if (circle_while->body_graph() != nullptr)
+    s.args().append("body_graph", circle_while->body_graph()->name());
+  else
+    s.args().append("body_branch", std::to_string(circle_while->body_branch()));
+}
+
+// ---- Virtual/output nodes ----
+std::vector<std::string> CircleOutputSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+  return {"from"};
+}
+
+std::vector<std::string> CircleTopKV2OutSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+  return {"topkv2"};
+}
+
+std::vector<std::string> CircleUniqueOutSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+  return {"unique"};
+}
+
+std::vector<std::string> CircleUnpackOutSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+  return {"unpack"};
+}
+
+std::vector<std::string> CircleWhileOutSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+  return {"while"};
+}
+
+} // namespace luci
diff --git a/compiler/luci/logex/src/CircleNodeSummaryBuilders.h b/compiler/luci/logex/src/CircleNodeSummaryBuilders.h
new file mode 100644
index 000000000..6cd24b7f1
--- /dev/null
+++ b/compiler/luci/logex/src/CircleNodeSummaryBuilders.h
@@ -0,0 +1,821 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_LOGEX_CIRCLE_NODE_SUMMARY_BUILDERS__
+#define __LUCI_LOGEX_CIRCLE_NODE_SUMMARY_BUILDERS__
+
+#include "CircleNodeSummaryBuilder.h"
+
+#include <luci/IR/CircleNode.h>
+
+#include <string>
+#include <vector>
+
+namespace luci
+{
+
+// Shared bases that only supply get_input_names(); the definitions live out of line.
+// NOTE(review): the X / INPUT / XY / FEATURES suffixes presumably name the labels
+// returned -- confirm against the .cpp definitions.
+class CircleNodeWithXSummaryBuilder : public CircleNodeSummaryBuilder
+{
+private:
+  std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CircleNodeWithINPUTSummaryBuilder : public CircleNodeSummaryBuilder
+{
+private:
+  std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CircleNodeWithXYSummaryBuilder : public CircleNodeSummaryBuilder
+{
+private:
+  std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CircleNodeWithFEATURESSummaryBuilder : public CircleNodeSummaryBuilder
+{
+private:
+  std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+// Shared base for reducer nodes: labels the inputs "input" and "reduction_indices" and
+// reports the node's keep_dims flag. REDUCER_NODE must provide keep_dims().
+template <class REDUCER_NODE>
+class CircleNodeWithReducerSummaryBuilder : public CircleNodeSummaryBuilder
+{
+private:
+  std::vector<std::string> get_input_names(const luci::CircleNode *)
+  {
+    return {"input", "reduction_indices"};
+  }
+
+  void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s)
+  {
+    auto reducer = loco::must_cast<const REDUCER_NODE *>(node);
+    s.args().append("keep_dims", reducer->keep_dims() ? "true" : "false");
+  }
+};
+
+} // namespace luci
+
+namespace luci
+{
+
+class CircleAbsSummaryBuilder final : public CircleNodeWithXSummaryBuilder
+{
+};
+
+class CircleAddSummaryBuilder final : public CircleNodeWithXYSummaryBuilder
+{
+private:
+  bool validate(const luci::CircleNode *node);
+  void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleAddNSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+  std::vector<std::string> get_input_names(const luci::CircleNode *node);
+};
+
+class CircleArgMaxSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+  std::vector<std::string> get_input_names(const luci::CircleNode *);
+  void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleArgMinSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+  std::vector<std::string> get_input_names(const luci::CircleNode *);
+  void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleAveragePool2DSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+  bool validate(const luci::CircleNode *node);
+  std::vector<std::string> get_input_names(const luci::CircleNode *);
+  void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleBatchMatMulSummaryBuilder final : public CircleNodeWithXYSummaryBuilder
+{
+private:
+  void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleBatchToSpaceNDSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+  std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CircleBCQFullyConnectedSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+  bool validate(const luci::CircleNode *node);
+  std::vector<std::string> get_input_names(const luci::CircleNode *);
+  void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleBCQGatherSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+  std::vector<std::string> get_input_names(const luci::CircleNode *);
+  void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleBidirectionalSequenceLSTMSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+  std::vector<std::string> get_input_names(const luci::CircleNode *);
+  void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleCastSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+  std::vector<std::string> get_input_names(const luci::CircleNode *);
+  void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleCeilSummaryBuilder final : public CircleNodeWithXSummaryBuilder
+{
+};
+
+class CircleConcatenationSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+  bool validate(const luci::CircleNode *node);
+  std::vector<std::string> get_input_names(const luci::CircleNode *node);
+  void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleConstSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+  void update_status(locop::NodeSummary &s);
+};
+
+class CircleConv2DSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+  bool validate(const luci::CircleNode *node);
+  std::vector<std::string> get_input_names(const luci::CircleNode *);
+  void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleCosSummaryBuilder final : public CircleNodeWithXSummaryBuilder
+{
+};
+
+class CircleCustomSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+  std::vector<std::string> get_input_names(const luci::CircleNode *node);
+  void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleDepthToSpaceSummaryBuilder final : public CircleNodeWithINPUTSummaryBuilder
+{
+private:
+  void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleDepthwiseConv2DSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+  bool validate(const luci::CircleNode *node);
+  std::vector<std::string> get_input_names(const luci::CircleNode *);
+  void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleDequantizeSummaryBuilder final : public CircleNodeWithINPUTSummaryBuilder
+{
+};
+
+class CircleDivSummaryBuilder final : public CircleNodeWithXYSummaryBuilder
+{
+};
+
+class CircleEluSummaryBuilder final : public CircleNodeWithFEATURESSummaryBuilder
+{
+};
+
+class CircleEqualSummaryBuilder final : public CircleNodeWithXYSummaryBuilder
+{
+};
+
+class CircleExpSummaryBuilder final : public CircleNodeWithXSummaryBuilder
+{
+};
+
+class CircleExpandDimsSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+  std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CircleFakeQuantSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+  std::vector<std::string> get_input_names(const luci::CircleNode *);
+  void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleFillSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+  std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CircleFloorSummaryBuilder final : public CircleNodeWithXSummaryBuilder
+{
+};
+
+class CircleFloorDivSummaryBuilder final : public CircleNodeWithXYSummaryBuilder
+{
+};
+
+class CircleFloorModSummaryBuilder final : public CircleNodeWithXYSummaryBuilder
+{
+};
+
+class CircleFullyConnectedSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+  bool validate(const luci::CircleNode *node);
+  std::vector<std::string> get_input_names(const luci::CircleNode *);
+  void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleGatherSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+  std::vector<std::string> get_input_names(const luci::CircleNode *);
+  void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleGatherNdSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+  std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CircleGreaterSummaryBuilder final : public CircleNodeWithXYSummaryBuilder
+{
+};
+
+class CircleGreaterEqualSummaryBuilder final : public CircleNodeWithXYSummaryBuilder
+{
+};
+
+class CircleIfSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+  std::vector<std::string> get_input_names(const luci::CircleNode *node);
+  void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleInstanceNormSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+  bool validate(const luci::CircleNode *node);
+  std::vector<std::string> get_input_names(const luci::CircleNode *);
+  void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleL2NormalizeSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+  bool validate(const luci::CircleNode *node);
+  std::vector<std::string> get_input_names(const luci::CircleNode *);
+  void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleL2Pool2DSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+  bool validate(const luci::CircleNode *node);
+  std::vector<std::string> get_input_names(const luci::CircleNode *);
+  void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleLeakyReluSummaryBuilder final : public CircleNodeWithFEATURESSummaryBuilder
+{
+private:
+  void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleLessSummaryBuilder final : public CircleNodeWithXYSummaryBuilder
+{
+};
+
+class CircleLessEqualSummaryBuilder final : public CircleNodeWithXYSummaryBuilder
+{
+};
+
+class CircleLocalResponseNormalizationSummaryBuilder final
+  : public CircleNodeWithINPUTSummaryBuilder
+{
+private:
+  void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleLogSummaryBuilder final : public CircleNodeWithXSummaryBuilder
+{
+};
+
+class CircleLogicalAndSummaryBuilder final : public CircleNodeWithXYSummaryBuilder
+{
+};
+
+class CircleLogicalNotSummaryBuilder final : public CircleNodeWithXSummaryBuilder
+{
+};
+
+class CircleLogicalOrSummaryBuilder final : public CircleNodeWithXYSummaryBuilder
+{
+};
+
+class CircleLogisticSummaryBuilder final : public CircleNodeWithXSummaryBuilder
+{
+};
+
+class CircleLogSoftmaxSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+  std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CircleMatrixDiagSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+  std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CircleMatrixSetDiagSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+  std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CircleMaximumSummaryBuilder final : public CircleNodeWithXYSummaryBuilder
+{
+};
+
+class CircleMaxPool2DSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+  bool validate(const luci::CircleNode *node);
+  std::vector<std::string> get_input_names(const luci::CircleNode *);
+  void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleMeanSummaryBuilder final : public CircleNodeWithReducerSummaryBuilder<luci::CircleMean>
+{
+};
+
+class CircleMinimumSummaryBuilder final : public CircleNodeWithXYSummaryBuilder
+{
+};
+
+class CircleMirrorPadSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+  bool validate(const luci::CircleNode *node);
+  std::vector<std::string> get_input_names(const luci::CircleNode *);
+  void 
build_attributes(const luci::CircleNode *node, locop::NodeSummary &s); +}; + +class CircleMulSummaryBuilder final : public CircleNodeWithXYSummaryBuilder +{ +private: + bool validate(const luci::CircleNode *node); + void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s); +}; + +class CircleNegSummaryBuilder final : public CircleNodeWithXSummaryBuilder +{ +}; + +class CircleNonMaxSuppressionV4SummaryBuilder final : public CircleNodeSummaryBuilder +{ +private: + std::vector<std::string> get_input_names(const luci::CircleNode *); +}; + +class CircleNonMaxSuppressionV5SummaryBuilder final : public CircleNodeSummaryBuilder +{ +private: + std::vector<std::string> get_input_names(const luci::CircleNode *); +}; + +class CircleNotEqualSummaryBuilder final : public CircleNodeWithXYSummaryBuilder +{ +}; + +class CircleOneHotSummaryBuilder final : public CircleNodeSummaryBuilder +{ +private: + std::vector<std::string> get_input_names(const luci::CircleNode *); + void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s); +}; + +class CirclePackSummaryBuilder final : public CircleNodeSummaryBuilder +{ +private: + std::vector<std::string> get_input_names(const luci::CircleNode *node); + void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s); +}; + +class CirclePadSummaryBuilder final : public CircleNodeSummaryBuilder +{ +private: + std::vector<std::string> get_input_names(const luci::CircleNode *); +}; + +class CirclePadV2SummaryBuilder final : public CircleNodeSummaryBuilder +{ +private: + std::vector<std::string> get_input_names(const luci::CircleNode *); +}; + +class CirclePowSummaryBuilder final : public CircleNodeWithXYSummaryBuilder +{ +}; + +class CirclePReluSummaryBuilder final : public CircleNodeSummaryBuilder +{ +private: + std::vector<std::string> get_input_names(const luci::CircleNode *); +}; + +class CircleQuantizeSummaryBuilder final : public CircleNodeWithINPUTSummaryBuilder +{ +}; + +class 
CircleRangeSummaryBuilder final : public CircleNodeSummaryBuilder +{ +private: + std::vector<std::string> get_input_names(const luci::CircleNode *); +}; + +class CircleRankSummaryBuilder final : public CircleNodeWithINPUTSummaryBuilder +{ +}; + +class CircleReduceAnySummaryBuilder final + : public CircleNodeWithReducerSummaryBuilder<luci::CircleReduceAny> +{ +}; + +class CircleReduceMaxSummaryBuilder final + : public CircleNodeWithReducerSummaryBuilder<luci::CircleReduceMax> +{ +}; + +class CircleReduceMinSummaryBuilder final + : public CircleNodeWithReducerSummaryBuilder<luci::CircleReduceMin> +{ +}; + +class CircleReduceProdSummaryBuilder final + : public CircleNodeWithReducerSummaryBuilder<luci::CircleReduceProd> +{ +}; + +class CircleReluSummaryBuilder final : public CircleNodeWithFEATURESSummaryBuilder +{ +}; + +class CircleRelu6SummaryBuilder final : public CircleNodeWithFEATURESSummaryBuilder +{ +}; + +class CircleReluN1To1SummaryBuilder final : public CircleNodeWithFEATURESSummaryBuilder +{ +}; + +class CircleReshapeSummaryBuilder final : public CircleNodeSummaryBuilder +{ +private: + std::vector<std::string> get_input_names(const luci::CircleNode *); + void update_status(locop::NodeSummary &s); +}; + +class CircleResizeBilinearSummaryBuilder final : public CircleNodeSummaryBuilder +{ +private: + std::vector<std::string> get_input_names(const luci::CircleNode *); + void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s); +}; + +class CircleResizeNearestNeighborSummaryBuilder final : public CircleNodeSummaryBuilder +{ +private: + std::vector<std::string> get_input_names(const luci::CircleNode *); + void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s); +}; + +class CircleReverseSequenceSummaryBuilder final : public CircleNodeSummaryBuilder +{ +private: + std::vector<std::string> get_input_names(const luci::CircleNode *); + void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s); +}; + +class 
CircleReverseV2SummaryBuilder final : public CircleNodeSummaryBuilder +{ +private: + std::vector<std::string> get_input_names(const luci::CircleNode *); +}; + +class CircleRoundSummaryBuilder final : public CircleNodeWithXSummaryBuilder +{ +}; + +class CircleRsqrtSummaryBuilder final : public CircleNodeWithXSummaryBuilder +{ +}; + +class CircleScatterNdSummaryBuilder final : public CircleNodeSummaryBuilder +{ +private: + std::vector<std::string> get_input_names(const luci::CircleNode *); +}; + +class CircleSegmentSumSummaryBuilder final : public CircleNodeSummaryBuilder +{ +private: + std::vector<std::string> get_input_names(const luci::CircleNode *); +}; + +class CircleSelectSummaryBuilder final : public CircleNodeSummaryBuilder +{ +private: + std::vector<std::string> get_input_names(const luci::CircleNode *); +}; + +class CircleSelectV2SummaryBuilder final : public CircleNodeSummaryBuilder +{ +private: + std::vector<std::string> get_input_names(const luci::CircleNode *); +}; + +class CircleShapeSummaryBuilder final : public CircleNodeWithINPUTSummaryBuilder +{ +private: + void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s); +}; + +class CircleSinSummaryBuilder final : public CircleNodeWithXSummaryBuilder +{ +}; + +class CircleSliceSummaryBuilder final : public CircleNodeSummaryBuilder +{ +private: + std::vector<std::string> get_input_names(const luci::CircleNode *); +}; + +class CircleSoftmaxSummaryBuilder final : public CircleNodeSummaryBuilder +{ +private: + std::vector<std::string> get_input_names(const luci::CircleNode *); + void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s); +}; + +class CircleSpaceToBatchNDSummaryBuilder final : public CircleNodeSummaryBuilder +{ +private: + std::vector<std::string> get_input_names(const luci::CircleNode *); +}; + +class CircleSpaceToDepthSummaryBuilder final : public CircleNodeWithINPUTSummaryBuilder +{ +private: + void build_attributes(const luci::CircleNode *node, 
locop::NodeSummary &s); +}; + +class CircleSparseToDenseSummaryBuilder final : public CircleNodeSummaryBuilder +{ +private: + std::vector<std::string> get_input_names(const luci::CircleNode *); + void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s); +}; + +class CircleSplitSummaryBuilder final : public CircleNodeSummaryBuilder +{ +private: + std::vector<std::string> get_input_names(const luci::CircleNode *); + void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s); +}; + +class CircleSplitVSummaryBuilder final : public CircleNodeSummaryBuilder +{ +private: + std::vector<std::string> get_input_names(const luci::CircleNode *); + void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s); +}; + +class CircleSqrtSummaryBuilder final : public CircleNodeWithXSummaryBuilder +{ +}; + +class CircleSquareSummaryBuilder final : public CircleNodeWithXSummaryBuilder +{ +}; + +class CircleSquaredDifferenceSummaryBuilder final : public CircleNodeWithXYSummaryBuilder +{ +}; + +class CircleSqueezeSummaryBuilder final : public CircleNodeWithINPUTSummaryBuilder +{ +private: + void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s); +}; + +class CircleStridedSliceSummaryBuilder final : public CircleNodeSummaryBuilder +{ +private: + std::vector<std::string> get_input_names(const luci::CircleNode *); + void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s); +}; + +class CircleSubSummaryBuilder final : public CircleNodeWithXYSummaryBuilder +{ +}; + +class CircleSumSummaryBuilder final : public CircleNodeWithReducerSummaryBuilder<luci::CircleSum> +{ +}; + +class CircleSVDFSummaryBuilder final : public CircleNodeSummaryBuilder +{ +private: + bool validate(const luci::CircleNode *node); + std::vector<std::string> get_input_names(const luci::CircleNode *); + + void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s); +}; + +class CircleTanhSummaryBuilder final : public 
CircleNodeWithXSummaryBuilder +{ +}; + +class CircleTileSummaryBuilder final : public CircleNodeSummaryBuilder +{ +private: + std::vector<std::string> get_input_names(const luci::CircleNode *); +}; + +class CircleTopKV2SummaryBuilder final : public CircleNodeSummaryBuilder +{ +private: + std::vector<std::string> get_input_names(const luci::CircleNode *); +}; + +class CircleTransposeSummaryBuilder final : public CircleNodeSummaryBuilder +{ +private: + std::vector<std::string> get_input_names(const luci::CircleNode *); +}; + +class CircleTransposeConvSummaryBuilder final : public CircleNodeSummaryBuilder +{ +private: + bool validate(const luci::CircleNode *node); + std::vector<std::string> get_input_names(const luci::CircleNode *); + void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s); +}; + +class CircleUnidirectionalSequenceLSTMSummaryBuilder final : public CircleNodeSummaryBuilder +{ +private: + std::vector<std::string> get_input_names(const luci::CircleNode *); + void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s); +}; + +class CircleUniqueSummaryBuilder final : public CircleNodeWithINPUTSummaryBuilder +{ +private: + void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s); +}; + +class CircleUnpackSummaryBuilder final : public CircleNodeSummaryBuilder +{ +private: + std::vector<std::string> get_input_names(const luci::CircleNode *); + void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s); +}; + +class CircleWhereSummaryBuilder final : public CircleNodeSummaryBuilder +{ +private: + std::vector<std::string> get_input_names(const luci::CircleNode *); +}; + +class CircleWhileSummaryBuilder final : public CircleNodeSummaryBuilder +{ +private: + std::vector<std::string> get_input_names(const luci::CircleNode *node); + void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s); +}; + +class CircleZerosLikeSummaryBuilder final : public CircleNodeWithINPUTSummaryBuilder 
+{ +}; + +class CircleBidirectionalSequenceLSTMOutSummaryBuilder final + : public CircleNodeWithINPUTSummaryBuilder +{ +}; + +class CircleCustomOutSummaryBuilder final : public CircleNodeWithINPUTSummaryBuilder +{ +}; + +class CircleIfOutSummaryBuilder final : public CircleNodeWithINPUTSummaryBuilder +{ +}; + +class CircleInputSummaryBuilder final : public CircleNodeSummaryBuilder +{ +}; + +class CircleNonMaxSuppressionV4OutSummaryBuilder final : public CircleNodeWithINPUTSummaryBuilder +{ +}; + +class CircleNonMaxSuppressionV5OutSummaryBuilder final : public CircleNodeWithINPUTSummaryBuilder +{ +}; + +class CircleOutputSummaryBuilder final : public CircleNodeSummaryBuilder +{ +private: + std::vector<std::string> get_input_names(const luci::CircleNode *); +}; + +class CircleOutputDummySummaryBuilder final : public CircleNodeSummaryBuilder +{ +}; + +class CircleOutputExcludeSummaryBuilder final : public CircleNodeSummaryBuilder +{ +}; + +class CircleSplitOutSummaryBuilder final : public CircleNodeWithINPUTSummaryBuilder +{ +}; + +class CircleSplitVOutSummaryBuilder final : public CircleNodeWithINPUTSummaryBuilder +{ +}; + +class CircleTopKV2OutSummaryBuilder final : public CircleNodeSummaryBuilder +{ +private: + std::vector<std::string> get_input_names(const luci::CircleNode *); +}; + +class CircleUniqueOutSummaryBuilder final : public CircleNodeSummaryBuilder +{ +private: + std::vector<std::string> get_input_names(const luci::CircleNode *); +}; + +class CircleUnpackOutSummaryBuilder final : public CircleNodeSummaryBuilder +{ +private: + std::vector<std::string> get_input_names(const luci::CircleNode *); +}; + +class CircleVariableSummaryBuilder final : public CircleNodeSummaryBuilder +{ +}; + +class CircleWhileOutSummaryBuilder final : public CircleNodeSummaryBuilder +{ +private: + std::vector<std::string> get_input_names(const luci::CircleNode *); +}; + +} // namespace luci + +#endif // __LUCI_LOGEX_CIRCLE_NODE_SUMMARY_BUILDERS__ diff --git 
a/compiler/luci/logex/src/FormattedGraph.cpp b/compiler/luci/logex/src/FormattedGraph.cpp index 0588ed79e..d3b2170b0 100644 --- a/compiler/luci/logex/src/FormattedGraph.cpp +++ b/compiler/luci/logex/src/FormattedGraph.cpp @@ -14,6 +14,7 @@ * limitations under the License. */ +#include "CircleNodeSummaryBuilder.h" #include "luci/FormattedGraph.h" #include <luci/IR/CircleDialect.h> @@ -25,2179 +26,6 @@ #include <sstream> #include <vector> -using namespace luci; -/** - * @brief dump std::vector<int64_t> values to stream - */ -std::ostream &operator<<(std::ostream &os, const std::vector<int64_t> &vi64) -{ - for (auto vi : vi64) - { - os << vi << " "; - } - return os; -} - -// For TF lite -namespace -{ - -const char *to_str(loco::DataType type) -{ - switch (type) - { - case loco::DataType::U8: - return "UINT8"; - case loco::DataType::U16: - return "UINT16"; - case loco::DataType::U32: - return "UINT32"; - case loco::DataType::U64: - return "UINT64"; - - case loco::DataType::S8: - return "INT8"; - case loco::DataType::S16: - return "INT16"; - case loco::DataType::S32: - return "INT32"; - case loco::DataType::S64: - return "INT64"; - - case loco::DataType::FLOAT16: - return "FLOAT16"; - case loco::DataType::FLOAT32: - return "FLOAT32"; - case loco::DataType::FLOAT64: - return "FLOAT64"; - - case loco::DataType::BOOL: - return "BOOL"; - - default: - return "Error"; - } -} - -const char *to_str(bool value) { return value ? 
"true" : "false"; } - -const char *to_str(luci::FusedActFunc fused) -{ - switch (fused) - { - case luci::FusedActFunc::NONE: - return "NONE"; - case luci::FusedActFunc::RELU: - return "RELU"; - case luci::FusedActFunc::RELU_N1_TO_1: - return "RELU_N1_TO_1"; - case luci::FusedActFunc::RELU6: - return "RELU6"; - case luci::FusedActFunc::TANH: - return "TANH"; - case luci::FusedActFunc::SIGN_BIT: - return "SIGN_BIT"; - default: - return "Error"; - } -} - -const char *to_str(luci::Padding padding) -{ - switch (padding) - { - case luci::Padding::SAME: - return "SAME"; - case luci::Padding::VALID: - return "VALID"; - default: - return "Error"; - } -} - -const char *to_str(luci::MirrorPadMode mode) -{ - switch (mode) - { - case luci::MirrorPadMode::REFLECT: - return "REFLECT"; - case luci::MirrorPadMode::SYMMETRIC: - return "SYMMETRIC"; - default: - return "Error"; - } -} - -std::string to_str(const luci::Stride *stride) -{ - return pepper::str(stride->h(), ",", stride->w()); -} - -std::string to_str(const luci::Filter *filter) -{ - return pepper::str(filter->h(), ",", filter->w()); -} - -std::string circle_opname(uint32_t opnum) -{ - static const std::string prefix{"circle."}; - - switch (static_cast<luci::CircleOpcode>(opnum)) - { -#define CIRCLE_NODE(OPCODE, CLASS) \ - case luci::CircleOpcode::OPCODE: \ - return prefix + #OPCODE; -#define CIRCLE_VNODE CIRCLE_NODE -#include <luci/IR/CircleNodes.lst> -#undef CIRCLE_VNODE -#undef CIRCLE_NODE - default: - break; - }; - - return prefix + "Invalid"; -} - -// CircleNodeSummaryBuilder with default implementation -class CircleNodeSummaryBuilderBase : public locop::NodeSummaryBuilder -{ -public: - CircleNodeSummaryBuilderBase(const locop::SymbolTable *tbl) : _tbl{tbl} - { - // DO NOTHING - } - -public: - bool build(const loco::Node *, locop::NodeSummary &s) const final; - -protected: -#define CIRCLE_NODE(OPCODE, CLASS) \ - virtual bool summary(const CLASS *, locop::NodeSummary &) const { return false; } -#define CIRCLE_VNODE 
CIRCLE_NODE -#include <luci/IR/CircleNodes.lst> -#undef CIRCLE_VNODE -#undef CIRCLE_NODE - -protected: - const locop::SymbolTable *tbl(void) const { return _tbl; } - -private: - const locop::SymbolTable *_tbl; -}; - -template <class CIRCLENODE> -bool use_x(const locop::SymbolTable *tbl, const CIRCLENODE *node, locop::NodeSummary &s) -{ - s.args().append("x", tbl->lookup(node->x())); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -template <class CIRCLENODE> -bool use_input(const locop::SymbolTable *tbl, const CIRCLENODE *node, locop::NodeSummary &s) -{ - s.args().append("input", tbl->lookup(node->input())); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -template <class CIRCLENODE> -bool use_features(const locop::SymbolTable *tbl, const CIRCLENODE *node, locop::NodeSummary &s) -{ - s.args().append("features", tbl->lookup(node->features())); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -template <class CIRCLENODE> -bool use_xy(const locop::SymbolTable *tbl, const CIRCLENODE *node, locop::NodeSummary &s) -{ - s.args().append("x", tbl->lookup(node->x())); - s.args().append("y", tbl->lookup(node->y())); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -template <class CIRCLENODE> -bool use_xy_act(const locop::SymbolTable *tbl, const CIRCLENODE *node, locop::NodeSummary &s) -{ - assert(node->fusedActivationFunction() != luci::FusedActFunc::UNDEFINED); - - s.args().append("x", tbl->lookup(node->x())); - s.args().append("y", tbl->lookup(node->y())); - s.args().append("fused_activation_function", to_str(node->fusedActivationFunction())); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -template <class CIRCLENODE> -bool use_reducer(const locop::SymbolTable *tbl, const CIRCLENODE *node, locop::NodeSummary &s) -{ - s.args().append("input", tbl->lookup(node->input())); - s.args().append("reduction_indices", tbl->lookup(node->reduction_indices())); - 
s.args().append("keep_dims", node->keep_dims() ? "true" : "false"); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -template <class CIRCLENODE> -bool use_ido(const locop::SymbolTable *tbl, const CIRCLENODE *node, locop::NodeSummary &s) -{ - s.args().append("input", tbl->lookup(node->input())); - s.args().append("dimension", tbl->lookup(node->dimension())); - s.args().append("output_type", to_str(node->output_type())); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CircleAddN *node, - locop::NodeSummary &s) -{ - for (uint32_t i = 0; i < node->arity(); ++i) - s.args().append("inputs", tbl->lookup(node->inputs(i))); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CircleAveragePool2D *node, - locop::NodeSummary &s) -{ - assert(node->fusedActivationFunction() != luci::FusedActFunc::UNDEFINED); - - s.args().append("value", tbl->lookup(node->value())); - s.args().append("filter(h,w)", to_str(node->filter())); - s.args().append("stride(h,w)", to_str(node->stride())); - s.args().append("padding", to_str(node->padding())); - s.args().append("fused", to_str(node->fusedActivationFunction())); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CircleBatchMatMul *node, - locop::NodeSummary &s) -{ - s.args().append("x", tbl->lookup(node->x())); - s.args().append("y", tbl->lookup(node->y())); - s.args().append("adj_x", to_str(node->adj_x())); - s.args().append("adj_y", to_str(node->adj_y())); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CircleBatchToSpaceND *node, - locop::NodeSummary &s) -{ - s.args().append("input", tbl->lookup(node->input())); - s.args().append("block_shape", tbl->lookup(node->block_shape())); - 
s.args().append("crops", tbl->lookup(node->crops())); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CircleBidirectionalSequenceLSTM *node, - locop::NodeSummary &s) -{ - s.args().append("input", tbl->lookup(node->input())); - - s.args().append("fw_input_to_input_weights", tbl->lookup(node->fw_input_to_input_weights())); - s.args().append("fw_input_to_forget_weights", tbl->lookup(node->fw_input_to_forget_weights())); - s.args().append("fw_input_to_cell_weights", tbl->lookup(node->fw_input_to_cell_weights())); - s.args().append("fw_input_to_output_weights", tbl->lookup(node->fw_input_to_output_weights())); - - s.args().append("fw_recurrent_to_input_weights", - tbl->lookup(node->fw_recurrent_to_input_weights())); - s.args().append("fw_recurrent_to_forget_weights", - tbl->lookup(node->fw_recurrent_to_forget_weights())); - s.args().append("fw_recurrent_to_cell_weights", - tbl->lookup(node->fw_recurrent_to_cell_weights())); - s.args().append("fw_recurrent_to_output_weights", - tbl->lookup(node->fw_recurrent_to_output_weights())); - - s.args().append("fw_cell_to_input_weights", tbl->lookup(node->fw_cell_to_input_weights())); - s.args().append("fw_cell_to_forget_weights", tbl->lookup(node->fw_cell_to_forget_weights())); - s.args().append("fw_cell_to_output_weights", tbl->lookup(node->fw_cell_to_output_weights())); - - s.args().append("fw_input_gate_bias", tbl->lookup(node->fw_input_gate_bias())); - s.args().append("fw_forget_gate_bias", tbl->lookup(node->fw_forget_gate_bias())); - s.args().append("fw_cell_gate_bias", tbl->lookup(node->fw_cell_gate_bias())); - s.args().append("fw_output_gate_bias", tbl->lookup(node->fw_output_gate_bias())); - - s.args().append("fw_projection_weights", tbl->lookup(node->fw_projection_weights())); - s.args().append("fw_projection_bias", tbl->lookup(node->fw_projection_bias())); - - s.args().append("bw_input_to_input_weights", 
tbl->lookup(node->bw_input_to_input_weights())); - s.args().append("bw_input_to_forget_weights", tbl->lookup(node->bw_input_to_forget_weights())); - s.args().append("bw_input_to_cell_weights", tbl->lookup(node->bw_input_to_cell_weights())); - s.args().append("bw_input_to_output_weights", tbl->lookup(node->bw_input_to_output_weights())); - - s.args().append("bw_recurrent_to_input_weights", - tbl->lookup(node->bw_recurrent_to_input_weights())); - s.args().append("bw_recurrent_to_forget_weights", - tbl->lookup(node->bw_recurrent_to_forget_weights())); - s.args().append("bw_recurrent_to_cell_weights", - tbl->lookup(node->bw_recurrent_to_cell_weights())); - s.args().append("bw_recurrent_to_output_weights", - tbl->lookup(node->bw_recurrent_to_output_weights())); - - s.args().append("bw_cell_to_input_weights", tbl->lookup(node->bw_cell_to_input_weights())); - s.args().append("bw_cell_to_forget_weights", tbl->lookup(node->bw_cell_to_forget_weights())); - s.args().append("bw_cell_to_output_weights", tbl->lookup(node->bw_cell_to_output_weights())); - - s.args().append("bw_input_gate_bias", tbl->lookup(node->bw_input_gate_bias())); - s.args().append("bw_forget_gate_bias", tbl->lookup(node->bw_forget_gate_bias())); - s.args().append("bw_cell_gate_bias", tbl->lookup(node->bw_cell_gate_bias())); - s.args().append("bw_output_gate_bias", tbl->lookup(node->bw_output_gate_bias())); - - s.args().append("bw_projection_weights", tbl->lookup(node->bw_projection_weights())); - s.args().append("bw_projection_bias", tbl->lookup(node->bw_projection_bias())); - - s.args().append("fw_activation_state", tbl->lookup(node->fw_activation_state())); - s.args().append("fw_cell_state", tbl->lookup(node->fw_cell_state())); - s.args().append("bw_activation_state", tbl->lookup(node->bw_activation_state())); - s.args().append("bw_cell_state", tbl->lookup(node->bw_cell_state())); - - s.args().append("auxillary_input", tbl->lookup(node->auxillary_input())); - 
s.args().append("fw_auxillary_input_to_input_weights", - tbl->lookup(node->fw_auxillary_input_to_input_weights())); - s.args().append("fw_auxillary_input_to_forget_weights", - tbl->lookup(node->fw_auxillary_input_to_forget_weights())); - s.args().append("fw_auxillary_input_to_cell_weights", - tbl->lookup(node->fw_auxillary_input_to_cell_weights())); - s.args().append("fw_auxillary_input_to_output_weights", - tbl->lookup(node->fw_auxillary_input_to_output_weights())); - s.args().append("bw_auxillary_input_to_input_weights", - tbl->lookup(node->bw_auxillary_input_to_input_weights())); - s.args().append("bw_auxillary_input_to_forget_weights", - tbl->lookup(node->bw_auxillary_input_to_forget_weights())); - s.args().append("bw_auxillary_input_to_cell_weights", - tbl->lookup(node->bw_auxillary_input_to_cell_weights())); - s.args().append("bw_auxillary_input_to_output_weights", - tbl->lookup(node->bw_auxillary_input_to_output_weights())); - - s.args().append("cell_clip", to_str(node->cell_clip())); - s.args().append("proj_clip", to_str(node->proj_clip())); - s.args().append("merge_outputs", to_str(node->merge_outputs())); - s.args().append("time_major", to_str(node->time_major())); - s.args().append("asymmetric_quantize_inputs", to_str(node->asymmetric_quantize_inputs())); - - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CircleCast *node, - locop::NodeSummary &s) -{ - s.args().append("x", tbl->lookup(node->x())); - s.args().append("in_data_type", to_str(node->in_data_type())); - s.args().append("out_data_type", to_str(node->out_data_type())); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CircleConcatenation *node, - locop::NodeSummary &s) -{ - assert(node->fusedActivationFunction() != luci::FusedActFunc::UNDEFINED); - - for (uint32_t i = 0; i < node->numValues(); ++i) - s.args().append("values", 
tbl->lookup(node->values(i))); - s.args().append("axis", pepper::str(node->axis())); - s.args().append("fused", to_str(node->fusedActivationFunction())); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CircleConv2D *node, - locop::NodeSummary &s) -{ - assert(node->fusedActivationFunction() != luci::FusedActFunc::UNDEFINED); - assert(node->padding() != luci::Padding::UNDEFINED); - - s.args().append("input", tbl->lookup(node->input())); - s.args().append("filter", tbl->lookup(node->filter())); - s.args().append("bias", tbl->lookup(node->bias())); - s.args().append("stride(h,w)", to_str(node->stride())); - s.args().append("dilation(h,w)", to_str(node->dilation())); - s.args().append("padding", to_str(node->padding())); - s.args().append("fused", to_str(node->fusedActivationFunction())); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CircleCustom *node, - locop::NodeSummary &s) -{ - for (uint32_t i = 0; i < node->numInputs(); i++) - { - s.args().append("input" + std::to_string(i), tbl->lookup(node->inputs(i))); - } - s.args().append("custom_code", node->custom_code()); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CircleDepthToSpace *node, - locop::NodeSummary &s) -{ - s.args().append("input", tbl->lookup(node->input())); - s.args().append("block_size", std::to_string(node->block_size())); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CircleDepthwiseConv2D *node, - locop::NodeSummary &s) -{ - assert(node->fusedActivationFunction() != luci::FusedActFunc::UNDEFINED); - assert(node->padding() != luci::Padding::UNDEFINED); - - s.args().append("input", tbl->lookup(node->input())); - s.args().append("filter", tbl->lookup(node->filter())); - 
s.args().append("bias", tbl->lookup(node->bias())); - s.args().append("stride(h,w)", to_str(node->stride())); - s.args().append("dilation(h,w)", to_str(node->dilation())); - s.args().append("padding", to_str(node->padding())); - s.args().append("depthMultiplier", std::to_string(node->depthMultiplier())); - s.args().append("fused", to_str(node->fusedActivationFunction())); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CircleExpandDims *node, - locop::NodeSummary &s) -{ - s.args().append("input", tbl->lookup(node->input())); - s.args().append("axis", tbl->lookup(node->axis())); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CircleFakeQuant *node, - locop::NodeSummary &s) -{ - s.args().append("inputs", tbl->lookup(node->inputs())); - s.args().append("min", pepper::str(node->min())); - s.args().append("max", pepper::str(node->max())); - s.args().append("num_bits", pepper::str(node->num_bits())); - s.args().append("narrow_range", node->narrow_range() ? 
"true" : "false"); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CircleFill *node, - locop::NodeSummary &s) -{ - s.args().append("dims", tbl->lookup(node->dims())); - s.args().append("value", tbl->lookup(node->value())); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CircleFullyConnected *node, - locop::NodeSummary &s) -{ - assert(node->fusedActivationFunction() != luci::FusedActFunc::UNDEFINED); - - s.args().append("input", tbl->lookup(node->input())); - s.args().append("weights", tbl->lookup(node->weights())); - s.args().append("bias", tbl->lookup(node->bias())); - s.args().append("fused", to_str(node->fusedActivationFunction())); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CircleGather *node, - locop::NodeSummary &s) -{ - s.args().append("params", tbl->lookup(node->params())); - s.args().append("indices", tbl->lookup(node->indices())); - s.args().append("axis", pepper::str(node->axis())); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CircleGatherNd *node, - locop::NodeSummary &s) -{ - s.args().append("params", tbl->lookup(node->params())); - s.args().append("indices", tbl->lookup(node->indices())); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CircleIf *node, locop::NodeSummary &s) -{ - s.args().append("cond", tbl->lookup(node->cond())); - for (uint32_t i = 0; i < node->input_count(); ++i) - s.args().append("input", tbl->lookup(node->input(i))); - - if (node->then_graph() != nullptr) - s.args().append("then_graph", node->then_graph()->name()); - else - s.args().append("then_branch", pepper::str(node->then_branch())); - - if (node->else_graph() != 
nullptr) - s.args().append("else_graph", node->else_graph()->name()); - else - s.args().append("else_branch", pepper::str(node->else_branch())); - - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CircleL2Normalize *node, - locop::NodeSummary &s) -{ - s.args().append("x", tbl->lookup(node->x())); - s.args().append("fused_activation_function", to_str(node->fusedActivationFunction())); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CircleL2Pool2D *node, - locop::NodeSummary &s) -{ - assert(node->fusedActivationFunction() != luci::FusedActFunc::UNDEFINED); - - s.args().append("value", tbl->lookup(node->value())); - s.args().append("filter(h,w)", to_str(node->filter())); - s.args().append("stride(h,w)", to_str(node->stride())); - s.args().append("padding", to_str(node->padding())); - s.args().append("fused", to_str(node->fusedActivationFunction())); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CircleLeakyRelu *node, - locop::NodeSummary &s) -{ - s.args().append("features", tbl->lookup(node->features())); - s.args().append("alpha", std::to_string(node->alpha())); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CircleLocalResponseNormalization *node, - locop::NodeSummary &s) -{ - s.args().append("input", tbl->lookup(node->input())); - s.args().append("radius", pepper::str(node->radius())); - s.args().append("bias", pepper::str(node->bias())); - s.args().append("alpha", pepper::str(node->alpha())); - s.args().append("beta", pepper::str(node->beta())); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CircleLogSoftmax *node, - locop::NodeSummary &s) -{ - 
s.args().append("logits", tbl->lookup(node->logits())); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CircleMatrixDiag *node, - locop::NodeSummary &s) -{ - s.args().append("diagonal", tbl->lookup(node->diagonal())); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CircleMatrixSetDiag *node, - locop::NodeSummary &s) -{ - s.args().append("input", tbl->lookup(node->input())); - s.args().append("diagonal", tbl->lookup(node->diagonal())); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CircleMaxPool2D *node, - locop::NodeSummary &s) -{ - assert(node->fusedActivationFunction() != luci::FusedActFunc::UNDEFINED); - - s.args().append("value", tbl->lookup(node->value())); - s.args().append("filter(h,w)", to_str(node->filter())); - s.args().append("stride(h,w)", to_str(node->stride())); - s.args().append("padding", to_str(node->padding())); - s.args().append("fused", to_str(node->fusedActivationFunction())); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CircleMirrorPad *node, - locop::NodeSummary &s) -{ - s.args().append("input", tbl->lookup(node->input())); - s.args().append("paddings", tbl->lookup(node->paddings())); - s.args().append("mode", to_str(node->mode())); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CircleNonMaxSuppressionV4 *node, - locop::NodeSummary &s) -{ - s.args().append("boxes", tbl->lookup(node->boxes())); - s.args().append("scores", tbl->lookup(node->scores())); - s.args().append("max_output_size", tbl->lookup(node->max_output_size())); - s.args().append("iou_threshold", tbl->lookup(node->iou_threshold())); - s.args().append("score_threshold", 
tbl->lookup(node->score_threshold())); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CircleNonMaxSuppressionV5 *node, - locop::NodeSummary &s) -{ - s.args().append("boxes", tbl->lookup(node->boxes())); - s.args().append("scores", tbl->lookup(node->scores())); - s.args().append("max_output_size", tbl->lookup(node->max_output_size())); - s.args().append("iou_threshold", tbl->lookup(node->iou_threshold())); - s.args().append("score_threshold", tbl->lookup(node->score_threshold())); - s.args().append("soft_nms_sigma", tbl->lookup(node->soft_nms_sigma())); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CircleOneHot *node, - locop::NodeSummary &s) -{ - s.args().append("indices", tbl->lookup(node->indices())); - s.args().append("depth", tbl->lookup(node->depth())); - s.args().append("on_value", tbl->lookup(node->on_value())); - s.args().append("off_value", tbl->lookup(node->off_value())); - s.args().append("axis", pepper::str(node->axis())); - - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CirclePack *node, - locop::NodeSummary &s) -{ - for (uint32_t i = 0; i < node->values_count(); ++i) - s.args().append("values", tbl->lookup(node->values(i))); - s.args().append("values_count", pepper::str(node->values_count())); - s.args().append("axis", pepper::str(node->axis())); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CirclePad *node, locop::NodeSummary &s) -{ - s.args().append("input", tbl->lookup(node->input())); - s.args().append("paddings", tbl->lookup(node->paddings())); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CirclePadV2 *node, - locop::NodeSummary &s) -{ 
- s.args().append("input", tbl->lookup(node->input())); - s.args().append("paddings", tbl->lookup(node->paddings())); - s.args().append("constant_values", tbl->lookup(node->constant_values())); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CirclePRelu *node, - locop::NodeSummary &s) -{ - s.args().append("input", tbl->lookup(node->input())); - s.args().append("alpha", tbl->lookup(node->alpha())); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CircleRange *node, - locop::NodeSummary &s) -{ - s.args().append("start", tbl->lookup(node->start())); - s.args().append("limit", tbl->lookup(node->limit())); - s.args().append("delta", tbl->lookup(node->delta())); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CircleReshape *node, - locop::NodeSummary &s) -{ - s.args().append("tensor", tbl->lookup(node->tensor())); - s.args().append("shape", tbl->lookup(node->shape())); - // TODO Show newShape info - s.state(locop::NodeSummary::State::PartiallyKnown); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CircleResizeBilinear *node, - locop::NodeSummary &s) -{ - s.args().append("input", tbl->lookup(node->input())); - s.args().append("size", tbl->lookup(node->size())); - s.args().append("align_corners", node->align_corners() ? "true" : "false"); - s.args().append("half_pixel_centers", node->half_pixel_centers() ? "true" : "false"); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CircleResizeNearestNeighbor *node, - locop::NodeSummary &s) -{ - s.args().append("input", tbl->lookup(node->input())); - s.args().append("size", tbl->lookup(node->size())); - s.args().append("align_corners", node->align_corners() ? 
"true" : "false"); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CircleReverseSequence *node, - locop::NodeSummary &s) -{ - s.args().append("input", tbl->lookup(node->input())); - s.args().append("seq_lengths", tbl->lookup(node->seq_lengths())); - s.args().append("seq_axis", std::to_string(node->seq_axis())); - s.args().append("batch_axis", std::to_string(node->batch_axis())); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CircleReverseV2 *node, - locop::NodeSummary &s) -{ - s.args().append("tensor", tbl->lookup(node->tensor())); - s.args().append("axis", tbl->lookup(node->axis())); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CircleScatterNd *node, - locop::NodeSummary &s) -{ - s.args().append("indices", tbl->lookup(node->indices())); - s.args().append("updates", tbl->lookup(node->updates())); - s.args().append("shape", tbl->lookup(node->shape())); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CircleSegmentSum *node, - locop::NodeSummary &s) -{ - s.args().append("input", tbl->lookup(node->input())); - s.args().append("segment_ids", tbl->lookup(node->segment_ids())); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CircleSelect *node, - locop::NodeSummary &s) -{ - s.args().append("condition", tbl->lookup(node->condition())); - s.args().append("t", tbl->lookup(node->t())); - s.args().append("e", tbl->lookup(node->e())); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CircleSelectV2 *node, - locop::NodeSummary &s) -{ - s.args().append("condition", 
tbl->lookup(node->condition())); - s.args().append("t", tbl->lookup(node->t())); - s.args().append("e", tbl->lookup(node->e())); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CircleShape *node, - locop::NodeSummary &s) -{ - s.args().append("input", tbl->lookup(node->input())); - s.args().append("out_type", to_str(node->out_type())); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CircleSlice *node, - locop::NodeSummary &s) -{ - s.args().append("input", tbl->lookup(node->input())); - s.args().append("begin", tbl->lookup(node->begin())); - s.args().append("size", tbl->lookup(node->size())); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CircleSoftmax *node, - locop::NodeSummary &s) -{ - s.args().append("logits", tbl->lookup(node->logits())); - s.args().append("beta", pepper::str(node->beta())); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CircleSpaceToBatchND *node, - locop::NodeSummary &s) -{ - s.args().append("input", tbl->lookup(node->input())); - s.args().append("block_shape", tbl->lookup(node->block_shape())); - s.args().append("paddings", tbl->lookup(node->paddings())); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CircleSpaceToDepth *node, - locop::NodeSummary &s) -{ - s.args().append("input", tbl->lookup(node->input())); - s.args().append("block_size", pepper::str(node->block_size())); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CircleSparseToDense *node, - locop::NodeSummary &s) -{ - s.args().append("indices", tbl->lookup(node->indices())); - 
s.args().append("output_shape", tbl->lookup(node->output_shape())); - s.args().append("values", tbl->lookup(node->values())); - s.args().append("default_value", tbl->lookup(node->default_value())); - s.args().append("Validate_indices", pepper::str(node->validate_indices())); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CircleSplit *node, - locop::NodeSummary &s) -{ - s.args().append("split_dim", tbl->lookup(node->split_dim())); - s.args().append("input", tbl->lookup(node->input())); - s.args().append("num_split", pepper::str(node->num_split())); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CircleSplitV *node, - locop::NodeSummary &s) -{ - s.args().append("input", tbl->lookup(node->input())); - s.args().append("size_splits", tbl->lookup(node->size_splits())); - s.args().append("split_dim", tbl->lookup(node->split_dim())); - s.args().append("num_split", pepper::str(node->num_split())); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -/* Summarizes a CircleSqueeze node: appends the symbol of its "input" and its squeeze_dims rendered as "(d0, d1, ...)", marks the summary Complete and returns true. */ -bool summary_node(const locop::SymbolTable *tbl, const luci::CircleSqueeze *node, - locop::NodeSummary &s) -{ - s.args().append("input", tbl->lookup(node->input())); - - /* Insert "(" with operator<<: a string passed to the stringstream constructor is overwritten by the first write because the put position starts at offset 0. */ - std::stringstream ss; - ss << "("; - for (size_t i = 0; i < node->squeeze_dims().size(); ++i) - { - if (i != 0) - ss << ", "; - ss << node->squeeze_dims()[i]; - } - ss << ")"; - s.args().append("squeeze_dims", ss.str()); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CircleStridedSlice *node, - locop::NodeSummary &s) -{ - s.args().append("input", tbl->lookup(node->input())); - s.args().append("begin", tbl->lookup(node->begin())); - s.args().append("end", tbl->lookup(node->end())); - s.args().append("strides", tbl->lookup(node->strides())); - s.args().append("begin_mask", pepper::str(node->begin_mask())); - 
s.args().append("end_mask", pepper::str(node->end_mask())); - s.args().append("ellipsis_mask", pepper::str(node->ellipsis_mask())); - s.args().append("new_axis_mask", pepper::str(node->new_axis_mask())); - s.args().append("shrink_axis_mask", pepper::str(node->shrink_axis_mask())); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CircleTile *node, - locop::NodeSummary &s) -{ - s.args().append("input", tbl->lookup(node->input())); - s.args().append("multiples", tbl->lookup(node->multiples())); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CircleTopKV2 *node, - locop::NodeSummary &s) -{ - s.args().append("input", tbl->lookup(node->input())); - s.args().append("k", tbl->lookup(node->k())); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CircleTranspose *node, - locop::NodeSummary &s) -{ - s.args().append("a", tbl->lookup(node->a())); - s.args().append("perm", tbl->lookup(node->perm())); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CircleTransposeConv *node, - locop::NodeSummary &s) -{ - assert(node->padding() != luci::Padding::UNDEFINED); - - s.args().append("inputSizes", tbl->lookup(node->inputSizes())); - s.args().append("filter", tbl->lookup(node->filter())); - s.args().append("outBackprop", tbl->lookup(node->outBackprop())); - s.args().append("bias", tbl->lookup(node->bias())); - s.args().append("stride(h,w)", to_str(node->stride())); - s.args().append("padding", to_str(node->padding())); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CircleUnidirectionalSequenceLSTM *node, - locop::NodeSummary &s) -{ - s.args().append("input", tbl->lookup(node->input())); - 
- s.args().append("input_to_input_weights", tbl->lookup(node->input_to_input_weights())); - s.args().append("input_to_forget_weights", tbl->lookup(node->input_to_forget_weights())); - s.args().append("input_to_cell_weights", tbl->lookup(node->input_to_cell_weights())); - s.args().append("input_to_output_weights", tbl->lookup(node->input_to_output_weights())); - - s.args().append("recurrent_to_input_weights", tbl->lookup(node->recurrent_to_input_weights())); - s.args().append("recurrent_to_forget_weights", tbl->lookup(node->recurrent_to_forget_weights())); - s.args().append("recurrent_to_cell_weights", tbl->lookup(node->recurrent_to_cell_weights())); - s.args().append("recurrent_to_output_weights", tbl->lookup(node->recurrent_to_output_weights())); - - s.args().append("cell_to_input_weights", tbl->lookup(node->cell_to_input_weights())); - s.args().append("cell_to_forget_weights", tbl->lookup(node->cell_to_forget_weights())); - s.args().append("cell_to_output_weights", tbl->lookup(node->cell_to_output_weights())); - - s.args().append("input_gate_bias", tbl->lookup(node->input_gate_bias())); - s.args().append("forget_gate_bias", tbl->lookup(node->forget_gate_bias())); - s.args().append("cell_gate_bias", tbl->lookup(node->cell_gate_bias())); - s.args().append("output_gate_bias", tbl->lookup(node->output_gate_bias())); - - s.args().append("projection_weights", tbl->lookup(node->projection_weights())); - s.args().append("projection_bias", tbl->lookup(node->projection_bias())); - - s.args().append("activation_state", tbl->lookup(node->activation_state())); - s.args().append("cell_state", tbl->lookup(node->cell_state())); - - s.args().append("input_layer_norm_coefficients", - tbl->lookup(node->input_layer_norm_coefficients())); - s.args().append("forget_layer_norm_coefficients", - tbl->lookup(node->forget_layer_norm_coefficients())); - s.args().append("cell_layer_norm_coefficients", - tbl->lookup(node->cell_layer_norm_coefficients())); - 
s.args().append("output_layer_norm_coefficients", - tbl->lookup(node->output_layer_norm_coefficients())); - - s.args().append("cell_clip", to_str(node->cell_clip())); - s.args().append("proj_clip", to_str(node->proj_clip())); - s.args().append("time_major", to_str(node->time_major())); - s.args().append("asymmetric_quantize_inputs", to_str(node->asymmetric_quantize_inputs())); - - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CircleUnique *node, - locop::NodeSummary &s) -{ - s.args().append("input", tbl->lookup(node->input())); - s.args().append("idx_out_type", to_str(node->idx_out_type())); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CircleUnpack *node, - locop::NodeSummary &s) -{ - s.args().append("value", tbl->lookup(node->value())); - s.args().append("num", pepper::str(node->num())); - s.args().append("axis", pepper::str(node->axis())); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CircleWhere *node, - locop::NodeSummary &s) -{ - s.args().append("condition", tbl->lookup(node->condition())); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CircleWhile *node, - locop::NodeSummary &s) -{ - for (uint32_t i = 0; i < node->input_count(); ++i) - s.args().append("input", tbl->lookup(node->input(i))); - - if (node->cond_graph() != nullptr) - s.args().append("cond_graph", node->cond_graph()->name()); - else - s.args().append("cond_branch", pepper::str(node->cond_branch())); - - if (node->body_graph() != nullptr) - s.args().append("body_graph", node->body_graph()->name()); - else - s.args().append("body_branch", pepper::str(node->body_branch())); - - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const 
locop::SymbolTable *tbl, const luci::CircleTopKV2Out *node, - locop::NodeSummary &s) -{ - s.args().append("topkv2", tbl->lookup(node->input())); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CircleUniqueOut *node, - locop::NodeSummary &s) -{ - s.args().append("unique", tbl->lookup(node->input())); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CircleUnpackOut *node, - locop::NodeSummary &s) -{ - s.args().append("unpack", tbl->lookup(node->input())); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CircleWhileOut *node, - locop::NodeSummary &s) -{ - s.args().append("while", tbl->lookup(node->input())); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CircleOutput *node, - locop::NodeSummary &s) -{ - s.args().append("from", tbl->lookup(node->from())); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *, const luci::CircleOutputDummy *, - locop::NodeSummary &s) -{ - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *, const luci::CircleOutputExclude *, - locop::NodeSummary &s) -{ - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CircleBCQFullyConnected *node, - locop::NodeSummary &s) -{ - assert(node->fusedActivationFunction() != luci::FusedActFunc::UNDEFINED); - s.args().append("input", tbl->lookup(node->input())); - s.args().append("weights_scales", tbl->lookup(node->weights_scales())); - s.args().append("weights_binary", tbl->lookup(node->weights_binary())); - s.args().append("bias", tbl->lookup(node->bias())); - 
s.args().append("weights_clusters", tbl->lookup(node->weights_clusters())); - s.args().append("fused", to_str(node->fusedActivationFunction())); - s.args().append("weights_hidden_size", pepper::str(node->weights_hidden_size())); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CircleBCQGather *node, - locop::NodeSummary &s) -{ - s.args().append("input_scales", tbl->lookup(node->input_scales())); - s.args().append("input_binary", tbl->lookup(node->input_binary())); - s.args().append("indices", tbl->lookup(node->indices())); - s.args().append("input_clusters", tbl->lookup(node->input_clusters())); - s.args().append("axis", pepper::str(node->axis())); - s.args().append("input_hidden_size", pepper::str(node->input_hidden_size())); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool summary_node(const locop::SymbolTable *tbl, const luci::CircleInstanceNorm *node, - locop::NodeSummary &s) -{ - auto fused = node->fusedActivationFunction(); - assert(fused != luci::FusedActFunc::UNDEFINED); - - s.args().append("input", tbl->lookup(node->input())); - s.args().append("gamma", tbl->lookup(node->gamma())); - s.args().append("beta", tbl->lookup(node->beta())); - s.args().append("epsilon", pepper::str(node->epsilon())); - s.args().append("fused_activation_function", to_str(fused)); - s.state(locop::NodeSummary::State::Complete); - return true; -} - -// SummaryBuilderLet type -enum class SB -{ - ABC, - DEF, - GHIJ, - KLMN, - OPQR, - STUV, - WXYZ, - CIRC, // circle only - VIRT, // virtual -}; - -template <SB sb> class SummaryBuilderLet; - -#define IMPLEMENT(CLASS) bool summary(const CLASS *, locop::NodeSummary &) const final; - -template <> class SummaryBuilderLet<SB::ABC> final : public CircleNodeSummaryBuilderBase -{ -public: - SummaryBuilderLet(const locop::SymbolTable *tbl) : CircleNodeSummaryBuilderBase(tbl) - { - // DO NOTHING - } - -private: - IMPLEMENT(luci::CircleAbs) - 
IMPLEMENT(luci::CircleAdd) - IMPLEMENT(luci::CircleAddN) - IMPLEMENT(luci::CircleArgMax) - IMPLEMENT(luci::CircleArgMin) - IMPLEMENT(luci::CircleAveragePool2D) - IMPLEMENT(luci::CircleBatchMatMul) - IMPLEMENT(luci::CircleBatchToSpaceND) - IMPLEMENT(luci::CircleBidirectionalSequenceLSTM) - IMPLEMENT(luci::CircleCast) - IMPLEMENT(luci::CircleCeil) - IMPLEMENT(luci::CircleConcatenation) - IMPLEMENT(luci::CircleConst) - IMPLEMENT(luci::CircleConv2D) - IMPLEMENT(luci::CircleCos) - IMPLEMENT(luci::CircleCustom) -}; - -template <> class SummaryBuilderLet<SB::DEF> final : public CircleNodeSummaryBuilderBase -{ -public: - SummaryBuilderLet(const locop::SymbolTable *tbl) : CircleNodeSummaryBuilderBase(tbl) - { - // DO NOTHING - } - -private: - IMPLEMENT(luci::CircleDepthToSpace) - IMPLEMENT(luci::CircleDepthwiseConv2D) - IMPLEMENT(luci::CircleDequantize) - IMPLEMENT(luci::CircleDiv) - IMPLEMENT(luci::CircleElu) - IMPLEMENT(luci::CircleEqual) - IMPLEMENT(luci::CircleExp) - IMPLEMENT(luci::CircleExpandDims) - IMPLEMENT(luci::CircleFakeQuant) - IMPLEMENT(luci::CircleFill) - IMPLEMENT(luci::CircleFloor) - IMPLEMENT(luci::CircleFloorDiv) - IMPLEMENT(luci::CircleFloorMod) - IMPLEMENT(luci::CircleFullyConnected) -}; - -template <> class SummaryBuilderLet<SB::GHIJ> final : public CircleNodeSummaryBuilderBase -{ -public: - SummaryBuilderLet(const locop::SymbolTable *tbl) : CircleNodeSummaryBuilderBase(tbl) - { - // DO NOTHING - } - -private: - IMPLEMENT(luci::CircleGather) - IMPLEMENT(luci::CircleGatherNd) - IMPLEMENT(luci::CircleGreater) - IMPLEMENT(luci::CircleGreaterEqual) - IMPLEMENT(luci::CircleIf) -}; - -template <> class SummaryBuilderLet<SB::KLMN> final : public CircleNodeSummaryBuilderBase -{ -public: - SummaryBuilderLet(const locop::SymbolTable *tbl) : CircleNodeSummaryBuilderBase(tbl) - { - // DO NOTHING - } - -private: - IMPLEMENT(luci::CircleL2Normalize) - IMPLEMENT(luci::CircleL2Pool2D) - IMPLEMENT(luci::CircleLeakyRelu) - IMPLEMENT(luci::CircleLess) - 
IMPLEMENT(luci::CircleLessEqual) - IMPLEMENT(luci::CircleLocalResponseNormalization) - IMPLEMENT(luci::CircleLog) - IMPLEMENT(luci::CircleLogicalAnd) - IMPLEMENT(luci::CircleLogicalNot) - IMPLEMENT(luci::CircleLogicalOr) - IMPLEMENT(luci::CircleLogistic) - IMPLEMENT(luci::CircleLogSoftmax) - IMPLEMENT(luci::CircleMatrixDiag) - IMPLEMENT(luci::CircleMatrixSetDiag) - IMPLEMENT(luci::CircleMaximum) - IMPLEMENT(luci::CircleMaxPool2D) - IMPLEMENT(luci::CircleMean) - IMPLEMENT(luci::CircleMinimum) - IMPLEMENT(luci::CircleMirrorPad) - IMPLEMENT(luci::CircleMul) - IMPLEMENT(luci::CircleNeg) - IMPLEMENT(luci::CircleNonMaxSuppressionV4) - IMPLEMENT(luci::CircleNonMaxSuppressionV5) - IMPLEMENT(luci::CircleNotEqual) -}; - -template <> class SummaryBuilderLet<SB::OPQR> final : public CircleNodeSummaryBuilderBase -{ -public: - SummaryBuilderLet(const locop::SymbolTable *tbl) : CircleNodeSummaryBuilderBase(tbl) - { - // DO NOTHING - } - -private: - IMPLEMENT(luci::CircleOneHot) - IMPLEMENT(luci::CirclePack) - IMPLEMENT(luci::CirclePad) - IMPLEMENT(luci::CirclePadV2) - IMPLEMENT(luci::CirclePow) - IMPLEMENT(luci::CirclePRelu) - IMPLEMENT(luci::CircleQuantize) - IMPLEMENT(luci::CircleRange) - IMPLEMENT(luci::CircleRank) - IMPLEMENT(luci::CircleReduceAny) - IMPLEMENT(luci::CircleReduceMax) - IMPLEMENT(luci::CircleReduceMin) - IMPLEMENT(luci::CircleReduceProd) - IMPLEMENT(luci::CircleRelu) - IMPLEMENT(luci::CircleRelu6) - IMPLEMENT(luci::CircleReluN1To1) - IMPLEMENT(luci::CircleReshape) - IMPLEMENT(luci::CircleResizeBilinear) - IMPLEMENT(luci::CircleResizeNearestNeighbor) - IMPLEMENT(luci::CircleReverseSequence) - IMPLEMENT(luci::CircleReverseV2) - IMPLEMENT(luci::CircleRound) - IMPLEMENT(luci::CircleRsqrt) -}; - -template <> class SummaryBuilderLet<SB::STUV> final : public CircleNodeSummaryBuilderBase -{ -public: - SummaryBuilderLet(const locop::SymbolTable *tbl) : CircleNodeSummaryBuilderBase(tbl) - { - // DO NOTHING - } - -private: - IMPLEMENT(luci::CircleScatterNd) - 
IMPLEMENT(luci::CircleSegmentSum) - IMPLEMENT(luci::CircleSelect) - IMPLEMENT(luci::CircleSelectV2) - IMPLEMENT(luci::CircleShape) - IMPLEMENT(luci::CircleSin) - IMPLEMENT(luci::CircleSlice) - IMPLEMENT(luci::CircleSoftmax) - IMPLEMENT(luci::CircleSpaceToBatchND) - IMPLEMENT(luci::CircleSpaceToDepth) - IMPLEMENT(luci::CircleSparseToDense) - IMPLEMENT(luci::CircleSplit) - IMPLEMENT(luci::CircleSplitV) - IMPLEMENT(luci::CircleSqrt) - IMPLEMENT(luci::CircleSquare) - IMPLEMENT(luci::CircleSquaredDifference) - IMPLEMENT(luci::CircleSqueeze) - IMPLEMENT(luci::CircleStridedSlice) - IMPLEMENT(luci::CircleSub) - IMPLEMENT(luci::CircleSum) - IMPLEMENT(luci::CircleTanh) - IMPLEMENT(luci::CircleTile) - IMPLEMENT(luci::CircleTopKV2) - IMPLEMENT(luci::CircleTranspose) - IMPLEMENT(luci::CircleTransposeConv) - IMPLEMENT(luci::CircleUnidirectionalSequenceLSTM) - IMPLEMENT(luci::CircleUnique) - IMPLEMENT(luci::CircleUnpack) -}; - -template <> class SummaryBuilderLet<SB::WXYZ> final : public CircleNodeSummaryBuilderBase -{ -public: - SummaryBuilderLet(const locop::SymbolTable *tbl) : CircleNodeSummaryBuilderBase(tbl) - { - // DO NOTHING - } - -private: - IMPLEMENT(luci::CircleWhere) - IMPLEMENT(luci::CircleWhile) - IMPLEMENT(luci::CircleZerosLike) -}; - -template <> class SummaryBuilderLet<SB::CIRC> final : public CircleNodeSummaryBuilderBase -{ -public: - SummaryBuilderLet(const locop::SymbolTable *tbl) : CircleNodeSummaryBuilderBase(tbl) - { - // DO NOTHING - } - -private: - IMPLEMENT(luci::CircleBCQFullyConnected) - IMPLEMENT(luci::CircleBCQGather) - IMPLEMENT(luci::CircleInstanceNorm) -}; - -template <> class SummaryBuilderLet<SB::VIRT> final : public CircleNodeSummaryBuilderBase -{ -public: - SummaryBuilderLet(const locop::SymbolTable *tbl) : CircleNodeSummaryBuilderBase(tbl) - { - // DO NOTHING - } - -private: - IMPLEMENT(luci::CircleInput) - IMPLEMENT(luci::CircleOutput) - IMPLEMENT(luci::CircleCustomOut) - IMPLEMENT(luci::CircleIfOut) - 
IMPLEMENT(luci::CircleNonMaxSuppressionV4Out) - IMPLEMENT(luci::CircleNonMaxSuppressionV5Out) - IMPLEMENT(luci::CircleOutputDummy) - IMPLEMENT(luci::CircleOutputExclude) - IMPLEMENT(luci::CircleSplitOut) - IMPLEMENT(luci::CircleSplitVOut) - IMPLEMENT(luci::CircleTopKV2Out) - IMPLEMENT(luci::CircleUniqueOut) - IMPLEMENT(luci::CircleUnpackOut) - IMPLEMENT(luci::CircleWhileOut) -}; - -#undef IMPLEMENT - -bool CircleNodeSummaryBuilderBase::build(const loco::Node *node, locop::NodeSummary &s) const -{ - if (node->dialect() != luci::CircleDialect::get()) - return false; - - auto ptr_to_str = [](const void *ptr) { - std::stringstream ss; - ss << ptr; - return ss.str(); - }; - - auto add_comment = [&]() { - auto cnode = loco::must_cast<const luci::CircleNode *>(node); - s.opname(circle_opname(node->opnum())); - s.comments().append("[" + cnode->name() + "] = " + ptr_to_str(node)); - }; - -#define CIRCLE_NODE(OPCODE, CLASS) \ - if (dynamic_cast<const CLASS *>(node)) \ - { \ - if (summary(dynamic_cast<const CLASS *>(node), s)) \ - { \ - add_comment(); \ - return true; \ - } \ - } -#define CIRCLE_VNODE CIRCLE_NODE -#include <luci/IR/CircleNodes.lst> -#undef CIRCLE_VNODE -#undef CIRCLE_NODE - - return false; -} - -bool SummaryBuilderLet<SB::ABC>::summary(const luci::CircleAbs *node, locop::NodeSummary &s) const -{ - return use_x(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::ABC>::summary(const luci::CircleAdd *node, locop::NodeSummary &s) const -{ - return use_xy_act(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::ABC>::summary(const luci::CircleAddN *node, locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::ABC>::summary(const luci::CircleArgMax *node, - locop::NodeSummary &s) const -{ - return use_ido(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::ABC>::summary(const luci::CircleArgMin *node, - locop::NodeSummary &s) const -{ - return use_ido(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::ABC>::summary(const 
luci::CircleAveragePool2D *node, - locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::ABC>::summary(const luci::CircleBatchMatMul *node, - locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::ABC>::summary(const luci::CircleBatchToSpaceND *node, - locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::ABC>::summary(const luci::CircleBidirectionalSequenceLSTM *node, - locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::ABC>::summary(const luci::CircleCast *node, locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::ABC>::summary(const luci::CircleCeil *node, locop::NodeSummary &s) const -{ - return use_x(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::ABC>::summary(const luci::CircleConcatenation *node, - locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::ABC>::summary(const luci::CircleConst *, locop::NodeSummary &s) const -{ - s.state(locop::NodeSummary::State::PartiallyKnown); - return true; -} - -bool SummaryBuilderLet<SB::ABC>::summary(const luci::CircleConv2D *node, - locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::ABC>::summary(const luci::CircleCos *node, locop::NodeSummary &s) const -{ - return use_x(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::ABC>::summary(const luci::CircleCustom *node, - locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::DEF>::summary(const luci::CircleDepthToSpace *node, - locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::DEF>::summary(const luci::CircleDepthwiseConv2D *node, - locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool 
SummaryBuilderLet<SB::DEF>::summary(const luci::CircleDequantize *node, - locop::NodeSummary &s) const -{ - return use_input(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::DEF>::summary(const luci::CircleDiv *node, locop::NodeSummary &s) const -{ - return use_xy(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::DEF>::summary(const luci::CircleElu *node, locop::NodeSummary &s) const -{ - return use_features(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::DEF>::summary(const luci::CircleEqual *node, locop::NodeSummary &s) const -{ - return use_xy(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::DEF>::summary(const luci::CircleExp *node, locop::NodeSummary &s) const -{ - return use_x(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::DEF>::summary(const luci::CircleExpandDims *node, - locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::DEF>::summary(const luci::CircleFakeQuant *node, - locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::DEF>::summary(const luci::CircleFill *node, locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::DEF>::summary(const luci::CircleFloor *node, locop::NodeSummary &s) const -{ - return use_x(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::DEF>::summary(const luci::CircleFloorDiv *node, - locop::NodeSummary &s) const -{ - return use_xy(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::DEF>::summary(const luci::CircleFloorMod *node, - locop::NodeSummary &s) const -{ - return use_xy(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::DEF>::summary(const luci::CircleFullyConnected *node, - locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::GHIJ>::summary(const luci::CircleGather *node, - locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::GHIJ>::summary(const luci::CircleGatherNd 
*node, - locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::GHIJ>::summary(const luci::CircleGreater *node, - locop::NodeSummary &s) const -{ - return use_xy(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::GHIJ>::summary(const luci::CircleGreaterEqual *node, - locop::NodeSummary &s) const -{ - return use_xy(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::GHIJ>::summary(const luci::CircleIf *node, locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleL2Normalize *node, - locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleL2Pool2D *node, - locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleLess *node, locop::NodeSummary &s) const -{ - return use_xy(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleLessEqual *node, - locop::NodeSummary &s) const -{ - return use_xy(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleLeakyRelu *node, - locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleLocalResponseNormalization *node, - locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleLog *node, locop::NodeSummary &s) const -{ - return use_x(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleLogicalAnd *node, - locop::NodeSummary &s) const -{ - return use_xy(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleLogicalNot *node, - locop::NodeSummary &s) const -{ - return use_x(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleLogicalOr *node, - 
locop::NodeSummary &s) const -{ - return use_xy(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleLogistic *node, - locop::NodeSummary &s) const -{ - return use_x(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleLogSoftmax *node, - locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleMatrixDiag *node, - locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleMatrixSetDiag *node, - locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleMaximum *node, - locop::NodeSummary &s) const -{ - return use_xy(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleMaxPool2D *node, - locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleMean *node, locop::NodeSummary &s) const -{ - return use_reducer(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleMinimum *node, - locop::NodeSummary &s) const -{ - return use_xy(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleMirrorPad *node, - locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleMul *node, locop::NodeSummary &s) const -{ - return use_xy_act(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleNeg *node, locop::NodeSummary &s) const -{ - return use_x(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleNonMaxSuppressionV4 *node, - locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleNonMaxSuppressionV5 *node, - 
locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleNotEqual *node, - locop::NodeSummary &s) const -{ - return use_xy(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleOneHot *node, - locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CirclePack *node, locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CirclePad *node, locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CirclePadV2 *node, - locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CirclePow *node, locop::NodeSummary &s) const -{ - return use_xy(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CirclePRelu *node, - locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleQuantize *node, - locop::NodeSummary &s) const -{ - return use_input(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleRange *node, - locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleRank *node, locop::NodeSummary &s) const -{ - return use_input(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleReduceAny *node, - locop::NodeSummary &s) const -{ - return use_reducer(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleReduceMax *node, - locop::NodeSummary &s) const -{ - return use_reducer(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleReduceMin *node, - locop::NodeSummary &s) const -{ - 
return use_reducer(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleReduceProd *node, - locop::NodeSummary &s) const -{ - return use_reducer(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleRelu *node, locop::NodeSummary &s) const -{ - return use_features(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleRelu6 *node, - locop::NodeSummary &s) const -{ - return use_features(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleReluN1To1 *node, - locop::NodeSummary &s) const -{ - return use_features(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleReshape *node, - locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleResizeBilinear *node, - locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleResizeNearestNeighbor *node, - locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleReverseSequence *node, - locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleReverseV2 *node, - locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleRound *node, - locop::NodeSummary &s) const -{ - return use_x(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleRsqrt *node, - locop::NodeSummary &s) const -{ - return use_x(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleScatterNd *node, - locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleSegmentSum *node, - 
locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleSelect *node, - locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleSelectV2 *node, - locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleShape *node, - locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleSin *node, locop::NodeSummary &s) const -{ - return use_x(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleSlice *node, - locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleSoftmax *node, - locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleSpaceToBatchND *node, - locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleSpaceToDepth *node, - locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleSparseToDense *node, - locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleSplit *node, - locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleSplitV *node, - locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleSqrt *node, locop::NodeSummary &s) const -{ - return use_x(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleSquare *node, - 
locop::NodeSummary &s) const -{ - return use_x(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleSquaredDifference *node, - locop::NodeSummary &s) const -{ - return use_xy(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleSqueeze *node, - locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleStridedSlice *node, - locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleSub *node, locop::NodeSummary &s) const -{ - return use_xy(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleSum *node, locop::NodeSummary &s) const -{ - return use_reducer(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleTanh *node, locop::NodeSummary &s) const -{ - return use_x(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleTile *node, locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleTopKV2 *node, - locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleTranspose *node, - locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleTransposeConv *node, - locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleUnidirectionalSequenceLSTM *node, - locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleUnique *node, - locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleUnpack *node, - 
locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::WXYZ>::summary(const luci::CircleWhere *node, - locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::WXYZ>::summary(const luci::CircleWhile *node, - locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::WXYZ>::summary(const luci::CircleZerosLike *node, - locop::NodeSummary &s) const -{ - return use_input(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::CIRC>::summary(const luci::CircleBCQFullyConnected *node, - locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::CIRC>::summary(const luci::CircleBCQGather *node, - locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::CIRC>::summary(const luci::CircleInstanceNorm *node, - locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::VIRT>::summary(const luci::CircleInput *, locop::NodeSummary &s) const -{ - s.state(locop::NodeSummary::State::Complete); - return true; -} - -bool SummaryBuilderLet<SB::VIRT>::summary(const luci::CircleOutput *node, - locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::VIRT>::summary(const luci::CircleCustomOut *node, - locop::NodeSummary &s) const -{ - return use_input(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::VIRT>::summary(const luci::CircleIfOut *node, - locop::NodeSummary &s) const -{ - return use_input(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::VIRT>::summary(const luci::CircleNonMaxSuppressionV4Out *node, - locop::NodeSummary &s) const -{ - return use_input(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::VIRT>::summary(const luci::CircleNonMaxSuppressionV5Out *node, - locop::NodeSummary &s) const -{ - return use_input(tbl(), node, s); -} - -bool 
SummaryBuilderLet<SB::VIRT>::summary(const luci::CircleOutputDummy *node, - locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::VIRT>::summary(const luci::CircleOutputExclude *node, - locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::VIRT>::summary(const luci::CircleSplitOut *node, - locop::NodeSummary &s) const -{ - return use_input(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::VIRT>::summary(const luci::CircleSplitVOut *node, - locop::NodeSummary &s) const -{ - return use_input(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::VIRT>::summary(const luci::CircleTopKV2Out *node, - locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::VIRT>::summary(const luci::CircleUniqueOut *node, - locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::VIRT>::summary(const luci::CircleUnpackOut *node, - locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -bool SummaryBuilderLet<SB::VIRT>::summary(const luci::CircleWhileOut *node, - locop::NodeSummary &s) const -{ - return summary_node(tbl(), node, s); -} - -} // namespace - namespace luci { @@ -2208,22 +36,10 @@ bool NodeSummaryBuilder::build(const loco::Node *node, locop::NodeSummary &s) co return true; } -#define BUILD_GRP(GRP) \ - do \ - { \ - if (SummaryBuilderLet<SB::GRP>(_tbl).build(node, s)) \ - return true; \ - } while (false) - - BUILD_GRP(ABC); - BUILD_GRP(DEF); - BUILD_GRP(GHIJ); - BUILD_GRP(KLMN); - BUILD_GRP(OPQR); - BUILD_GRP(STUV); - BUILD_GRP(WXYZ); - BUILD_GRP(CIRC); - BUILD_GRP(VIRT); + if (CircleNodeSummaryBuilder().build(node, _tbl, s)) + { + return true; + } return false; } diff --git a/compiler/luci/partition/CMakeLists.txt b/compiler/luci/partition/CMakeLists.txt index ec8e0b0d6..f28207df2 100644 --- a/compiler/luci/partition/CMakeLists.txt +++ 
b/compiler/luci/partition/CMakeLists.txt @@ -13,7 +13,7 @@ target_link_libraries(luci_partition PUBLIC luci_lang) target_link_libraries(luci_partition PRIVATE luci_service) target_link_libraries(luci_partition PRIVATE luci_log) target_link_libraries(luci_partition PRIVATE luci_logex) -target_link_libraries(luci_partition PRIVATE mio_circle) +target_link_libraries(luci_partition PRIVATE mio_circle04) target_link_libraries(luci_partition PRIVATE nncc_common) target_link_libraries(luci_partition PRIVATE pepper_csv2vec) target_link_libraries(luci_partition PRIVATE oops) diff --git a/compiler/luci/partition/src/ConnectNode.h b/compiler/luci/partition/src/ConnectNode.h index ebbff7a6a..e60567c69 100644 --- a/compiler/luci/partition/src/ConnectNode.h +++ b/compiler/luci/partition/src/ConnectNode.h @@ -161,6 +161,7 @@ public: void visit(const luci::CircleSquaredDifference *) final; void visit(const luci::CircleSqueeze *) final; void visit(const luci::CircleStridedSlice *) final; + void visit(const luci::CircleSVDF *) final; void visit(const luci::CircleSub *) final; void visit(const luci::CircleSum *) final; void visit(const luci::CircleTanh *) final; @@ -197,6 +198,7 @@ public: void visit(const luci::CircleTopKV2Out *) final; void visit(const luci::CircleUniqueOut *) final; void visit(const luci::CircleUnpackOut *) final; + void visit(const luci::CircleVariable *) final; void visit(const luci::CircleWhileOut *) final; public: diff --git a/compiler/luci/partition/src/Nodes/CircleSVDF.cpp b/compiler/luci/partition/src/Nodes/CircleSVDF.cpp new file mode 100644 index 000000000..f661a794c --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleSVDF.cpp @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleSVDF *node) +{ + auto *cloned = loco::must_cast<luci::CircleSVDF *>(cn->find_clone(node)); + + luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input()); + luci::CircleNode *weight_feature = loco::must_cast<luci::CircleNode *>(node->weight_feature()); + luci::CircleNode *weight_time = loco::must_cast<luci::CircleNode *>(node->weight_time()); + luci::CircleNode *bias = loco::must_cast<luci::CircleNode *>(node->bias()); + luci::CircleNode *input_activation_state = + loco::must_cast<luci::CircleNode *>(node->input_activation_state()); + + cloned->input(cn->find_clone(input)); + cloned->weight_feature(cn->find_clone(weight_feature)); + cloned->weight_time(cn->find_clone(weight_time)); + cloned->bias(cn->find_clone(bias)); + cloned->input_activation_state(cn->find_clone(input_activation_state)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleSVDF *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleSVDF.test.cpp b/compiler/luci/partition/src/Nodes/CircleSVDF.test.cpp new file mode 100644 index 000000000..5fae5206e --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleSVDF.test.cpp @@ -0,0 +1,106 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleSVDF> +{ +public: + NodeGraphlet() = default; + +public: + void init(loco::Graph *g) + { + NodeGraphletT<luci::CircleSVDF>::init(g); + + _node->fusedActivationFunction(luci::FusedActFunc::RELU); + } +}; + +class TestNodeGraph : public TestIsOGraph<5>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<5>::init({shape, shape, shape, shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->input(input(0)); + node()->weight_feature(input(1)); + node()->weight_time(input(2)); + node()->bias(input(3)); + node()->input_activation_state(input(4)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_SVDF) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSVDF *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSVDF *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(5, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); + ASSERT_EQ(cth.inputs(2), clone->arg(2)); + ASSERT_EQ(cth.inputs(3), clone->arg(3)); + ASSERT_EQ(cth.inputs(4), clone->arg(4)); +} + 
+TEST(ConnectNodeTest, connect_SVDF_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSVDF *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSVDF *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleVariable.cpp b/compiler/luci/partition/src/Nodes/CircleVariable.cpp new file mode 100644 index 000000000..f7f6f21fd --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleVariable.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "ConnectNode.h" + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleVariable *) +{ + // Nothing to do +} + +} // namespace luci diff --git a/compiler/luci/partition/src/PartitionIRDump.cpp b/compiler/luci/partition/src/PartitionIRDump.cpp index 4f2c26800..0fabfc416 100644 --- a/compiler/luci/partition/src/PartitionIRDump.cpp +++ b/compiler/luci/partition/src/PartitionIRDump.cpp @@ -32,18 +32,18 @@ void dump(std::ostream &os, const PNode *pnode) void dump(std::ostream &os, const PGroup *pgroup) { os << "--- PGroup: " << pgroup->group << std::endl; - os << "Input(s): "; + os << "Input(s): [ "; for (auto &node_in : pgroup->inputs) os << node_in->name() << " "; - os << std::endl; + os << "]" << std::endl; for (auto &pnode : pgroup->pnodes) { dump(os, pnode.get()); } - os << "Output(s): "; + os << "Output(s): [ "; for (auto &node_out : pgroup->outputs) os << node_out->name() << " "; - os << std::endl; + os << "]" << std::endl; } void dump(std::ostream &os, const PGroups *pgroups) @@ -57,7 +57,8 @@ void dump(std::ostream &os, const PGroups *pgroups) { auto node = it->first; auto group = it->second; - os << " Node: " << node << "(" << node->name() << "): " << group << std::endl; + os << " Node: " << node << "(" << luci::opcode_name(node) << "," << node->name() + << "): " << group << std::endl; } } diff --git a/compiler/luci/partition/src/PartitionMerge.cpp b/compiler/luci/partition/src/PartitionMerge.cpp index c517bf93f..4c3971bd8 100644 --- a/compiler/luci/partition/src/PartitionMerge.cpp +++ b/compiler/luci/partition/src/PartitionMerge.cpp @@ -58,9 +58,6 @@ bool is_input_same(const luci::PGroup *pgroup, const luci::PGroups *pgroups) // we need to clone this CircleConst for each graph of the group. 
if (dynamic_cast<const luci::CircleConst *>(input) != nullptr) continue; - // Skip also for OutputExclude - if (dynamic_cast<const luci::CircleOutputExclude *>(input) != nullptr) - continue; auto input_group = pgroups->group_of(input); // NOTE: all the nodes should be registered and return should be valid group. @@ -87,7 +84,7 @@ bool is_input_same(const luci::PGroup *pgroup, const luci::PGroups *pgroups) input_pgroup = pgroup_input; else { - if (input_pgroup != pgroup_input) + if (input_pgroup->group != pgroup_input->group) return false; } } @@ -96,6 +93,48 @@ bool is_input_same(const luci::PGroup *pgroup, const luci::PGroups *pgroups) } /** + * @brief return true if there is only one output and is fed to same group of nodes + * @note pgroups is used to find group of pgroup + * ex) + * /-- pgroup_user_1 (grp_1) + * --- pgroup + * \-- pgroup_user_2 (grp_2) + * + * return false if grp_1 != grp_2 + */ +bool is_output_same(const luci::PGroup *pgroup, const luci::PGroups *pgroups) +{ + assert(pgroups != nullptr); + assert(pgroup != nullptr); + + std::string group; + for (auto &output : pgroup->outputs) + { + // get output_group + auto output_group = pgroups->group_of(output); + assert(not output_group.empty()); + if (output_group.empty()) + output_group = pgroups->default_group; + + // find all PGroup that uses output + for (auto &pgroup_user : pgroups->pgroups) + { + for (auto &user_inputs : pgroup_user->inputs) + { + if (output == user_inputs) + { + // OK, these are connected, check group is same + if (pgroup_user->group != output_group) + return false; + } + } + } + } + + return true; +} + +/** * @brief merge pgroup into pgroup_i * @note output of pgroup_i should be input of pgroup */ @@ -191,6 +230,9 @@ std::unique_ptr<luci::PGroups> merge_pgroups(const luci::PGroups *s_pgroups) // skip if there are multiple inputs but inputs differ in group if (!is_input_same(pgroup.get(), d_pgroups.get())) continue; + // skip if pgroup has different group for other users of 
pgroup_i + if (!is_output_same(pgroup_i.get(), d_pgroups.get())) + continue; // TODO add more condition may be needed merge_into(pgroup.get(), pgroup_i.get()); diff --git a/compiler/luci/partition/src/PartitionPGroups.cpp b/compiler/luci/partition/src/PartitionPGroups.cpp index 0080873e6..eaeacf9c4 100644 --- a/compiler/luci/partition/src/PartitionPGroups.cpp +++ b/compiler/luci/partition/src/PartitionPGroups.cpp @@ -46,6 +46,9 @@ public: bool visit(const luci::CircleUniqueOut *) final { return true; } bool visit(const luci::CircleUnpackOut *) final { return true; } bool visit(const luci::CircleWhileOut *) final { return true; } + // For inputs not used + bool visit(const luci::CircleOutputExclude *) final { return true; } + bool visit(const luci::CircleVariable *) final { return true; } // TODO add all virtual nodes // default is false @@ -69,59 +72,80 @@ bool check_allocate_partition(const luci::CircleNode *node) return true; } -class FindGroupToFollow final : public luci::CircleNodeVisitor<const std::string &> +} // namespace + +namespace { -public: - FindGroupToFollow(const luci::PartitionTable &partition, luci::PGroups *pgroups) - : _partition(partition), _pgroups(pgroups) - { - // NOTHING TODO - } -private: - const std::string &groupof(const luci::CircleNode *input) const +std::string group_from_partition(const luci::CircleNode *node, + const luci::PartitionTable &partition) +{ + LOGGER(l); + + auto group = partition.default_group; + + std::string opcodename; // opcodename or opname + + switch (partition.comply) { - auto group = _pgroups->node2group[input]; - assert(not group.empty()); - if (group.empty()) - return _partition.default_group; - return _pgroups->node2group[input]; + case luci::PartitionTable::COMPLY::OPCODE: + { + opcodename = luci::opcode_name(node); + assert(!opcodename.empty()); + + auto it = partition.byopcodes.find(opcodename); + if (it != partition.byopcodes.end()) + group = it->second; + break; + } + case 
luci::PartitionTable::COMPLY::OPNAME: + { + opcodename = node->name(); + assert(!opcodename.empty()); + + auto it = partition.byopnames.find(opcodename); + if (it != partition.byopnames.end()) + group = it->second; + break; + } + + default: + throw std::runtime_error("Unsupported partition.comply"); } + INFO(l) << "Op: " << node->name() << ": " << opcodename << ", " << node << ", " << group + << std::endl; + + return group; +} + +class IsVirtualInputNode final : public luci::CircleNodeVisitor<bool> +{ public: -#define IMPLEMENT(CLASS) \ - const std::string &visit(const luci::CLASS *node) final \ - { \ - auto input = loco::must_cast<luci::CircleNode *>(node->input()); \ - return groupof(input); \ - } + // TODO check CircleOutputDummy + bool visit(const luci::CircleOutputExclude *) final { return true; } + bool visit(const luci::CircleVariable *) final { return true; } - IMPLEMENT(CircleCustomOut); - IMPLEMENT(CircleIfOut); - IMPLEMENT(CircleNonMaxSuppressionV4Out); - IMPLEMENT(CircleNonMaxSuppressionV5Out); - IMPLEMENT(CircleSplitOut); - IMPLEMENT(CircleSplitVOut); - IMPLEMENT(CircleTopKV2Out); - IMPLEMENT(CircleUniqueOut); - IMPLEMENT(CircleUnpackOut); - IMPLEMENT(CircleWhileOut); - -#undef IMPLEMENT - - // return empty for nothing to do - const std::string &visit(const luci::CircleNode *) final { return _empty_str; } - -private: - const luci::PartitionTable &_partition; - luci::PGroups *_pgroups = nullptr; - std::string _empty_str; + // default is false + bool visit(const luci::CircleNode *) final { return false; } }; -} // namespace - -namespace +class IsMultiOutputNode final : public luci::CircleNodeVisitor<bool> { +public: + bool visit(const luci::CircleCustom *) final { return true; } + bool visit(const luci::CircleIf *) final { return true; } + bool visit(const luci::CircleNonMaxSuppressionV4 *) final { return true; } + bool visit(const luci::CircleNonMaxSuppressionV5 *) final { return true; } + bool visit(const luci::CircleSplit *) final { return true; } + 
bool visit(const luci::CircleSplitV *) final { return true; } + bool visit(const luci::CircleTopKV2 *) final { return true; } + bool visit(const luci::CircleUnique *) final { return true; } + bool visit(const luci::CircleUnpack *) final { return true; } + bool visit(const luci::CircleWhile *) final { return true; } + // default is false + bool visit(const luci::CircleNode *) final { return false; } +}; void append(luci::CircleNode *node, luci::PGroups *pgroups, const std::string &group, uint32_t idx) { @@ -136,17 +160,56 @@ void append(luci::CircleNode *node, luci::PGroups *pgroups, const std::string &g pgroup->pnodes.push_back(std::move(pnode)); + IsVirtualInputNode queryvi; // Set input of PGroup for (uint32_t in = 0; in < node->arity(); ++in) { auto input = loco::must_cast<luci::CircleNode *>(node->arg(in)); - // this input maybe CircleInput in source graph - // --> not confident this is safe - pgroup->inputs.push_back(input); + if (input->accept(&queryvi)) + { + auto pnode = std::make_unique<luci::PNode>(); + pnode->node = input; + pnode->group = group; + pnode->pgroup = pgroup.get(); + + pgroup->pnodes.push_back(std::move(pnode)); + + pgroups->node2group[input] = group; + } + else + { + // this input maybe CircleInput in source graph + // --> not confident this is safe + pgroup->inputs.push_back(input); + } + } + + IsMultiOutputNode query; + if (node->accept(&query)) + { + // Include CircleXXXOut virtual nodes in this group + auto succs = loco::succs(node); + for (auto &succ_node : succs) + { + auto nodeout = loco::must_cast<luci::CircleNode *>(succ_node); + + auto pnode = std::make_unique<luci::PNode>(); + pnode->node = nodeout; + pnode->group = group; + pnode->pgroup = pgroup.get(); + + pgroup->pnodes.push_back(std::move(pnode)); + + pgroups->node2group[nodeout] = group; + + pgroup->outputs.push_back(nodeout); + } + } + else + { + // Set output of PGroup: node itself + pgroup->outputs.push_back(node); } - // Set output of PGroup: node itself or multiple 
virtual outputs - // TODO support multiple virtual outputs - pgroup->outputs.push_back(node); pgroups->node2group[node] = group; pgroups->id2pgroup[pgroup->id] = pgroup.get(); @@ -182,70 +245,9 @@ std::unique_ptr<luci::PGroups> produce_pgroups(const luci::Module *source, // check if node is normal node that we are interested if (check_allocate_partition(node)) { - auto group = partition.default_group; - - std::string opcodename; // opcodename or opname - - switch (partition.comply) - { - case luci::PartitionTable::COMPLY::OPCODE: - { - opcodename = luci::opcode_name(node); - assert(!opcodename.empty()); - - auto it = partition.byopcodes.find(opcodename); - if (it != partition.byopcodes.end()) - group = it->second; - break; - } - case luci::PartitionTable::COMPLY::OPNAME: - { - opcodename = node->name(); - assert(!opcodename.empty()); - - auto it = partition.byopnames.find(opcodename); - if (it != partition.byopnames.end()) - group = it->second; - break; - } - - default: - throw std::runtime_error("Unsupported partition.comply"); - } - - INFO(l) << "Op: " << node->name() << ": " << opcodename << ", " << node << ", " << group - << std::endl; + auto group = group_from_partition(node, partition); append(node, pgroups.get(), group, idx); -#if 0 - auto pgroup = std::make_unique<luci::PGroup>(); - pgroup->group = group; - pgroup->id = idx + 1; - - auto pnode = std::make_unique<luci::PNode>(); - pnode->node = node; - pnode->group = group; - pnode->pgroup = pgroup.get(); - - pgroup->pnodes.push_back(std::move(pnode)); - - // Set input of PGroup - for (uint32_t in = 0; in < node->arity(); ++in) - { - auto input = loco::must_cast<luci::CircleNode *>(node->arg(in)); - // this input maybe CircleInput in source graph - // --> not confident this is safe - pgroup->inputs.push_back(input); - } - // Set output of PGroup: node itself or multiple virtual outputs - // TODO support multiple virtual outputs - pgroup->outputs.push_back(node); - - pgroups->node2group[node] = group; - 
pgroups->id2pgroup[pgroup->id] = pgroup.get(); - - pgroups->pgroups.push_back(std::move(pgroup)); -#endif } else { @@ -255,22 +257,6 @@ std::unique_ptr<luci::PGroups> produce_pgroups(const luci::Module *source, } } - // handle for virtual nodes like multiple outputs - // these nodes should follow group of the input - for (uint32_t idx = 0; idx < nodes->size(); ++idx) - { - auto node = loco::must_cast<luci::CircleNode *>(nodes->at(idx)); - - // for virtual nodes like CircleUnpackOut should follow it's input (owner) - // or just set to default - FindGroupToFollow query(partition, pgroups.get()); - const auto &group = node->accept(&query); - if (not group.empty()) - { - append(node, pgroups.get(), group, idx); - } - } - return std::move(pgroups); } diff --git a/compiler/luci/pass/CMakeLists.txt b/compiler/luci/pass/CMakeLists.txt index b8b406a38..5237c6d3f 100644 --- a/compiler/luci/pass/CMakeLists.txt +++ b/compiler/luci/pass/CMakeLists.txt @@ -1,4 +1,4 @@ -nnas_find_package(FlatBuffers EXACT 1.12 QUIET) +nnas_find_package(FlatBuffers EXACT 2.0 QUIET) if(NOT FlatBuffers_FOUND) message(STATUS "FlatBuffers NOT FOUND") return() @@ -23,11 +23,11 @@ target_link_libraries(luci_pass PRIVATE luci_log) target_link_libraries(luci_pass PRIVATE luci_service) target_link_libraries(luci_pass PRIVATE luci_logex) target_link_libraries(luci_pass PRIVATE luci_profile) -target_link_libraries(luci_pass PRIVATE mio_tflite260_inc) +target_link_libraries(luci_pass PRIVATE mio_tflite280_inc) target_link_libraries(luci_pass PRIVATE nncc_common) target_link_libraries(luci_pass PRIVATE pepper_csv2vec) target_link_libraries(luci_pass PRIVATE oops) -target_link_libraries(luci_pass PRIVATE flatbuffers-1.12) +target_link_libraries(luci_pass PRIVATE flatbuffers-2.0) install(TARGETS luci_pass DESTINATION lib) install(DIRECTORY include/ DESTINATION include FILES_MATCHING PATTERN "*.h") @@ -43,5 +43,5 @@ target_include_directories(luci_pass_test PRIVATE src) target_link_libraries(luci_pass_test 
luci_pass) target_link_libraries(luci_pass_test luci_lang) target_link_libraries(luci_pass_test luci_testhelper) -target_link_libraries(luci_pass_test flatbuffers-1.12) +target_link_libraries(luci_pass_test flatbuffers-2.0) #target_link_libraries(luci_pass_test oops) diff --git a/compiler/luci/pass/include/luci/CircleOptimizer.h b/compiler/luci/pass/include/luci/CircleOptimizer.h index 658563ecf..c803898f6 100644 --- a/compiler/luci/pass/include/luci/CircleOptimizer.h +++ b/compiler/luci/pass/include/luci/CircleOptimizer.h @@ -47,15 +47,12 @@ public: ResolveCustomOpBatchMatMul, ResolveCustomOpMatMul, ResolveCustomOpMaxPoolWithArgmax, - QuantizeDequantizeWeights, - QuantizeWithMinMax, - Requantize, FoldAddV2, FoldCast, FoldDepthwiseConv2D, FoldDequantize, + FoldGather, FoldSparseToDense, - ForceQuantParam, ForwardReshapeToUnaryOp, SparsifyTensorPass, FusePreActivationBatchNorm, @@ -79,6 +76,7 @@ public: TransformMinReluToRelu6Pass, SubstituteStridedSliceToReshape, SubstituteTransposeToReshape, + RemoveRedundantQuantize, RemoveRedundantReshape, RemoveFakeQuant, RemoveQuantDequantSeq, @@ -86,16 +84,6 @@ public: enum AlgorithmParameters { - // quantize - Quantize_input_model_dtype, - Quantize_output_model_dtype, - Quantize_granularity, // layer-wise or channel-wise - Quantize_tensor_names, - Quantize_scales, - Quantize_zero_points, - Quantize_input_type, - Quantize_output_type, - // sparsify Sparsify_tensor_name, Sparsify_traversal_order, @@ -114,8 +102,6 @@ public: virtual bool query(Algorithm) = 0; virtual void param(AlgorithmParameters, const std::string &) = 0; virtual const std::string param(AlgorithmParameters) const = 0; - virtual void params(AlgorithmParameters, std::vector<std::string> &) = 0; - virtual std::vector<std::string> params(AlgorithmParameters) const = 0; }; public: @@ -127,8 +113,6 @@ public: void optimize(loco::Graph *) const; - void quantize(loco::Graph *) const; - void sparsify(loco::Graph *) const; private: diff --git 
a/compiler/luci/pass/include/luci/CircleQuantizer.h b/compiler/luci/pass/include/luci/CircleQuantizer.h new file mode 100644 index 000000000..4e7074d98 --- /dev/null +++ b/compiler/luci/pass/include/luci/CircleQuantizer.h @@ -0,0 +1,97 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_CIRCLE_QUANTIZER_H__ +#define __LUCI_CIRCLE_QUANTIZER_H__ + +#include <loco.h> + +#include <string> +#include <vector> + +namespace luci +{ + +class CircleQuantizer final +{ +public: + struct Options + { + struct LayerParam + { + std::string name; + std::string dtype; + std::string granularity; + }; + + enum Algorithm + { + QuantizeDequantizeWeights, + QuantizeWithMinMax, + Requantize, + CopyQuantParam, + ForceQuantParam, + ConvertToFakeQuantizedModel, + }; + + enum AlgorithmParameters + { + // quantize + Quantize_input_model_dtype, + Quantize_output_model_dtype, + Quantize_granularity, // layer-wise or channel-wise + Quantize_tensor_names, + Quantize_scales, + Quantize_zero_points, + Quantize_layer_params, + + // copy_quantparam + Quantize_src_tensor_names, + Quantize_dst_tensor_names, + + Quantize_input_type, + Quantize_output_type, + Quantize_TF_style_maxpool, + }; + + virtual ~Options() = default; + + virtual void enable(Algorithm) = 0; + virtual bool query(Algorithm) = 0; + virtual void param(AlgorithmParameters, const std::string &) = 0; + virtual const std::string 
param(AlgorithmParameters) const = 0; + virtual void params(AlgorithmParameters, std::vector<std::string> &) = 0; + virtual std::vector<std::string> params(AlgorithmParameters) const = 0; + + // Quantization parameters for multiple layers + virtual void layer_params(AlgorithmParameters, std::vector<std::shared_ptr<LayerParam>> &) = 0; + virtual std::vector<std::shared_ptr<LayerParam>> layer_params(AlgorithmParameters) const = 0; + }; + +public: + // TODO maybe caller can provide Options as ctor parameters + Options *options(void); + +public: + void quantize(loco::Graph *) const; + +private: + std::unique_ptr<Options> _options; +}; + +} // namespace luci + +#endif // __LUCI_CIRCLE_QUANTIZER_H__ diff --git a/compiler/luci/pass/include/luci/Pass/ConvertToFakeQuantizedModelPass.h b/compiler/luci/pass/include/luci/Pass/ConvertToFakeQuantizedModelPass.h new file mode 100644 index 000000000..91dd2300e --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/ConvertToFakeQuantizedModelPass.h @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_CONVERT_TO_FAKE_QUANTIZED_MODEL_PASS_H__ +#define __LUCI_CONVERT_TO_FAKE_QUANTIZED_MODEL_PASS_H__ + +#include <logo/Pass.h> + +namespace luci +{ + +/** + * @brief Class to convert a quantized model to a fake-quantized fp32 model. 
+ */ +struct ConvertToFakeQuantizedModelPass final : public logo::Pass +{ + ConvertToFakeQuantizedModelPass() {} + + const char *name(void) const final { return "luci::ConvertToFakeQuantizedModelPass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_CONVERT_TO_FAKE_QUANTIZED_MODEL_PASS_H__ diff --git a/compiler/luci/pass/include/luci/Pass/CopyQuantParamPass.h b/compiler/luci/pass/include/luci/Pass/CopyQuantParamPass.h new file mode 100644 index 000000000..18c9cd56a --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/CopyQuantParamPass.h @@ -0,0 +1,53 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __LUCI_COPY_QUANT_PARAM_PASS_H__ +#define __LUCI_COPY_QUANT_PARAM_PASS_H__ + +#include <loco.h> + +#include <logo/Pass.h> + +namespace luci +{ + +/** + * @brief Pass to copy quantparam (scale, zerop) of a tensor to another tensor + */ +class CopyQuantParamPass : public logo::Pass +{ +public: + using TensorVector = std::vector<std::string>; + +public: + CopyQuantParamPass(TensorVector &src_tensors, TensorVector &dst_tensors) + : _src_tensors{src_tensors}, _dst_tensors{dst_tensors} + { + // DO NOTHING + } + virtual const char *name(void) const { return "luci::CopyQuantParamPass"; } + +public: + bool run(loco::Graph *graph); + +private: + TensorVector _src_tensors; + TensorVector _dst_tensors; +}; + +} // namespace luci + +#endif //__LUCI_COPY_QUANT_PARAM_PASS_H__ diff --git a/compiler/luci/pass/include/luci/Pass/FoldGatherPass.h b/compiler/luci/pass/include/luci/Pass/FoldGatherPass.h new file mode 100644 index 000000000..de08c8845 --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/FoldGatherPass.h @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __LUCI_FOLD_GATHER_PASS_H__ +#define __LUCI_FOLD_GATHER_PASS_H__ + +#include <logo/Pass.h> + +namespace luci +{ + +/** + * @brief Class to fold Gather to a constant tensor + * + */ +struct FoldGatherPass final : public logo::Pass +{ + const char *name(void) const final { return "luci::FoldGatherPass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_FOLD_GATHER_PASS_H__ diff --git a/compiler/luci/pass/include/luci/Pass/PropagateQParamBackwardPass.h b/compiler/luci/pass/include/luci/Pass/PropagateQParamBackwardPass.h new file mode 100644 index 000000000..0c489fc30 --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/PropagateQParamBackwardPass.h @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __LUCI_PROPAGATE_QPARAM_BACKWARD_PASS_H__ +#define __LUCI_PROPAGATE_QPARAM_BACKWARD_PASS_H__ + +#include <logo/Pass.h> + +namespace luci +{ + +/** + * @brief Class to propagate quantization parameters of an operator's output to input + */ +struct PropagateQParamBackwardPass final : public logo::Pass +{ + PropagateQParamBackwardPass(loco::DataType output) : _output_model_dtype(output) {} + + const char *name(void) const final { return "luci::PropagateQParamBackwardPass"; } + + bool run(loco::Graph *g) final; + +private: + loco::DataType _output_model_dtype; +}; + +} // namespace luci + +#endif // __LUCI_PROPAGATE_QPARAM_BACKWARD_PASS_H__ diff --git a/compiler/luci/pass/include/luci/Pass/PropagateQuantParamPass.h b/compiler/luci/pass/include/luci/Pass/PropagateQParamForwardPass.h index 7e0c44b8c..952bd9614 100644 --- a/compiler/luci/pass/include/luci/Pass/PropagateQuantParamPass.h +++ b/compiler/luci/pass/include/luci/Pass/PropagateQParamForwardPass.h @@ -14,8 +14,8 @@ * limitations under the License. 
*/ -#ifndef __LUCI_PROPAGATE_QUANT_PARAM_PASS_H__ -#define __LUCI_PROPAGATE_QUANT_PARAM_PASS_H__ +#ifndef __LUCI_PROPAGATE_QPARAM_FORWARD_PASS_H__ +#define __LUCI_PROPAGATE_QPARAM_FORWARD_PASS_H__ #include <logo/Pass.h> @@ -23,15 +23,22 @@ namespace luci { /** - * @brief Class to propagate quantization parameters of an operator's output to input + * @brief Class to propagate quantization parameters of an operator's input to output */ -struct PropagateQuantParamPass final : public logo::Pass +struct PropagateQParamForwardPass final : public logo::Pass { - const char *name(void) const final { return "luci::PropagateQuantParamPass"; } + PropagateQParamForwardPass(bool TF_style_maxpool) : _TF_style_maxpool(TF_style_maxpool) {} + + PropagateQParamForwardPass() {} + + const char *name(void) const final { return "luci::PropagateQParamForwardPass"; } bool run(loco::Graph *g) final; + +private: + bool _TF_style_maxpool = false; }; } // namespace luci -#endif // __LUCI_PROPAGATE_QUANT_PARAM_PASS_H__ +#endif // __LUCI_PROPAGATE_QPARAM_FORWARD_PASS_H__ diff --git a/compiler/luci/pass/include/luci/Pass/QuantizationParameters.h b/compiler/luci/pass/include/luci/Pass/QuantizationParameters.h index 5c9cd427f..30c8db058 100644 --- a/compiler/luci/pass/include/luci/Pass/QuantizationParameters.h +++ b/compiler/luci/pass/include/luci/Pass/QuantizationParameters.h @@ -17,6 +17,10 @@ #ifndef __LUCI_QUANTIZATION_PARAMETERS_H__ #define __LUCI_QUANTIZATION_PARAMETERS_H__ +#include <loco.h> + +#include <string> + namespace luci { @@ -26,6 +30,13 @@ enum QuantizationGranularity ChannelWise = 1, }; +struct LayerInfo +{ + std::string name; + loco::DataType dtype; + QuantizationGranularity granularity; +}; + } // namespace luci #endif // __LUCI_QUANTIZATION_PARAMETERS_H__ diff --git a/compiler/luci/pass/include/luci/Pass/QuantizeDequantizeWeightsPass.h b/compiler/luci/pass/include/luci/Pass/QuantizeDequantizeWeightsPass.h index 68765ec5b..1825ee1aa 100644 --- 
a/compiler/luci/pass/include/luci/Pass/QuantizeDequantizeWeightsPass.h +++ b/compiler/luci/pass/include/luci/Pass/QuantizeDequantizeWeightsPass.h @@ -32,12 +32,30 @@ namespace luci class QuantizeDequantizeWeightsPass : public logo::Pass { public: + struct Context + { + loco::DataType input_model_dtype = loco::DataType::Unknown; + loco::DataType output_model_dtype = loco::DataType::Unknown; + QuantizationGranularity granularity = QuantizationGranularity::ChannelWise; + std::vector<LayerInfo> layers_info; + }; + +public: + QuantizeDequantizeWeightsPass(std::unique_ptr<Context> &&ctx) : _ctx{std::move(ctx)} + { + // DO NOTHING + } + +public: QuantizeDequantizeWeightsPass(loco::DataType input_model_dtype, loco::DataType output_model_dtype, QuantizationGranularity granularity) - : _input_model_dtype{input_model_dtype}, _output_model_dtype{output_model_dtype}, _granularity{ - granularity} { - // DO NOTHING + _ctx = std::make_unique<Context>(); + { + _ctx->input_model_dtype = input_model_dtype; + _ctx->output_model_dtype = output_model_dtype; + _ctx->granularity = granularity; + } } virtual const char *name(void) const { return "luci::QuantizeDequantizeWeightsPass"; } @@ -45,9 +63,7 @@ public: bool run(loco::Graph *graph); private: - loco::DataType _input_model_dtype; - loco::DataType _output_model_dtype; - QuantizationGranularity _granularity; + std::unique_ptr<Context> _ctx; }; } // namespace luci diff --git a/compiler/luci/pass/include/luci/Pass/QuantizePreCheckerPass.h b/compiler/luci/pass/include/luci/Pass/QuantizePreCheckerPass.h new file mode 100644 index 000000000..c852f88e0 --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/QuantizePreCheckerPass.h @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_QUANTIZE_PRE_CHECKER_PASS_H__ +#define __LUCI_QUANTIZE_PRE_CHECKER_PASS_H__ + +#include <logo/Pass.h> + +namespace luci +{ + +/** + * @brief Pass to verify the input model has the form acceptable by quantizer + */ +class QuantizePreCheckerPass : public logo::Pass +{ +public: + const char *name(void) const final { return "luci::QuantizePreCheckerPass"; } + +public: + bool run(loco::Graph *graph) final; +}; + +} // namespace luci + +#endif //__LUCI_QUANTIZE_PRE_CHECKER_PASS_H__ diff --git a/compiler/luci/pass/include/luci/Pass/QuantizeWithMinMaxPass.h b/compiler/luci/pass/include/luci/Pass/QuantizeWithMinMaxPass.h index 648abad70..ea6db85d1 100644 --- a/compiler/luci/pass/include/luci/Pass/QuantizeWithMinMaxPass.h +++ b/compiler/luci/pass/include/luci/Pass/QuantizeWithMinMaxPass.h @@ -23,6 +23,8 @@ #include <luci/Pass/QuantizationParameters.h> +#include <vector> + namespace luci { @@ -31,26 +33,41 @@ namespace luci */ class QuantizeWithMinMaxPass : public logo::Pass { +public: + struct Context + { + loco::DataType input_model_dtype = loco::DataType::Unknown; + loco::DataType output_model_dtype = loco::DataType::Unknown; + QuantizationGranularity granularity = QuantizationGranularity::ChannelWise; + loco::DataType input_type = loco::DataType::Unknown; + loco::DataType output_type = loco::DataType::Unknown; + bool TF_style_maxpool = false; + std::vector<LayerInfo> layers_info; + }; + // For backward-compatibility // TODO Remove this constructor public: QuantizeWithMinMaxPass(loco::DataType input_model_dtype, loco::DataType 
output_model_dtype, QuantizationGranularity granularity) - : _input_model_dtype{input_model_dtype}, _output_model_dtype{output_model_dtype}, - _granularity{granularity}, _input_type{output_model_dtype}, _output_type{output_model_dtype} { - // DO NOTHING + _ctx = std::make_unique<Context>(); + { + _ctx->input_model_dtype = input_model_dtype; + _ctx->output_model_dtype = output_model_dtype; + _ctx->granularity = granularity; + _ctx->input_type = output_model_dtype; + _ctx->output_type = output_model_dtype; + _ctx->TF_style_maxpool = false; + } } public: - QuantizeWithMinMaxPass(loco::DataType input_model_dtype, loco::DataType output_model_dtype, - QuantizationGranularity granularity, loco::DataType input_type, - loco::DataType output_type) - : _input_model_dtype{input_model_dtype}, _output_model_dtype{output_model_dtype}, - _granularity{granularity}, _input_type{input_type}, _output_type{output_type} + QuantizeWithMinMaxPass(std::unique_ptr<Context> &&ctx) : _ctx{std::move(ctx)} { // DO NOTHING } + virtual const char *name(void) const { return "luci::QuantizeWithMinMaxPass"; } public: @@ -61,11 +78,7 @@ private: void set_output_type(loco::Graph *graph) const; private: - loco::DataType _input_model_dtype; - loco::DataType _output_model_dtype; - QuantizationGranularity _granularity; - loco::DataType _input_type; - loco::DataType _output_type; + std::unique_ptr<Context> _ctx; }; } // namespace luci diff --git a/compiler/luci/pass/include/luci/Pass/RemoveRedundantQuantizePass.h b/compiler/luci/pass/include/luci/Pass/RemoveRedundantQuantizePass.h new file mode 100644 index 000000000..3e76bcdc3 --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/RemoveRedundantQuantizePass.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_REMOVE_REDUNDANT_QUANTIZE_PASS_H__ +#define __LUCI_REMOVE_REDUNDANT_QUANTIZE_PASS_H__ + +#include <logo/Pass.h> + +namespace luci +{ + +/** + * @brief Class to remove redundant quantize operations + */ +struct RemoveRedundantQuantizePass final : public logo::Pass +{ + const char *name(void) const final { return "luci::RemoveRedundantQuantizePass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_REMOVE_REDUNDANT_QUANTIZE_PASS_H__ diff --git a/compiler/luci/pass/src/BatchNormPatternFinder.cpp b/compiler/luci/pass/src/BatchNormPatternFinder.cpp index c1a06bfda..e3f126b15 100644 --- a/compiler/luci/pass/src/BatchNormPatternFinder.cpp +++ b/compiler/luci/pass/src/BatchNormPatternFinder.cpp @@ -44,10 +44,26 @@ bool is_batchnorm_add(const luci::CircleAdd *add, luci::CircleMul *&mul, luci::C return false; } - if (constant->rank() != 1) + uint32_t channel_dim = 0; + + if (constant->rank() == 1) + { + channel_dim = constant->dim(0).value(); + } + else if (constant->rank() == 4) + { + for (uint32_t i = 0; i < 3; i++) + { + if (constant->dim(i).value() != 1) + return false; + } + channel_dim = constant->dim(3).value(); + } + else + { return false; + } - auto channel_dim = constant->dim(0); // Assumption: Layout is channel-last if (!(channel_dim == add->dim(add->rank() - 1))) return false; @@ -90,10 +106,26 @@ bool is_batchnorm_mul(const luci::CircleMul *mul, luci::CircleNode *&pred_node, return false; } - if (constant->rank() != 1) + uint32_t channel_dim = 0; + + if (constant->rank() == 1) + { + 
channel_dim = constant->dim(0).value(); + } + else if (constant->rank() == 4) + { + for (uint32_t i = 0; i < 3; i++) + { + if (constant->dim(i).value() != 1) + return false; + } + channel_dim = constant->dim(3).value(); + } + else + { return false; + } - auto channel_dim = constant->dim(0); // Assumption: Layout is channel-last if (!(channel_dim == mul->dim(mul->rank() - 1))) return false; diff --git a/compiler/luci/pass/src/BatchNormPatternFinder.test.cpp b/compiler/luci/pass/src/BatchNormPatternFinder.test.cpp index 08e7fac1c..cc8c5615f 100644 --- a/compiler/luci/pass/src/BatchNormPatternFinder.test.cpp +++ b/compiler/luci/pass/src/BatchNormPatternFinder.test.cpp @@ -50,7 +50,7 @@ public: auto channel_size = *last_it; _add->shape(shape); - _add_beta->shape({channel_size}); + set_beta_shape(channel_size); _add_beta->size<loco::DataType::FLOAT32>(channel_size); for (uint32_t i = 0; i < channel_size; i++) _add_beta->at<loco::DataType::FLOAT32>(i) = i; @@ -63,10 +63,23 @@ public: luci::CircleAdd *add() { return _add; } protected: + virtual void set_beta_shape(uint32_t channel) = 0; + +protected: luci::CircleAdd *_add = nullptr; luci::CircleConst *_add_beta = nullptr; }; +class AddRank1BetaGraphlet : public AddBetaGraphlet +{ + void set_beta_shape(uint32_t channel) final { _add_beta->shape({channel}); } +}; + +class AddRank4BetaGraphlet : public AddBetaGraphlet +{ + void set_beta_shape(uint32_t channel) final { _add_beta->shape({1, 1, 1, channel}); } +}; + /** * @brief Graphlet with Mul and Const as gamma from BatchNorm */ @@ -90,7 +103,7 @@ public: auto channel_size = *last_it; _mul->shape(shape); - _mul_gamma->shape({channel_size}); + set_gamma_shape(channel_size); _mul_gamma->size<loco::DataType::FLOAT32>(channel_size); for (uint32_t i = 0; i < channel_size; i++) _mul_gamma->at<loco::DataType::FLOAT32>(i) = i; @@ -103,14 +116,27 @@ public: luci::CircleMul *mul(void) { return _mul; } protected: + virtual void set_gamma_shape(uint32_t channel) = 0; + +protected: 
luci::CircleMul *_mul = nullptr; luci::CircleConst *_mul_gamma = nullptr; }; +class MulRank1GammaGraphlet : public MulGammaGraphlet +{ + void set_gamma_shape(uint32_t channel) final { _mul_gamma->shape({channel}); } +}; + +class MulRank4GammaGraphlet : public MulGammaGraphlet +{ + void set_gamma_shape(uint32_t channel) final { _mul_gamma->shape({1, 1, 1, channel}); } +}; + /** * @brief Graph of Mul-Add pattern from BatchNorm */ -class MulAddGraph : public TestIOGraph, public AddBetaGraphlet, public MulGammaGraphlet +class MulAddGraph : public TestIOGraph, public AddRank1BetaGraphlet, public MulRank1GammaGraphlet { public: MulAddGraph() = default; @@ -118,8 +144,30 @@ public: void init(const ShapeU32 shape_in, const ShapeU32 shape_out) { TestIOGraph::init(shape_in, shape_out); - MulGammaGraphlet::init(g(), shape_in, luci::FusedActFunc::NONE); - AddBetaGraphlet::init(g(), shape_out, luci::FusedActFunc::RELU); + MulRank1GammaGraphlet::init(g(), shape_in, luci::FusedActFunc::NONE); + AddRank1BetaGraphlet::init(g(), shape_out, luci::FusedActFunc::RELU); + + // connect network + _mul->x(input()); + _mul->y(_mul_gamma); + _add->x(_mul); + _add->y(_add_beta); + output()->from(_add); + } +}; + +class MulAddRank4Graph : public TestIOGraph, + public AddRank4BetaGraphlet, + public MulRank4GammaGraphlet +{ +public: + MulAddRank4Graph() = default; + + void init(const ShapeU32 shape_in, const ShapeU32 shape_out) + { + TestIOGraph::init(shape_in, shape_out); + MulRank4GammaGraphlet::init(g(), shape_in, luci::FusedActFunc::NONE); + AddRank4BetaGraphlet::init(g(), shape_out, luci::FusedActFunc::RELU); // connect network _mul->x(input()); @@ -133,7 +181,7 @@ public: /** * @brief Graph of Add with Const */ -class AddGraph : public TestIOGraph, public AddBetaGraphlet +class AddGraph : public TestIOGraph, public AddRank1BetaGraphlet { public: AddGraph() = default; @@ -141,7 +189,24 @@ public: void init(const ShapeU32 shape_in, const ShapeU32 shape_out) { TestIOGraph::init(shape_in, 
shape_out); - AddBetaGraphlet::init(g(), shape_in, luci::FusedActFunc::RELU); + AddRank1BetaGraphlet::init(g(), shape_in, luci::FusedActFunc::RELU); + + // connect network + _add->x(input()); + _add->y(_add_beta); + output()->from(_add); + } +}; + +class AddRank4Graph : public TestIOGraph, public AddRank4BetaGraphlet +{ +public: + AddRank4Graph() = default; + + void init(const ShapeU32 shape_in, const ShapeU32 shape_out) + { + TestIOGraph::init(shape_in, shape_out); + AddRank4BetaGraphlet::init(g(), shape_in, luci::FusedActFunc::RELU); // connect network _add->x(input()); @@ -160,6 +225,7 @@ public: protected: luci::test::MulAddGraph _mag; + luci::test::MulAddRank4Graph _mag_r4; }; class BatchNormPatternFinderAddTest : public ::testing::Test @@ -169,6 +235,7 @@ public: protected: luci::test::AddGraph _ag; + luci::test::AddRank4Graph _ag_r4; }; TEST_F(BatchNormPatternFinderMulAddTest, is_batchnorm_add) @@ -192,6 +259,19 @@ TEST_F(BatchNormPatternFinderMulAddTest, is_batchnorm_add2) ASSERT_TRUE(res); } +TEST_F(BatchNormPatternFinderMulAddTest, is_batchnorm_add_rank4) +{ + _mag_r4.init({1, 16, 16, 4}, {1, 16, 16, 4}); + + luci::CircleMul *mul = nullptr; + luci::CircleConst *beta = nullptr; + + auto res = luci::is_batchnorm_add(_mag_r4.add(), mul, beta); + ASSERT_TRUE(res); + ASSERT_NE(nullptr, mul); + ASSERT_NE(nullptr, beta); +} + TEST_F(BatchNormPatternFinderAddTest, is_batchnorm_add_NEG) { _ag.init({1, 16, 16, 4}, {1, 16, 16, 4}); @@ -215,3 +295,16 @@ TEST_F(BatchNormPatternFinderMulAddTest, is_batchnorm_mul) ASSERT_NE(nullptr, pred); ASSERT_NE(nullptr, gamma); } + +TEST_F(BatchNormPatternFinderMulAddTest, is_batchnorm_mul_rank4) +{ + _mag_r4.init({1, 16, 16, 4}, {1, 16, 16, 4}); + + luci::CircleNode *pred = nullptr; + luci::CircleConst *gamma = nullptr; + + auto res = luci::is_batchnorm_mul(_mag_r4.mul(), pred, gamma); + ASSERT_TRUE(res); + ASSERT_NE(nullptr, pred); + ASSERT_NE(nullptr, gamma); +} diff --git a/compiler/luci/pass/src/CircleOptimizer.cpp 
b/compiler/luci/pass/src/CircleOptimizer.cpp index 75f04b3b5..6dbb22d7c 100644 --- a/compiler/luci/pass/src/CircleOptimizer.cpp +++ b/compiler/luci/pass/src/CircleOptimizer.cpp @@ -22,9 +22,9 @@ #include "luci/Pass/FoldCastPass.h" #include "luci/Pass/FoldDepthwiseConv2DPass.h" #include "luci/Pass/FoldDequantizePass.h" +#include "luci/Pass/FoldGatherPass.h" #include "luci/Pass/FoldSparseToDensePass.h" #include "luci/Pass/ForwardReshapeToUnaryOpPass.h" -#include "luci/Pass/ForceQuantParamPass.h" #include "luci/Pass/FuseActivationFunctionPass.h" #include "luci/Pass/FuseAddWithFullyConnectedPass.h" #include "luci/Pass/FuseAddWithTConvPass.h" @@ -37,11 +37,11 @@ #include "luci/Pass/FusePreActivationBatchNormPass.h" #include "luci/Pass/FuseTransposeWithMeanPass.h" #include "luci/Pass/MakeBatchNormGammaPositivePass.h" -#include "luci/Pass/PropagateQuantParamPass.h" #include "luci/Pass/RemoveFakeQuantPass.h" #include "luci/Pass/RemoveQuantDequantSeqPass.h" #include "luci/Pass/RemoveRedundantReshapePass.h" #include "luci/Pass/RemoveRedundantTransposePass.h" +#include "luci/Pass/RemoveRedundantQuantizePass.h" #include "luci/Pass/RemoveUnnecessaryReshapePass.h" #include "luci/Pass/RemoveUnnecessarySlicePass.h" #include "luci/Pass/RemoveUnnecessaryStridedSlicePass.h" @@ -52,9 +52,6 @@ #include "luci/Pass/ResolveCustomOpBatchMatMulPass.h" #include "luci/Pass/ResolveCustomOpMatMulPass.h" #include "luci/Pass/ResolveCustomOpMaxPoolWithArgmaxPass.h" -#include "luci/Pass/RequantizePass.h" -#include "luci/Pass/QuantizeWithMinMaxPass.h" -#include "luci/Pass/QuantizeDequantizeWeightsPass.h" #include "luci/Pass/SparsifyTensorPass.h" #include "luci/Pass/ShuffleWeightTo16x1Float32Pass.h" #include "luci/Pass/SubstitutePackToReshapePass.h" @@ -75,9 +72,6 @@ #include "ModulePhase.h" #include "ProgressReporter.h" -#include "helpers/Strings.h" - -#include "QuantizedModelVerifier.h" #include <luci/IR/CircleNodes.h> #include <logo/Phase.h> @@ -91,37 +85,17 @@ namespace using namespace luci; 
-template <typename T> T lexical_cast(const std::string &str) -{ - std::istringstream ss; - ss.str(str); - T data; - ss >> data; - return data; -} - -template <typename T> std::vector<T> lexical_cast(std::vector<std::string> &sv) -{ - std::vector<T> result; - std::transform(sv.begin(), sv.end(), std::back_inserter(result), - [](std::string str) -> T { return lexical_cast<T>(str); }); - return result; -} - class OptimizeOptionsImpl final : public luci::CircleOptimizer::Options { public: void enable(Algorithm) final; void param(AlgorithmParameters, const std::string &) final; const std::string param(AlgorithmParameters) const final; - void params(AlgorithmParameters, std::vector<std::string> &) final; - std::vector<std::string> params(AlgorithmParameters) const final; bool query(Algorithm) final; private: std::vector<Algorithm> _algorithms; std::map<AlgorithmParameters, const std::string> _algorithm_params; - std::map<AlgorithmParameters, std::vector<std::string>> _multiple_params; }; void OptimizeOptionsImpl::enable(Algorithm algo) { _algorithms.push_back(algo); } @@ -144,24 +118,6 @@ const std::string OptimizeOptionsImpl::param(AlgorithmParameters param) const } } -void OptimizeOptionsImpl::params(AlgorithmParameters param, std::vector<std::string> &vec) -{ - _multiple_params[param] = vec; -} - -std::vector<std::string> OptimizeOptionsImpl::params(AlgorithmParameters param) const -{ - auto param_vec = _multiple_params.find(param); - if (param_vec != _multiple_params.end()) - { - return param_vec->second; - } - else - { - return std::vector<std::string>(); - } -} - bool OptimizeOptionsImpl::query(Algorithm algo) { std::vector<Algorithm>::iterator it = std::find(_algorithms.begin(), _algorithms.end(), algo); @@ -312,6 +268,10 @@ void CircleOptimizer::optimize(loco::Graph *g) const { phase.emplace_back(std::make_unique<luci::FoldDequantizePass>()); } + if (_options->query(Options::Algorithm::FoldGather)) + { + 
phase.emplace_back(std::make_unique<luci::FoldGatherPass>()); + } if (_options->query(Options::Algorithm::FoldSparseToDense)) { phase.emplace_back(std::make_unique<luci::FoldSparseToDensePass>()); @@ -368,6 +328,10 @@ void CircleOptimizer::optimize(loco::Graph *g) const { phase.emplace_back(std::make_unique<luci::RemoveRedundantTransposePass>()); } + if (_options->query(Options::Algorithm::RemoveRedundantQuantize)) + { + phase.emplace_back(std::make_unique<luci::RemoveRedundantQuantizePass>()); + } if (_options->query(Options::Algorithm::ReplaceMulAddWithDepthwiseConv)) { phase.emplace_back(std::make_unique<luci::ReplaceMulAddWithDepthwiseConvPass>()); @@ -417,174 +381,6 @@ void CircleOptimizer::optimize(loco::Graph *g) const phase_runner.run(phase); } -void CircleOptimizer::quantize(loco::Graph *g) const -{ - // Fake quantization of weights - if (_options->query(Options::Algorithm::QuantizeDequantizeWeights)) - { - static const std::vector<std::string> fakeq_supported_input_model_dtype{"float32"}; - static const std::vector<std::string> fakeq_supported_output_model_dtype{"uint8", "int16"}; - static const std::vector<std::string> fakeq_supported_granularity{"layer", "channel"}; - - auto input_model_dtype = - _options->param(Options::AlgorithmParameters::Quantize_input_model_dtype); - auto output_model_dtype = - _options->param(Options::AlgorithmParameters::Quantize_output_model_dtype); - auto granularity = _options->param(Options::AlgorithmParameters::Quantize_granularity); - - if (!in_array(to_lower_case(input_model_dtype), fakeq_supported_input_model_dtype)) - throw std::runtime_error("Unsupported input type. List of supported input type: " + - to_string(fakeq_supported_input_model_dtype)); - - if (!in_array(to_lower_case(output_model_dtype), fakeq_supported_output_model_dtype)) - throw std::runtime_error("Unsupported output type. 
List of supported output type: " + - to_string(fakeq_supported_output_model_dtype)); - - if (!in_array(to_lower_case(granularity), fakeq_supported_granularity)) - throw std::runtime_error("Unsupported granularity. List of supported granularity: " + - to_string(fakeq_supported_granularity)); - - if (str_to_granularity(granularity) == QuantizationGranularity::LayerWise && - str_to_dtype(output_model_dtype) != loco::DataType::U8) - throw std::runtime_error("Layer-wise quantization only supports uint8 dtype."); - - // Clear existing quantparams before doing fake quantization - for (auto node : loco::active_nodes(loco::output_nodes(g))) - { - auto circle_node = loco::must_cast<luci::CircleNode *>(node); - if (circle_node->quantparam() != nullptr) - circle_node->quantparam(nullptr); - } - - luci::QuantizeDequantizeWeightsPass fake_quantizer(str_to_dtype(input_model_dtype), - str_to_dtype(output_model_dtype), - str_to_granularity(granularity)); - fake_quantizer.run(g); - } - - // Actual quantization of weights, bias, and activation - if (_options->query(Options::Algorithm::QuantizeWithMinMax)) - { - static const std::vector<std::string> qwmm_supported_input_model_dtype{"float32"}; - static const std::vector<std::string> qwmm_supported_output_model_dtype{"uint8", "int16"}; - static const std::vector<std::string> qwmm_supported_granularity{"layer", "channel"}; - static const std::vector<std::string> qwmm_supported_input_type{"uint8", "int16"}; - static const std::vector<std::string> qwmm_supported_output_type{"uint8", "int16"}; - - auto input_model_dtype = - _options->param(Options::AlgorithmParameters::Quantize_input_model_dtype); - auto output_model_dtype = - _options->param(Options::AlgorithmParameters::Quantize_output_model_dtype); - auto granularity = _options->param(Options::AlgorithmParameters::Quantize_granularity); - auto input_type = _options->param(Options::AlgorithmParameters::Quantize_input_type); - if (input_type.empty()) - input_type = output_model_dtype; - 
auto output_type = _options->param(Options::AlgorithmParameters::Quantize_output_type); - if (output_type.empty()) - output_type = output_model_dtype; - - if (!in_array(to_lower_case(input_model_dtype), qwmm_supported_input_model_dtype)) - throw std::runtime_error("Unsupported input type. List of supported input types: " + - to_string(qwmm_supported_input_model_dtype)); - - if (!in_array(to_lower_case(output_model_dtype), qwmm_supported_output_model_dtype)) - throw std::runtime_error("Unsupported output type. List of supported output types: " + - to_string(qwmm_supported_output_model_dtype)); - - if (!in_array(to_lower_case(granularity), qwmm_supported_granularity)) - throw std::runtime_error("Unsupported granularity. List of supported granularity: " + - to_string(qwmm_supported_granularity)); - - if (!in_array(to_lower_case(input_type), qwmm_supported_input_type)) - throw std::runtime_error("Unsupported input type. List of supported input types: " + - to_string(qwmm_supported_input_type)); - - if (!in_array(to_lower_case(output_type), qwmm_supported_output_type)) - throw std::runtime_error("Unsupported output type. 
List of supported output types: " + - to_string(qwmm_supported_output_type)); - - if (str_to_granularity(granularity) == QuantizationGranularity::LayerWise && - str_to_dtype(output_model_dtype) != loco::DataType::U8) - throw std::runtime_error("Layer-wise quantization only supports uint8 dtype."); - - luci::QuantizeWithMinMaxPass quantizer( - str_to_dtype(input_model_dtype), str_to_dtype(output_model_dtype), - str_to_granularity(granularity), str_to_dtype(input_type), str_to_dtype(output_type)); - quantizer.run(g); - - // Post-quantization optimizations - logo::Phase phase; - - phase.emplace_back(std::make_unique<luci::PropagateQuantParamPass>()); - - phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>()); - phase.emplace_back(std::make_unique<luci::CircleTypeInferencePass>()); - phase.emplace_back(std::make_unique<logo::RemoveDeadNodeWithQueryPass>()); - - ProgressReporter prog(g, logo::PhaseStrategy::Saturate); - logo::PhaseRunner<logo::PhaseStrategy::Saturate> phase_runner{g}; - phase_runner.attach(&prog); - phase_runner.run(phase); - - // Verify the type/granularity of the quantized model - luci::QuantizedModelVerifier verifier(str_to_dtype(output_model_dtype), - str_to_granularity(granularity)); - verifier.verify(g); - } - - // Requantize - if (_options->query(Options::Algorithm::Requantize)) - { - static const std::vector<std::string> rq_supported_input_model_dtype{"int8"}; - static const std::vector<std::string> rq_supported_output_model_dtype{"uint8"}; - - auto input_model_dtype = - _options->param(Options::AlgorithmParameters::Quantize_input_model_dtype); - auto output_model_dtype = - _options->param(Options::AlgorithmParameters::Quantize_output_model_dtype); - - if (!in_array(to_lower_case(input_model_dtype), rq_supported_input_model_dtype)) - throw std::runtime_error("Unsupported input type. 
List of supported input types: " + - to_string(rq_supported_input_model_dtype)); - - if (!in_array(to_lower_case(output_model_dtype), rq_supported_output_model_dtype)) - throw std::runtime_error("Unsupported output type. List of supported output types: " + - to_string(rq_supported_output_model_dtype)); - - luci::RequantizePass requantizer(str_to_dtype(input_model_dtype), - str_to_dtype(output_model_dtype)); - requantizer.run(g); - } - - // Force to write quantparam to specified tensors - // NOTE Only per-tensor (not per-channel) qparam can be written - if (_options->query(Options::Algorithm::ForceQuantParam)) - { - ForceQuantParamPass::TensorVector tensors = - _options->params(Options::AlgorithmParameters::Quantize_tensor_names); - auto str_scales = _options->params(Options::AlgorithmParameters::Quantize_scales); - auto str_zero_points = _options->params(Options::AlgorithmParameters::Quantize_zero_points); - - // Cast scales/zero_points to proper types - ForceQuantParamPass::ScaleVector scales = lexical_cast<float>(str_scales); - ForceQuantParamPass::ZPVector zero_points = lexical_cast<int64_t>(str_zero_points); - - ForceQuantParamPass fq(tensors, scales, zero_points); - fq.run(g); - } - - logo::Phase phase; - - // Do Shape/Type inference - phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>()); - phase.emplace_back(std::make_unique<luci::CircleTypeInferencePass>()); - - ProgressReporter prog(g, logo::PhaseStrategy::Saturate); - logo::PhaseRunner<logo::PhaseStrategy::Saturate> phase_runner{g}; - phase_runner.attach(&prog); - phase_runner.run(phase); -} - void CircleOptimizer::sparsify(loco::Graph *g) const { if (_options->query(Options::Algorithm::SparsifyTensorPass)) diff --git a/compiler/luci/pass/src/CircleOptimizer.test.cpp b/compiler/luci/pass/src/CircleOptimizer.test.cpp index a1b5c7f80..041fc7d75 100644 --- a/compiler/luci/pass/src/CircleOptimizer.test.cpp +++ b/compiler/luci/pass/src/CircleOptimizer.test.cpp @@ -71,171 +71,3 @@ 
TEST(CircleOptimizerTest, sparsify_simple) SUCCEED(); } - -TEST(CircleOptimizerTest, quantize_quantdequant_simple) -{ - loco::Graph g; - luci::CircleOptimizer o; - - auto options = o.options(); - - options->enable(Algorithms::QuantizeDequantizeWeights); - options->param(AlgorithmParameters::Quantize_input_model_dtype, "float32"); - options->param(AlgorithmParameters::Quantize_output_model_dtype, "uint8"); - options->param(AlgorithmParameters::Quantize_granularity, "layer"); - - o.quantize(&g); - - SUCCEED(); -} - -TEST(CircleOptimizerTest, quantize_quantdequant_input_NEG) -{ - loco::Graph g; - luci::CircleOptimizer o; - - auto options = o.options(); - - options->enable(Algorithms::QuantizeDequantizeWeights); - options->param(AlgorithmParameters::Quantize_input_model_dtype, "invalid"); - options->param(AlgorithmParameters::Quantize_output_model_dtype, "uint8"); - options->param(AlgorithmParameters::Quantize_granularity, "layer"); - - EXPECT_THROW(o.quantize(&g), std::runtime_error); -} - -TEST(CircleOptimizerTest, quantize_quantdequant_output_NEG) -{ - loco::Graph g; - luci::CircleOptimizer o; - - auto options = o.options(); - - options->enable(Algorithms::QuantizeDequantizeWeights); - options->param(AlgorithmParameters::Quantize_input_model_dtype, "float32"); - options->param(AlgorithmParameters::Quantize_output_model_dtype, "invalid"); - options->param(AlgorithmParameters::Quantize_granularity, "layer"); - - EXPECT_THROW(o.quantize(&g), std::runtime_error); -} - -TEST(CircleOptimizerTest, quantize_quantdequant_gran_NEG) -{ - loco::Graph g; - luci::CircleOptimizer o; - - auto options = o.options(); - - options->enable(Algorithms::QuantizeDequantizeWeights); - options->param(AlgorithmParameters::Quantize_input_model_dtype, "float32"); - options->param(AlgorithmParameters::Quantize_output_model_dtype, "uint8"); - options->param(AlgorithmParameters::Quantize_granularity, "invalid"); - - EXPECT_THROW(o.quantize(&g), std::runtime_error); -} - -TEST(CircleOptimizerTest, 
quantize_minmax_simple) -{ - loco::Graph g; - luci::CircleOptimizer o; - - auto options = o.options(); - - options->enable(Algorithms::QuantizeWithMinMax); - options->param(AlgorithmParameters::Quantize_input_model_dtype, "float32"); - options->param(AlgorithmParameters::Quantize_output_model_dtype, "uint8"); - options->param(AlgorithmParameters::Quantize_granularity, "layer"); - - o.quantize(&g); - - SUCCEED(); -} - -TEST(CircleOptimizerTest, quantize_minmax_input_NEG) -{ - loco::Graph g; - luci::CircleOptimizer o; - - auto options = o.options(); - - options->enable(Algorithms::QuantizeWithMinMax); - options->param(AlgorithmParameters::Quantize_input_model_dtype, "invalid"); - options->param(AlgorithmParameters::Quantize_output_model_dtype, "uint8"); - options->param(AlgorithmParameters::Quantize_granularity, "layer"); - - EXPECT_THROW(o.quantize(&g), std::runtime_error); -} - -TEST(CircleOptimizerTest, quantize_minmax_output_NEG) -{ - loco::Graph g; - luci::CircleOptimizer o; - - auto options = o.options(); - - options->enable(Algorithms::QuantizeWithMinMax); - options->param(AlgorithmParameters::Quantize_input_model_dtype, "float32"); - options->param(AlgorithmParameters::Quantize_output_model_dtype, "invalid"); - options->param(AlgorithmParameters::Quantize_granularity, "layer"); - - EXPECT_THROW(o.quantize(&g), std::runtime_error); -} - -TEST(CircleOptimizerTest, quantize_minmax_gran_NEG) -{ - loco::Graph g; - luci::CircleOptimizer o; - - auto options = o.options(); - - options->enable(Algorithms::QuantizeWithMinMax); - options->param(AlgorithmParameters::Quantize_input_model_dtype, "float32"); - options->param(AlgorithmParameters::Quantize_output_model_dtype, "uint8"); - options->param(AlgorithmParameters::Quantize_granularity, "invalid"); - - EXPECT_THROW(o.quantize(&g), std::runtime_error); -} - -TEST(CircleOptimizerTest, quantize_requant_simple) -{ - loco::Graph g; - luci::CircleOptimizer o; - - auto options = o.options(); - - 
options->enable(Algorithms::Requantize); - options->param(AlgorithmParameters::Quantize_input_model_dtype, "int8"); - options->param(AlgorithmParameters::Quantize_output_model_dtype, "uint8"); - - o.quantize(&g); - - SUCCEED(); -} - -TEST(CircleOptimizerTest, quantize_requant_input_NEG) -{ - loco::Graph g; - luci::CircleOptimizer o; - - auto options = o.options(); - - options->enable(Algorithms::Requantize); - options->param(AlgorithmParameters::Quantize_input_model_dtype, "invalid"); - options->param(AlgorithmParameters::Quantize_output_model_dtype, "uint8"); - - EXPECT_THROW(o.quantize(&g), std::runtime_error); -} - -TEST(CircleOptimizerTest, quantize_requant_output_NEG) -{ - loco::Graph g; - luci::CircleOptimizer o; - - auto options = o.options(); - - options->enable(Algorithms::Requantize); - options->param(AlgorithmParameters::Quantize_input_model_dtype, "int8"); - options->param(AlgorithmParameters::Quantize_output_model_dtype, "invalid"); - - EXPECT_THROW(o.quantize(&g), std::runtime_error); -} diff --git a/compiler/luci/pass/src/CircleQuantizer.cpp b/compiler/luci/pass/src/CircleQuantizer.cpp new file mode 100644 index 000000000..ce38a90b9 --- /dev/null +++ b/compiler/luci/pass/src/CircleQuantizer.cpp @@ -0,0 +1,458 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "luci/CircleQuantizer.h" + +#include "luci/Pass/CopyQuantParamPass.h" +#include "luci/Pass/ForceQuantParamPass.h" +#include "luci/Pass/PropagateQParamForwardPass.h" +#include "luci/Pass/RequantizePass.h" +#include "luci/Pass/ConvertToFakeQuantizedModelPass.h" +#include "luci/Pass/FoldDequantizePass.h" +#include "luci/Pass/QuantizePreCheckerPass.h" +#include "luci/Pass/QuantizeWithMinMaxPass.h" +#include "luci/Pass/QuantizeDequantizeWeightsPass.h" + +#include "luci/Pass/CircleShapeInferencePass.h" +#include "luci/Pass/CircleTypeInferencePass.h" + +// logo passes +#include <logo/RemoveDeadNodeWithQueryPass.h> + +#include "ProgressReporter.h" +#include "helpers/Strings.h" + +#include "QuantizedModelVerifier.h" + +#include <luci/IR/CircleNode.h> +#include <logo/Phase.h> + +#include <memory> + +namespace +{ + +using namespace luci; +using LayerParam = luci::CircleQuantizer::Options::LayerParam; + +template <typename T> T lexical_cast(const std::string &str) +{ + std::istringstream ss; + ss.str(str); + T data; + ss >> data; + return data; +} + +template <typename T> std::vector<T> lexical_cast(std::vector<std::string> &sv) +{ + std::vector<T> result; + std::transform(sv.begin(), sv.end(), std::back_inserter(result), + [](std::string str) -> T { return lexical_cast<T>(str); }); + return result; +} + +class QuantizeOptionsImpl final : public luci::CircleQuantizer::Options +{ +public: + void enable(Algorithm) final; + void param(AlgorithmParameters, const std::string &) final; + const std::string param(AlgorithmParameters) const final; + void params(AlgorithmParameters, std::vector<std::string> &) final; + std::vector<std::string> params(AlgorithmParameters) const final; + void layer_params(AlgorithmParameters, std::vector<std::shared_ptr<LayerParam>> &) final; + std::vector<std::shared_ptr<LayerParam>> layer_params(AlgorithmParameters) const final; + bool query(Algorithm) final; + +private: + std::vector<Algorithm> _algorithms; + 
std::map<AlgorithmParameters, const std::string> _algorithm_params; + std::map<AlgorithmParameters, std::vector<std::string>> _multiple_params; + std::map<AlgorithmParameters, std::vector<std::shared_ptr<LayerParam>>> _layer_params; +}; + +void QuantizeOptionsImpl::enable(Algorithm algo) { _algorithms.push_back(algo); } + +void QuantizeOptionsImpl::param(AlgorithmParameters param, const std::string &str) +{ + _algorithm_params.insert(std::pair<AlgorithmParameters, const std::string>(param, str)); +} + +const std::string QuantizeOptionsImpl::param(AlgorithmParameters param) const +{ + auto param_str = _algorithm_params.find(param); + if (param_str != _algorithm_params.end()) + { + return param_str->second; + } + else + { + return std::string(); + } +} + +void QuantizeOptionsImpl::params(AlgorithmParameters param, std::vector<std::string> &vec) +{ + _multiple_params[param] = vec; +} + +std::vector<std::string> QuantizeOptionsImpl::params(AlgorithmParameters param) const +{ + auto param_vec = _multiple_params.find(param); + if (param_vec != _multiple_params.end()) + { + return param_vec->second; + } + else + { + return std::vector<std::string>(); + } +} + +void QuantizeOptionsImpl::layer_params(AlgorithmParameters param, + std::vector<std::shared_ptr<LayerParam>> &vec) +{ + _layer_params[param] = vec; +} + +std::vector<std::shared_ptr<LayerParam>> +QuantizeOptionsImpl::layer_params(AlgorithmParameters param) const +{ + auto param_vec = _layer_params.find(param); + if (param_vec != _layer_params.end()) + { + return param_vec->second; + } + else + { + return std::vector<std::shared_ptr<LayerParam>>(); + } +} + +bool QuantizeOptionsImpl::query(Algorithm algo) +{ + std::vector<Algorithm>::iterator it = std::find(_algorithms.begin(), _algorithms.end(), algo); + if (it == _algorithms.end()) + return false; + + return true; +} + +} // namespace + +namespace luci +{ + +CircleQuantizer::Options *CircleQuantizer::options(void) +{ + if (_options == nullptr) + { + _options = 
std::make_unique<QuantizeOptionsImpl>(); + } + + return _options.get(); +} + +void CircleQuantizer::quantize(loco::Graph *g) const +{ + // Fake quantization of weights + if (_options->query(Options::Algorithm::QuantizeDequantizeWeights)) + { + static const std::vector<std::string> fakeq_supported_input_model_dtype{"float32"}; + static const std::vector<std::string> fakeq_supported_output_model_dtype{"uint8", "int16"}; + static const std::vector<std::string> fakeq_supported_granularity{"layer", "channel"}; + + auto input_model_dtype = + _options->param(Options::AlgorithmParameters::Quantize_input_model_dtype); + auto output_model_dtype = + _options->param(Options::AlgorithmParameters::Quantize_output_model_dtype); + auto granularity = _options->param(Options::AlgorithmParameters::Quantize_granularity); + auto layer_params = _options->layer_params(Options::AlgorithmParameters::Quantize_layer_params); + + if (!in_array(to_lower_case(input_model_dtype), fakeq_supported_input_model_dtype)) + throw std::runtime_error("Unsupported input type. List of supported input type: " + + to_string(fakeq_supported_input_model_dtype)); + + if (!in_array(to_lower_case(output_model_dtype), fakeq_supported_output_model_dtype)) + throw std::runtime_error("Unsupported output type. List of supported output type: " + + to_string(fakeq_supported_output_model_dtype)); + + if (!in_array(to_lower_case(granularity), fakeq_supported_granularity)) + throw std::runtime_error("Unsupported granularity. 
List of supported granularity: " + + to_string(fakeq_supported_granularity)); + + if (str_to_granularity(granularity) == QuantizationGranularity::LayerWise && + str_to_dtype(output_model_dtype) != loco::DataType::U8) + throw std::runtime_error("Layer-wise quantization only supports uint8 dtype."); + + // Check dtype/granularity of layer params + for (auto layer_param : layer_params) + { + auto name = layer_param->name; + if (!in_array(to_lower_case(layer_param->dtype), fakeq_supported_output_model_dtype)) + { + throw std::runtime_error("Unsupported dtype in " + name + ". List of supported dtype: " + + to_string(fakeq_supported_output_model_dtype)); + } + if (!in_array(to_lower_case(layer_param->granularity), fakeq_supported_granularity)) + { + throw std::runtime_error( + "Unsupported granularity in " + name + + ". List of supported granularity: " + to_string(fakeq_supported_granularity)); + } + } + + // Clear existing quantparams before doing fake quantization + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + auto circle_node = loco::must_cast<luci::CircleNode *>(node); + if (circle_node->quantparam() != nullptr) + circle_node->quantparam(nullptr); + } + + auto ctx = std::make_unique<luci::QuantizeDequantizeWeightsPass::Context>(); + { + ctx->input_model_dtype = str_to_dtype(input_model_dtype); + ctx->output_model_dtype = str_to_dtype(output_model_dtype); + ctx->granularity = str_to_granularity(granularity); + + for (auto layer_param : layer_params) + { + LayerInfo info; + { + info.name = layer_param->name; + info.dtype = str_to_dtype(layer_param->dtype); + info.granularity = str_to_granularity(layer_param->granularity); + } + ctx->layers_info.emplace_back(info); + } + } + + luci::QuantizeDequantizeWeightsPass fake_quantizer(std::move(ctx)); + + fake_quantizer.run(g); + } + + // Actual quantization of weights, bias, and activation + if (_options->query(Options::Algorithm::QuantizeWithMinMax)) + { + static const std::vector<std::string> 
qwmm_supported_input_model_dtype{"float32"}; + static const std::vector<std::string> qwmm_supported_output_model_dtype{"uint8", "int16"}; + static const std::vector<std::string> qwmm_supported_granularity{"layer", "channel"}; + static const std::vector<std::string> qwmm_supported_input_type{"uint8", "int16"}; + static const std::vector<std::string> qwmm_supported_output_type{"uint8", "int16"}; + + auto input_model_dtype = + _options->param(Options::AlgorithmParameters::Quantize_input_model_dtype); + auto output_model_dtype = + _options->param(Options::AlgorithmParameters::Quantize_output_model_dtype); + auto granularity = _options->param(Options::AlgorithmParameters::Quantize_granularity); + auto input_type = _options->param(Options::AlgorithmParameters::Quantize_input_type); + if (input_type.empty()) + input_type = output_model_dtype; + auto output_type = _options->param(Options::AlgorithmParameters::Quantize_output_type); + if (output_type.empty()) + output_type = output_model_dtype; + + bool TF_style_maxpool = + _options->param(Options::AlgorithmParameters::Quantize_TF_style_maxpool) == "True"; + + auto layer_params = _options->layer_params(Options::AlgorithmParameters::Quantize_layer_params); + + if (!in_array(to_lower_case(input_model_dtype), qwmm_supported_input_model_dtype)) + throw std::runtime_error("Unsupported input type. List of supported input types: " + + to_string(qwmm_supported_input_model_dtype)); + + if (!in_array(to_lower_case(output_model_dtype), qwmm_supported_output_model_dtype)) + throw std::runtime_error("Unsupported output type. List of supported output types: " + + to_string(qwmm_supported_output_model_dtype)); + + if (!in_array(to_lower_case(granularity), qwmm_supported_granularity)) + throw std::runtime_error("Unsupported granularity. List of supported granularity: " + + to_string(qwmm_supported_granularity)); + + if (!in_array(to_lower_case(input_type), qwmm_supported_input_type)) + throw std::runtime_error("Unsupported input type. 
List of supported input types: " + + to_string(qwmm_supported_input_type)); + + if (!in_array(to_lower_case(output_type), qwmm_supported_output_type)) + throw std::runtime_error("Unsupported output type. List of supported output types: " + + to_string(qwmm_supported_output_type)); + + if (str_to_granularity(granularity) == QuantizationGranularity::LayerWise && + str_to_dtype(output_model_dtype) != loco::DataType::U8) + throw std::runtime_error("Layer-wise quantization only supports uint8 dtype."); + + // Check dtype/granularity of layer params + for (auto layer_param : layer_params) + { + auto name = layer_param->name; + if (!in_array(to_lower_case(layer_param->dtype), qwmm_supported_output_model_dtype)) + { + throw std::runtime_error("Unsupported dtype in " + name + ". List of supported dtype: " + + to_string(qwmm_supported_output_model_dtype)); + } + if (!in_array(to_lower_case(layer_param->granularity), qwmm_supported_granularity)) + { + throw std::runtime_error( + "Unsupported granularity in " + name + + ". 
List of supported granularity: " + to_string(qwmm_supported_granularity)); + } + } + + // Input model checker for quantization + luci::QuantizePreCheckerPass input_model_checker{}; + input_model_checker.run(g); + + auto ctx = std::make_unique<luci::QuantizeWithMinMaxPass::Context>(); + { + ctx->input_model_dtype = str_to_dtype(input_model_dtype); + ctx->output_model_dtype = str_to_dtype(output_model_dtype); + ctx->granularity = str_to_granularity(granularity); + ctx->input_type = str_to_dtype(input_type); + ctx->output_type = str_to_dtype(output_type); + ctx->TF_style_maxpool = TF_style_maxpool; + + for (auto layer_param : layer_params) + { + LayerInfo info; + { + info.name = layer_param->name; + info.dtype = str_to_dtype(layer_param->dtype); + info.granularity = str_to_granularity(layer_param->granularity); + } + ctx->layers_info.emplace_back(info); + } + } + + luci::QuantizeWithMinMaxPass quantizer(std::move(ctx)); + + quantizer.run(g); + + auto verify_ctx = std::make_unique<luci::QuantizedModelVerifier::Context>(); + { + verify_ctx->output_model_dtype = str_to_dtype(output_model_dtype); + verify_ctx->granularity = str_to_granularity(granularity); + verify_ctx->input_type = str_to_dtype(input_type); + verify_ctx->output_type = str_to_dtype(output_type); + verify_ctx->TF_style_maxpool = TF_style_maxpool; + + for (auto layer_param : layer_params) + { + LayerInfo info; + { + info.name = layer_param->name; + info.dtype = str_to_dtype(layer_param->dtype); + info.granularity = str_to_granularity(layer_param->granularity); + } + verify_ctx->layers_info.emplace_back(info); + } + } + + // Verify the type/granularity of the quantized model + luci::QuantizedModelVerifier verifier(std::move(verify_ctx)); + + verifier.verify(g); + } + + // Requantize + if (_options->query(Options::Algorithm::Requantize)) + { + static const std::vector<std::string> rq_supported_input_model_dtype{"int8"}; + static const std::vector<std::string> rq_supported_output_model_dtype{"uint8"}; + + auto 
input_model_dtype = + _options->param(Options::AlgorithmParameters::Quantize_input_model_dtype); + auto output_model_dtype = + _options->param(Options::AlgorithmParameters::Quantize_output_model_dtype); + + if (!in_array(to_lower_case(input_model_dtype), rq_supported_input_model_dtype)) + throw std::runtime_error("Unsupported input type. List of supported input types: " + + to_string(rq_supported_input_model_dtype)); + + if (!in_array(to_lower_case(output_model_dtype), rq_supported_output_model_dtype)) + throw std::runtime_error("Unsupported output type. List of supported output types: " + + to_string(rq_supported_output_model_dtype)); + + luci::RequantizePass requantizer(str_to_dtype(input_model_dtype), + str_to_dtype(output_model_dtype)); + requantizer.run(g); + } + + // Force to write quantparam to specified tensors + // NOTE Only per-tensor (not per-channel) qparam can be written + if (_options->query(Options::Algorithm::ForceQuantParam)) + { + ForceQuantParamPass::TensorVector tensors = + _options->params(Options::AlgorithmParameters::Quantize_tensor_names); + auto str_scales = _options->params(Options::AlgorithmParameters::Quantize_scales); + auto str_zero_points = _options->params(Options::AlgorithmParameters::Quantize_zero_points); + + // Cast scales/zero_points to proper types + ForceQuantParamPass::ScaleVector scales = lexical_cast<float>(str_scales); + ForceQuantParamPass::ZPVector zero_points = lexical_cast<int64_t>(str_zero_points); + + ForceQuantParamPass fq(tensors, scales, zero_points); + fq.run(g); + } + + // Copy quantparam of a tensor to another tensor + if (_options->query(Options::Algorithm::CopyQuantParam)) + { + CopyQuantParamPass::TensorVector src_tensors = + _options->params(Options::AlgorithmParameters::Quantize_src_tensor_names); + CopyQuantParamPass::TensorVector dst_tensors = + _options->params(Options::AlgorithmParameters::Quantize_dst_tensor_names); + + CopyQuantParamPass cq(src_tensors, dst_tensors); + cq.run(g); + } + + // Convert 
quantized model to fake-quantized model + if (_options->query(Options::Algorithm::ConvertToFakeQuantizedModel)) + { + luci::ConvertToFakeQuantizedModelPass fake_quantizer; + fake_quantizer.run(g); + + logo::Phase phase; + + // Default passes + phase.emplace_back(std::make_unique<logo::RemoveDeadNodeWithQueryPass>()); + phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>()); + phase.emplace_back(std::make_unique<luci::CircleTypeInferencePass>()); + + // Fold Dequantize Ops generated during fake quantization + phase.emplace_back(std::make_unique<luci::FoldDequantizePass>()); + + ProgressReporter prog(g, logo::PhaseStrategy::Restart); + logo::PhaseRunner<logo::PhaseStrategy::Restart> phase_runner{g}; + phase_runner.attach(&prog); + phase_runner.run(phase); + } + + logo::Phase phase; + + // Do Shape/Type inference + phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>()); + phase.emplace_back(std::make_unique<luci::CircleTypeInferencePass>()); + + ProgressReporter prog(g, logo::PhaseStrategy::Saturate); + logo::PhaseRunner<logo::PhaseStrategy::Saturate> phase_runner{g}; + phase_runner.attach(&prog); + phase_runner.run(phase); +} + +} // namespace luci diff --git a/compiler/luci/pass/src/CircleQuantizer.test.cpp b/compiler/luci/pass/src/CircleQuantizer.test.cpp new file mode 100644 index 000000000..5766d5fe5 --- /dev/null +++ b/compiler/luci/pass/src/CircleQuantizer.test.cpp @@ -0,0 +1,191 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/CircleQuantizer.h" + +#include <gtest/gtest.h> + +using namespace luci; +using Algorithms = luci::CircleQuantizer::Options::Algorithm; +using AlgorithmParameters = luci::CircleQuantizer::Options::AlgorithmParameters; + +TEST(CircleQuantizerTest, quantize_quantdequant_simple) +{ + loco::Graph g; + luci::CircleQuantizer o; + + auto options = o.options(); + + options->enable(Algorithms::QuantizeDequantizeWeights); + options->param(AlgorithmParameters::Quantize_input_model_dtype, "float32"); + options->param(AlgorithmParameters::Quantize_output_model_dtype, "uint8"); + options->param(AlgorithmParameters::Quantize_granularity, "layer"); + + o.quantize(&g); + + SUCCEED(); +} + +TEST(CircleQuantizerTest, quantize_quantdequant_input_NEG) +{ + loco::Graph g; + luci::CircleQuantizer o; + + auto options = o.options(); + + options->enable(Algorithms::QuantizeDequantizeWeights); + options->param(AlgorithmParameters::Quantize_input_model_dtype, "invalid"); + options->param(AlgorithmParameters::Quantize_output_model_dtype, "uint8"); + options->param(AlgorithmParameters::Quantize_granularity, "layer"); + + EXPECT_THROW(o.quantize(&g), std::runtime_error); +} + +TEST(CircleQuantizerTest, quantize_quantdequant_output_NEG) +{ + loco::Graph g; + luci::CircleQuantizer o; + + auto options = o.options(); + + options->enable(Algorithms::QuantizeDequantizeWeights); + options->param(AlgorithmParameters::Quantize_input_model_dtype, "float32"); + options->param(AlgorithmParameters::Quantize_output_model_dtype, "invalid"); + options->param(AlgorithmParameters::Quantize_granularity, "layer"); + + EXPECT_THROW(o.quantize(&g), std::runtime_error); +} + +TEST(CircleQuantizerTest, quantize_quantdequant_gran_NEG) +{ + loco::Graph g; + luci::CircleQuantizer o; + + auto options = o.options(); + + options->enable(Algorithms::QuantizeDequantizeWeights); + 
options->param(AlgorithmParameters::Quantize_input_model_dtype, "float32"); + options->param(AlgorithmParameters::Quantize_output_model_dtype, "uint8"); + options->param(AlgorithmParameters::Quantize_granularity, "invalid"); + + EXPECT_THROW(o.quantize(&g), std::runtime_error); +} + +TEST(CircleQuantizerTest, quantize_minmax_simple) +{ + loco::Graph g; + luci::CircleQuantizer o; + + auto options = o.options(); + + options->enable(Algorithms::QuantizeWithMinMax); + options->param(AlgorithmParameters::Quantize_input_model_dtype, "float32"); + options->param(AlgorithmParameters::Quantize_output_model_dtype, "uint8"); + options->param(AlgorithmParameters::Quantize_granularity, "layer"); + + o.quantize(&g); + + SUCCEED(); +} + +TEST(CircleQuantizerTest, quantize_minmax_input_NEG) +{ + loco::Graph g; + luci::CircleQuantizer o; + + auto options = o.options(); + + options->enable(Algorithms::QuantizeWithMinMax); + options->param(AlgorithmParameters::Quantize_input_model_dtype, "invalid"); + options->param(AlgorithmParameters::Quantize_output_model_dtype, "uint8"); + options->param(AlgorithmParameters::Quantize_granularity, "layer"); + + EXPECT_THROW(o.quantize(&g), std::runtime_error); +} + +TEST(CircleQuantizerTest, quantize_minmax_output_NEG) +{ + loco::Graph g; + luci::CircleQuantizer o; + + auto options = o.options(); + + options->enable(Algorithms::QuantizeWithMinMax); + options->param(AlgorithmParameters::Quantize_input_model_dtype, "float32"); + options->param(AlgorithmParameters::Quantize_output_model_dtype, "invalid"); + options->param(AlgorithmParameters::Quantize_granularity, "layer"); + + EXPECT_THROW(o.quantize(&g), std::runtime_error); +} + +TEST(CircleQuantizerTest, quantize_minmax_gran_NEG) +{ + loco::Graph g; + luci::CircleQuantizer o; + + auto options = o.options(); + + options->enable(Algorithms::QuantizeWithMinMax); + options->param(AlgorithmParameters::Quantize_input_model_dtype, "float32"); + 
options->param(AlgorithmParameters::Quantize_output_model_dtype, "uint8"); + options->param(AlgorithmParameters::Quantize_granularity, "invalid"); + + EXPECT_THROW(o.quantize(&g), std::runtime_error); +} + +TEST(CircleQuantizerTest, quantize_requant_simple) +{ + loco::Graph g; + luci::CircleQuantizer o; + + auto options = o.options(); + + options->enable(Algorithms::Requantize); + options->param(AlgorithmParameters::Quantize_input_model_dtype, "int8"); + options->param(AlgorithmParameters::Quantize_output_model_dtype, "uint8"); + + o.quantize(&g); + + SUCCEED(); +} + +TEST(CircleQuantizerTest, quantize_requant_input_NEG) +{ + loco::Graph g; + luci::CircleQuantizer o; + + auto options = o.options(); + + options->enable(Algorithms::Requantize); + options->param(AlgorithmParameters::Quantize_input_model_dtype, "invalid"); + options->param(AlgorithmParameters::Quantize_output_model_dtype, "uint8"); + + EXPECT_THROW(o.quantize(&g), std::runtime_error); +} + +TEST(CircleQuantizerTest, quantize_requant_output_NEG) +{ + loco::Graph g; + luci::CircleQuantizer o; + + auto options = o.options(); + + options->enable(Algorithms::Requantize); + options->param(AlgorithmParameters::Quantize_input_model_dtype, "int8"); + options->param(AlgorithmParameters::Quantize_output_model_dtype, "invalid"); + + EXPECT_THROW(o.quantize(&g), std::runtime_error); +} diff --git a/compiler/luci/pass/src/ConvertNCHWToNHWCPass.cpp b/compiler/luci/pass/src/ConvertNCHWToNHWCPass.cpp index 270714049..ce4f54035 100644 --- a/compiler/luci/pass/src/ConvertNCHWToNHWCPass.cpp +++ b/compiler/luci/pass/src/ConvertNCHWToNHWCPass.cpp @@ -228,6 +228,9 @@ bool check_4d_reshape(loco::Node *node, const std::vector<int32_t> indices) if (input->shape_status() != luci::ShapeStatus::VALID) return false; + if (input->rank() != 4) + return false; + if (reshape->shape_status() != luci::ShapeStatus::VALID) return false; @@ -804,6 +807,8 @@ class ConvertNCHWToNHWC final : public luci::CircleNodeMutableVisitor<bool> return 
true; } + bool visit(luci::CircleElu *node) { return convert_unary_features<luci::CircleElu>(node); } + bool visit(luci::CircleLeakyRelu *node) { return convert_unary_features<luci::CircleLeakyRelu>(node); @@ -1240,6 +1245,7 @@ bool ConvertNCHWToNHWCPass::run(loco::Graph *g) break; case luci::CircleOpcode::ADD: case luci::CircleOpcode::CONCATENATION: + case luci::CircleOpcode::ELU: case luci::CircleOpcode::LEAKY_RELU: case luci::CircleOpcode::LOGISTIC: case luci::CircleOpcode::MAXIMUM: diff --git a/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp b/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp index c9412fbb1..dd81d1380 100644 --- a/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp +++ b/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp @@ -264,6 +264,22 @@ public: luci::CircleConst *input2 = nullptr; }; +class EluGraph final : public SimpleGraph +{ +protected: + loco::Node *insertGraphBody(loco::Node *input) override + { + elu = g.nodes()->create<luci::CircleElu>(); + elu->features(input); + elu->name("elu"); + + return elu; + } + +public: + luci::CircleElu *elu = nullptr; +}; + class LeakyReluGraph final : public SimpleGraph { protected: @@ -941,6 +957,26 @@ TEST(ConvertNCHWToNHWC, Concatenation) EXPECT_EQ(3, g.concat->axis()); } +TEST(ConvertNCHWToNHWC, Elu) +{ + EluGraph g; + g.init(); + + run_phase(&g.g, true, true); + + check_pre_trans(g.elu->features()); + + auto elu_succs = loco::succs(g.elu); + EXPECT_EQ(1, elu_succs.size()); + check_post_trans(*elu_succs.begin()); + + // Check elu shape + EXPECT_EQ(1, g.elu->dim(0).value()); + EXPECT_EQ(4, g.elu->dim(1).value()); + EXPECT_EQ(4, g.elu->dim(2).value()); + EXPECT_EQ(16, g.elu->dim(3).value()); +} + TEST(ConvertNCHWToNHWC, LeakyRelu) { LeakyReluGraph g; diff --git a/compiler/luci/pass/src/ConvertToFakeQuantizedModelPass.cpp b/compiler/luci/pass/src/ConvertToFakeQuantizedModelPass.cpp new file mode 100644 index 000000000..11970fff5 --- /dev/null +++ 
b/compiler/luci/pass/src/ConvertToFakeQuantizedModelPass.cpp @@ -0,0 +1,214 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/ConvertToFakeQuantizedModelPass.h" +#include "luci/Pass/QuantizationParameters.h" + +#include "QuantizationUtils.h" + +#include <luci/Profile/CircleNodeOrigin.h> +#include <luci/IR/CircleNodes.h> +#include <luci/IR/CircleNodeVisitor.h> +#include <luci/Log.h> + +namespace +{ + +// Create Quantize Op whose dtype/shape/qparam are the same with node +luci::CircleQuantize *create_quantize(luci::CircleNode *node) +{ + auto quantize = node->graph()->nodes()->create<luci::CircleQuantize>(); + quantize->name(node->name() + "_Quantize"); + quantize->dtype(node->dtype()); + quantize->rank(node->rank()); + for (uint32_t i = 0; i < node->rank(); i++) + quantize->dim(i).set(node->dim(i).value()); + + quantize->shape_status(luci::ShapeStatus::VALID); + + copy_quantparam(node, quantize); + + luci::add_origin(quantize, luci::get_origin(node)); + + return quantize; +} + +// Create Dequantize Op whose shape is the same with node +luci::CircleDequantize *create_dequantize(luci::CircleNode *node) +{ + auto dequantize = node->graph()->nodes()->create<luci::CircleDequantize>(); + dequantize->name(node->name() + "_Dequantize"); + dequantize->dtype(loco::DataType::FLOAT32); + dequantize->rank(node->rank()); + for (uint32_t i = 0; i < node->rank(); i++) + 
dequantize->dim(i).set(node->dim(i).value()); + + dequantize->shape_status(luci::ShapeStatus::VALID); + + luci::add_origin(dequantize, luci::get_origin(node)); + + return dequantize; +} + +// Return true if node is quantized activation +// 1. dtype is u8 or s16 +// 2. node has qparam +bool is_quant_act(const luci::CircleNode *node) +{ + if (node->dtype() != loco::DataType::U8 and node->dtype() != loco::DataType::S16) + return false; + + if (not node->quantparam()) + return false; + + return true; +} + +// Return true if node is quantized const +// 1. dtype is not fp32 +// 2. node has qparam +// NOTE Quantized const can have the following types +// u8 (weights, activation), s16 (weights, activation), s32 (bias), s64 (bias) +bool is_quant_const(const luci::CircleConst *node) +{ + if (node->dtype() == loco::DataType::FLOAT32) + return false; + + if (not node->quantparam()) + return false; + + return true; +} + +// Insert dequantize Op after node +void insert_dequantize(loco::Node *lnode) +{ + auto node = loco::must_cast<luci::CircleNode *>(lnode); + auto dequant = create_dequantize(node); + loco::replace(node).with(dequant); + dequant->input(node); +} + +// Insert quantize Op after node and return the quantize Op +luci::CircleQuantize *insert_quantize(loco::Node *lnode) +{ + auto node = loco::must_cast<luci::CircleNode *>(lnode); + auto quant = create_quantize(node); + loco::replace(node).with(quant); + quant->input(node); + return quant; +} + +// Dequantize node +void dequantize(luci::CircleNode *node) +{ + node->dtype(loco::DataType::FLOAT32); + node->quantparam(nullptr); +} + +// Do fake quantization on quantized activation +// 1. Insert Quantize-Dequantize Ops +// 2. 
Update dtype/quantparam of node +void fq_activation(luci::CircleNode *node) +{ + if (not is_quant_act(node)) + return; + + auto quant = insert_quantize(node); + insert_dequantize(quant); + + dequantize(node); +} + +#define RETURN_UNLESS(COND) \ + if (not(COND)) \ + return; + +// Visitor to do fake quantization for each Op +// For non-const activation, insert Quantize-Dequantize after the ofm +// For quantized const, insert Dequantize after the const +struct FakeQuantize final : public luci::CircleNodeMutableVisitor<void> +{ + void visit(luci::CircleNode *node) + { + throw std::runtime_error("Unsupported op for fake quantization in " + node->name()); + } + + void visit(luci::CircleInput *node) + { + RETURN_UNLESS(is_quant_act(node)); + + auto quant = insert_quantize(node); + insert_dequantize(quant); + + dequantize(node); + + // Update graph input + const auto inputs = node->graph()->inputs(); + auto graph_input = inputs->at(node->index()); + graph_input->dtype(loco::DataType::FLOAT32); + } + + void visit(luci::CircleOutput *node) + { + RETURN_UNLESS(is_quant_act(node)); + + dequantize(node); + + // Update graph output + const auto outputs = node->graph()->outputs(); + auto graph_output = outputs->at(node->index()); + graph_output->dtype(loco::DataType::FLOAT32); + } + + // For quantized const, insert Dequantize Op + void visit(luci::CircleConst *node) + { + RETURN_UNLESS(is_quant_const(node)); + + insert_dequantize(node); + } + + // For non-const activation, insert Quantize-Dequantize Ops + // and dequantize the node + void visit(luci::CircleConv2D *node) { fq_activation(node); } + void visit(luci::CircleAdd *node) { fq_activation(node); } +}; + +#undef RETURN_UNLESS + +} // namespace + +namespace luci +{ + +bool ConvertToFakeQuantizedModelPass::run(loco::Graph *g) +{ + LOGGER(l); + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + auto circle_node = loco::must_cast<luci::CircleNode *>(node); + INFO(l) << "ConvertToFakeQuantizedModelPass visit node: 
" << circle_node->name() << std::endl; + + FakeQuantize fq; + circle_node->accept(&fq); + } + + // One time run + return false; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/ConvertToFakeQuantizedModelPass.test.cpp b/compiler/luci/pass/src/ConvertToFakeQuantizedModelPass.test.cpp new file mode 100644 index 000000000..560d68a74 --- /dev/null +++ b/compiler/luci/pass/src/ConvertToFakeQuantizedModelPass.test.cpp @@ -0,0 +1,277 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include <logo/Phase.h> + +#include "luci/Pass/ConvertToFakeQuantizedModelPass.h" +#include <luci/IR/CircleNodes.h> + +#include <gtest/gtest.h> + +namespace +{ + +// Check the below pattern +// Quantize (scale, zp) -> Dequantize (node) +// NOTE Fatal ASSERT_* is used for every check whose failure would otherwise be followed by +// a null dereference or an access to an empty vector; it returns from this helper only. +void check_q_dq(loco::Node *node, float scale, int64_t zp) +{ + auto dequant = dynamic_cast<luci::CircleDequantize *>(node); + ASSERT_TRUE(dequant != nullptr); + auto quant = dynamic_cast<luci::CircleQuantize *>(dequant->input()); + ASSERT_TRUE(quant != nullptr); + auto qparam = quant->quantparam(); + ASSERT_TRUE(qparam != nullptr); + ASSERT_FALSE(qparam->scale.empty()); + ASSERT_FALSE(qparam->zerop.empty()); + EXPECT_EQ(scale, qparam->scale[0]); + EXPECT_EQ(zp, qparam->zerop[0]); +} + +// Check the below pattern +// Dequantize (node) +void check_dq(loco::Node *node) +{ + auto dequant = dynamic_cast<luci::CircleDequantize *>(node); + EXPECT_TRUE(dequant != nullptr); +} + +// Attach a single-element (scale, zp) quantization parameter to node +void set_qparam(luci::CircleNode *node, float scale, int64_t zp) +{ + auto qparam = std::make_unique<luci::CircleQuantParam>(); + { + qparam->scale.push_back(scale); + qparam->zerop.push_back(zp); + } + node->quantparam(std::move(qparam)); +} + +/** + * SimpleGraph for testing + * - Child class should implement insertGraphBody() + * + * Example (U8ConvGraph inherits SimpleGraph and create Conv2D Op) + * + * BEFORE + * - A model is quantized (ex: u8) + * + * [Input(u8)] [Filter(u8)] [Bias(s32)] + * \ | / + * \ | / + * \ | / + * [Conv2D(u8)] + * | + * [Output(u8)] + * + * AFTER + * - Ops are converted to fp32 + * - Quantize/Dequantize Ops are inserted properly + * - Q-DQ is inserted after non-const activation + * - DQ is inserted after const + * + * [Input(u8)] + * | + * [Quant(u8)] [Filter(u8)] [Bias(s32)] + * | | | + * [Dequant(fp32)] [Dequant(fp32)] [Dequant(fp32)] + * \ | / + * \ | / + * \ | / + * [Conv2D(fp32)] + * | + * [Quant(u8)] + * | + * [Dequant(fp32)] + * | + * [Output(fp32)] + */ +template <loco::DataType T> class SimpleGraph +{ +public: + void init() + { + input = g.nodes()->create<luci::CircleInput>(); + output = g.nodes()->create<luci::CircleOutput>(); 
+ input->name("input"); + output->name("output"); + + auto graph_input = g.inputs()->create(); + input->index(graph_input->index()); + auto graph_output = g.outputs()->create(); + output->index(graph_output->index()); + + graph_input->dtype(T); + input->dtype(T); + output->dtype(T); + graph_output->dtype(T); + + graph_input->shape({1, 4, 4, 4}); + input->shape({1, 4, 4, 4}); + output->shape({1, 4, 4, 4}); + graph_output->shape({1, 4, 4, 4}); + + set_qparam(input, 1.0, 0); + set_qparam(output, 1.0, 0); + + auto graph_body = insertGraphBody(input); + output->from(graph_body); + } + + virtual ~SimpleGraph() = default; + +protected: + virtual loco::Node *insertGraphBody(loco::Node *input) = 0; + +public: + loco::Graph g; + luci::CircleInput *input = nullptr; + luci::CircleOutput *output = nullptr; +}; + +class U8ConvGraph final : public SimpleGraph<loco::DataType::U8> +{ +protected: + loco::Node *insertGraphBody(loco::Node *input) override + { + conv = g.nodes()->create<luci::CircleConv2D>(); + weights = g.nodes()->create<luci::CircleConst>(); + bias = g.nodes()->create<luci::CircleConst>(); + + conv->dtype(loco::DataType::U8); + weights->dtype(loco::DataType::U8); + bias->dtype(loco::DataType::S32); + + conv->shape({1, 4, 4, 4}); + weights->shape({4, 1, 1, 4}); + bias->shape({4}); + + weights->size<loco::DataType::U8>(16); + for (uint32_t i = 0; i < 16; i++) + weights->at<loco::DataType::U8>(i) = i; + + bias->size<loco::DataType::S32>(4); + for (uint32_t i = 0; i < 4; i++) + bias->at<loco::DataType::S32>(i) = i; + + set_qparam(conv, 2.0, 127); + set_qparam(weights, 2.0, 127); + set_qparam(bias, 2.0, 127); + + conv->input(input); + conv->filter(weights); + conv->bias(bias); + + conv->name("conv"); + weights->name("weights"); + bias->name("bias"); + + return conv; + } + +public: + luci::CircleConv2D *conv = nullptr; + luci::CircleConst *weights = nullptr; + luci::CircleConst *bias = nullptr; +}; + +class FP32ConvGraph final : public SimpleGraph<loco::DataType::FLOAT32> 
+{ +protected: + loco::Node *insertGraphBody(loco::Node *input) override + { + conv = g.nodes()->create<luci::CircleConv2D>(); + weights = g.nodes()->create<luci::CircleConst>(); + bias = g.nodes()->create<luci::CircleConst>(); + + conv->dtype(loco::DataType::FLOAT32); + weights->dtype(loco::DataType::FLOAT32); + bias->dtype(loco::DataType::FLOAT32); + + conv->shape({1, 4, 4, 4}); + weights->shape({4, 1, 1, 4}); + bias->shape({4}); + + weights->size<loco::DataType::FLOAT32>(16); + for (uint32_t i = 0; i < 16; i++) + weights->at<loco::DataType::FLOAT32>(i) = i; + + bias->size<loco::DataType::FLOAT32>(4); + for (uint32_t i = 0; i < 4; i++) + bias->at<loco::DataType::FLOAT32>(i) = i; + + conv->input(input); + conv->filter(weights); + conv->bias(bias); + + conv->name("conv"); + weights->name("weights"); + bias->name("bias"); + + return conv; + } + +public: + luci::CircleConv2D *conv = nullptr; + luci::CircleConst *weights = nullptr; + luci::CircleConst *bias = nullptr; +}; + +} // namespace + +TEST(ConvertToFakeQuantizedModelTest, U8Conv2D) +{ + U8ConvGraph g; + g.init(); + + luci::ConvertToFakeQuantizedModelPass fq; + fq.run(&g.g); + + // Check ifm + check_q_dq(g.conv->input(), 1.0, 0); + + // Check weights + check_dq(g.conv->filter()); + + // Check bias + check_dq(g.conv->bias()); + + // Check ofm + check_q_dq(g.output->from(), 2.0, 127); + + SUCCEED(); +} + +TEST(ConvertToFakeQuantizedModelTest, F32Conv2D_NEG) +{ + FP32ConvGraph g; + g.init(); + + luci::ConvertToFakeQuantizedModelPass fq; + fq.run(&g.g); + + uint32_t dequant_count = 0; + uint32_t quant_count = 0; + + for (auto node : loco::active_nodes(loco::output_nodes(&g.g))) + { + auto cnode = loco::must_cast<luci::CircleNode *>(node); + auto opcode = cnode->opcode(); + if (opcode == luci::CircleOpcode::DEQUANTIZE) + dequant_count++; + if (opcode == luci::CircleOpcode::QUANTIZE) + quant_count++; + } + + // Check no quant/dequant Op is inserted + EXPECT_EQ(0, quant_count); + EXPECT_EQ(0, dequant_count); +} diff 
--git a/compiler/luci/pass/src/CopyQuantParamPass.cpp b/compiler/luci/pass/src/CopyQuantParamPass.cpp new file mode 100644 index 000000000..9b1bb0ea9 --- /dev/null +++ b/compiler/luci/pass/src/CopyQuantParamPass.cpp @@ -0,0 +1,82 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/CopyQuantParamPass.h" + +#include <luci/IR/CircleNodes.h> +#include <luci/Log.h> + +namespace luci +{ + +namespace +{ + +struct SrcDst +{ + CircleNode *src = nullptr; + CircleNode *dst = nullptr; +}; + +} // namespace + +bool CopyQuantParamPass::run(loco::Graph *g) +{ + LOGGER(l); + + INFO(l) << "CopyQuantParamPass Start" << std::endl; + + if (_src_tensors.size() != _dst_tensors.size()) + throw std::runtime_error("The numbers of Source/Destination tensors do not match."); + + // Return src/dst CircleNodes + auto get_src_dst = [&g](std::string src, std::string dst) { + SrcDst src_dst; + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + auto const cnode = loco::must_cast<CircleNode *>(node); + auto const name = cnode->name(); + if (name == src) + src_dst.src = cnode; + + if (name == dst) + src_dst.dst = cnode; + } + return src_dst; + }; + + for (uint32_t i = 0; i < _src_tensors.size(); i++) + { + auto src = _src_tensors[i]; + auto dst = _dst_tensors[i]; + + auto nodes = get_src_dst(src, dst); + if (not nodes.src) + throw std::runtime_error("The tensor named " + src + " 
does not exist."); + + if (not nodes.dst) + throw std::runtime_error("The tensor named " + dst + " does not exist."); + + copy_quantparam(nodes.src, nodes.dst); + + INFO(l) << "Quantparam of " << src << " is copied to " << dst << std::endl; + } + + INFO(l) << "CopyQuantParamPass End" << std::endl; + + return false; // one time run +} + +} // namespace luci diff --git a/compiler/luci/pass/src/FoldGatherPass.cpp b/compiler/luci/pass/src/FoldGatherPass.cpp new file mode 100644 index 000000000..f179d74bd --- /dev/null +++ b/compiler/luci/pass/src/FoldGatherPass.cpp @@ -0,0 +1,185 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/FoldGatherPass.h" +#include "CircleOptimizerUtils.h" + +#include <luci/IR/CircleNodes.h> + +namespace +{ + +/** + * Fold to const if + * + * 1. params: const and dtype = S32 or S64 + * 2. 
indices: const and dtype = S32 or S64 + * + * BEFORE + * + * [CircleConst] [CircleConst] + * | | + * +---------[Gather]---------+ + * + * AFTER + * + * [CircleConst] + * + **/ +template <loco::DataType InputT, loco::DataType IndexT> +bool fold_gather(luci::CircleGather *gather_node) +{ + // Both operands must already be CircleConst; the untyped fold_gather() overload checks this + const auto params = loco::must_cast<luci::CircleConst *>(gather_node->params()); + const auto indices = loco::must_cast<luci::CircleConst *>(gather_node->indices()); + + // A negative axis counts back from the last dimension of params + const auto rank = params->rank(); + auto axis = gather_node->axis(); + if (axis < 0) + { + axis += static_cast<int32_t>(rank); + } + + if (axis < 0 or axis >= static_cast<int32_t>(rank)) + throw std::runtime_error("Unsupported axis value"); + + const auto name = gather_node->name(); + assert(name.length() > 0); + + auto constant = gather_node->graph()->nodes()->create<luci::CircleConst>(); + constant->dtype(InputT); + constant->name(name + "_folded"); + + constant->rank(rank + indices->rank() - 1); + + assert(constant->rank() > 0); + + // Output shape: params shape with the axis dimension replaced by the whole indices shape + std::vector<uint32_t> shape; + for (uint32_t i = 0; i < rank; ++i) + { + if (i != static_cast<uint32_t>(axis)) + { + const auto dim = params->dim(i).value(); + shape.push_back(dim); + } + else + { + for (uint32_t j = 0; j < indices->rank(); ++j) + { + const auto dim = indices->dim(j).value(); + shape.push_back(dim); + } + } + } + + uint32_t size = 1; + for (uint32_t i = 0; i < shape.size(); ++i) + { + constant->dim(i).set(shape.at(i)); + size *= shape.at(i); + } + + // Rank and every dimension are set now, so mark the shape as valid + constant->shape_status(luci::ShapeStatus::VALID); + + constant->size<InputT>(size); + + // outer_size: product of params dims before axis + uint32_t outer_size = 1; + for (uint32_t i = 0; i < static_cast<uint32_t>(axis); ++i) + { + outer_size *= params->dim(i).value(); + } + + // inner_size: product of params dims after axis (elements per gathered slice) + uint32_t inner_size = 1; + for (uint32_t i = axis + 1; i < rank; ++i) + { + inner_size *= params->dim(i).value(); + } + + // coord_size: total number of indices + uint32_t coord_size = 1; + for (uint32_t i = 0; i < indices->rank(); ++i) + { + coord_size *= indices->dim(i).value(); + } + + const auto axis_size = params->dim(axis).value(); + + for (uint32_t outer = 0; outer < outer_size; 
++outer) + { + for (uint32_t i = 0; i < coord_size; ++i) + { + // Reject an index outside [0, axis_size) instead of reading outside params + const auto index = static_cast<int64_t>(indices->at<IndexT>(i)); + if (index < 0 or index >= static_cast<int64_t>(axis_size)) + throw std::runtime_error("Unsupported index value"); + + // Copy the whole inner slice (inner_size elements), not only its first element + const uint32_t dst_base = (outer * coord_size + i) * inner_size; + const uint32_t src_base = (outer * axis_size + static_cast<uint32_t>(index)) * inner_size; + for (uint32_t inner = 0; inner < inner_size; ++inner) + constant->at<InputT>(dst_base + inner) = params->at<InputT>(src_base + inner); + } + } + loco::replace(gather_node).with(constant); + + return true; +} + +// Select the typed fold_gather by params/indices dtype; returns false when the node is left as is +bool fold_gather(luci::CircleGather *gather_node) +{ + const auto params = dynamic_cast<luci::CircleConst *>(gather_node->params()); + if (not params) + return false; + + const auto indices = dynamic_cast<luci::CircleConst *>(gather_node->indices()); + if (not indices) + return false; + + // TODO: support more types + if (params->dtype() != loco::DataType::S32 and params->dtype() != loco::DataType::S64) + return false; + + if (indices->dtype() != loco::DataType::S32 and indices->dtype() != loco::DataType::S64) + throw std::runtime_error("Unsupported type"); + + if (params->dtype() == loco::DataType::S64) + { + if (indices->dtype() == loco::DataType::S64) + return fold_gather<loco::DataType::S64, loco::DataType::S64>(gather_node); + else + return fold_gather<loco::DataType::S64, loco::DataType::S32>(gather_node); + } + else + { + if (indices->dtype() == loco::DataType::S64) + return fold_gather<loco::DataType::S32, loco::DataType::S64>(gather_node); + else + return fold_gather<loco::DataType::S32, loco::DataType::S32>(gather_node); + } +} + +} // namespace + +namespace luci +{ + +/** + * Constant Folding for Gather Op + **/ +bool FoldGatherPass::run(loco::Graph *g) +{ + bool changed = false; + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + if (auto gather_node = dynamic_cast<luci::CircleGather *>(node)) + { + if (fold_gather(gather_node)) + changed = true; + } + } + + return changed; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/FoldGatherPass.test.cpp b/compiler/luci/pass/src/FoldGatherPass.test.cpp new file mode 100644 index 000000000..b02c034a5 --- /dev/null +++ b/compiler/luci/pass/src/FoldGatherPass.test.cpp @@ -0,0 +1,214 @@ +/* + * Copyright (c) 
2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/FoldGatherPass.h" +#include "PassTestGraphs.h" + +#include <luci/IR/CircleNodes.h> + +#include <gtest/gtest.h> + +namespace +{ + +/** + * + * Graph that has a Gather S64 Op with const inputs + * + * BEFORE + * params: [Const] (shape: [3], values: [1, 2, 3]) + * indices: [Const] (shape: [1], values: [1]) + * + * [params] [indices] + * | | + * ---[Gather]--- + * + * AFTER + * [Const] (shape: [1], values: [2]) + * + */ +class S64FoldGatherSimpleTest : public luci::ConstantFoldingAddTestGraph, public ::testing::Test +{ +public: + S64FoldGatherSimpleTest() : luci::ConstantFoldingAddTestGraph({1}, loco::DataType::S64) {} + + virtual void SetUp() { init(); } + + loco::Node *createFoldedPattern() override + { + _gather = _g.nodes()->create<luci::CircleGather>(); + _params = _g.nodes()->create<luci::CircleConst>(); + _indices = _g.nodes()->create<luci::CircleConst>(); + + _gather->dtype(loco::DataType::S64); + _params->dtype(loco::DataType::S64); + _indices->dtype(loco::DataType::S64); + + _params->shape({3}); + _indices->shape({1}); + + _params->size<loco::DataType::S64>(3); + _params->at<loco::DataType::S64>(0) = 1; + _params->at<loco::DataType::S64>(1) = 2; + _params->at<loco::DataType::S64>(2) = 3; + + _indices->size<loco::DataType::S64>(1); + _indices->at<loco::DataType::S64>(0) = 1; + + _gather->params(_params); + 
_gather->indices(_indices); + + _gather->name("gather"); + _params->name("params"); + _indices->name("indices"); + + return _gather; + } + +protected: + luci::CircleGather *_gather = nullptr; + luci::CircleConst *_params = nullptr; + luci::CircleConst *_indices = nullptr; +}; + +/** + * + * Graph that has a Gather S32 Op with axis = 1 and with const inputs + * + * BEFORE + * params: [Const] (shape: [2, 3], values: [0, 1, 2, 3, 4, 5]) + * indices: [Const] (shape: [2], values: [2, 1]) + * + * [params] [indices] + * | | + * ---[Gather]--- + * + * AFTER + * [Const] (shape: [2, 2], values: [2, 1, 5, 4]) + * + */ + +class S32FoldGatherTwoDimsTest : public luci::ConstantFoldingAddTestGraph, public ::testing::Test +{ +public: + S32FoldGatherTwoDimsTest() : luci::ConstantFoldingAddTestGraph({4, 2}, loco::DataType::S32) {} + + virtual void SetUp() { init(); } + + loco::Node *createFoldedPattern() override + { + _gather = _g.nodes()->create<luci::CircleGather>(); + _params = _g.nodes()->create<luci::CircleConst>(); + _indices = _g.nodes()->create<luci::CircleConst>(); + + _gather->dtype(loco::DataType::S32); + _params->dtype(loco::DataType::S32); + _indices->dtype(loco::DataType::S32); + + _params->shape({2, 3}); + _indices->shape({2}); + + _params->size<loco::DataType::S32>(6); + _params->at<loco::DataType::S32>(0) = 0; + _params->at<loco::DataType::S32>(1) = 1; + _params->at<loco::DataType::S32>(2) = 2; + _params->at<loco::DataType::S32>(3) = 3; + _params->at<loco::DataType::S32>(4) = 4; + _params->at<loco::DataType::S32>(5) = 5; + + _indices->size<loco::DataType::S32>(2); + _indices->at<loco::DataType::S32>(0) = 2; + _indices->at<loco::DataType::S32>(1) = 1; + + _gather->params(_params); + _gather->indices(_indices); + + _gather->axis(1); + + _gather->name("gather"); + _params->name("params"); + _indices->name("indices"); + + return _gather; + } + +protected: + luci::CircleGather *_gather = nullptr; + luci::CircleConst *_params = nullptr; + luci::CircleConst *_indices = 
nullptr; +}; + +} // namespace + +TEST(FoldGatherTest, name) +{ + luci::FoldGatherPass pass; + auto const name = pass.name(); + ASSERT_NE(nullptr, name); +} + +TEST_F(S64FoldGatherSimpleTest, fold_gather_simple) +{ + luci::FoldGatherPass pass; + while (pass.run(graph())) + ; + + auto folded_const = getFoldedPattern(); + EXPECT_NE(nullptr, folded_const); + + // Check type, shape, values of folded const + EXPECT_EQ(loco::DataType::S64, folded_const->dtype()); + EXPECT_EQ(1, folded_const->rank()); + EXPECT_EQ(1, folded_const->dim(0).value()); + EXPECT_EQ(2, folded_const->at<loco::DataType::S64>(0)); +} + +TEST_F(S32FoldGatherTwoDimsTest, fold_gather_with_two_dim) +{ + luci::FoldGatherPass pass; + while (pass.run(graph())) + ; + + auto folded_const = getFoldedPattern(); + EXPECT_NE(nullptr, folded_const); + + // Check type, shape, values of folded const + EXPECT_EQ(loco::DataType::S32, folded_const->dtype()); + EXPECT_EQ(2, folded_const->rank()); + EXPECT_EQ(2, folded_const->dim(0).value()); + EXPECT_EQ(2, folded_const->dim(1).value()); + + EXPECT_EQ(2, folded_const->at<loco::DataType::S32>(0)); + EXPECT_EQ(1, folded_const->at<loco::DataType::S32>(1)); + EXPECT_EQ(5, folded_const->at<loco::DataType::S32>(2)); + EXPECT_EQ(4, folded_const->at<loco::DataType::S32>(3)); +} + +TEST_F(S64FoldGatherSimpleTest, illegal_input_NEG) +{ + _indices->dtype(loco::DataType::FLOAT32); + + luci::FoldGatherPass pass; + EXPECT_ANY_THROW(pass.run(graph())); +} + +TEST_F(S64FoldGatherSimpleTest, illegal_axis_NEG) +{ + _gather->axis(1); + + luci::FoldGatherPass pass; + EXPECT_ANY_THROW(pass.run(graph())); +} diff --git a/compiler/luci/pass/src/PropagateConcatenationQparam.test.cpp b/compiler/luci/pass/src/PropagateConcatenationQparam.test.cpp index de973a431..68136b244 100644 --- a/compiler/luci/pass/src/PropagateConcatenationQparam.test.cpp +++ b/compiler/luci/pass/src/PropagateConcatenationQparam.test.cpp @@ -186,12 +186,12 @@ TEST(PropagateConcatenationQparam, 
propagate_concat_quantparam_u8) // (1) normal case: qparam is propagated to input_1 and input_2 // (2) input used by other Op: input_1 is an input of input_2. qparam is propagated only to // input_2 - // (3) subsequent concat: input_1 is concat. qparam is propagated only to input_2 + // (3) subsequent concat: input_1 is concat. qparam is propagated to subsequent concat // (4) const input: input_1 is const. constant values are quantized // normal case: qparam of concat_node is propagated to input_1 and input_2 SimpleConcatGraph g(loco::DataType::U8); - luci::propagate_concat_quantparam(&g.concat_node, loco::DataType::U8); + luci::propagate_concat_quantparam(&g.concat_node); EXPECT_FLOAT_EQ(3.14, g.concat_node.quantparam()->scale[0]); EXPECT_EQ(77, g.concat_node.quantparam()->zerop[0]); EXPECT_FLOAT_EQ(3.14, g.input_1.quantparam()->scale[0]); @@ -202,7 +202,7 @@ TEST(PropagateConcatenationQparam, propagate_concat_quantparam_u8) // input_1 is an input of input_2. qparam is propagated only to input_2 SimpleConcatGraph g2(loco::DataType::U8); g2.input_2.input(&g2.input_1); - luci::propagate_concat_quantparam(&g2.concat_node, loco::DataType::U8); + luci::propagate_concat_quantparam(&g2.concat_node); EXPECT_FLOAT_EQ(3.14, g2.concat_node.quantparam()->scale[0]); EXPECT_EQ(77, g2.concat_node.quantparam()->zerop[0]); EXPECT_FLOAT_EQ(1.0, g2.input_1.quantparam()->scale[0]); @@ -210,19 +210,19 @@ TEST(PropagateConcatenationQparam, propagate_concat_quantparam_u8) EXPECT_FLOAT_EQ(3.14, g2.input_2.quantparam()->scale[0]); EXPECT_EQ(77, g2.input_2.quantparam()->zerop[0]); - // input_1 is concat. qparam is propagated only to input_2 + // input_1 is concat. 
qparam is propagated to subsequent concat SubsequentConcatGraph sg(loco::DataType::U8); - luci::propagate_concat_quantparam(&sg.concat_node, loco::DataType::U8); + luci::propagate_concat_quantparam(&sg.concat_node); EXPECT_FLOAT_EQ(3.14, sg.concat_node.quantparam()->scale[0]); EXPECT_EQ(77, sg.concat_node.quantparam()->zerop[0]); - EXPECT_FLOAT_EQ(1.0, sg.input_1.quantparam()->scale[0]); - EXPECT_EQ(1, sg.input_1.quantparam()->zerop[0]); + EXPECT_FLOAT_EQ(3.14, sg.input_1.quantparam()->scale[0]); + EXPECT_EQ(77, sg.input_1.quantparam()->zerop[0]); EXPECT_FLOAT_EQ(3.14, sg.input_2.quantparam()->scale[0]); EXPECT_EQ(77, sg.input_2.quantparam()->zerop[0]); // input_1 is const. const values are quantized with the qparam of concat ConstInputConcatGraph cg(loco::DataType::U8); - luci::propagate_concat_quantparam(cg.concat_node, loco::DataType::U8); + luci::propagate_concat_quantparam(cg.concat_node); EXPECT_FLOAT_EQ(0.1, cg.concat_node->quantparam()->scale[0]); EXPECT_EQ(10, cg.concat_node->quantparam()->zerop[0]); const auto cg_input_1 = loco::must_cast<luci::CircleConst *>(cg.concat_node->values(0)); @@ -248,7 +248,7 @@ TEST(PropagateConcatenationQparam, propagate_concat_quantparam_u8_NEG) // concat has fused activation function g.concat_node.fusedActivationFunction(luci::FusedActFunc::RELU); - luci::propagate_concat_quantparam(&g.concat_node, loco::DataType::U8); + luci::propagate_concat_quantparam(&g.concat_node); EXPECT_FLOAT_EQ(3.14, g.concat_node.quantparam()->scale[0]); EXPECT_EQ(77, g.concat_node.quantparam()->zerop[0]); EXPECT_FLOAT_EQ(1.0, g.input_1.quantparam()->scale[0]); @@ -261,7 +261,7 @@ TEST(PropagateConcatenationQparam, propagate_concat_quantparam_u8_NEG) // const values are quantized using its min/max ConstInputConcatGraph cg(loco::DataType::U8); cg.concat_node->fusedActivationFunction(luci::FusedActFunc::RELU); - luci::propagate_concat_quantparam(cg.concat_node, loco::DataType::U8); + luci::propagate_concat_quantparam(cg.concat_node); 
EXPECT_FLOAT_EQ(0.1, cg.concat_node->quantparam()->scale[0]); EXPECT_EQ(10, cg.concat_node->quantparam()->zerop[0]); const auto cg_input_1 = loco::must_cast<luci::CircleConst *>(cg.concat_node->values(0)); @@ -283,12 +283,12 @@ TEST(PropagateConcatenationQparam, propagate_concat_quantparam_i16) // (1) normal case: qparam is propagated to input_1 and input_2 // (2) input used by other Op: input_1 is an input of input_2. qparam is propagated only to // input_2 - // (3) subsequent concat: input_1 is concat. qparam is propagated only to input_2 + // (3) subsequent concat: input_1 is concat. qparam is propagated to subsequent concat // (4) const input: input_1 is const. constant values are quantized // normal case: qparam of concat_node is propagated to input_1 and input_2 SimpleConcatGraph g(loco::DataType::S16); - luci::propagate_concat_quantparam(&g.concat_node, loco::DataType::S16); + luci::propagate_concat_quantparam(&g.concat_node); EXPECT_FLOAT_EQ(3.14, g.concat_node.quantparam()->scale[0]); EXPECT_EQ(0, g.concat_node.quantparam()->zerop[0]); EXPECT_FLOAT_EQ(3.14, g.input_1.quantparam()->scale[0]); @@ -299,7 +299,7 @@ TEST(PropagateConcatenationQparam, propagate_concat_quantparam_i16) // input_1 is an input of input_2. qparam is propagated only to input_2 SimpleConcatGraph g2(loco::DataType::S16); g2.input_2.input(&g2.input_1); - luci::propagate_concat_quantparam(&g2.concat_node, loco::DataType::S16); + luci::propagate_concat_quantparam(&g2.concat_node); EXPECT_FLOAT_EQ(3.14, g2.concat_node.quantparam()->scale[0]); EXPECT_EQ(0, g2.concat_node.quantparam()->zerop[0]); EXPECT_FLOAT_EQ(1.0, g2.input_1.quantparam()->scale[0]); @@ -309,17 +309,17 @@ TEST(PropagateConcatenationQparam, propagate_concat_quantparam_i16) // input_1 is concat. 
qparam is propagated only to input_2 SubsequentConcatGraph sg(loco::DataType::S16); - luci::propagate_concat_quantparam(&sg.concat_node, loco::DataType::S16); + luci::propagate_concat_quantparam(&sg.concat_node); EXPECT_FLOAT_EQ(3.14, sg.concat_node.quantparam()->scale[0]); EXPECT_EQ(0, sg.concat_node.quantparam()->zerop[0]); - EXPECT_FLOAT_EQ(1.0, sg.input_1.quantparam()->scale[0]); + EXPECT_FLOAT_EQ(3.14, sg.input_1.quantparam()->scale[0]); EXPECT_EQ(0, sg.input_1.quantparam()->zerop[0]); EXPECT_FLOAT_EQ(3.14, sg.input_2.quantparam()->scale[0]); EXPECT_EQ(0, sg.input_2.quantparam()->zerop[0]); // input_1 is const. const values are quantized with the qparam of concat ConstInputConcatGraph cg(loco::DataType::S16); - luci::propagate_concat_quantparam(cg.concat_node, loco::DataType::S16); + luci::propagate_concat_quantparam(cg.concat_node); EXPECT_FLOAT_EQ(0.1, cg.concat_node->quantparam()->scale[0]); EXPECT_EQ(0, cg.concat_node->quantparam()->zerop[0]); const auto cg_input_1 = loco::must_cast<luci::CircleConst *>(cg.concat_node->values(0)); @@ -345,7 +345,7 @@ TEST(PropagateConcatenationQparam, propagate_concat_quantparam_i16_NEG) // concat has fused activation function g.concat_node.fusedActivationFunction(luci::FusedActFunc::RELU); - luci::propagate_concat_quantparam(&g.concat_node, loco::DataType::S16); + luci::propagate_concat_quantparam(&g.concat_node); EXPECT_FLOAT_EQ(3.14, g.concat_node.quantparam()->scale[0]); EXPECT_EQ(0, g.concat_node.quantparam()->zerop[0]); EXPECT_FLOAT_EQ(1.0, g.input_1.quantparam()->scale[0]); @@ -358,7 +358,7 @@ TEST(PropagateConcatenationQparam, propagate_concat_quantparam_i16_NEG) // const values are quantized using its min/max ConstInputConcatGraph cg(loco::DataType::S16); cg.concat_node->fusedActivationFunction(luci::FusedActFunc::RELU); - luci::propagate_concat_quantparam(cg.concat_node, loco::DataType::S16); + luci::propagate_concat_quantparam(cg.concat_node); EXPECT_FLOAT_EQ(0.1, cg.concat_node->quantparam()->scale[0]); 
EXPECT_EQ(0, cg.concat_node->quantparam()->zerop[0]); const auto cg_input_1 = loco::must_cast<luci::CircleConst *>(cg.concat_node->values(0)); diff --git a/compiler/luci/pass/src/PropagateQParamBackwardPass.cpp b/compiler/luci/pass/src/PropagateQParamBackwardPass.cpp new file mode 100644 index 000000000..b4975486d --- /dev/null +++ b/compiler/luci/pass/src/PropagateQParamBackwardPass.cpp @@ -0,0 +1,482 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "luci/Pass/PropagateQParamBackwardPass.h" +#include "QuantizationUtils.h" + +#include <luci/IR/CircleNodes.h> +#include <luci/IR/CircleNodeVisitor.h> +#include <luci/Service/Nodes/CircleConst.h> +#include <luci/Log.h> + +#include <cmath> + +namespace +{ + +void quant_const_values(luci::CircleConst *const_node, float scaling_factor, float zerop, + loco::DataType quant_type) +{ + uint32_t size = const_node->size<loco::DataType::FLOAT32>(); + + const float scaling_factor_inv = 1.0 / scaling_factor; + std::vector<int32_t> quantized_values(size); + for (uint32_t i = 0; i < size; ++i) + { + auto data = static_cast<double>(const_node->at<loco::DataType::FLOAT32>(i)); + double quantized_data = std::round(data * scaling_factor_inv) + zerop; + constexpr double int_max = static_cast<double>(std::numeric_limits<int32_t>::max()); + constexpr double int_min = static_cast<double>(std::numeric_limits<int32_t>::min()); + quantized_data = std::min(int_max, std::max(int_min, quantized_data)); + + quantized_values[i] = static_cast<int32_t>(quantized_data); + } + + switch (quant_type) + { + case loco::DataType::U8: + const_node->dtype(loco::DataType::U8); // change the type of tensor + const_node->size<loco::DataType::U8>(size); // resize tensor + for (uint32_t i = 0; i < size; ++i) + const_node->at<loco::DataType::U8>(i) = std::min(255, std::max(0, quantized_values[i])); + break; + case loco::DataType::S16: + assert(zerop == 0); + const_node->dtype(loco::DataType::S16); // change the type of tensor + const_node->size<loco::DataType::S16>(size); // resize tensor + for (uint32_t i = 0; i < size; ++i) + const_node->at<loco::DataType::S16>(i) = + std::min(32767, std::max(-32767, quantized_values[i])); + break; + default: + throw std::runtime_error("Unsupported data type"); + } +} + +void overwrite_quantparam(const luci::CircleNode *source, luci::CircleNode *target) +{ + auto source_qparam = source->quantparam(); + if (source_qparam == nullptr) + throw 
std::runtime_error("source quantparam is not found during overwrite"); + + auto target_qparam = target->quantparam(); + if (target_qparam == nullptr) + { + auto quantparam = std::make_unique<luci::CircleQuantParam>(); + target->quantparam(std::move(quantparam)); + target_qparam = target->quantparam(); + + if (target_qparam == nullptr) + throw std::runtime_error("Creating new quant param failed"); + } + target_qparam->min = source_qparam->min; + target_qparam->max = source_qparam->max; + target_qparam->scale = source_qparam->scale; + target_qparam->zerop = source_qparam->zerop; + target_qparam->quantized_dimension = source_qparam->quantized_dimension; +} + +/** + * Tells if pad_v2 quantization should ignore padding value + * In that case padding const will be quantized with input parameters, and probably clipped + */ +bool ignore_pad_v2_const_quantization(const luci::CirclePadV2 *pad) +{ + // This is a workaround to quantize pad generated from MaxPoolWithArgmax operation properly + // TODO use metadata hints to detect this case + auto const_value_node = dynamic_cast<const luci::CircleConst *>(pad->arg(2)); + if (!const_value_node) + return false; + if (const_value_node->dtype() == loco::DataType::FLOAT32) + { + float const_value = const_value_node->at<loco::DataType::FLOAT32>(0); + if (const_value == std::numeric_limits<float>::lowest()) + return true; + } + return false; +} + +/** EXAMPLE + * + * BEFORE + * + * [CircleNode] [CircleConst] + * (qparam1) (FP32) + * \ / + * \ / + * [CirclePack] + * (qparam2) + * + * AFTER + * + * [CircleNode] [CircleConst] [CircleConst] <- Dead node + * (qparam2) (qparam2) (FP32) + * \ / + * \ / + * [CirclePack] + * (qparam2) + * + * NOTE Quantization parameter of CirclePack (qparam2) is propagated to the inputs. 
+ */ +void propagate_pack_quantparam(luci::CirclePack *pack) +{ + assert(pack->quantparam() != nullptr); + + const auto num_inputs = pack->values_count(); + + for (uint32_t i = 0; i < num_inputs; i++) + { + auto node = loco::must_cast<luci::CircleNode *>(pack->arg(i)); + + // Quantize constant values + if (node->opcode() == luci::CircleOpcode::CIRCLECONST) + { + luci::CircleConst *const_node = loco::must_cast<luci::CircleConst *>(node); + if (const_node->dtype() != loco::DataType::FLOAT32) + throw std::runtime_error("Unsupported data type for constant input of pack Op"); + + const auto pack_qparam = pack->quantparam(); + if (pack_qparam == nullptr) + throw std::runtime_error("quantparam of pack is not found during propagation"); + + assert(pack_qparam->scale.size() == 1); + assert(pack_qparam->zerop.size() == 1); + const auto scaling_factor = pack_qparam->scale[0]; + const auto zerop = pack_qparam->zerop[0]; + + auto new_const = luci::clone(const_node); + quant_const_values(new_const, scaling_factor, zerop, pack->dtype()); + pack->values(i, new_const); + overwrite_quantparam(pack, new_const); + } + else + { + const auto succs = loco::succs(node); + if (succs.size() > 1) + continue; + + // Non-const input must have been quantized + assert(node->quantparam() != nullptr); + overwrite_quantparam(pack, node); + } + } +} + +/** EXAMPLE + * + * + * + * BEFORE + * + * [CircleNode] [CircleConst] [CircleConst] [CircleNode] + * (S32) (S32) (FP32) (U8 qparam1) + * \ \ / / + * \ \ / / + * \ \ / / + * -------[CircleOneHot]------- + * (U8 qparam2) + * + * AFTER + * + * [CircleNode] [CircleConst] [CircleConst] [CircleNode] [CircleConst] <- Dead node + * (S32) (S32) (U8 qparam2) (U8 qparam2) (FP32) + * \ \ / / + * \ \ / / + * \ \ / / + * -------[CircleOneHot]------- + * (U8 qparam2) + * + * NOTE Quantization parameter of CircleOneHot (qparam2) is propagated to on_value/off_value. 
+ */ +void propagate_one_hot_quantparam(luci::CircleOneHot *one_hot) +{ + assert(one_hot->quantparam() != nullptr); + + // Propagate quantization parameters from output to inputs, + // to fit both input and constant_value in one quant range. + auto quant_input = [one_hot](void (luci::CircleOneHot::*arg_setter)(loco::Node *), + loco::Node *(luci::CircleOneHot::*arg_getter)() const) { + auto node = loco::must_cast<luci::CircleNode *>((one_hot->*arg_getter)()); + + // Quantize constant values + if (node->opcode() == luci::CircleOpcode::CIRCLECONST) + { + luci::CircleConst *const_node = loco::must_cast<luci::CircleConst *>(node); + if (is_quantized(const_node)) + return; + + if (const_node->dtype() != loco::DataType::FLOAT32) + throw std::runtime_error("Unsupported data type for constant input of OneHot Op"); + + const auto qparam = one_hot->quantparam(); + if (qparam == nullptr) + throw std::runtime_error("quantparam of OneHot is not found during propagation"); + + assert(qparam->scale.size() == 1); + const auto scaling_factor = qparam->scale.at(0); + const auto zerop = qparam->zerop.at(0); + + auto new_const = luci::clone(const_node); + quant_const_values(new_const, scaling_factor, zerop, one_hot->dtype()); + overwrite_quantparam(one_hot, new_const); + (one_hot->*arg_setter)(new_const); + } + else + { + const auto succs = loco::succs(node); + if (succs.size() > 1) + return; + + // Non-const input must have been quantized + assert(node->quantparam() != nullptr); + overwrite_quantparam(one_hot, node); + } + }; + + quant_input(&luci::CircleOneHot::on_value, &luci::CircleOneHot::on_value); + quant_input(&luci::CircleOneHot::off_value, &luci::CircleOneHot::off_value); +} + +} // namespace + +namespace luci +{ + +/** BEFORE + * + * [CircleNode] [CircleConst] + * (U8 qparam1) (FP32) + * \ / + * \ / + * [CircleConcatenation] + * (U8 qparam2) + * + * AFTER + * [CircleNode] [CircleConst] [CircleConst] <- Dead node + * (U8 qparam2) (U8 qparam2) (FP32) + * \ / + * \ / + * 
[CircleConcatenation] + * (U8 qparam2) + */ +void propagate_concat_quantparam(luci::CircleConcatenation *concat) +{ + assert(concat->quantparam() != nullptr); + + const auto num_inputs = concat->numValues(); + + // Quantize const inputs using their values if concat has fused act function + if (concat->fusedActivationFunction() != luci::FusedActFunc::NONE) + { + for (uint32_t i = 0; i < num_inputs; i++) + { + auto node = concat->arg(i); + auto const_node = dynamic_cast<luci::CircleConst *>(node); + if (const_node != nullptr) + { + auto new_const = luci::clone(const_node); + quant_const(new_const, concat->dtype()); + concat->values(i, new_const); + } + } + return; + } + + for (uint32_t i = 0; i < num_inputs; i++) + { + auto node = loco::must_cast<luci::CircleNode *>(concat->arg(i)); + + // Quantize constant values + if (node->opcode() == luci::CircleOpcode::CIRCLECONST) + { + luci::CircleConst *const_node = loco::must_cast<luci::CircleConst *>(node); + + const auto concat_qparam = concat->quantparam(); + assert(concat_qparam->scale.size() == 1); + const auto scaling_factor = concat_qparam->scale[0]; + const auto zerop = concat_qparam->zerop[0]; + + auto new_const = luci::clone(const_node); + quant_const_values(new_const, scaling_factor, zerop, concat->dtype()); + concat->values(i, new_const); + overwrite_quantparam(concat, new_const); + } + else + { + const auto succs = loco::succs(node); + if (succs.size() > 1) + continue; + + // Non-const input must have been quantized + assert(node->quantparam() != nullptr); + overwrite_quantparam(concat, node); + } + } +} + +/** BEFORE + * + * [CircleNode] [CircleConst] [CircleConst] + * (U8 qparam1) (S32) (FP32) + * \ | / + * \ | / + * [CirclePadV2] + * (U8 qparam2) + * + * AFTER (case 1) + * + * By default qparam is propagated from output to inputs to meet backend requirements. 
+ * + * [CircleNode] [CircleConst] [CircleConst] [CircleConst] <- Dead node + * (U8 qparam2) (S32) (U8 qparam2) (FP32) + * \ | / + * \ | / + * [CirclePadV2] + * (U8 qparam2) + * + * AFTER (case 2) + * + * In case padded value is the lowest float value + * Qparam is propagated from input to output and constant. + * + * This is a special case for optimization constructed pad, needed to guarantee that + * extremely large negative constant does not stretch output quantization range. + * + * [CircleNode] [CircleConst] [CircleConst] [CircleConst] <- Dead node + * (U8 qparam1) (S32) (U8 qparam1) (FP32) + * \ | / + * \ | / + * [CirclePadV2] + * (U8 qparam1) + */ +void propagate_pad_v2_quantparam(luci::CirclePadV2 *pad_v2) +{ + if (ignore_pad_v2_const_quantization(pad_v2)) + { + // propagate input quantization parameters from input to output and padding const value + auto pad_v2_input = loco::must_cast<luci::CircleNode *>(pad_v2->arg(0)); + overwrite_quantparam(pad_v2_input, pad_v2); + + auto const_value_node = loco::must_cast<luci::CircleConst *>( + pad_v2->arg(2)); // FIX ignore_pad_v2_const_quantization UNLESS + auto new_const = luci::clone(const_value_node); + + const auto pad_v2_input_qparam = pad_v2_input->quantparam(); + assert(pad_v2_input_qparam != nullptr); + assert(pad_v2_input_qparam->scale.size() == 1); + const auto scaling_factor = pad_v2_input_qparam->scale.at(0); + const auto zerop = pad_v2_input_qparam->zerop.at(0); + + quant_const_values(new_const, scaling_factor, zerop, pad_v2->dtype()); + overwrite_quantparam(pad_v2_input, new_const); + pad_v2->constant_values(new_const); + return; + } + + // Propagate quantization parameters from output to inputs, + // to fit both input and constant_values in one quant range. 
+ auto quant_input = [pad_v2](void (CirclePadV2::*arg_setter)(loco::Node *), uint32_t arg) { + auto node = loco::must_cast<luci::CircleNode *>(pad_v2->arg(arg)); + + // Quantize constant values + if (node->opcode() == luci::CircleOpcode::CIRCLECONST) + { + luci::CircleConst *const_node = loco::must_cast<luci::CircleConst *>(node); + if (is_quantized(const_node)) + return; + + if (const_node->dtype() != loco::DataType::FLOAT32) + throw std::runtime_error("Unsupported data type for constant input of PadV2 Op"); + + const auto pad_v2_qparam = pad_v2->quantparam(); + if (pad_v2_qparam == nullptr) + throw std::runtime_error("quantparam of PadV2 is not found during propagation"); + + assert(pad_v2_qparam->scale.size() == 1); + const auto scaling_factor = pad_v2_qparam->scale.at(0); + const auto zerop = pad_v2_qparam->zerop.at(0); + + auto new_const = luci::clone(const_node); + quant_const_values(new_const, scaling_factor, zerop, pad_v2->dtype()); + overwrite_quantparam(pad_v2, new_const); + (pad_v2->*arg_setter)(new_const); + } + else + { + const auto succs = loco::succs(node); + if (succs.size() > 1) + return; + + // Non-const input must have been quantized + assert(node->quantparam() != nullptr); + overwrite_quantparam(pad_v2, node); + } + }; + + quant_input(&CirclePadV2::input, 0); + quant_input(&CirclePadV2::constant_values, 2); +} + +} // namespace luci + +namespace +{ + +// Visitor to propagate quantization parameters backwards +struct PropagateQParamBackward final : public luci::CircleNodeMutableVisitor<void> +{ + void visit(luci::CircleNode *) {} + + void visit(luci::CircleConcatenation *node) { propagate_concat_quantparam(node); } + + void visit(luci::CircleOneHot *node) { propagate_one_hot_quantparam(node); } + + void visit(luci::CirclePack *node) { propagate_pack_quantparam(node); } + + void visit(luci::CirclePadV2 *node) { propagate_pad_v2_quantparam(node); } +}; + +} // namespace + +namespace luci +{ + +bool PropagateQParamBackwardPass::run(loco::Graph *g) 
+{ + LOGGER(l); + + // We use reverse post-order traversal as qparam is propagated backward + auto nodes = loco::postorder_traversal(loco::output_nodes(g)); + std::reverse(nodes.begin(), nodes.end()); + for (auto node : nodes) + { + auto circle_node = loco::must_cast<luci::CircleNode *>(node); + INFO(l) << "PropagateQParamBackwardPass visit node: " << circle_node->name() << std::endl; + + // We can't propagate non-existent qparam + if (circle_node->quantparam() == nullptr) + continue; + + PropagateQParamBackward pqb; + circle_node->accept(&pqb); + } + + // This pass is only run once, so return false + // TODO Refactoring not to return meaningless value + return false; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/PropagateQParamBackwardPass.test.cpp b/compiler/luci/pass/src/PropagateQParamBackwardPass.test.cpp new file mode 100644 index 000000000..33af70449 --- /dev/null +++ b/compiler/luci/pass/src/PropagateQParamBackwardPass.test.cpp @@ -0,0 +1,167 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "luci/Pass/PropagateQParamBackwardPass.h" + +#include <luci/IR/CircleNodes.h> + +#include <gtest/gtest.h> + +using namespace luci; + +namespace +{ + +void set_qparam(luci::CircleNode *node, float scale, int64_t zp) +{ + auto qparam = std::make_unique<luci::CircleQuantParam>(); + qparam->scale.emplace_back(scale); + qparam->zerop.emplace_back(zp); + + node->quantparam(std::move(qparam)); +} + +/** + * @brief Base Test Graph + */ +struct TestGraph +{ +public: + virtual void init(void) = 0; +}; + +/** + * Graph with two concats + * + * [CircleInput] [CircleConst] + * \ / + * [CircleConcatenation] [CircleConst] + * | | + * [CircleConcatenation] + * | + * [CircleOutput] + * + * BEFORE + * - Concat1 and Concat 2 have different qparams + * + * AFTER + * - All Ops have the same qparam + */ +struct SubsequentConcatGraph : public TestGraph +{ +public: + void init(void) final + { + // graph input and output + auto graph_input = g.inputs()->create(); + auto graph_output = g.outputs()->create(); + + // input + input = g.nodes()->create<luci::CircleInput>(); + input->index(graph_input->index()); + input->shape({1, 4, 4, 3}); + input->dtype(loco::DataType::U8); + set_qparam(input, 1.0, 1); + + // const1 + const1 = g.nodes()->create<luci::CircleConst>(); + const1->shape({1, 4, 4, 3}); + const1->dtype(loco::DataType::FLOAT32); + const1->size<loco::DataType::FLOAT32>(48); + for (uint32_t i = 0; i < 48; i++) + const1->at<loco::DataType::FLOAT32>(i) = i; + + // concat1 + concat1 = g.nodes()->create<luci::CircleConcatenation>(2); + concat1->shape({1, 4, 4, 6}); + concat1->dtype(loco::DataType::U8); + set_qparam(concat1, 2.0, 2); + concat1->values(0, input); + concat1->values(1, const1); + concat1->fusedActivationFunction(luci::FusedActFunc::NONE); + + // const2 + const2 = g.nodes()->create<luci::CircleConst>(); + const2->shape({1, 4, 4, 3}); + const2->dtype(loco::DataType::FLOAT32); + const2->size<loco::DataType::FLOAT32>(48); + for (uint32_t i = 0; i < 48; i++) + 
const2->at<loco::DataType::FLOAT32>(i) = i; + + // concat2 + concat2 = g.nodes()->create<luci::CircleConcatenation>(2); + concat2->shape({1, 4, 4, 9}); + concat2->dtype(loco::DataType::U8); + set_qparam(concat2, 3.0, 3); + concat2->values(0, concat1); + concat2->values(1, const2); + concat2->fusedActivationFunction(luci::FusedActFunc::NONE); + + // output + output = g.nodes()->create<luci::CircleOutput>(); + output->index(graph_output->index()); + output->from(concat2); + output->shape({1, 4, 4, 9}); + output->dtype(loco::DataType::U8); + set_qparam(output, 3.0, 3); + } + +public: + loco::Graph g; + CircleInput *input = nullptr; + CircleConcatenation *concat1 = nullptr; + CircleConcatenation *concat2 = nullptr; + CircleConst *const1 = nullptr; + CircleConst *const2 = nullptr; + CircleOutput *output = nullptr; +}; + +} // namespace + +TEST(PropagateQParamBackwardPassTest, name) +{ + luci::PropagateQParamBackwardPass pass(loco::DataType::U8); + auto const name = pass.name(); + ASSERT_NE(nullptr, name); +} + +TEST(PropagateQParamBackwardPassTest, subsequent_propagation) +{ + SubsequentConcatGraph graph; + + graph.init(); + + luci::PropagateQParamBackwardPass pass(loco::DataType::U8); + + pass.run(&graph.g); + + EXPECT_EQ(3.0, graph.concat2->quantparam()->scale[0]); + EXPECT_EQ(3, graph.concat2->quantparam()->zerop[0]); + + auto const2 = loco::must_cast<CircleNode *>(graph.concat2->values(1)); + EXPECT_EQ(3.0, const2->quantparam()->scale[0]); + EXPECT_EQ(3, const2->quantparam()->zerop[0]); + + EXPECT_EQ(3.0, graph.concat1->quantparam()->scale[0]); + EXPECT_EQ(3, graph.concat1->quantparam()->zerop[0]); + + auto const1 = loco::must_cast<CircleNode *>(graph.concat1->values(1)); + EXPECT_EQ(3.0, const1->quantparam()->scale[0]); + EXPECT_EQ(3, const1->quantparam()->zerop[0]); + + EXPECT_EQ(3.0, graph.input->quantparam()->scale[0]); + EXPECT_EQ(3, graph.input->quantparam()->zerop[0]); +} diff --git a/compiler/luci/pass/src/PropagateQParamForwardPass.cpp 
b/compiler/luci/pass/src/PropagateQParamForwardPass.cpp new file mode 100644 index 000000000..003e4c293 --- /dev/null +++ b/compiler/luci/pass/src/PropagateQParamForwardPass.cpp @@ -0,0 +1,194 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/PropagateQParamForwardPass.h" + +#include "QuantizationUtils.h" + +#include <luci/IR/CircleNodes.h> +#include <luci/IR/CircleNodeVisitor.h> +#include <luci/Log.h> + +#include <iostream> + +namespace +{ + +bool copy_qparam(luci::CircleQuantParam *src, luci::CircleQuantParam *dst) +{ + assert(src->scale.size() == dst->scale.size()); + assert(src->zerop.size() == dst->zerop.size()); + + // src and dst have the same qparam + if (std::equal(src->scale.begin(), src->scale.end(), dst->scale.begin()) && + std::equal(src->zerop.begin(), src->zerop.end(), dst->zerop.begin()) && + src->quantized_dimension == dst->quantized_dimension) + return false; + + dst->scale.assign(src->scale.begin(), src->scale.end()); + dst->zerop.assign(src->zerop.begin(), src->zerop.end()); + dst->quantized_dimension = src->quantized_dimension; + return true; +} + +bool copy_qparam(luci::CircleNode *src, luci::CircleNode *dst) +{ + // Skip nodes that do not have quantparams + auto src_qparam = src->quantparam(); + if (not src_qparam) + return false; + + auto dst_qparam = dst->quantparam(); + if (not dst_qparam) + return false; + + return 
copy_qparam(src_qparam, dst_qparam); +} + +// Visitor to propagate quantization parameters +struct PropagateQParamForward final : public luci::CircleNodeMutableVisitor<bool> +{ + PropagateQParamForward() = default; + + bool visit(luci::CircleNode *) { return false; } + + bool visit(luci::CircleGather *node) + { + auto input_node = loco::must_cast<luci::CircleNode *>(node->params()); + return copy_qparam(input_node, node); + } + + bool visit(luci::CircleReshape *node) + { + auto input_node = loco::must_cast<luci::CircleNode *>(node->tensor()); + return copy_qparam(input_node, node); + } + + bool visit(luci::CircleTranspose *node) + { + auto input_node = loco::must_cast<luci::CircleNode *>(node->a()); + return copy_qparam(input_node, node); + } + + bool visit(luci::CircleStridedSlice *node) + { + auto input_node = loco::must_cast<luci::CircleNode *>(node->input()); + return copy_qparam(input_node, node); + } + + bool visit(luci::CircleSplitOut *node) + { + auto split = loco::must_cast<luci::CircleSplit *>(node->input()); + auto input_node = loco::must_cast<luci::CircleNode *>(split->input()); + return copy_qparam(input_node, node); + } + + bool visit(luci::CircleSplitVOut *node) + { + auto splitv = loco::must_cast<luci::CircleSplitV *>(node->input()); + auto input_node = loco::must_cast<luci::CircleNode *>(splitv->input()); + return copy_qparam(input_node, node); + } + + bool visit(luci::CircleUnpackOut *node) + { + auto unpack = loco::must_cast<luci::CircleUnpack *>(node->input()); + auto input_node = loco::must_cast<luci::CircleNode *>(unpack->value()); + return copy_qparam(input_node, node); + } + + // Propagate qparam across Quantize op to ensure + // special qparams (pre-defined values, integer scale) + bool visit(luci::CircleQuantize *node) + { + auto input_node = loco::must_cast<luci::CircleNode *>(node->input()); + + // Skip if input_node is not quantized activation + if (input_node->dtype() != loco::DataType::U8 and input_node->dtype() != 
loco::DataType::S16) + return false; + + // If input_node and node have the same dtype, Quantize op + // will do rescale, not requantize for mixed-precision + if (input_node->dtype() == node->dtype()) + return false; + + assert(node->dtype() == loco::DataType::U8 or node->dtype() == loco::DataType::S16); + + auto prev_qparam = node->quantparam(); + assert(prev_qparam); + assert(prev_qparam->scale.size() == 1); + assert(prev_qparam->zerop.size() == 1); + + const auto prev_scale = prev_qparam->scale[0]; + const auto prev_zerop = prev_qparam->zerop[0]; + + auto qtype = luci::activation_qtype(input_node); + switch (qtype) + { + case luci::ActivationQType::PreDefinedValue: + node->quantparam(luci::make_predefined_qparam(input_node->opcode(), node->dtype())); + break; + case luci::ActivationQType::IntScale: + luci::set_int_scale(node); + break; + default: + break; + } + + assert(node->quantparam()); + assert(node->quantparam()->scale.size() == 1); + assert(node->quantparam()->zerop.size() == 1); + + const auto scale = node->quantparam()->scale[0]; + const auto zerop = node->quantparam()->zerop[0]; + + // Compare qparam with saved values to detect update + return scale != prev_scale or zerop != prev_zerop; + } +}; + +} // namespace + +namespace luci +{ + +bool PropagateQParamForwardPass::run(loco::Graph *g) +{ + bool changed = false; + LOGGER(l); + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + auto circle_node = loco::must_cast<luci::CircleNode *>(node); + INFO(l) << "PropagateQParamForwardPass visit node: " << circle_node->name() << std::endl; + + PropagateQParamForward pqp; + if (circle_node->accept(&pqp)) + changed = true; + + if (_TF_style_maxpool) + { + if (auto maxpool = dynamic_cast<luci::CircleMaxPool2D *>(node)) + { + auto input = loco::must_cast<luci::CircleNode *>(maxpool->value()); + copy_qparam(input, maxpool); + } + } + } + + return changed; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/PropagateQParamForwardPass.test.cpp 
b/compiler/luci/pass/src/PropagateQParamForwardPass.test.cpp new file mode 100644 index 000000000..a734c0873 --- /dev/null +++ b/compiler/luci/pass/src/PropagateQParamForwardPass.test.cpp @@ -0,0 +1,260 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/PropagateQParamForwardPass.h" + +#include <luci/IR/CircleNodes.h> + +#include <gtest/gtest.h> + +namespace +{ + +void addQuantParam(luci::CircleNode *node, const std::vector<float> &scale, + const std::vector<int64_t> &zp) +{ + assert(node->quantparam() == nullptr); + + auto quantparam = std::make_unique<luci::CircleQuantParam>(); + quantparam->scale = scale; + quantparam->zerop = zp; + node->quantparam(std::move(quantparam)); +} + +/** + * Simple graph for test + * + * BEFORE + * + * [Conv] (qparam 1) + * | + * [Reshape] (qparam 2) + * + * AFTER + * + * [Conv] (qparam 1) + * | + * [Reshape] (qparam 1) + * + */ +class SimpleGraph +{ +public: + SimpleGraph() + { + input = g.nodes()->create<luci::CircleInput>(); + conv = g.nodes()->create<luci::CircleConv2D>(); + reshape = g.nodes()->create<luci::CircleReshape>(); + output = g.nodes()->create<luci::CircleOutput>(); + + auto graph_input = g.inputs()->create(); + input->index(graph_input->index()); + auto graph_output = g.outputs()->create(); + output->index(graph_output->index()); + + addQuantParam(conv, {0.1, 0.2, 0.3}, {0, 10, 20}); + 
addQuantParam(reshape, {0.2, 0.4, 0.6}, {-10, 0, 10}); + + conv->input(input); + reshape->tensor(conv); + output->from(reshape); + } + +public: + loco::Graph g; + luci::CircleInput *input = nullptr; + luci::CircleConv2D *conv = nullptr; + luci::CircleReshape *reshape = nullptr; + luci::CircleOutput *output = nullptr; +}; + +/** + * Test graph for forward propagation in Quantize Op + * + * BEFORE + * + * [Tanh U8] (qparam 1 - pre-defined for U8) + * | + * [Quantize S16] (qparam 2 - not pre-defined value) + * + * AFTER + * + * [Tanh U8] (qparam 1 - pre-defined for U8) + * | + * [Quantize S16] (qparam 3 - pre-defined for S16) + * + */ +class TanhQuantizeGraph +{ +public: + TanhQuantizeGraph() + { + input = g.nodes()->create<luci::CircleInput>(); + tanh = g.nodes()->create<luci::CircleTanh>(); + quantize = g.nodes()->create<luci::CircleQuantize>(); + output = g.nodes()->create<luci::CircleOutput>(); + + auto graph_input = g.inputs()->create(); + input->index(graph_input->index()); + auto graph_output = g.outputs()->create(); + output->index(graph_output->index()); + + tanh->dtype(loco::DataType::U8); + quantize->dtype(loco::DataType::S16); + + addQuantParam(tanh, {2.0f / 256.0f}, {128}); // pre-defined qparam for U8 + addQuantParam(quantize, {1.0}, {0}); // not pre-defined values + + tanh->x(input); + quantize->input(tanh); + output->from(quantize); + } + +public: + loco::Graph g; + luci::CircleInput *input = nullptr; + luci::CircleTanh *tanh = nullptr; + luci::CircleQuantize *quantize = nullptr; + luci::CircleOutput *output = nullptr; +}; + +/** + * Test graph for forward propagation in Quantize Op + * + * BEFORE + * + * [Floor U8] (qparam 1 - int scale) + * | + * [Quantize S16] (qparam 2 - not int scale) + * + * AFTER + * + * [Floor U8] (qparam 1 - int scale) + * | + * [Quantize S16] (qparam 3 - int scale) + * + */ +class FloorQuantizeGraph +{ +public: + FloorQuantizeGraph() + { + input = g.nodes()->create<luci::CircleInput>(); + floor = 
g.nodes()->create<luci::CircleFloor>(); + quantize = g.nodes()->create<luci::CircleQuantize>(); + output = g.nodes()->create<luci::CircleOutput>(); + + auto graph_input = g.inputs()->create(); + input->index(graph_input->index()); + auto graph_output = g.outputs()->create(); + output->index(graph_output->index()); + + floor->dtype(loco::DataType::U8); + quantize->dtype(loco::DataType::S16); + + addQuantParam(floor, {4.0f}, {128}); // int scale + addQuantParam(quantize, {0.3}, {0}); // not int scale + + floor->x(input); + quantize->input(floor); + output->from(quantize); + } + +public: + loco::Graph g; + luci::CircleInput *input = nullptr; + luci::CircleFloor *floor = nullptr; + luci::CircleQuantize *quantize = nullptr; + luci::CircleOutput *output = nullptr; +}; + +} // namespace + +TEST(PropagateQParamForwardPassTest, name) +{ + luci::PropagateQParamForwardPass pass; + auto const name = pass.name(); + ASSERT_NE(nullptr, name); +} + +TEST(PropagateQParamForward, simple) +{ + SimpleGraph g; + + luci::PropagateQParamForwardPass pass; + while (pass.run(&g.g)) + ; + + EXPECT_FLOAT_EQ(0.1, g.reshape->quantparam()->scale[0]); + EXPECT_FLOAT_EQ(0.2, g.reshape->quantparam()->scale[1]); + EXPECT_FLOAT_EQ(0.3, g.reshape->quantparam()->scale[2]); + EXPECT_EQ(0, g.reshape->quantparam()->zerop[0]); + EXPECT_EQ(10, g.reshape->quantparam()->zerop[1]); + EXPECT_EQ(20, g.reshape->quantparam()->zerop[2]); +} + +TEST(PropagateQParamForward, wrong_op_NEG) +{ + SimpleGraph g; + g.output->from(g.conv); + g.reshape->drop(); + + luci::PropagateQParamForwardPass pass; + while (pass.run(&g.g)) + ; + + EXPECT_FLOAT_EQ(0.1, g.conv->quantparam()->scale[0]); + EXPECT_FLOAT_EQ(0.2, g.conv->quantparam()->scale[1]); + EXPECT_FLOAT_EQ(0.3, g.conv->quantparam()->scale[2]); + EXPECT_EQ(0, g.conv->quantparam()->zerop[0]); + EXPECT_EQ(10, g.conv->quantparam()->zerop[1]); + EXPECT_EQ(20, g.conv->quantparam()->zerop[2]); +} + +TEST(PropagateQParamForward, tanh_predefined_value) +{ + TanhQuantizeGraph g; 
+ + luci::PropagateQParamForwardPass pass; + while (pass.run(&g.g)) + ; + + EXPECT_FLOAT_EQ(1.0f / 32768.0f, g.quantize->quantparam()->scale[0]); +} + +TEST(PropagateQParamForward, floor_int_scale) +{ + FloorQuantizeGraph g; + + luci::PropagateQParamForwardPass pass; + while (pass.run(&g.g)) + ; + + EXPECT_FLOAT_EQ(1.0f, g.quantize->quantparam()->scale[0]); +} + +TEST(PropagateQParamForward, same_dtype_NEG) +{ + FloorQuantizeGraph g; + g.quantize->dtype(loco::DataType::U8); + + luci::PropagateQParamForwardPass pass; + while (pass.run(&g.g)) + ; + + // Qparam is not propagated as ifm/ofm of Quantize Op have the same dtype + EXPECT_FLOAT_EQ(0.3f, g.quantize->quantparam()->scale[0]); +} diff --git a/compiler/luci/pass/src/PropagateQuantParamPass.cpp b/compiler/luci/pass/src/PropagateQuantParamPass.cpp deleted file mode 100644 index b1cb7a418..000000000 --- a/compiler/luci/pass/src/PropagateQuantParamPass.cpp +++ /dev/null @@ -1,107 +0,0 @@ -/* - * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "luci/Pass/PropagateQuantParamPass.h" - -#include <luci/IR/CircleNodes.h> -#include <luci/IR/CircleNodeVisitor.h> -#include <luci/Log.h> - -#include <iostream> - -namespace -{ - -bool copy_qparam(luci::CircleQuantParam *src, luci::CircleQuantParam *dst) -{ - assert(src->scale.size() == dst->scale.size()); - assert(src->zerop.size() == dst->zerop.size()); - - // src and dst have the same qparam - if (std::equal(src->scale.begin(), src->scale.end(), dst->scale.begin()) && - std::equal(src->zerop.begin(), src->zerop.end(), dst->zerop.begin()) && - src->quantized_dimension == dst->quantized_dimension) - return false; - - dst->scale.assign(src->scale.begin(), src->scale.end()); - dst->zerop.assign(src->zerop.begin(), src->zerop.end()); - dst->quantized_dimension = src->quantized_dimension; - return true; -} - -bool copy_qparam(luci::CircleNode *src, luci::CircleNode *dst) -{ - // Skip nodes that do not have quantparams - auto src_qparam = src->quantparam(); - if (not src_qparam) - return false; - - auto dst_qparam = dst->quantparam(); - if (not dst_qparam) - return false; - - return copy_qparam(src_qparam, dst_qparam); -} - -// Visitor to propagate quantization parameters -struct PropagateQuantParam final : public luci::CircleNodeMutableVisitor<bool> -{ - PropagateQuantParam() = default; - - bool visit(luci::CircleNode *) { return false; } - - bool visit(luci::CircleReshape *node) - { - auto input = node->tensor(); - if (loco::succs(input).size() != 1) - return false; - - auto input_node = loco::must_cast<luci::CircleNode *>(input); - return copy_qparam(input_node, node); - } - - bool visit(luci::CircleTranspose *node) - { - auto input_node = loco::must_cast<luci::CircleNode *>(node->a()); - return copy_qparam(input_node, node); - } - - // TODO : Add more Ops (e.g., layout-changing Ops) -}; - -} // namespace - -namespace luci -{ - -bool PropagateQuantParamPass::run(loco::Graph *g) -{ - bool changed = false; - LOGGER(l); - for (auto node : 
loco::active_nodes(loco::output_nodes(g))) - { - auto circle_node = loco::must_cast<luci::CircleNode *>(node); - INFO(l) << "PropagateQuantParamPass visit node: " << circle_node->name() << std::endl; - - PropagateQuantParam pqp; - if (circle_node->accept(&pqp)) - changed = true; - } - - return changed; -} - -} // namespace luci diff --git a/compiler/luci/pass/src/PropagateQuantParamPass.test.cpp b/compiler/luci/pass/src/PropagateQuantParamPass.test.cpp deleted file mode 100644 index 0f1564223..000000000 --- a/compiler/luci/pass/src/PropagateQuantParamPass.test.cpp +++ /dev/null @@ -1,125 +0,0 @@ -/* - * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "luci/Pass/PropagateQuantParamPass.h" - -#include <luci/IR/CircleNodes.h> - -#include <gtest/gtest.h> - -namespace -{ - -void addQuantParam(luci::CircleNode *node, const std::vector<float> &scale, - const std::vector<int64_t> &zp) -{ - assert(node->quantparam() == nullptr); - - auto quantparam = std::make_unique<luci::CircleQuantParam>(); - quantparam->scale = scale; - quantparam->zerop = zp; - node->quantparam(std::move(quantparam)); -} - -/** - * Simple graph for test - * - * BEFORE - * - * [Conv] (qparam 1) - * | - * [Reshape] (qparam 2) - * - * AFTER - * - * [Conv] (qparam 2) - * | - * [Reshape] (qparam 2) - * - */ -class SimpleGraph -{ -public: - SimpleGraph() - { - input = g.nodes()->create<luci::CircleInput>(); - conv = g.nodes()->create<luci::CircleConv2D>(); - reshape = g.nodes()->create<luci::CircleReshape>(); - output = g.nodes()->create<luci::CircleOutput>(); - - auto graph_input = g.inputs()->create(); - input->index(graph_input->index()); - auto graph_output = g.outputs()->create(); - output->index(graph_output->index()); - - addQuantParam(conv, {0.1, 0.2, 0.3}, {0, 10, 20}); - addQuantParam(reshape, {0.2, 0.4, 0.6}, {-10, 0, 10}); - - conv->input(input); - reshape->tensor(conv); - output->from(reshape); - } - -public: - loco::Graph g; - luci::CircleInput *input; - luci::CircleConv2D *conv; - luci::CircleReshape *reshape; - luci::CircleOutput *output; -}; - -} // namespace - -TEST(PropagateQuantParamPassTest, name) -{ - luci::PropagateQuantParamPass pass; - auto const name = pass.name(); - ASSERT_NE(nullptr, name); -} - -TEST(PropagateQuantParam, simple) -{ - SimpleGraph g; - - luci::PropagateQuantParamPass pass; - while (pass.run(&g.g)) - ; - - EXPECT_FLOAT_EQ(0.1, g.reshape->quantparam()->scale[0]); - EXPECT_FLOAT_EQ(0.2, g.reshape->quantparam()->scale[1]); - EXPECT_FLOAT_EQ(0.3, g.reshape->quantparam()->scale[2]); - EXPECT_EQ(0, g.reshape->quantparam()->zerop[0]); - EXPECT_EQ(10, g.reshape->quantparam()->zerop[1]); - EXPECT_EQ(20, 
g.reshape->quantparam()->zerop[2]); -} - -TEST(PropagateQuantParam, wrong_op_NEG) -{ - SimpleGraph g; - g.output->from(g.conv); - g.reshape->drop(); - - luci::PropagateQuantParamPass pass; - while (pass.run(&g.g)) - ; - - EXPECT_FLOAT_EQ(0.1, g.conv->quantparam()->scale[0]); - EXPECT_FLOAT_EQ(0.2, g.conv->quantparam()->scale[1]); - EXPECT_FLOAT_EQ(0.3, g.conv->quantparam()->scale[2]); - EXPECT_EQ(0, g.conv->quantparam()->zerop[0]); - EXPECT_EQ(10, g.conv->quantparam()->zerop[1]); - EXPECT_EQ(20, g.conv->quantparam()->zerop[2]); -} diff --git a/compiler/luci/pass/src/QuantizationUtils.cpp b/compiler/luci/pass/src/QuantizationUtils.cpp index 2f6fed46e..ad86cedf4 100644 --- a/compiler/luci/pass/src/QuantizationUtils.cpp +++ b/compiler/luci/pass/src/QuantizationUtils.cpp @@ -33,43 +33,6 @@ bool is_quantized(const CircleNode *node) node->dtype() == loco::DataType::S64); // bias (int16 quant) } -// Check if node is weights of conv2d, depthwise_conv2d, or fully_connected layer -bool is_weights(CircleNode *node) -{ - auto circle_const = dynamic_cast<CircleConst *>(node); - if (circle_const == nullptr) - return false; - - auto succs = loco::succs(node); - - // Node is weights if it is the weights of all of its successors - for (auto out : succs) - { - bool is_weights = false; - - auto conv = dynamic_cast<CircleConv2D *>(out); - if (conv != nullptr && conv->filter() == circle_const) - is_weights = true; - - auto dw_conv = dynamic_cast<CircleDepthwiseConv2D *>(out); - if (dw_conv != nullptr && dw_conv->filter() == circle_const) - is_weights = true; - - auto t_conv = dynamic_cast<CircleTransposeConv *>(out); - if (t_conv != nullptr && t_conv->filter() == circle_const && circle_const->rank() == 4) - is_weights = true; - - auto fc = dynamic_cast<CircleFullyConnected *>(out); - if (fc != nullptr && fc->weights() == circle_const) - is_weights = true; - - if (!is_weights) - return false; - } - - return true; -} - uint8_t fp32_to_uint8_cast(float f) { 
assert(std::numeric_limits<uint8_t>::min() <= f); @@ -77,7 +40,6 @@ uint8_t fp32_to_uint8_cast(float f) return static_cast<uint8_t>(f); } -// Per-layer quantization of weights (const tensor) using given min/max values void asymmetric_wquant_with_minmax_per_layer(CircleConst *node, float min, float max, float &scaling_factor, int64_t &zp, float &nudged_min, float &nudged_max) @@ -107,7 +69,6 @@ void asymmetric_wquant_with_minmax_per_layer(CircleConst *node, float min, float } } -// Per-layer quantization of weights (const tensor) using given min/max values void symmetric_wquant_with_minmax_per_layer(CircleConst *node, float min, float max, float &scaling_factor, int64_t &zp, float &nudged_min, float &nudged_max) @@ -315,4 +276,123 @@ uint32_t cal_offset(loco::TensorShape &dimension, uint32_t *indices) indices[2] * dimension.dim(3).value() + indices[3]; } +ActivationQType activation_qtype(const CircleNode *node) +{ + auto fused_act_node = dynamic_cast<const CircleNodeMixin<CircleNodeTrait::FusedActFunc> *>(node); + if (fused_act_node && fused_act_node->fusedActivationFunction() == FusedActFunc::TANH) + return ActivationQType::PreDefinedValue; + + switch (node->opcode()) + { + case CircleOpcode::LOGISTIC: + case CircleOpcode::TANH: + case CircleOpcode::SOFTMAX: + return ActivationQType::PreDefinedValue; + case CircleOpcode::FLOOR: + case CircleOpcode::FLOOR_DIV: + case CircleOpcode::FLOOR_MOD: + case CircleOpcode::CEIL: + return ActivationQType::IntScale; + default: + break; + } + + return ActivationQType::MinMax; +} + +std::unique_ptr<CircleQuantParam> make_predefined_qparam(CircleOpcode opcode, loco::DataType dtype) +{ + auto qparam = std::make_unique<CircleQuantParam>(); + + auto set_qparam = [&qparam](float scale, int64_t zp) { + qparam->scale.emplace_back(scale); + qparam->zerop.emplace_back(zp); + }; + + switch (opcode) + { + case CircleOpcode::LOGISTIC: + if (dtype == loco::DataType::U8) + set_qparam(1.0f / 256.0f, 0); + else + { + assert(dtype == 
loco::DataType::S16); + set_qparam(1.0f / 32768.0f, 0); + } + break; + case CircleOpcode::TANH: + if (dtype == loco::DataType::U8) + set_qparam(2.0f / 256.0f, 128); + else + { + assert(dtype == loco::DataType::S16); + set_qparam(1.0f / 32768.0f, 0); + } + break; + case CircleOpcode::SOFTMAX: + if (dtype == loco::DataType::U8) + set_qparam(1.0f / 255.0f, 0); + else + { + assert(dtype == loco::DataType::S16); + set_qparam(1.0f / 32767.0f, 0); + } + break; + default: + throw std::runtime_error("Unsupported opcode with pre-defined qparam"); + } + return std::move(qparam); +} + +// For nodes with integer output, we use integer scale +void set_int_scale(luci::CircleNode *node) +{ + assert(node); // FIX_CALLER_UNLESS + + auto qparam = node->quantparam(); + assert(qparam); // FIX_CALLER_UNLESS + assert(qparam->scale.size() == 1); // FIX_CALLER_UNLESS + + auto fp_scale = qparam->scale[0]; + qparam->scale[0] = fp_scale < 1 ? 1.0f : std::round(fp_scale); +} + +void quant_const(luci::CircleConst *node, loco::DataType quant_type) +{ + assert(node->dtype() == loco::DataType::FLOAT32); + + float min = std::numeric_limits<float>::max(); + float max = std::numeric_limits<float>::lowest(); + for (uint32_t i = 0; i < node->size<loco::DataType::FLOAT32>(); i++) + { + auto data = node->at<loco::DataType::FLOAT32>(i); + min = data < min ? data : min; + max = data > max ? 
data : max; + } + + float scaling_factor{0.0}; + int64_t zp{0}; + float nudged_min{0.0}; + float nudged_max{0.0}; + + switch (quant_type) + { + case loco::DataType::U8: + asymmetric_wquant_with_minmax_per_layer(node, min, max, scaling_factor, zp, nudged_min, + nudged_max); + break; + case loco::DataType::S16: + symmetric_wquant_with_minmax_per_layer(node, min, max, scaling_factor, zp, nudged_min, + nudged_max); + break; + default: + throw std::runtime_error("Unsupported data type"); + } + + auto quantparam = std::make_unique<luci::CircleQuantParam>(); + quantparam->scale.push_back(scaling_factor); + quantparam->zerop.push_back(zp); + node->quantparam(std::move(quantparam)); +} + } // namespace luci diff --git a/compiler/luci/pass/src/QuantizationUtils.h b/compiler/luci/pass/src/QuantizationUtils.h index 605f6a77e..cd8cec95a 100644 --- a/compiler/luci/pass/src/QuantizationUtils.h +++ b/compiler/luci/pass/src/QuantizationUtils.h @@ -23,33 +23,61 @@ namespace luci { +// Compute scale/zp using given min/max for symmetric quantization (int16) void compute_sym_scale_zp(float min, float max, float &scaling_factor, int64_t &zp, float &nudged_min, float &nudged_max); +// Compute scale/zp using given min/max for asymmetric quantization (uint8) void compute_asym_scale_zp(float min, float max, float &scaling_factor, int64_t &zp, float &nudged_min, float &nudged_max); +// Asymmetric per-layer quantization of weights (const tensor) using given min/max values +// NOTE: in-place update of node data void asymmetric_wquant_with_minmax_per_layer(CircleConst *node, float min, float max, float &scaling_factor, int64_t &zp, float &nudged_min, float &nudged_max); +// Symmetric per-layer quantization of weights (const tensor) using given min/max values +// NOTE: in-place update of node data void symmetric_wquant_with_minmax_per_layer(CircleConst *node, float min, float max, float &scaling_factor, int64_t &zp, float &nudged_min, float &nudged_max); +// Helper function to get channel 
dimension +// TODO Embed this function into iterate_per_channel bool get_channel_dim_index(CircleConst *node, loco::TensorShape &dimension, int32_t &channel_dim_index); +// Calculate offset of the given indices in dimension uint32_t cal_offset(loco::TensorShape &dimension, uint32_t *indices); -void propagate_concat_quantparam(luci::CircleConcatenation *concat, loco::DataType quant_type); +// Backward propagation of concatenation qparam +void propagate_concat_quantparam(luci::CircleConcatenation *concat); -void propagate_pad_v2_quantparam(luci::CirclePadV2 *pad_v2, loco::DataType quant_type); - -bool is_weights(CircleNode *node); +// Backward propagation of pad_v2 qparam +void propagate_pad_v2_quantparam(luci::CirclePadV2 *pad_v2); +// Return true if the node is quantized bool is_quantized(const CircleNode *node); +enum ActivationQType +{ + MinMax, // Quantize using recorded min/max + PreDefinedValue, // Quantize using pre-defined values + IntScale, // Round scale to a positive integer +}; + +ActivationQType activation_qtype(const CircleNode *node); + +// Create qparam with pre-defined values for special operators +std::unique_ptr<CircleQuantParam> make_predefined_qparam(CircleOpcode opcode, loco::DataType dtype); + +// Update node's scale to a positive integer (for special Ops e.g., Floor, Ceil) +void set_int_scale(luci::CircleNode *node); + +// Quantize const tensor using its min/max values +void quant_const(luci::CircleConst *node, loco::DataType quant_type); + } // namespace luci #endif // __LUCI_QUANTIZATION_UTILS_H__ diff --git a/compiler/luci/pass/src/QuantizeActivation.cpp b/compiler/luci/pass/src/QuantizeActivation.cpp new file mode 100644 index 000000000..149331824 --- /dev/null +++ b/compiler/luci/pass/src/QuantizeActivation.cpp @@ -0,0 +1,296 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. 
All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "QuantizeActivation.h" +#include "QuantizationUtils.h" + +#include <luci/Service/Nodes/CircleConst.h> +#include <luci/Log.h> + +#include <algorithm> +#include <cmath> + +using namespace luci; + +namespace +{ + +bool has_min_max(const CircleNode *node) +{ + return node->quantparam() && !node->quantparam()->min.empty() && !node->quantparam()->max.empty(); +} + +} // namespace + +// QuantizeActivation +namespace luci +{ + +void QuantizeActivation::visit(luci::CircleNode *node) +{ + LOGGER(l); + INFO(l) << "QuantizeActivation visit node: " << node->name() << std::endl; + + // Check if this is already quantized + if (is_quantized(node)) + return; + + // Check if this is bool type (bool type is not quantized) + if (node->dtype() == loco::DataType::BOOL) + return; + + // Check if this is const (const activation is handled by QuantizeConstInputActivation) + // NOTE QuantizePreChecker guarantees weights/bias are const. + // Update this code when we accept non-const weights/bias. 
+ if (node->opcode() == luci::CircleOpcode::CIRCLECONST) + return; + + // Check if this is activation + // We assume min/max are recorded only for activations + if (has_min_max(node)) + { + // Quantize using recorded min/max + auto quantparam = node->quantparam(); + assert(quantparam); + assert(quantparam->min.size() == 1); // only support layer-wise quant + assert(quantparam->max.size() == 1); // only support layer-wise quant + auto min = quantparam->min[0]; + auto max = quantparam->max[0]; + + float scaling_factor{0}; + int64_t zp{0}; + float nudged_min{0}; + float nudged_max{0}; + + if (output_type == loco::DataType::U8) + { + compute_asym_scale_zp(min, max, scaling_factor, zp, nudged_min, nudged_max); + node->dtype(loco::DataType::U8); + } + else + { + compute_sym_scale_zp(min, max, scaling_factor, zp, nudged_min, nudged_max); + node->dtype(loco::DataType::S16); + } + + node->quantparam()->scale.push_back(scaling_factor); + node->quantparam()->zerop.push_back(zp); + } + // Fix special attributes + if (node->opcode() == luci::CircleOpcode::CAST) + { + auto *cast = loco::must_cast<luci::CircleCast *>(node); + auto *cast_input = loco::must_cast<luci::CircleNode *>(cast->x()); + + // make sure that cast_input is already quantized + assert(cast_input->dtype() != loco::DataType::FLOAT32); + cast->in_data_type(cast_input->dtype()); + cast->out_data_type(cast->dtype()); + } +} + +} // namespace luci + +// QuantizeSpecialActivation +namespace luci +{ + +void QuantizeSpecialActivation::visit(luci::CircleNode *node) +{ + // Nodes fused with activation functions which need special quantization + auto fused_act_node = dynamic_cast<CircleNodeMixin<CircleNodeTrait::FusedActFunc> *>(node); + if (fused_act_node != nullptr && fused_act_node->fusedActivationFunction() == FusedActFunc::TANH) + { + auto qparam = make_predefined_qparam(luci::CircleOpcode::TANH, output_type); + node->quantparam(std::move(qparam)); + } +} + +void QuantizeSpecialActivation::visit(luci::CircleLogistic 
*node) +{ + assert(activation_qtype(node) == luci::ActivationQType::PreDefinedValue); + auto qparam = make_predefined_qparam(luci::CircleOpcode::LOGISTIC, output_type); + node->quantparam(std::move(qparam)); +} + +void QuantizeSpecialActivation::visit(luci::CircleTanh *node) +{ + assert(activation_qtype(node) == luci::ActivationQType::PreDefinedValue); + auto qparam = make_predefined_qparam(luci::CircleOpcode::TANH, output_type); + node->quantparam(std::move(qparam)); +} + +void QuantizeSpecialActivation::visit(luci::CircleSoftmax *node) +{ + assert(activation_qtype(node) == luci::ActivationQType::PreDefinedValue); + auto qparam = make_predefined_qparam(luci::CircleOpcode::SOFTMAX, output_type); + node->quantparam(std::move(qparam)); +} + +void QuantizeSpecialActivation::visit(luci::CircleFloor *node) +{ + assert(activation_qtype(node) == luci::ActivationQType::IntScale); + set_int_scale(node); +} + +void QuantizeSpecialActivation::visit(luci::CircleFloorDiv *node) +{ + assert(activation_qtype(node) == luci::ActivationQType::IntScale); + set_int_scale(node); +} + +void QuantizeSpecialActivation::visit(luci::CircleFloorMod *node) +{ + assert(activation_qtype(node) == luci::ActivationQType::IntScale); + set_int_scale(node); +} + +void QuantizeSpecialActivation::visit(luci::CircleCeil *node) +{ + assert(activation_qtype(node) == luci::ActivationQType::IntScale); + set_int_scale(node); +} + +} // namespace luci + +// QuantizeConstInputActivation +namespace luci +{ + +// Default behavior (NYI) +void QuantizeConstInputActivation::visit(luci::CircleNode *node) +{ + for (uint32_t i = 0; i < node->arity(); i++) + { + auto input_node = node->arg(i); + auto const_node = dynamic_cast<luci::CircleConst *>(input_node); + if (const_node != nullptr) + throw std::runtime_error("Unsupported Op for const inputs"); + } +} + +// INPUT_NAME is the only activation of NODE +#define QUANTIZE_SINGLE_CONST_INPUT(NODE, INPUT_NAME) \ + void QuantizeConstInputActivation::visit(NODE *node) \ + { 
\ + auto input = node->INPUT_NAME(); \ + auto const_node = dynamic_cast<luci::CircleConst *>(input); \ + if (const_node && !is_quantized(const_node)) \ + { \ + auto new_const = luci::clone(const_node); \ + quant_const(new_const, _output_type); \ + node->INPUT_NAME(new_const); \ + } \ + } + +// INPUT_NAME1 and INPUT_NAME2 are the only activations of NODE +#define QUANTIZE_TWO_CONST_INPUTS(NODE, INPUT_NAME1, INPUT_NAME2) \ + void QuantizeConstInputActivation::visit(NODE *node) \ + { \ + auto input1 = node->INPUT_NAME1(); \ + auto const_node1 = dynamic_cast<luci::CircleConst *>(input1); \ + if (const_node1 && !is_quantized(const_node1)) \ + { \ + auto new_const1 = luci::clone(const_node1); \ + quant_const(new_const1, _output_type); \ + node->INPUT_NAME1(new_const1); \ + } \ + auto input2 = node->INPUT_NAME2(); \ + auto const_node2 = dynamic_cast<luci::CircleConst *>(input2); \ + if (const_node2 && !is_quantized(const_node2)) \ + { \ + auto new_const2 = luci::clone(const_node2); \ + quant_const(new_const2, _output_type); \ + node->INPUT_NAME2(new_const2); \ + } \ + } + +// Ops that receive a single activation as an input +QUANTIZE_SINGLE_CONST_INPUT(luci::CircleArgMax, input) +QUANTIZE_SINGLE_CONST_INPUT(luci::CircleArgMin, input) +QUANTIZE_SINGLE_CONST_INPUT(luci::CircleBatchToSpaceND, input) +QUANTIZE_SINGLE_CONST_INPUT(luci::CircleDepthToSpace, input) +QUANTIZE_SINGLE_CONST_INPUT(luci::CircleElu, features) +QUANTIZE_SINGLE_CONST_INPUT(luci::CircleExp, x) +QUANTIZE_SINGLE_CONST_INPUT(luci::CircleFloor, x) +QUANTIZE_SINGLE_CONST_INPUT(luci::CircleGather, params) +QUANTIZE_SINGLE_CONST_INPUT(luci::CircleLocalResponseNormalization, input) +QUANTIZE_SINGLE_CONST_INPUT(luci::CircleLogistic, x) +QUANTIZE_SINGLE_CONST_INPUT(luci::CircleMean, input) +QUANTIZE_SINGLE_CONST_INPUT(luci::CircleMirrorPad, input) +QUANTIZE_SINGLE_CONST_INPUT(luci::CirclePad, input) +QUANTIZE_SINGLE_CONST_INPUT(luci::CircleReduceAny, input) +QUANTIZE_SINGLE_CONST_INPUT(luci::CircleReduceProd, 
input) +QUANTIZE_SINGLE_CONST_INPUT(luci::CircleReduceMax, input) +QUANTIZE_SINGLE_CONST_INPUT(luci::CircleReduceMin, input) +QUANTIZE_SINGLE_CONST_INPUT(luci::CircleReshape, tensor) +QUANTIZE_SINGLE_CONST_INPUT(luci::CircleResizeBilinear, input) +QUANTIZE_SINGLE_CONST_INPUT(luci::CircleResizeNearestNeighbor, input) +QUANTIZE_SINGLE_CONST_INPUT(luci::CircleReverseSequence, input) +QUANTIZE_SINGLE_CONST_INPUT(luci::CircleRsqrt, x) +QUANTIZE_SINGLE_CONST_INPUT(luci::CircleSlice, input) +QUANTIZE_SINGLE_CONST_INPUT(luci::CircleSoftmax, logits) +QUANTIZE_SINGLE_CONST_INPUT(luci::CircleSpaceToBatchND, input) +QUANTIZE_SINGLE_CONST_INPUT(luci::CircleSpaceToDepth, input) +QUANTIZE_SINGLE_CONST_INPUT(luci::CircleSplit, input) +QUANTIZE_SINGLE_CONST_INPUT(luci::CircleSplitV, input) +QUANTIZE_SINGLE_CONST_INPUT(luci::CircleSqrt, x) +QUANTIZE_SINGLE_CONST_INPUT(luci::CircleStridedSlice, input) +QUANTIZE_SINGLE_CONST_INPUT(luci::CircleSum, input) +QUANTIZE_SINGLE_CONST_INPUT(luci::CircleTanh, x) +QUANTIZE_SINGLE_CONST_INPUT(luci::CircleTile, input) +QUANTIZE_SINGLE_CONST_INPUT(luci::CircleTopKV2, input) +QUANTIZE_SINGLE_CONST_INPUT(luci::CircleTranspose, a) +QUANTIZE_SINGLE_CONST_INPUT(luci::CircleUnpack, value) + +// Ops that receive two activations as inputs +QUANTIZE_TWO_CONST_INPUTS(luci::CircleAdd, x, y) +QUANTIZE_TWO_CONST_INPUTS(luci::CircleBatchMatMul, x, y) +QUANTIZE_TWO_CONST_INPUTS(luci::CircleDiv, x, y) +QUANTIZE_TWO_CONST_INPUTS(luci::CircleEqual, x, y) +QUANTIZE_TWO_CONST_INPUTS(luci::CircleFloorDiv, x, y) +QUANTIZE_TWO_CONST_INPUTS(luci::CircleGreater, x, y) +QUANTIZE_TWO_CONST_INPUTS(luci::CircleGreaterEqual, x, y) +QUANTIZE_TWO_CONST_INPUTS(luci::CircleLess, x, y) +QUANTIZE_TWO_CONST_INPUTS(luci::CircleLessEqual, x, y) +QUANTIZE_TWO_CONST_INPUTS(luci::CircleMaximum, x, y) +QUANTIZE_TWO_CONST_INPUTS(luci::CircleMinimum, x, y) +QUANTIZE_TWO_CONST_INPUTS(luci::CircleMul, x, y) +QUANTIZE_TWO_CONST_INPUTS(luci::CircleNotEqual, x, y) 
+QUANTIZE_TWO_CONST_INPUTS(luci::CirclePow, x, y) +QUANTIZE_TWO_CONST_INPUTS(luci::CircleSub, x, y) + +// AddN has arbitrary number of inputs +void QuantizeConstInputActivation::visit(luci::CircleAddN *node) +{ + auto arity = node->arity(); + for (uint32_t i = 0; i < arity; i++) + { + auto input_node = node->inputs(i); + auto const_node = dynamic_cast<luci::CircleConst *>(input_node); + if (const_node && !is_quantized(const_node)) + { + auto new_const = luci::clone(const_node); + quant_const(new_const, _output_type); + node->inputs(i, new_const); + } + } +} + +#undef QUANTIZE_SINGLE_CONST_INPUT +#undef QUANTIZE_TWO_CONST_INPUTS + +} // namespace luci diff --git a/compiler/luci/pass/src/QuantizeActivation.h b/compiler/luci/pass/src/QuantizeActivation.h new file mode 100644 index 000000000..fc32d1cde --- /dev/null +++ b/compiler/luci/pass/src/QuantizeActivation.h @@ -0,0 +1,165 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __LUCI_QUANTIZATION_ACTIVATION_H__ +#define __LUCI_QUANTIZATION_ACTIVATION_H__ + +#include <luci/IR/CircleNodeVisitor.h> + +namespace luci +{ + +/** + * @brief Quantize non-const activation using recorded min/max values + */ +struct QuantizeActivation final : public luci::CircleNodeMutableVisitor<void> +{ + QuantizeActivation(loco::DataType input, loco::DataType output) + : input_type(input), output_type(output) + { + } + + loco::DataType input_type; + loco::DataType output_type; + + // Quantize each node using recorded min/max + void visit(luci::CircleNode *node); +}; + +/** + * @brief Quantize non-const activaion using pre-defined scale/zp for special Ops + */ +struct QuantizeSpecialActivation final : public luci::CircleNodeMutableVisitor<void> +{ + QuantizeSpecialActivation(loco::DataType input, loco::DataType output) + : input_type(input), output_type(output) + { + } + + loco::DataType input_type; + loco::DataType output_type; + + void visit(luci::CircleNode *node); + void visit(luci::CircleLogistic *node); + void visit(luci::CircleTanh *node); + void visit(luci::CircleSoftmax *node); + void visit(luci::CircleFloor *node); + void visit(luci::CircleFloorDiv *node); + void visit(luci::CircleFloorMod *node); + void visit(luci::CircleCeil *node); +}; + +// Quantize constant input activation of a node +// The input of a node is quantized if it is +// 1. Constant (instance of CircleConst*) +// 2. 
Activation (other inputs e.g., weights, bias, axis, etc should not be quantized here) +struct QuantizeConstInputActivation final : public luci::CircleNodeMutableVisitor<void> +{ + QuantizeConstInputActivation(loco::DataType output_type) : _output_type(output_type) {} + +private: + loco::DataType _output_type; + +// Skip NODE +#define SKIP(NODE) \ + void visit(NODE *) {} + + // Handled in QuantizeWeights and QuantizeBias + SKIP(luci::CircleConv2D) + SKIP(luci::CircleDepthwiseConv2D) + SKIP(luci::CircleFullyConnected) + SKIP(luci::CircleInstanceNorm) + SKIP(luci::CirclePRelu) + SKIP(luci::CircleTransposeConv) + + // Handled in PropagateQParamBackwardPass + SKIP(luci::CircleConcatenation) + SKIP(luci::CirclePadV2) + SKIP(luci::CirclePack) + SKIP(luci::CircleOneHot) + + // Inputs of logical Ops are bool, thus not quantized + SKIP(luci::CircleLogicalOr) + SKIP(luci::CircleLogicalAnd) + SKIP(luci::CircleLogicalNot) + +#undef SKIP + + // Default behavior (NYI) + void visit(luci::CircleNode *node); + + // Ops that receive a single activation as an input + void visit(luci::CircleArgMax *node); + void visit(luci::CircleArgMin *node); + void visit(luci::CircleBatchToSpaceND *node); + void visit(luci::CircleDepthToSpace *node); + void visit(luci::CircleElu *node); + void visit(luci::CircleExp *node); + void visit(luci::CircleFloor *node); + void visit(luci::CircleGather *node); + void visit(luci::CircleLocalResponseNormalization *node); + void visit(luci::CircleLogistic *node); + void visit(luci::CircleMean *node); + void visit(luci::CircleMirrorPad *node); + void visit(luci::CirclePad *node); + void visit(luci::CircleReduceAny *node); + void visit(luci::CircleReduceProd *node); + void visit(luci::CircleReduceMax *node); + void visit(luci::CircleReduceMin *node); + void visit(luci::CircleReshape *node); + void visit(luci::CircleResizeBilinear *node); + void visit(luci::CircleResizeNearestNeighbor *node); + void visit(luci::CircleReverseSequence *node); + void 
visit(luci::CircleRsqrt *node); + void visit(luci::CircleSlice *node); + void visit(luci::CircleSoftmax *node); + void visit(luci::CircleSpaceToBatchND *node); + void visit(luci::CircleSpaceToDepth *node); + void visit(luci::CircleSplit *node); + void visit(luci::CircleSplitV *node); + void visit(luci::CircleSqrt *node); + void visit(luci::CircleStridedSlice *node); + void visit(luci::CircleSum *node); + void visit(luci::CircleTanh *node); + void visit(luci::CircleTile *node); + void visit(luci::CircleTopKV2 *node); + void visit(luci::CircleTranspose *node); + void visit(luci::CircleUnpack *node); + + // Ops that receive two activations as inputs + void visit(luci::CircleAdd *node); + void visit(luci::CircleBatchMatMul *node); + void visit(luci::CircleDiv *node); + void visit(luci::CircleEqual *node); + void visit(luci::CircleFloorDiv *node); + void visit(luci::CircleGreater *node); + void visit(luci::CircleGreaterEqual *node); + void visit(luci::CircleLess *node); + void visit(luci::CircleLessEqual *node); + void visit(luci::CircleMaximum *node); + void visit(luci::CircleMinimum *node); + void visit(luci::CircleMul *node); + void visit(luci::CircleNotEqual *node); + void visit(luci::CirclePow *node); + void visit(luci::CircleSub *node); + + // AddN has arbitrary number of inputs + void visit(luci::CircleAddN *node); +}; + +} // namespace luci + +#endif // __LUCI_QUANTIZATION_ACTIVATION_H__ diff --git a/compiler/luci/pass/src/QuantizeBias.cpp b/compiler/luci/pass/src/QuantizeBias.cpp new file mode 100644 index 000000000..aa496232a --- /dev/null +++ b/compiler/luci/pass/src/QuantizeBias.cpp @@ -0,0 +1,300 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "QuantizeBias.h" +#include "QuantizationUtils.h" + +#include <luci/Service/Nodes/CircleConst.h> +#include <luci/Log.h> + +#include <algorithm> +#include <cmath> + +using namespace luci; + +namespace +{ + +// struct to carry Input/Weights/Bias +struct IWB +{ + CircleNode *input = nullptr; + CircleNode *weights = nullptr; + CircleConst *bias = nullptr; + + IWB(loco::Node *i, loco::Node *w, loco::Node *b) + { + input = dynamic_cast<luci::CircleNode *>(i); + weights = dynamic_cast<luci::CircleNode *>(w); + bias = dynamic_cast<luci::CircleConst *>(b); + } + + // Return true if bias can be quantized with valid input an weights + operator bool() + { + if (bias == nullptr || is_quantized(bias)) + return false; + if (input == nullptr || weights == nullptr) + return false; + return true; + } +}; + +// Create a new const node from an existing node. +// The new node has the following characteristics +// type: T +// shape: same with 'node' (given as an argument) +// buffer size: 'size' (given as an argument) +// Note that contents are not filled in this function. +template <loco::DataType T> +luci::CircleConst *create_empty_const_from(luci::CircleConst *node, uint32_t size) +{ + auto new_node = node->graph()->nodes()->create<CircleConst>(); + // TODO: We don't have any naming convention for quantized nodes yet. + // Fix this when we have one. 
+ new_node->name(node->name()); + new_node->dtype(T); + new_node->rank(node->rank()); + for (uint32_t i = 0; i < node->rank(); i++) + new_node->dim(i).set(node->dim(i).value()); + + new_node->size<T>(size); + new_node->shape_status(luci::ShapeStatus::VALID); + + return new_node; +} + +CircleConst *asym_quant_bias_per_layer(CircleConst *node, float input_scale, float weight_scale, + float *scaling_factor, int64_t *zp) +{ + float scale = input_scale * weight_scale; + const float scaling_factor_inv = (scale == 0) ? 0 : 1.0 / scale; + + uint32_t size = node->size<loco::DataType::FLOAT32>(); + std::vector<int32_t> quantized_values(size); + for (uint32_t i = 0; i < size; ++i) + { + quantized_values[i] = + static_cast<int32_t>(std::round(node->at<loco::DataType::FLOAT32>(i) * scaling_factor_inv)); + } + + auto new_bias = create_empty_const_from<loco::DataType::S32>(node, size); + + const int32_t kMinScale = std::numeric_limits<int32_t>::lowest(); + const int32_t kMaxScale = std::numeric_limits<int32_t>::max(); + for (uint32_t i = 0; i < size; ++i) + { + new_bias->at<loco::DataType::S32>(i) = + std::min(kMaxScale, std::max(kMinScale, quantized_values[i])); + } + *scaling_factor = scale; + *zp = 0; + + return new_bias; +} + +CircleConst *quant_bias_per_channel(CircleConst *node, float input_scale, + std::vector<float> &weight_scale, + std::vector<float> &scaling_factor, std::vector<int64_t> &zp) +{ + float scaling_factor_inv{0}; + + uint32_t size = node->size<loco::DataType::FLOAT32>(); + std::vector<int32_t> quantized_values(size); + + for (uint32_t i = 0; i < size; ++i) + { + scaling_factor[i] = input_scale * weight_scale[i]; + scaling_factor_inv = (scaling_factor[i] == 0) ? 
0 : 1.0 / scaling_factor[i]; + quantized_values[i] = + static_cast<int32_t>(std::round(node->at<loco::DataType::FLOAT32>(i) * scaling_factor_inv)); + zp[i] = 0; + } + + auto new_bias = create_empty_const_from<loco::DataType::S32>(node, size); + + const int32_t kMinScale = std::numeric_limits<int32_t>::lowest(); + const int32_t kMaxScale = std::numeric_limits<int32_t>::max(); + for (uint32_t i = 0; i < size; ++i) + { + new_bias->at<loco::DataType::S32>(i) = + std::min(kMaxScale, std::max(kMinScale, quantized_values[i])); + } + + return new_bias; +} + +CircleConst *int16_quant_bias_per_channel(CircleConst *node, float input_scale, + std::vector<float> &weight_scale, + std::vector<float> &scaling_factor, + std::vector<int64_t> &zp) +{ + float scaling_factor_inv{0}; + + uint32_t size = node->size<loco::DataType::FLOAT32>(); + std::vector<int64_t> quantized_values(size); + + for (uint32_t i = 0; i < size; ++i) + { + scaling_factor[i] = input_scale * weight_scale[i]; + scaling_factor_inv = (scaling_factor[i] == 0) ? 
0 : 1.0 / scaling_factor[i]; + quantized_values[i] = + static_cast<int64_t>(std::round(node->at<loco::DataType::FLOAT32>(i) * scaling_factor_inv)); + zp[i] = 0; + } + + auto new_bias = create_empty_const_from<loco::DataType::S64>(node, size); + + for (uint32_t i = 0; i < size; ++i) + { + new_bias->at<loco::DataType::S64>(i) = quantized_values[i]; + } + + return new_bias; +} + +} // namespace + +namespace luci +{ + +// Return a quantized bias node +CircleConst *QuantizeBias::quantized_bias(CircleNode *input, const CircleNode *weight, + CircleNode *bias) +{ + auto const_bias = loco::must_cast<luci::CircleConst *>(bias); + assert(const_bias->dtype() == loco::DataType::FLOAT32); + + // If input is const, it is quantized here, not in QuantizeActivation + if (auto const_input = dynamic_cast<luci::CircleConst *>(input)) + { + quant_const(const_input, output_type); + } + + CircleConst *new_bias = nullptr; + + if (granularity == QuantizationGranularity::ChannelWise) + { + auto input_q = input->quantparam(); + assert(input_q); + assert(input_q->scale.size() == 1); // input scale's layer-wise + auto input_scale = input_q->scale[0]; + + assert(weight->quantparam() != nullptr); // weight scale's channel-wise + auto weight_scale = weight->quantparam()->scale; + + uint32_t size = const_bias->size<loco::DataType::FLOAT32>(); + assert(size == weight_scale.size()); + std::vector<float> scaling_factor(size); + std::vector<int64_t> zp(size); + + if (output_type == loco::DataType::U8) + { + new_bias = quant_bias_per_channel(const_bias, input_scale, weight_scale, scaling_factor, zp); + } + else if (output_type == loco::DataType::S16) + { + new_bias = + int16_quant_bias_per_channel(const_bias, input_scale, weight_scale, scaling_factor, zp); + } + else + { + throw std::runtime_error("Unsupported quantization type."); + } + + auto quantparam = std::make_unique<CircleQuantParam>(); + quantparam->scale = scaling_factor; + quantparam->zerop = zp; + assert(new_bias->quantparam() == nullptr); 
// bias should not be quantized before + new_bias->quantparam(std::move(quantparam)); + + return new_bias; + } + else + { + auto input_q = input->quantparam(); + assert(input_q); + assert(input_q->scale.size() == 1); // Only support per-layer quant + auto input_scale = input_q->scale[0]; + + auto weight_q = weight->quantparam(); + assert(weight_q); + assert(weight_q->scale.size() == 1); // Only support per-layer quant + auto weight_scale = weight_q->scale[0]; + + float scaling_factor{0}; + int64_t zp{0}; + new_bias = + asym_quant_bias_per_layer(const_bias, input_scale, weight_scale, &scaling_factor, &zp); + auto quantparam = std::make_unique<CircleQuantParam>(); + quantparam->scale.push_back(scaling_factor); + quantparam->zerop.push_back(zp); + assert(new_bias->quantparam() == nullptr); // bias should not be quantized before + new_bias->quantparam(std::move(quantparam)); + + return new_bias; + } +} + +void QuantizeBias::visit(luci::CircleConv2D *node) +{ + LOGGER(l); + INFO(l) << "QuantizeBias QuantizeBias::visit node: " << node->name() << std::endl; + + if (auto iwb = IWB(node->input(), node->filter(), node->bias())) + { + auto new_bias = quantized_bias(iwb.input, iwb.weights, iwb.bias); + node->bias(new_bias); + } +} + +void QuantizeBias::visit(luci::CircleDepthwiseConv2D *node) +{ + LOGGER(l); + INFO(l) << "QuantizeBias QuantizeBias::visit node: " << node->name() << std::endl; + + if (auto iwb = IWB(node->input(), node->filter(), node->bias())) + { + auto new_bias = quantized_bias(iwb.input, iwb.weights, iwb.bias); + node->bias(new_bias); + } +} + +void QuantizeBias::visit(luci::CircleTransposeConv *node) +{ + LOGGER(l); + INFO(l) << "QuantizeBias QuantizeBias::visit node: " << node->name() << std::endl; + + if (auto iwb = IWB(node->outBackprop(), node->filter(), node->bias())) + { + auto new_bias = quantized_bias(iwb.input, iwb.weights, iwb.bias); + node->bias(new_bias); + } +} + +void QuantizeBias::visit(luci::CircleFullyConnected *node) +{ + LOGGER(l); + 
INFO(l) << "QuantizeBias visit node: " << node->name() << std::endl; + + if (auto iwb = IWB(node->input(), node->weights(), node->bias())) + { + auto new_bias = quantized_bias(iwb.input, iwb.weights, iwb.bias); + node->bias(new_bias); + } +} + +} // namespace luci diff --git a/compiler/luci/pass/src/QuantizeBias.h b/compiler/luci/pass/src/QuantizeBias.h new file mode 100644 index 000000000..8de09df72 --- /dev/null +++ b/compiler/luci/pass/src/QuantizeBias.h @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __LUCI_QUANTIZE_BIAS_H__ +#define __LUCI_QUANTIZE_BIAS_H__ + +#include <luci/Pass/QuantizationParameters.h> +#include <luci/IR/CircleNodeVisitor.h> + +namespace luci +{ + +/** + * @brief QuantizeBias quantizes tensors for bias + * @details Use input/weights scale to quantize values + */ +struct QuantizeBias final : public luci::CircleNodeMutableVisitor<void> +{ + QuantizeBias(loco::DataType input, loco::DataType output, QuantizationGranularity gr) + : input_type(input), output_type(output), granularity(gr) + { + } + + loco::DataType input_type; + loco::DataType output_type; + QuantizationGranularity granularity; + +private: + // Return a quantized bias node + CircleConst *quantized_bias(CircleNode *input, const CircleNode *weight, CircleNode *bias); + + void visit(luci::CircleConv2D *node); + void visit(luci::CircleDepthwiseConv2D *node); + void visit(luci::CircleTransposeConv *node); + void visit(luci::CircleFullyConnected *node); + + // Default behavior + void visit(luci::CircleNode *) {} +}; + +} // namespace luci + +#endif // __LUCI_QUANTIZE_BIAS_H__ diff --git a/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.cpp b/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.cpp index c8ad87e3d..c9b35e0be 100644 --- a/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.cpp +++ b/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.cpp @@ -16,9 +16,11 @@ #include "luci/Pass/QuantizeDequantizeWeightsPass.h" #include "QuantizationUtils.h" +#include "helpers/LayerInfoMap.h" #include <luci/IR/CircleNodes.h> #include <luci/IR/CircleNodeVisitor.h> +#include <luci/Service/Nodes/CircleConst.h> #include <luci/Log.h> #include <loco/IR/TensorShape.h> @@ -251,7 +253,7 @@ void asymmetric_wdequant_with_minmax_per_layer(CircleConst *node, float scaling_ * @brief QuantizeDequantizeWeights quantizes and dequantizes tensors for weights * @details Find min/max values on the fly, quantize the model, and dequantize the model */ -struct QuantizeDequantizeWeights final 
: public luci::CircleNodeMutableVisitor<bool> +struct QuantizeDequantizeWeights final : public luci::CircleNodeMutableVisitor<void> { QuantizeDequantizeWeights(loco::DataType input, loco::DataType output, QuantizationGranularity granularity) @@ -263,88 +265,164 @@ struct QuantizeDequantizeWeights final : public luci::CircleNodeMutableVisitor<b loco::DataType output_type; QuantizationGranularity granularity; - // Quantize and dequantize input tensors of each node - bool visit(luci::CircleNode *node) +private: + // Fake quantize weights (Only u8 quantization is supported for LWQ) + void fake_quantize_lwq(luci::CircleConst *weights) const { - assert(output_type == loco::DataType::U8 || output_type == loco::DataType::S16); - LOGGER(l); - INFO(l) << "QuantizeDequantizeWeights visit node: " << node->name() << std::endl; - auto arity = node->arity(); - for (uint32_t i = 0; i < arity; i++) + assert(output_type == loco::DataType::U8); // FIX_CALLER_UNLESS + + // Find min/max per layer + float min = std::numeric_limits<float>::max(); + float max = std::numeric_limits<float>::lowest(); + for (uint32_t i = 0; i < weights->size<loco::DataType::FLOAT32>(); i++) { - auto input_node = node->arg(i); - auto circle_node = loco::must_cast<luci::CircleNode *>(input_node); + auto data = weights->at<loco::DataType::FLOAT32>(i); + min = data < min ? data : min; + max = data > max ? 
data : max; + } + float scaling_factor{0}; + int64_t zp{0}; + float nudged_min{0}; + float nudged_max{0}; + + asymmetric_wquant_with_minmax_per_layer(weights, min, max, scaling_factor, zp, nudged_min, + nudged_max); + asymmetric_wdequant_with_minmax_per_layer(weights, scaling_factor, nudged_min); + auto quantparam = std::make_unique<CircleQuantParam>(); + quantparam->min.push_back(nudged_min); + quantparam->max.push_back(nudged_max); + quantparam->scale.push_back(scaling_factor); + quantparam->zerop.push_back(zp); + weights->quantparam(std::move(quantparam)); + } - // Check if this is already quantized - if (is_quantized(circle_node)) - continue; +private: + // Fake quantize weights (u8/s16 quantization are supported for CWQ) + void fake_quantize_cwq(luci::CircleConst *weights) const + { + assert(output_type == loco::DataType::U8 || + output_type == loco::DataType::S16); // FIX_CALLER_UNLESS - if (is_weights(circle_node)) - { - auto circle_const = loco::must_cast<luci::CircleConst *>(circle_node); + // Find min/max per channel + std::vector<float> min; + std::vector<float> max; - // Find min/max per channel-wise - if (granularity == QuantizationGranularity::ChannelWise) - { - std::vector<float> min; - std::vector<float> max; - - cal_minmax_per_channel(circle_const, min, max); - - std::vector<float> nudged_min(min.size()); - std::vector<float> nudged_max(min.size()); - std::vector<float> scaling_factor(min.size()); - std::vector<int64_t> zp(min.size()); - - if (output_type == loco::DataType::U8) - { - asymmetric_wquant_per_channel(circle_const, min, max, scaling_factor, zp, nudged_min, - nudged_max); - asymmetric_wdequant_per_channel(circle_const, scaling_factor, nudged_min); - } - else - { - sym_wquant_per_channel(circle_const, min, max, scaling_factor, zp, nudged_min, - nudged_max); - sym_wdequant_per_channel(circle_const, scaling_factor); - } - - auto quantparam = std::make_unique<CircleQuantParam>(); - quantparam->min = nudged_min; - quantparam->max = 
nudged_max; - quantparam->scale = scaling_factor; - quantparam->zerop = zp; - circle_node->quantparam(std::move(quantparam)); - } - // Find min/max per layer-wise - else - { - float min = std::numeric_limits<float>::max(); - float max = std::numeric_limits<float>::lowest(); - for (uint32_t i = 0; i < circle_const->size<loco::DataType::FLOAT32>(); i++) - { - auto data = circle_const->at<loco::DataType::FLOAT32>(i); - min = data < min ? data : min; - max = data > max ? data : max; - } - float scaling_factor{0}; - int64_t zp{0}; - float nudged_min{0}; - float nudged_max{0}; - - asymmetric_wquant_with_minmax_per_layer(circle_const, min, max, scaling_factor, zp, - nudged_min, nudged_max); - asymmetric_wdequant_with_minmax_per_layer(circle_const, scaling_factor, nudged_min); - auto quantparam = std::make_unique<CircleQuantParam>(); - quantparam->min.push_back(nudged_min); - quantparam->max.push_back(nudged_max); - quantparam->scale.push_back(scaling_factor); - quantparam->zerop.push_back(zp); - circle_node->quantparam(std::move(quantparam)); - } - } + cal_minmax_per_channel(weights, min, max); + + std::vector<float> nudged_min(min.size()); + std::vector<float> nudged_max(min.size()); + std::vector<float> scaling_factor(min.size()); + std::vector<int64_t> zp(min.size()); + + if (output_type == loco::DataType::U8) + { + asymmetric_wquant_per_channel(weights, min, max, scaling_factor, zp, nudged_min, nudged_max); + asymmetric_wdequant_per_channel(weights, scaling_factor, nudged_min); + } + else + { + sym_wquant_per_channel(weights, min, max, scaling_factor, zp, nudged_min, nudged_max); + sym_wdequant_per_channel(weights, scaling_factor); } - return false; + + auto quantparam = std::make_unique<CircleQuantParam>(); + quantparam->min = nudged_min; + quantparam->max = nudged_max; + quantparam->scale = scaling_factor; + quantparam->zerop = zp; + weights->quantparam(std::move(quantparam)); + } + +private: + void fake_quantize(luci::CircleConst *weights) const + { + switch 
(granularity) + { + case luci::QuantizationGranularity::ChannelWise: + fake_quantize_cwq(weights); + break; + case luci::QuantizationGranularity::LayerWise: + fake_quantize_lwq(weights); + break; + default: + throw std::invalid_argument("Unsupported granularity"); + } + } + +private: + // Check if + // 1. node is const + // 2. node was not quantized + bool is_quantizable(loco::Node *node) + { + auto const_node = dynamic_cast<luci::CircleConst *>(node); + if (not const_node) + return false; + + // Skip if this is already quantized + if (is_quantized(const_node)) + return false; + + return true; + } + + // Default behavior (Do nothing) + void visit(luci::CircleNode *) {} + + void visit(luci::CircleConv2D *node) + { + LOGGER(l); + INFO(l) << "QuantizeDequantizeWeights visit node: " << node->name() << std::endl; + + if (not is_quantizable(node->filter())) + return; + + auto weights = loco::must_cast<luci::CircleConst *>(node->filter()); + auto new_weights = luci::clone(weights); + node->filter(new_weights); + fake_quantize(new_weights); + } + + void visit(luci::CircleDepthwiseConv2D *node) + { + LOGGER(l); + INFO(l) << "QuantizeDequantizeWeights visit node: " << node->name() << std::endl; + + if (not is_quantizable(node->filter())) + return; + + auto weights = loco::must_cast<luci::CircleConst *>(node->filter()); + auto new_weights = luci::clone(weights); + node->filter(new_weights); + fake_quantize(new_weights); + } + + void visit(luci::CircleTransposeConv *node) + { + LOGGER(l); + INFO(l) << "QuantizeDequantizeWeights visit node: " << node->name() << std::endl; + + if (not is_quantizable(node->filter())) + return; + + auto weights = loco::must_cast<luci::CircleConst *>(node->filter()); + auto new_weights = luci::clone(weights); + node->filter(new_weights); + fake_quantize(new_weights); + } + + void visit(luci::CircleFullyConnected *node) + { + LOGGER(l); + INFO(l) << "QuantizeDequantizeWeights visit node: " << node->name() << std::endl; + + if (not 
is_quantizable(node->weights())) + return; + + auto weights = loco::must_cast<luci::CircleConst *>(node->weights()); + auto new_weights = luci::clone(weights); + node->weights(new_weights); + fake_quantize(new_weights); } }; @@ -355,11 +433,36 @@ bool QuantizeDequantizeWeightsPass::run(loco::Graph *g) LOGGER(l); INFO(l) << "QuantizeDequantizeWeightsPass Start" << std::endl; + auto info_by_name = layer_info_map(g, _ctx->layers_info); + + auto quantize_dtype = [&](const luci::CircleNode *node) { + auto iter = info_by_name.find(node->name()); + + // Return designated quantization dtype + if (iter != info_by_name.end()) + return iter->second.dtype; + + // Return default quantization dtype + return _ctx->output_model_dtype; + }; + + auto quantize_granularity = [&](const luci::CircleNode *node) { + auto iter = info_by_name.find(node->name()); + + // Return designated quantization granularity + if (iter != info_by_name.end()) + return iter->second.granularity; + + // Return default quantization granularity + return _ctx->granularity; + }; + // Quantize weights for (auto node : loco::active_nodes(loco::output_nodes(g))) { - QuantizeDequantizeWeights qw(_input_model_dtype, _output_model_dtype, _granularity); auto circle_node = loco::must_cast<luci::CircleNode *>(node); + QuantizeDequantizeWeights qw(_ctx->input_model_dtype, quantize_dtype(circle_node), + quantize_granularity(circle_node)); circle_node->accept(&qw); } diff --git a/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.test.cpp b/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.test.cpp index f226253c2..15f5ca7ac 100644 --- a/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.test.cpp +++ b/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.test.cpp @@ -25,3 +25,17 @@ TEST(QuantizeDequantizeWeightsPassTest, name) auto const name = pass.name(); ASSERT_NE(nullptr, name); } + +TEST(QuantizeDequantizeWeightsPassTest, name_ctx) +{ + auto ctx = std::make_unique<luci::QuantizeDequantizeWeightsPass::Context>(); 
+ { + ctx->input_model_dtype = loco::DataType::FLOAT32; + ctx->output_model_dtype = loco::DataType::U8; + ctx->granularity = luci::QuantizationGranularity::LayerWise; + } + + luci::QuantizeDequantizeWeightsPass pass(std::move(ctx)); + auto const name = pass.name(); + ASSERT_NE(nullptr, name); +} diff --git a/compiler/luci/pass/src/QuantizePreCheckerPass.cpp b/compiler/luci/pass/src/QuantizePreCheckerPass.cpp new file mode 100644 index 000000000..4b3b7e330 --- /dev/null +++ b/compiler/luci/pass/src/QuantizePreCheckerPass.cpp @@ -0,0 +1,119 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "luci/Pass/QuantizePreCheckerPass.h" + +#include <luci/IR/CircleNodes.h> +#include <luci/IR/CircleNodeVisitor.h> + +#include <luci/Log.h> + +namespace luci +{ + +namespace +{ + +void check_const_opcode(luci::CircleNode *node) +{ + if (node == nullptr) + return; + + if (node->opcode() != luci::CircleOpcode::CIRCLECONST and + node->opcode() != luci::CircleOpcode::CIRCLEOUTPUTEXCLUDE) + { + throw std::runtime_error("Unsupported non const input " + node->name()); + } +} + +struct ConstInputChecker final : public luci::CircleNodeMutableVisitor<void> +{ +// INPUT_NAME is name for input const for current NODE +#define CHECK_NODE_WITH_ONE_INPUT_CONST(NODE, INPUT_NAME) \ + void visit(NODE *node) \ + { \ + const auto input = dynamic_cast<luci::CircleNode *>(node->INPUT_NAME()); \ + check_const_opcode(input); \ + } + +// INPUT_NAME_1 and INPUT_NAME_2 are names for input const for current NODE +#define CHECK_NODE_WITH_TWO_INPUT_CONST(NODE, INPUT_NAME_1, INPUT_NAME_2) \ + void visit(NODE *node) \ + { \ + const auto input_1 = dynamic_cast<luci::CircleNode *>(node->INPUT_NAME_1()); \ + const auto input_2 = dynamic_cast<luci::CircleNode *>(node->INPUT_NAME_2()); \ + \ + check_const_opcode(input_1); \ + check_const_opcode(input_2); \ + } + +// INPUT_NAME_1, INPUT_NAME_2 and INPUT_NAME_3 are names for input const for current NODE +#define CHECK_NODE_WITH_THREE_INPUT_CONST(NODE, INPUT_NAME_1, INPUT_NAME_2, INPUT_NAME_3) \ + void visit(NODE *node) \ + { \ + const auto input_1 = dynamic_cast<luci::CircleNode *>(node->INPUT_NAME_1()); \ + const auto input_2 = dynamic_cast<luci::CircleNode *>(node->INPUT_NAME_2()); \ + const auto input_3 = dynamic_cast<luci::CircleNode *>(node->INPUT_NAME_3()); \ + \ + check_const_opcode(input_1); \ + check_const_opcode(input_2); \ + check_const_opcode(input_3); \ + } + + // Skip other circle node + void visit(luci::CircleNode *) {} + + // Ops that receive one const nodes as inputs + CHECK_NODE_WITH_ONE_INPUT_CONST(luci::CirclePRelu, 
alpha) + + // Ops that receive two const node as an inputs + CHECK_NODE_WITH_TWO_INPUT_CONST(luci::CircleConv2D, filter, bias) + CHECK_NODE_WITH_TWO_INPUT_CONST(luci::CircleDepthwiseConv2D, filter, bias) + CHECK_NODE_WITH_TWO_INPUT_CONST(luci::CircleFullyConnected, weights, bias) + CHECK_NODE_WITH_TWO_INPUT_CONST(luci::CircleInstanceNorm, gamma, beta) + + // Ops that receive three const nodes as an inputs + CHECK_NODE_WITH_THREE_INPUT_CONST(luci::CircleTransposeConv, inputSizes, filter, bias) + +#undef CHECK_NODE_WITH_ONE_INPUT_CONST +#undef CHECK_NODE_WITH_TWO_INPUT_CONST +#undef CHECK_NODE_WITH_THREE_INPUT_CONST +}; + +} // namespace + +/** + * Verify the input model has the form acceptable by quantizer + */ +bool QuantizePreCheckerPass::run(loco::Graph *g) +{ + LOGGER(l); + INFO(l) << "QuantizePreCheckerPass Start" << std::endl; + + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + // Check const inputs + auto circle_node = loco::must_cast<luci::CircleNode *>(node); + ConstInputChecker checker{}; + circle_node->accept(&checker); + } + + INFO(l) << "QuantizePreCheckerPass End" << std::endl; + + return false; // one time run +} + +} // namespace luci diff --git a/compiler/luci/pass/src/QuantizePreCheckerPass.test.cpp b/compiler/luci/pass/src/QuantizePreCheckerPass.test.cpp new file mode 100644 index 000000000..788353cd8 --- /dev/null +++ b/compiler/luci/pass/src/QuantizePreCheckerPass.test.cpp @@ -0,0 +1,401 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/QuantizePreCheckerPass.h" + +#include <luci/IR/CircleNodes.h> + +#include <gtest/gtest.h> + +class SimpleConv2DGraph +{ +public: + SimpleConv2DGraph(bool make_valid) + { + conv2d_node = g.nodes()->create<luci::CircleConv2D>(); + input_1 = g.nodes()->create<luci::CircleInput>(); + filter = g.nodes()->create<luci::CircleConst>(); + + conv2d_node->input(input_1); + conv2d_node->filter(filter); + + if (make_valid) + { + bias = g.nodes()->create<luci::CircleConst>(); + conv2d_node->bias(bias); + } + else + { + input_2 = g.nodes()->create<luci::CircleInput>(); + conv2d_node->bias(input_2); + } + + output = g.nodes()->create<luci::CircleOutput>(); + + auto graph_output = g.outputs()->create(); + output->index(graph_output->index()); + + output->from(conv2d_node); + } + +public: + loco::Graph g; + +private: + luci::CircleConv2D *conv2d_node = nullptr; + luci::CircleInput *input_1 = nullptr; + luci::CircleInput *input_2 = nullptr; + luci::CircleConst *filter = nullptr; + luci::CircleConst *bias = nullptr; + luci::CircleOutput *output = nullptr; +}; + +class SimpleDepthConv2DGraph +{ +public: + SimpleDepthConv2DGraph(bool make_valid) + { + depth_conv2d_node = g.nodes()->create<luci::CircleDepthwiseConv2D>(); + input_1 = g.nodes()->create<luci::CircleInput>(); + filter = g.nodes()->create<luci::CircleConst>(); + + depth_conv2d_node->input(input_1); + depth_conv2d_node->filter(filter); + + if (make_valid) + { + bias = g.nodes()->create<luci::CircleConst>(); + depth_conv2d_node->bias(bias); + } + else + { + input_2 = g.nodes()->create<luci::CircleInput>(); + depth_conv2d_node->bias(input_2); + } + + output = g.nodes()->create<luci::CircleOutput>(); + + auto graph_output = g.outputs()->create(); + output->index(graph_output->index()); + + output->from(depth_conv2d_node); + } + +public: + loco::Graph g; + +private: + luci::CircleDepthwiseConv2D 
*depth_conv2d_node = nullptr; + luci::CircleInput *input_1 = nullptr; + luci::CircleInput *input_2 = nullptr; + luci::CircleConst *filter = nullptr; + luci::CircleConst *bias = nullptr; + luci::CircleOutput *output = nullptr; +}; + +class SimpleFCGraph +{ +public: + SimpleFCGraph(bool make_valid) + { + fc_node = g.nodes()->create<luci::CircleFullyConnected>(); + input_1 = g.nodes()->create<luci::CircleInput>(); + weights = g.nodes()->create<luci::CircleConst>(); + + fc_node->input(input_1); + fc_node->weights(weights); + + if (make_valid) + { + bias = g.nodes()->create<luci::CircleConst>(); + fc_node->bias(bias); + } + else + { + input_2 = g.nodes()->create<luci::CircleInput>(); + fc_node->bias(input_2); + } + + output = g.nodes()->create<luci::CircleOutput>(); + + auto graph_output = g.outputs()->create(); + output->index(graph_output->index()); + + output->from(fc_node); + } + +public: + loco::Graph g; + +private: + luci::CircleFullyConnected *fc_node = nullptr; + luci::CircleInput *input_1 = nullptr; + luci::CircleInput *input_2 = nullptr; + luci::CircleConst *weights = nullptr; + luci::CircleConst *bias = nullptr; + luci::CircleOutput *output = nullptr; +}; + +class SimpleInstanceNormGraph +{ +public: + SimpleInstanceNormGraph(bool make_valid) + { + instance_norm_node = g.nodes()->create<luci::CircleInstanceNorm>(); + input_1 = g.nodes()->create<luci::CircleInput>(); + gamma = g.nodes()->create<luci::CircleConst>(); + + instance_norm_node->input(input_1); + instance_norm_node->gamma(gamma); + + if (make_valid) + { + beta = g.nodes()->create<luci::CircleConst>(); + instance_norm_node->beta(beta); + } + else + { + input_2 = g.nodes()->create<luci::CircleInput>(); + instance_norm_node->beta(input_2); + } + + output = g.nodes()->create<luci::CircleOutput>(); + + auto graph_output = g.outputs()->create(); + output->index(graph_output->index()); + + output->from(instance_norm_node); + } + +public: + loco::Graph g; + +private: + luci::CircleInstanceNorm 
*instance_norm_node = nullptr; + luci::CircleInput *input_1 = nullptr; + luci::CircleInput *input_2 = nullptr; + luci::CircleConst *gamma = nullptr; + luci::CircleConst *beta = nullptr; + luci::CircleOutput *output = nullptr; +}; + +class SimpleTransposeConvGraph +{ +public: + SimpleTransposeConvGraph(bool make_valid) + { + transpose_conv = g.nodes()->create<luci::CircleTransposeConv>(); + input_1 = g.nodes()->create<luci::CircleInput>(); + + input_sizes = g.nodes()->create<luci::CircleConst>(); + filter = g.nodes()->create<luci::CircleConst>(); + + transpose_conv->outBackprop(input_1); + transpose_conv->filter(filter); + transpose_conv->inputSizes(input_sizes); + + if (make_valid) + { + bias = g.nodes()->create<luci::CircleConst>(); + transpose_conv->bias(bias); + } + else + { + input_2 = g.nodes()->create<luci::CircleInput>(); + transpose_conv->bias(input_2); + } + + output = g.nodes()->create<luci::CircleOutput>(); + + auto graph_output = g.outputs()->create(); + output->index(graph_output->index()); + + output->from(transpose_conv); + } + +public: + loco::Graph g; + +private: + luci::CircleTransposeConv *transpose_conv = nullptr; + luci::CircleInput *input_1 = nullptr; + luci::CircleInput *input_2 = nullptr; + luci::CircleConst *input_sizes = nullptr; + luci::CircleConst *filter = nullptr; + luci::CircleConst *bias = nullptr; + luci::CircleOutput *output = nullptr; +}; + +class SimplePReluGraph +{ +public: + SimplePReluGraph(bool make_valid) + { + prelu = g.nodes()->create<luci::CirclePRelu>(); + input_1 = g.nodes()->create<luci::CircleInput>(); + + prelu->input(input_1); + + if (make_valid) + { + alpha = g.nodes()->create<luci::CircleConst>(); + prelu->alpha(alpha); + } + else + { + input_2 = g.nodes()->create<luci::CircleInput>(); + prelu->alpha(input_2); + } + + output = g.nodes()->create<luci::CircleOutput>(); + + auto graph_output = g.outputs()->create(); + output->index(graph_output->index()); + + output->from(prelu); + } + +public: + loco::Graph g; + 
+private: + luci::CirclePRelu *prelu = nullptr; + luci::CircleInput *input_1 = nullptr; + luci::CircleInput *input_2 = nullptr; + luci::CircleConst *alpha = nullptr; + luci::CircleOutput *output = nullptr; +}; + +TEST(QuantizePreCheckerPassTest, name) +{ + luci::QuantizePreCheckerPass pass{}; + auto const name = pass.name(); + ASSERT_NE(nullptr, name); +} + +// Test Conv2d +TEST(QuantizePreCheckerPassTest, conv2d) +{ + SimpleConv2DGraph valid_graph(true); + + luci::QuantizePreCheckerPass checker{}; + + EXPECT_NO_THROW(checker.run(&valid_graph.g)); +} + +TEST(QuantizePreCheckerPassTest, conv2d_NEG) +{ + SimpleConv2DGraph invalid_graph(false); + + luci::QuantizePreCheckerPass checker{}; + + EXPECT_ANY_THROW(checker.run(&invalid_graph.g)); +} + +// Test DepthwiseConv2d +TEST(QuantizePreCheckerPassTest, depthwise_conv2d) +{ + SimpleDepthConv2DGraph valid_graph(true); + + luci::QuantizePreCheckerPass checker{}; + + EXPECT_NO_THROW(checker.run(&valid_graph.g)); +} + +TEST(QuantizePreCheckerPassTest, depthwise_conv2d_NEG) +{ + SimpleDepthConv2DGraph invalid_graph(false); + + luci::QuantizePreCheckerPass checker{}; + + EXPECT_ANY_THROW(checker.run(&invalid_graph.g)); +} + +// Test FullyConnected +TEST(QuantizePreCheckerPassTest, fully_connected) +{ + SimpleFCGraph valid_graph(true); + + luci::QuantizePreCheckerPass checker{}; + + EXPECT_NO_THROW(checker.run(&valid_graph.g)); +} + +TEST(QuantizePreCheckerPassTest, fully_connected_NEG) +{ + SimpleFCGraph invalid_graph(false); + + luci::QuantizePreCheckerPass checker{}; + + EXPECT_ANY_THROW(checker.run(&invalid_graph.g)); +} + +// Test InstanceNorm +TEST(QuantizePreCheckerPassTest, instance_norm) +{ + SimpleInstanceNormGraph valid_graph(true); + + luci::QuantizePreCheckerPass checker{}; + + EXPECT_NO_THROW(checker.run(&valid_graph.g)); +} + +TEST(QuantizePreCheckerPassTest, instance_norm_NEG) +{ + SimpleInstanceNormGraph invalid_graph(false); + + luci::QuantizePreCheckerPass checker{}; + + 
EXPECT_ANY_THROW(checker.run(&invalid_graph.g)); +} + +// Test TransposeConv +TEST(QuantizePreCheckerPassTest, transpose_conv) +{ + SimpleTransposeConvGraph valid_graph(true); + + luci::QuantizePreCheckerPass checker{}; + + EXPECT_NO_THROW(checker.run(&valid_graph.g)); +} + +TEST(QuantizePreCheckerPassTest, transpose_conv_NEG) +{ + SimpleTransposeConvGraph invalid_graph(false); + + luci::QuantizePreCheckerPass checker{}; + + EXPECT_ANY_THROW(checker.run(&invalid_graph.g)); +} + +// Test PRelu +TEST(QuantizePreCheckerPassTest, prelu) +{ + SimplePReluGraph valid_graph(true); + + luci::QuantizePreCheckerPass checker{}; + + EXPECT_NO_THROW(checker.run(&valid_graph.g)); +} + +TEST(QuantizePreCheckerPassTest, prelu_NEG) +{ + SimplePReluGraph invalid_graph(false); + + luci::QuantizePreCheckerPass checker{}; + + EXPECT_ANY_THROW(checker.run(&invalid_graph.g)); +} diff --git a/compiler/luci/pass/src/QuantizeWeights.cpp b/compiler/luci/pass/src/QuantizeWeights.cpp new file mode 100644 index 000000000..11322ab44 --- /dev/null +++ b/compiler/luci/pass/src/QuantizeWeights.cpp @@ -0,0 +1,394 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "QuantizeWeights.h" +#include "QuantizationUtils.h" + +#include <luci/Service/Nodes/CircleConst.h> +#include <luci/Log.h> + +#include <cmath> +#include <vector> +#include <functional> + +using namespace luci; + +namespace +{ + +using IterFunc = std::function<void(uint32_t *, loco::TensorShape &, int32_t)>; + +void iterate_per_channel(CircleConst *node, int32_t &channel_dim_index, IterFunc func) +{ + loco::TensorShape dimension; + dimension.rank(4); + uint32_t indices[4] = { + 0, + }; + + if (!get_channel_dim_index(node, dimension, channel_dim_index)) + { + assert(false); + return; + } + + for (indices[0] = 0; indices[0] < dimension.dim(0).value(); indices[0]++) + { + for (indices[1] = 0; indices[1] < dimension.dim(1).value(); indices[1]++) + { + for (indices[2] = 0; indices[2] < dimension.dim(2).value(); indices[2]++) + { + for (indices[3] = 0; indices[3] < dimension.dim(3).value(); indices[3]++) + { + func(indices, dimension, channel_dim_index); + } + } + } + } +} + +void asym_wquant_per_channel(CircleConst *node, std::vector<float> &min, + std::vector<float> &scaling_factor, int32_t &channel_dim_index) +{ + assert(node->dtype() == loco::DataType::FLOAT32); + + const int32_t kMinScale = 0; + const int32_t kMaxScale = 255; + + uint32_t size = node->size<loco::DataType::FLOAT32>(); + std::vector<int32_t> quantized_values(size); + + auto quantize = [&](uint32_t *indices, loco::TensorShape &dimension, int32_t channel_dim_index) { + int channel_idx = indices[channel_dim_index]; + const float scaling_factor_inv = 1.0 / scaling_factor[channel_idx]; + auto data = node->at<loco::DataType::FLOAT32>(cal_offset(dimension, indices)); + quantized_values[cal_offset(dimension, indices)] = + static_cast<int32_t>(std::round((data - min[channel_idx]) * scaling_factor_inv)); + }; + + iterate_per_channel(node, channel_dim_index, quantize); + + node->dtype(loco::DataType::U8); // change the type of tensor + node->size<loco::DataType::U8>(size); // resize tensor + for 
(uint32_t i = 0; i < size; ++i) + { + node->at<loco::DataType::U8>(i) = std::min(kMaxScale, std::max(kMinScale, quantized_values[i])); + } +} + +void sym_wquant_per_channel(CircleConst *node, std::vector<float> &scaling_factor, + int32_t &channel_dim_index) +{ + assert(node->dtype() == loco::DataType::FLOAT32); + + const int32_t kMaxScale = std::numeric_limits<int16_t>::max(); + const int32_t kMinScale = -kMaxScale; + + uint32_t size = node->size<loco::DataType::FLOAT32>(); + std::vector<int32_t> quantized_values(size); + + auto quantize = [&](uint32_t *indices, loco::TensorShape &dimension, int32_t channel_dim_index) { + int channel_idx = indices[channel_dim_index]; + const float scaling_factor_inv = 1.0 / scaling_factor[channel_idx]; + auto data = node->at<loco::DataType::FLOAT32>(cal_offset(dimension, indices)); + quantized_values[cal_offset(dimension, indices)] = + static_cast<int32_t>(std::round(data * scaling_factor_inv)); + }; + + iterate_per_channel(node, channel_dim_index, quantize); + + node->dtype(loco::DataType::S16); // change the type of tensor + node->size<loco::DataType::S16>(size); // resize tensor + for (uint32_t i = 0; i < size; ++i) + { + node->at<loco::DataType::S16>(i) = + std::min(kMaxScale, std::max(kMinScale, quantized_values[i])); + } +} + +void asym_wquant_per_layer(CircleConst *node, float min, float scaling_factor) +{ + const int32_t kMinScale = 0; + const int32_t kMaxScale = 255; + + uint32_t size = node->size<loco::DataType::FLOAT32>(); + + const float scaling_factor_inv = 1.0 / scaling_factor; + std::vector<int32_t> quantized_values(size); + for (uint32_t i = 0; i < size; ++i) + { + auto data = node->at<loco::DataType::FLOAT32>(i); + quantized_values[i] = static_cast<int32_t>(std::round((data - min) * scaling_factor_inv)); + } + + node->dtype(loco::DataType::U8); // change the type of tensor + node->size<loco::DataType::U8>(size); // resize tensor + for (uint32_t i = 0; i < size; ++i) + { + node->at<loco::DataType::U8>(i) = 
std::min(kMaxScale, std::max(kMinScale, quantized_values[i])); + } +} + +// Quantize const per channel +// +// The last dimension of const is the same as the dimension of channel +// And the rest of the const dimensions should be 1 +// So, a 'single value' is quantized per channel +// +// Quantization spec (f: fp value, q: quantized value) +// +// uint8 +// Positive f: f = f * (q - 0) [q = 1, scale = f, zp = 0] +// Negative f: f = (-f) * (q - 1) [q = 0, scale = -f, zp = 1] +// +// int16 +// Positive f: f = f * (q - 0) [q = 1, scale = f, zp = 0] +// Negative f: f = (-f) * (q - 0) [q = -1, scale = -f, zp = 0] +void quant_const_per_channel(CircleConst *node, loco::DataType quant_type) +{ + assert(node->dtype() == loco::DataType::FLOAT32); + assert(node->rank() > 0); + + for (uint32_t i = 0; i < node->rank() - 1; i++) + { + // Caller should call this function when the below condition is satisfied + if (node->dim(i).value() != 1) + throw std::runtime_error("Non-channel dimension of const node must be 1"); + } + + uint32_t size = node->size<loco::DataType::FLOAT32>(); + assert(size == node->dim(node->rank() - 1).value()); + + auto quantparam = std::make_unique<CircleQuantParam>(); + quantparam->quantized_dimension = node->rank() - 1; + std::vector<int32_t> quantized_data(size); + + for (uint32_t i = 0; i < size; ++i) + { + auto data = node->at<loco::DataType::FLOAT32>(i); + if (quant_type == loco::DataType::U8) + { + if (data >= 0) + { + quantparam->scale.push_back(data); + quantparam->zerop.push_back(0); + quantized_data[i] = 1; + } + else + { + quantparam->scale.push_back(-data); + quantparam->zerop.push_back(1); + quantized_data[i] = 0; + } + } + else if (quant_type == loco::DataType::S16) + { + if (data >= 0) + { + quantparam->scale.push_back(data); + quantized_data[i] = 1; + } + else + { + quantparam->scale.push_back(-data); + quantized_data[i] = -1; + } + quantparam->zerop.push_back(0); + } + } + node->quantparam(std::move(quantparam)); + + switch (quant_type) + { 
+ case loco::DataType::U8: + node->dtype(loco::DataType::U8); + node->size<loco::DataType::U8>(size); + for (uint32_t i = 0; i < size; ++i) + { + assert(quantized_data[i] == 0 || quantized_data[i] == 1); + node->at<loco::DataType::U8>(i) = quantized_data[i]; + } + break; + case loco::DataType::S16: + node->dtype(loco::DataType::S16); + node->size<loco::DataType::S16>(size); + for (uint32_t i = 0; i < size; ++i) + { + assert(quantized_data[i] == -1 || quantized_data[i] == 1); + node->at<loco::DataType::S16>(i) = quantized_data[i]; + } + break; + default: + throw std::runtime_error("Unsupported data type"); + } +} + +} // namespace + +namespace luci +{ + +void QuantizeWeights::quantize_weights(luci::CircleConst *weights) +{ + // Find min/max per channel-wise + if (granularity == QuantizationGranularity::ChannelWise) + { + auto quantparam = weights->quantparam(); + if (quantparam == nullptr) + { + assert(false && "quantparam is nullptr"); + return; + } + + auto min = quantparam->min; + auto scaling_factor = quantparam->scale; + int32_t channel_dim_index = 0; + + if (output_type == loco::DataType::U8) + { + asym_wquant_per_channel(weights, min, scaling_factor, channel_dim_index); + } + else + { + sym_wquant_per_channel(weights, scaling_factor, channel_dim_index); + } + quantparam->min.clear(); + quantparam->max.clear(); + quantparam->quantized_dimension = channel_dim_index; + } + // Find min/max per layer-wise + else + { + // Quantize using recorded quantparam + auto quantparam = weights->quantparam(); + assert(quantparam != nullptr); + assert(quantparam->min.size() == 1); // only support layer-wise quant + assert(quantparam->scale.size() == 1); // only support layer-wise quant + auto min = quantparam->min[0]; + auto scaling_factor = quantparam->scale[0]; + asym_wquant_per_layer(weights, min, scaling_factor); + quantparam->min.clear(); + quantparam->max.clear(); + } +} +void QuantizeWeights::visit(luci::CircleConv2D *node) +{ + LOGGER(l); + INFO(l) << "QuantizeWeights 
QuantizeWeights::visit node: " << node->name() << std::endl; + + auto weights = loco::must_cast<luci::CircleConst *>(node->filter()); + if (!is_quantized(weights)) + { + auto new_weights = luci::clone(weights); + node->filter(new_weights); + quantize_weights(new_weights); + } +} + +void QuantizeWeights::visit(luci::CircleDepthwiseConv2D *node) +{ + LOGGER(l); + INFO(l) << "QuantizeWeights QuantizeWeights::visit node: " << node->name() << std::endl; + + auto weights = loco::must_cast<luci::CircleConst *>(node->filter()); + if (!is_quantized(weights)) + { + auto new_weights = luci::clone(weights); + node->filter(new_weights); + quantize_weights(new_weights); + } +} + +void QuantizeWeights::visit(luci::CircleInstanceNorm *node) +{ + LOGGER(l); + INFO(l) << "QuantizeWeights QuantizeWeights::visit node: " << node->name() << std::endl; + + auto gamma = loco::must_cast<luci::CircleConst *>(node->gamma()); + auto beta = loco::must_cast<luci::CircleConst *>(node->beta()); + + if (!is_quantized(gamma)) + { + assert(gamma->dtype() == loco::DataType::FLOAT32); + auto new_gamma = luci::clone(gamma); + if (granularity == QuantizationGranularity::LayerWise) + quant_const(new_gamma, output_type); + else if (granularity == QuantizationGranularity::ChannelWise) + quant_const_per_channel(new_gamma, output_type); + node->gamma(new_gamma); + } + if (!is_quantized(beta)) + { + assert(beta->dtype() == loco::DataType::FLOAT32); + auto new_beta = luci::clone(beta); + if (granularity == QuantizationGranularity::LayerWise) + quant_const(new_beta, output_type); + else if (granularity == QuantizationGranularity::ChannelWise) + quant_const_per_channel(new_beta, output_type); + node->beta(new_beta); + } +} + +void QuantizeWeights::visit(luci::CirclePRelu *node) +{ + LOGGER(l); + INFO(l) << "QuantizeWeights QuantizeWeights::visit node: " << node->name() << std::endl; + + auto alpha = loco::must_cast<luci::CircleConst *>(node->alpha()); + + if (!is_quantized(alpha)) + { + assert(alpha->dtype() == 
loco::DataType::FLOAT32); + auto new_alpha = luci::clone(alpha); + if (granularity == QuantizationGranularity::LayerWise) + quant_const(new_alpha, output_type); + else if (granularity == QuantizationGranularity::ChannelWise) + quant_const_per_channel(new_alpha, output_type); + node->alpha(new_alpha); + } +} + +void QuantizeWeights::visit(luci::CircleTransposeConv *node) +{ + LOGGER(l); + INFO(l) << "QuantizeWeights QuantizeWeights::visit node: " << node->name() << std::endl; + + auto weights = loco::must_cast<luci::CircleConst *>(node->filter()); + if (!is_quantized(weights)) + { + auto new_weights = luci::clone(weights); + node->filter(new_weights); + quantize_weights(new_weights); + } +} + +void QuantizeWeights::visit(luci::CircleFullyConnected *node) +{ + LOGGER(l); + INFO(l) << "QuantizeWeights QuantizeWeights::visit node: " << node->name() << std::endl; + + auto weights = loco::must_cast<luci::CircleConst *>(node->weights()); + if (!is_quantized(weights)) + { + auto new_weights = luci::clone(weights); + node->weights(new_weights); + quantize_weights(new_weights); + } +} + +void QuantizeWeights::visit(luci::CircleNode *) {} + +} // namespace luci diff --git a/compiler/luci/pass/src/QuantizeWeights.h b/compiler/luci/pass/src/QuantizeWeights.h new file mode 100644 index 000000000..f62cd40f3 --- /dev/null +++ b/compiler/luci/pass/src/QuantizeWeights.h @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_QUANTIZE_WEIGHTS_H__ +#define __LUCI_QUANTIZE_WEIGHTS_H__ + +#include <luci/Pass/QuantizationParameters.h> +#include <luci/IR/CircleNodeVisitor.h> + +namespace luci +{ + +/** + * @brief QuantizeWeights quantizes tensors for weights + * @details Find min/max values on the fly and then quantize + */ +struct QuantizeWeights final : public luci::CircleNodeMutableVisitor<void> +{ + QuantizeWeights(loco::DataType input, loco::DataType output, QuantizationGranularity gr) + : input_type(input), output_type(output), granularity(gr) + { + } + + loco::DataType input_type; + loco::DataType output_type; + QuantizationGranularity granularity; + +private: + void quantize_weights(luci::CircleConst *weights); + + void visit(luci::CircleConv2D *node); + void visit(luci::CircleDepthwiseConv2D *node); + void visit(luci::CircleInstanceNorm *node); + void visit(luci::CirclePRelu *node); + void visit(luci::CircleTransposeConv *node); + void visit(luci::CircleFullyConnected *node); + void visit(luci::CircleNode *); +}; + +} // namespace luci + +#endif // __LUCI_QUANTIZE_WEIGHTS_H__ diff --git a/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp b/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp index c3552ec52..d9a9d4db7 100644 --- a/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp +++ b/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp @@ -15,55 +15,32 @@ */ #include "luci/Pass/QuantizeWithMinMaxPass.h" +#include "luci/Pass/PropagateQParamForwardPass.h" +#include "luci/Pass/PropagateQParamBackwardPass.h" +#include "luci/Pass/RemoveRedundantQuantizePass.h" +#include "QuantizeActivation.h" +#include "QuantizeWeights.h" +#include "QuantizeBias.h" #include "QuantizationUtils.h" +#include "ProgressReporter.h" +#include "helpers/LayerInfoMap.h" #include <luci/IR/CircleNodes.h> #include <luci/IR/CircleNodeVisitor.h> #include <luci/Service/Nodes/CircleConst.h> 
#include <luci/Profile/CircleNodeOrigin.h> #include <luci/Log.h> +#include <logo/Phase.h> #include <oops/UserExn.h> #include <iostream> #include <cmath> -#include <functional> namespace { using namespace luci; -using IterFunc = std::function<void(uint32_t *, loco::TensorShape &, int32_t)>; - -void iterate_per_channel(CircleConst *node, int32_t &channel_dim_index, IterFunc func) -{ - loco::TensorShape dimension; - dimension.rank(4); - uint32_t indices[4] = { - 0, - }; - - if (!get_channel_dim_index(node, dimension, channel_dim_index)) - { - assert(false); - return; - } - - for (indices[0] = 0; indices[0] < dimension.dim(0).value(); indices[0]++) - { - for (indices[1] = 0; indices[1] < dimension.dim(1).value(); indices[1]++) - { - for (indices[2] = 0; indices[2] < dimension.dim(2).value(); indices[2]++) - { - for (indices[3] = 0; indices[3] < dimension.dim(3).value(); indices[3]++) - { - func(indices, dimension, channel_dim_index); - } - } - } - } -} - // Create a Quantize Op whose // dtype is out_type // shape is the same with node @@ -80,7 +57,17 @@ luci::CircleQuantize *create_quantize_op(luci::CircleNode *node, loco::DataType quantize->shape_status(luci::ShapeStatus::VALID); auto qparam = node->quantparam(); - assert(qparam); // FIX_CALLER_UNLESS + assert(qparam); // FIX_CALLER_UNLESS + + auto qtype = luci::activation_qtype(node); + if (qtype == ActivationQType::PreDefinedValue) + { + quantize->quantparam(luci::make_predefined_qparam(node->opcode(), out_type)); + return quantize; + } + + assert(qtype == ActivationQType::MinMax or qtype == ActivationQType::IntScale); + assert(qparam->min.size() == 1); // FIX_CALLER_UNLESS assert(qparam->max.size() == 1); // FIX_CALLER_UNLESS auto min = qparam->min[0]; @@ -104,9 +91,17 @@ luci::CircleQuantize *create_quantize_op(luci::CircleNode *node, loco::DataType auto quantparam = std::make_unique<CircleQuantParam>(); quantparam->scale.push_back(scaling_factor); quantparam->zerop.push_back(zp); + // Save original min/max (not 
nudged_min/max). Nudged min/max + // is different from the real min/max values, causing wrong + // qparam when quantization dtype is changed. + quantparam->min.push_back(min); + quantparam->max.push_back(max); quantize->quantparam(std::move(quantparam)); + if (qtype == ActivationQType::IntScale) + set_int_scale(quantize); + return quantize; } @@ -118,1412 +113,232 @@ namespace luci namespace { -// Create a new const node from an existing node. -// The new node has the following characteristics -// type: T -// shape: same with 'node' (given as an argument) -// buffer size: 'size' (given as an argument) -// Note that contents are not filled in this function. -template <loco::DataType T> -luci::CircleConst *create_empty_const_from(luci::CircleConst *node, uint32_t size) -{ - auto new_node = node->graph()->nodes()->create<CircleConst>(); - // TODO: We don't have any naming convention for quantized nodes yet. - // Fix this when we have one. - new_node->name(node->name()); - new_node->dtype(T); - new_node->rank(node->rank()); - for (uint32_t i = 0; i < node->rank(); i++) - new_node->dim(i).set(node->dim(i).value()); - - new_node->size<T>(size); - new_node->shape_status(luci::ShapeStatus::VALID); - - return new_node; -} - -void overwrite_quantparam(luci::CircleNode *source, luci::CircleNode *target) -{ - auto source_qparam = source->quantparam(); - if (source_qparam == nullptr) - throw std::runtime_error("source quantparam is not found during overwrite"); - - auto target_qparam = target->quantparam(); - if (target_qparam == nullptr) - { - auto quantparam = std::make_unique<CircleQuantParam>(); - target->quantparam(std::move(quantparam)); - target_qparam = target->quantparam(); - - if (target_qparam == nullptr) - throw std::runtime_error("Creating new quant param failed"); - } - target_qparam->min = source_qparam->min; - target_qparam->max = source_qparam->max; - target_qparam->scale = source_qparam->scale; - target_qparam->zerop = source_qparam->zerop; - 
target_qparam->quantized_dimension = source_qparam->quantized_dimension; -} - -void quant_const_values(luci::CircleConst *const_node, float scaling_factor, float zerop, - loco::DataType quant_type) -{ - uint32_t size = const_node->size<loco::DataType::FLOAT32>(); - - const float scaling_factor_inv = 1.0 / scaling_factor; - std::vector<int32_t> quantized_values(size); - for (uint32_t i = 0; i < size; ++i) - { - auto data = static_cast<double>(const_node->at<loco::DataType::FLOAT32>(i)); - double quantized_float = std::round(data * scaling_factor_inv) + zerop; - constexpr auto int_max = static_cast<double>(std::numeric_limits<int32_t>::max()); - constexpr auto int_min = static_cast<double>(std::numeric_limits<int32_t>::min()); - quantized_float = std::min(int_max, std::max(int_min, quantized_float)); - - quantized_values[i] = static_cast<int32_t>(quantized_float); - } - - switch (quant_type) - { - case loco::DataType::U8: - const_node->dtype(loco::DataType::U8); // change the type of tensor - const_node->size<loco::DataType::U8>(size); // resize tensor - for (uint32_t i = 0; i < size; ++i) - const_node->at<loco::DataType::U8>(i) = std::min(255, std::max(0, quantized_values[i])); - break; - case loco::DataType::S16: - assert(zerop == 0); - const_node->dtype(loco::DataType::S16); // change the type of tensor - const_node->size<loco::DataType::S16>(size); // resize tensor - for (uint32_t i = 0; i < size; ++i) - const_node->at<loco::DataType::S16>(i) = - std::min(32767, std::max(-32767, quantized_values[i])); - break; - default: - throw std::runtime_error("Unsupported data type"); - } -} - -// Quantize const per channel -// -// The last dimension of const is the same as the dimension of channel -// And the rest of the const dimensions should be 1 -// So, a 'single value' is quantized per channel -// -// Quantization spec (f: fp value, q: quantized value) -// -// uint8 -// Positive f: f = f * (q - 0) [q = 1, scale = f, zp = 0] -// Negative f: f = (-f) * (q - 1) [q = 0, 
scale = -f, zp = 1] -// -// int16 -// Positive f: f = f * (q - 0) [q = 1, scale = f, zp = 0] -// Negative f: f = (-f) * (q - 0) [q = -1, scale = -f, zp = 0] -void quant_const_per_channel(CircleConst *node, loco::DataType quant_type) -{ - assert(node->dtype() == loco::DataType::FLOAT32); - assert(node->rank() > 0); - - for (uint32_t i = 0; i < node->rank() - 1; i++) - { - // Caller should call this function when the below condition is satisfied - if (node->dim(i).value() != 1) - throw std::runtime_error("Non-channel dimension of const node must be 1"); - } - - uint32_t size = node->size<loco::DataType::FLOAT32>(); - assert(size == node->dim(node->rank() - 1).value()); - - auto quantparam = std::make_unique<CircleQuantParam>(); - quantparam->quantized_dimension = node->rank() - 1; - std::vector<int32_t> quantized_data(size); - - for (uint32_t i = 0; i < size; ++i) - { - auto data = node->at<loco::DataType::FLOAT32>(i); - if (quant_type == loco::DataType::U8) - { - if (data >= 0) - { - quantparam->scale.push_back(data); - quantparam->zerop.push_back(0); - quantized_data[i] = 1; - } - else - { - quantparam->scale.push_back(-data); - quantparam->zerop.push_back(1); - quantized_data[i] = 0; - } - } - else if (quant_type == loco::DataType::S16) - { - if (data >= 0) - { - quantparam->scale.push_back(data); - quantized_data[i] = 1; - } - else - { - quantparam->scale.push_back(-data); - quantized_data[i] = -1; - } - quantparam->zerop.push_back(0); - } - } - node->quantparam(std::move(quantparam)); - - switch (quant_type) - { - case loco::DataType::U8: - node->dtype(loco::DataType::U8); - node->size<loco::DataType::U8>(size); - for (uint32_t i = 0; i < size; ++i) - { - assert(quantized_data[i] == 0 || quantized_data[i] == 1); - node->at<loco::DataType::U8>(i) = quantized_data[i]; - } - break; - case loco::DataType::S16: - node->dtype(loco::DataType::S16); - node->size<loco::DataType::S16>(size); - for (uint32_t i = 0; i < size; ++i) - { - assert(quantized_data[i] == -1 || 
quantized_data[i] == 1); - node->at<loco::DataType::S16>(i) = quantized_data[i]; - } - break; - default: - throw std::runtime_error("Unsupported data type"); - } -} - -void quant_const(CircleConst *node, loco::DataType quant_type) -{ - assert(node->dtype() == loco::DataType::FLOAT32); - - float min = std::numeric_limits<float>::max(); - float max = std::numeric_limits<float>::lowest(); - for (uint32_t i = 0; i < node->size<loco::DataType::FLOAT32>(); i++) - { - auto data = node->at<loco::DataType::FLOAT32>(i); - min = data < min ? data : min; - max = data > max ? data : max; - } - - float scaling_factor{0.0}; - int64_t zp{0}; - float nudged_min{0.0}; - float nudged_max{0.0}; - - switch (quant_type) - { - case loco::DataType::U8: - asymmetric_wquant_with_minmax_per_layer(node, min, max, scaling_factor, zp, nudged_min, - nudged_max); - break; - case loco::DataType::S16: - symmetric_wquant_with_minmax_per_layer(node, min, max, scaling_factor, zp, nudged_min, - nudged_max); - break; - default: - throw std::runtime_error("Unsupported data type"); - } - - auto quantparam = std::make_unique<CircleQuantParam>(); - quantparam->scale.push_back(scaling_factor); - quantparam->zerop.push_back(zp); - node->quantparam(std::move(quantparam)); -} - -// Check if the node is the bias of Conv2D, DepthwiseConv2D, FullyConnected, or TransposeConv layer -// Returns a list of <input, weights, output> vectors for the above operators. -// Note that it returns a 'list' because bias can be used by multiple operators. 
-std::vector<std::vector<loco::Node *>> get_input_weight_output_of_bias(CircleNode *node) -{ - std::vector<std::vector<loco::Node *>> result; - auto circle_const = dynamic_cast<CircleConst *>(node); - if (circle_const == nullptr) - return result; - - auto succs = loco::succs(node); - - for (auto out : succs) - { - auto conv = dynamic_cast<CircleConv2D *>(out); - if (conv != nullptr && conv->bias() == circle_const) - { - assert(conv->input() != nullptr); - assert(conv->filter() != nullptr); - result.push_back({conv->input(), conv->filter(), conv}); - continue; - } - auto dw_conv = dynamic_cast<CircleDepthwiseConv2D *>(out); - if (dw_conv != nullptr && dw_conv->bias() == circle_const) - { - assert(dw_conv->input() != nullptr); - assert(dw_conv->filter() != nullptr); - result.push_back({dw_conv->input(), dw_conv->filter(), dw_conv}); - continue; - } - auto fc = dynamic_cast<CircleFullyConnected *>(out); - if (fc != nullptr && fc->bias() == circle_const) - { - assert(fc->input() != nullptr); - assert(fc->weights() != nullptr); - result.push_back({fc->input(), fc->weights(), fc}); - continue; - } - auto tconv = dynamic_cast<CircleTransposeConv *>(out); - if (tconv != nullptr && tconv->bias() == circle_const) - { - assert(tconv->outBackprop() != nullptr); - assert(tconv->filter() != nullptr); - result.push_back({tconv->outBackprop(), tconv->filter(), tconv}); - continue; - } - } - return result; -} - -CircleConst *asym_quant_bias_per_layer(CircleConst *node, float input_scale, float weight_scale, - float *scaling_factor, int64_t *zp) -{ - float scale = input_scale * weight_scale; - const float scaling_factor_inv = (scale == 0) ? 
0 : 1.0 / scale; - - uint32_t size = node->size<loco::DataType::FLOAT32>(); - std::vector<int32_t> quantized_values(size); - for (uint32_t i = 0; i < size; ++i) - { - quantized_values[i] = - static_cast<int32_t>(std::round(node->at<loco::DataType::FLOAT32>(i) * scaling_factor_inv)); - } - - auto new_bias = create_empty_const_from<loco::DataType::S32>(node, size); - - const int32_t kMinScale = std::numeric_limits<int32_t>::lowest(); - const int32_t kMaxScale = std::numeric_limits<int32_t>::max(); - for (uint32_t i = 0; i < size; ++i) - { - new_bias->at<loco::DataType::S32>(i) = - std::min(kMaxScale, std::max(kMinScale, quantized_values[i])); - } - *scaling_factor = scale; - *zp = 0; - - return new_bias; -} - -CircleConst *quant_bias_per_channel(CircleConst *node, float input_scale, - std::vector<float> &weight_scale, - std::vector<float> &scaling_factor, std::vector<int64_t> &zp) -{ - float scaling_factor_inv{0}; - - uint32_t size = node->size<loco::DataType::FLOAT32>(); - std::vector<int32_t> quantized_values(size); - - for (uint32_t i = 0; i < size; ++i) - { - scaling_factor[i] = input_scale * weight_scale[i]; - scaling_factor_inv = (scaling_factor[i] == 0) ? 
0 : 1.0 / scaling_factor[i]; - quantized_values[i] = - static_cast<int32_t>(std::round(node->at<loco::DataType::FLOAT32>(i) * scaling_factor_inv)); - zp[i] = 0; - } - - auto new_bias = create_empty_const_from<loco::DataType::S32>(node, size); - - const int32_t kMinScale = std::numeric_limits<int32_t>::lowest(); - const int32_t kMaxScale = std::numeric_limits<int32_t>::max(); - for (uint32_t i = 0; i < size; ++i) - { - new_bias->at<loco::DataType::S32>(i) = - std::min(kMaxScale, std::max(kMinScale, quantized_values[i])); - } - - return new_bias; -} - -CircleConst *int16_quant_bias_per_channel(CircleConst *node, float input_scale, - std::vector<float> &weight_scale, - std::vector<float> &scaling_factor, - std::vector<int64_t> &zp) -{ - float scaling_factor_inv{0}; - - uint32_t size = node->size<loco::DataType::FLOAT32>(); - std::vector<int64_t> quantized_values(size); - - for (uint32_t i = 0; i < size; ++i) - { - scaling_factor[i] = input_scale * weight_scale[i]; - scaling_factor_inv = (scaling_factor[i] == 0) ? 
0 : 1.0 / scaling_factor[i]; - quantized_values[i] = - static_cast<int64_t>(std::round(node->at<loco::DataType::FLOAT32>(i) * scaling_factor_inv)); - zp[i] = 0; - } - - auto new_bias = create_empty_const_from<loco::DataType::S64>(node, size); - - for (uint32_t i = 0; i < size; ++i) - { - new_bias->at<loco::DataType::S64>(i) = quantized_values[i]; - } - - return new_bias; -} - -bool has_min_max(const CircleNode *node) -{ - return node->quantparam() && !node->quantparam()->min.empty() && !node->quantparam()->max.empty(); -} - -void sym_wquant_per_channel(CircleConst *node, std::vector<float> &scaling_factor, - int32_t &channel_dim_index) -{ - assert(node->dtype() == loco::DataType::FLOAT32); - - const int32_t kMaxScale = std::numeric_limits<int16_t>::max(); - const int32_t kMinScale = -kMaxScale; - - uint32_t size = node->size<loco::DataType::FLOAT32>(); - std::vector<int32_t> quantized_values(size); - - auto quantize = [&](uint32_t *indices, loco::TensorShape &dimension, int32_t channel_dim_index) { - int channel_idx = indices[channel_dim_index]; - const float scaling_factor_inv = 1.0 / scaling_factor[channel_idx]; - auto data = node->at<loco::DataType::FLOAT32>(cal_offset(dimension, indices)); - quantized_values[cal_offset(dimension, indices)] = - static_cast<int32_t>(std::round(data * scaling_factor_inv)); - }; - - iterate_per_channel(node, channel_dim_index, quantize); - - node->dtype(loco::DataType::S16); // change the type of tensor - node->size<loco::DataType::S16>(size); // resize tensor - for (uint32_t i = 0; i < size; ++i) - { - node->at<loco::DataType::S16>(i) = - std::min(kMaxScale, std::max(kMinScale, quantized_values[i])); - } -} - -void asym_wquant_per_channel(CircleConst *node, std::vector<float> &min, - std::vector<float> &scaling_factor, int32_t &channel_dim_index) -{ - assert(node->dtype() == loco::DataType::FLOAT32); - - const int32_t kMinScale = 0; - const int32_t kMaxScale = 255; - - uint32_t size = node->size<loco::DataType::FLOAT32>(); - 
std::vector<int32_t> quantized_values(size); - - auto quantize = [&](uint32_t *indices, loco::TensorShape &dimension, int32_t channel_dim_index) { - int channel_idx = indices[channel_dim_index]; - const float scaling_factor_inv = 1.0 / scaling_factor[channel_idx]; - auto data = node->at<loco::DataType::FLOAT32>(cal_offset(dimension, indices)); - quantized_values[cal_offset(dimension, indices)] = - static_cast<int32_t>(std::round((data - min[channel_idx]) * scaling_factor_inv)); - }; - - iterate_per_channel(node, channel_dim_index, quantize); - - node->dtype(loco::DataType::U8); // change the type of tensor - node->size<loco::DataType::U8>(size); // resize tensor - for (uint32_t i = 0; i < size; ++i) - { - node->at<loco::DataType::U8>(i) = std::min(kMaxScale, std::max(kMinScale, quantized_values[i])); - } -} - -void asym_wquant_per_layer(CircleConst *node, float min, float scaling_factor) -{ - const int32_t kMinScale = 0; - const int32_t kMaxScale = 255; - - uint32_t size = node->size<loco::DataType::FLOAT32>(); - - const float scaling_factor_inv = 1.0 / scaling_factor; - std::vector<int32_t> quantized_values(size); - for (uint32_t i = 0; i < size; ++i) - { - auto data = node->at<loco::DataType::FLOAT32>(i); - quantized_values[i] = static_cast<int32_t>(std::round((data - min) * scaling_factor_inv)); - } - - node->dtype(loco::DataType::U8); // change the type of tensor - node->size<loco::DataType::U8>(size); // resize tensor - for (uint32_t i = 0; i < size; ++i) - { - node->at<loco::DataType::U8>(i) = std::min(kMaxScale, std::max(kMinScale, quantized_values[i])); - } -} - -void set_bias(luci::CircleNode *node, luci::CircleConst *bias) -{ - if (auto conv = dynamic_cast<CircleConv2D *>(node)) - conv->bias(bias); - else if (auto dconv = dynamic_cast<CircleDepthwiseConv2D *>(node)) - dconv->bias(bias); - else if (auto tconv = dynamic_cast<CircleTransposeConv *>(node)) - tconv->bias(bias); - else if (auto fc = dynamic_cast<CircleFullyConnected *>(node)) - fc->bias(bias); 
- else - throw std::runtime_error("Only convolution, depthwise convolution, transposed convolution, and " - "fully-connected layer have bias"); -} - -void set_act_qparam(luci::CircleNode *node, float scale, int64_t zp) -{ - assert(node); // FIX_CALLER_UNLESS - assert(node->quantparam()); // FIX_CALLER_UNLESS - - auto qparam = node->quantparam(); - assert(qparam->scale.size() == 1); // FIX_CALLER_UNLESS - assert(qparam->zerop.size() == 1); // FIX_CALLER_UNLESS - qparam->scale[0] = scale; - qparam->zerop[0] = zp; -} - -/** - * @brief Manually set scale/zp of output tensor of special Ops - */ -struct QuantizeSpecialActivation final : public luci::CircleNodeMutableVisitor<void> -{ - QuantizeSpecialActivation(loco::DataType input, loco::DataType output) - : input_type(input), output_type(output) - { - } - - loco::DataType input_type; - loco::DataType output_type; - - void visit(luci::CircleNode *) - { - // Do nothing by default - } - - void visit(luci::CircleLogistic *node) - { - if (output_type == loco::DataType::U8) - set_act_qparam(node, 1.0f / 256.0f, 0); - else - { - assert(output_type == loco::DataType::S16); - set_act_qparam(node, 1.0f / 32768.0f, 0); - } - } - - void visit(luci::CircleTanh *node) - { - if (output_type == loco::DataType::U8) - set_act_qparam(node, 2.0f / 256.0f, 128); - else - { - assert(output_type == loco::DataType::S16); - set_act_qparam(node, 1.0f / 32768.0f, 0); - } - } - - void visit(luci::CircleStridedSlice *node) - { - auto input = loco::must_cast<luci::CircleNode *>(node->input()); - auto i_qparam = input->quantparam(); - assert(i_qparam); - assert(i_qparam->scale.size() == 1); // FIX_CALLER_UNLESS - assert(i_qparam->zerop.size() == 1); // FIX_CALLER_UNLESS - auto i_scale = i_qparam->scale[0]; - auto i_zp = i_qparam->zerop[0]; - - set_act_qparam(node, i_scale, i_zp); - } - - void visit(luci::CircleSplitOut *node) - { - auto split = loco::must_cast<luci::CircleSplit *>(node->input()); - auto input = loco::must_cast<luci::CircleNode 
*>(split->input()); - auto i_qparam = input->quantparam(); - assert(i_qparam); - assert(i_qparam->scale.size() == 1); // FIX_CALLER_UNLESS - assert(i_qparam->zerop.size() == 1); // FIX_CALLER_UNLESS - auto i_scale = i_qparam->scale[0]; - auto i_zp = i_qparam->zerop[0]; - - set_act_qparam(node, i_scale, i_zp); - } - - void visit(luci::CircleSplitVOut *node) - { - auto splitv = loco::must_cast<luci::CircleSplitV *>(node->input()); - auto input = loco::must_cast<luci::CircleNode *>(splitv->input()); - auto i_qparam = input->quantparam(); - assert(i_qparam); - assert(i_qparam->scale.size() == 1); // FIX_CALLER_UNLESS - assert(i_qparam->zerop.size() == 1); // FIX_CALLER_UNLESS - auto i_scale = i_qparam->scale[0]; - auto i_zp = i_qparam->zerop[0]; - - set_act_qparam(node, i_scale, i_zp); - } - - void visit(luci::CircleUnpackOut *node) - { - auto unpack = loco::must_cast<luci::CircleUnpack *>(node->input()); - auto input = loco::must_cast<luci::CircleNode *>(unpack->value()); - auto i_qparam = input->quantparam(); - assert(i_qparam); - assert(i_qparam->scale.size() == 1); // FIX_CALLER_UNLESS - assert(i_qparam->zerop.size() == 1); // FIX_CALLER_UNLESS - auto i_scale = i_qparam->scale[0]; - auto i_zp = i_qparam->zerop[0]; - - set_act_qparam(node, i_scale, i_zp); - } - - // TODO Move Softmax, Floor, Ceil from QuantizeActivation to here -}; - /** - * @brief QuantizeActivation quantizes tensors for activations - * @details Quantize using recorded min/max values + * Insert Quantize operator for mixed-precision quantization + * 1. Before input feature map (only for non-const) + * 2. After output feature map + * + * For example, if default_dtype = U8 and op_dtype = S16, + * 1. Quantize Op for U8->S16 is inserted before ifm + * 2. Quantize Op for S16->U8 is inserted after ofm + * + * Why not insert Quantize Op for const ifm? + * We quantize const tensor at once to preserve precision. 
+ * For example, if default dtype = U8, op_dtype = S16, and op is CONV2D, + * We directly quantize weights to 16 bits, not 8->16 bits. */ -struct QuantizeActivation final : public luci::CircleNodeMutableVisitor<bool> +struct InsertQuantizeOp final : public luci::CircleNodeMutableVisitor<void> { - QuantizeActivation(loco::DataType input, loco::DataType output) - : input_type(input), output_type(output) + InsertQuantizeOp(loco::DataType default_dtype, loco::DataType op_dtype) + : _default_dtype(default_dtype), _op_dtype(op_dtype) { + assert(default_dtype != op_dtype); // FIX_CALLER_UNLESS } - loco::DataType input_type; - loco::DataType output_type; +private: + loco::DataType _default_dtype; + loco::DataType _op_dtype; - // Quantize input tensors of each node - bool visit(luci::CircleNode *node) +private: + luci::CircleQuantize *create_in_quantize(loco::Node *in, loco::Node *origin) + { + auto input = loco::must_cast<luci::CircleNode *>(in); + if (input->opcode() == luci::CircleOpcode::CIRCLECONST) + return nullptr; + + auto input_quant = create_quantize_op(input, _op_dtype); + input_quant->input(input); + auto origin_node = loco::must_cast<luci::CircleNode *>(origin); + luci::add_origin(input_quant, luci::get_origin(origin_node)); + return input_quant; + } + + void insert_out_quantize(loco::Node *node) + { + auto output = loco::must_cast<luci::CircleNode *>(node); + assert(output->opcode() != luci::CircleOpcode::CIRCLECONST); // FIX_CALLER_UNLESS + auto output_quant = create_quantize_op(output, _default_dtype); + + luci::add_origin(output_quant, luci::get_origin(output)); + loco::replace(node).with(output_quant); + output_quant->input(node); + } + +// INPUT_NAME is the only activation of NODE +#define INSERT_QUANTIZE_TO_UNARY_OP(NODE, INPUT_NAME) \ + void visit(NODE *node) \ + { \ + if (auto input_quant = create_in_quantize(node->INPUT_NAME(), node)) \ + node->INPUT_NAME(input_quant); \ + \ + insert_out_quantize(node); \ + } + +// INPUT_NAME is the only activation of 
NODE +#define INSERT_QUANTIZE_TO_UNARY_MULTI_OUTPUT_OP(NODE, INPUT_NAME, OUT_NAME) \ + void visit(NODE *node) \ + { \ + if (auto input_quant = create_in_quantize(node->INPUT_NAME(), node)) \ + node->INPUT_NAME(input_quant); \ + \ + auto out_nodes = loco::succs(node); \ + for (auto out_node : out_nodes) \ + { \ + auto out_circle = loco::must_cast<OUT_NAME *>(out_node); \ + insert_out_quantize(out_circle); \ + } \ + } + +// INPUT_NAME1 and INPUT_NAME2 are the only activations of NODE +#define INSERT_QUANTIZE_TO_BINARY_OP(NODE, INPUT_NAME1, INPUT_NAME2) \ + void visit(NODE *node) \ + { \ + if (auto input1_quant = create_in_quantize(node->INPUT_NAME1(), node)) \ + node->INPUT_NAME1(input1_quant); \ + \ + if (auto input2_quant = create_in_quantize(node->INPUT_NAME2(), node)) \ + node->INPUT_NAME2(input2_quant); \ + \ + insert_out_quantize(node); \ + } + + // Default behavior (NYI) + void visit(luci::CircleNode *node) + { + throw std::runtime_error("Unsupported Op for mixed-precision quantization. 
Layer name: " + + node->name()); + } + + // Skip output layer + void visit(luci::CircleOutput *) {} + void visit(luci::CircleSplitVOut *) {} + void visit(luci::CircleSplitOut *) {} + void visit(luci::CircleTopKV2Out *) {} + void visit(luci::CircleUniqueOut *) {} + void visit(luci::CircleUnpackOut *) {} + + // Ops that receive a single activation as an input + INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleAveragePool2D, value) + INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleBatchToSpaceND, input) + INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleConv2D, input) + INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleDepthToSpace, input) + INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleDepthwiseConv2D, input) + INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleElu, features) + INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleExp, x) + INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleFloor, x) + INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleFullyConnected, input) + INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleGather, params) + INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleInstanceNorm, input) + INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleLocalResponseNormalization, input) + INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleLogistic, x) + INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleMaxPool2D, value) + INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleMean, input) + INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleMirrorPad, input) + INSERT_QUANTIZE_TO_UNARY_OP(luci::CirclePad, input) + INSERT_QUANTIZE_TO_UNARY_OP(luci::CirclePadV2, input) + INSERT_QUANTIZE_TO_UNARY_OP(luci::CirclePRelu, input) + INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleReduceProd, input) + INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleReduceMax, input) + INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleReduceMin, input) + INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleRelu, features) + INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleReshape, tensor) + INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleResizeBilinear, input) + INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleResizeNearestNeighbor, input) + INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleReverseSequence, input) + 
INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleRsqrt, x) + INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleSlice, input) + INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleSoftmax, logits) + INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleSpaceToBatchND, input) + INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleSpaceToDepth, input) + INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleSqrt, x) + INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleStridedSlice, input) + INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleSum, input) + INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleTanh, x) + INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleTile, input) + INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleTranspose, a) + INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleTransposeConv, outBackprop) + + // Ops that receive two activations as inputs + INSERT_QUANTIZE_TO_BINARY_OP(luci::CircleAdd, x, y) + INSERT_QUANTIZE_TO_BINARY_OP(luci::CircleBatchMatMul, x, y) + INSERT_QUANTIZE_TO_BINARY_OP(luci::CircleDiv, x, y) + INSERT_QUANTIZE_TO_BINARY_OP(luci::CircleFloorDiv, x, y) + INSERT_QUANTIZE_TO_BINARY_OP(luci::CircleMaximum, x, y) + INSERT_QUANTIZE_TO_BINARY_OP(luci::CircleMinimum, x, y) + INSERT_QUANTIZE_TO_BINARY_OP(luci::CircleMul, x, y) + INSERT_QUANTIZE_TO_BINARY_OP(luci::CircleOneHot, on_value, off_value) + INSERT_QUANTIZE_TO_BINARY_OP(luci::CirclePow, x, y) + INSERT_QUANTIZE_TO_BINARY_OP(luci::CircleSub, x, y) + + // Multiple-output ops that receive one activation as inputs + INSERT_QUANTIZE_TO_UNARY_MULTI_OUTPUT_OP(luci::CircleSplit, input, luci::CircleSplitOut) + INSERT_QUANTIZE_TO_UNARY_MULTI_OUTPUT_OP(luci::CircleSplitV, input, luci::CircleSplitVOut) + INSERT_QUANTIZE_TO_UNARY_MULTI_OUTPUT_OP(luci::CircleTopKV2, input, luci::CircleTopKV2Out) + INSERT_QUANTIZE_TO_UNARY_MULTI_OUTPUT_OP(luci::CircleUnique, input, luci::CircleUniqueOut) + INSERT_QUANTIZE_TO_UNARY_MULTI_OUTPUT_OP(luci::CircleUnpack, value, luci::CircleUnpackOut) + + // AddN has arbitrary number of inputs + void visit(luci::CircleAddN *node) { - LOGGER(l); - INFO(l) << "QuantizeActivation visit node: " << 
node->name() << std::endl; auto arity = node->arity(); for (uint32_t i = 0; i < arity; i++) { - auto input_node = node->arg(i); - auto circle_node = loco::must_cast<luci::CircleNode *>(input_node); - - // Check if this is already quantized - if (is_quantized(circle_node)) - continue; - - // Check if this is bias (bias is quantized later) - auto iwo = get_input_weight_output_of_bias(circle_node); - if (iwo.size() > 0) - continue; - - // Check if this is bool type (bool type is not quantized) - if (circle_node->dtype() == loco::DataType::BOOL) - continue; - - // Check if this is activation - // We assume min/max are recorded only for activations - if (has_min_max(circle_node) && !is_weights(circle_node)) - { - // Quantize using recorded min/max - auto quantparam = circle_node->quantparam(); - assert(quantparam); - assert(quantparam->min.size() == 1); // only support layer-wise quant - assert(quantparam->max.size() == 1); // only support layer-wise quant - auto min = quantparam->min[0]; - auto max = quantparam->max[0]; - - // Special values - if (circle_node->opcode() == luci::CircleOpcode::SOFTMAX) - { - min = 0.0f; - max = 1.0f; - } - - float scaling_factor{0}; - int64_t zp{0}; - float nudged_min{0}; - float nudged_max{0}; - - if (output_type == loco::DataType::U8) - { - compute_asym_scale_zp(min, max, scaling_factor, zp, nudged_min, nudged_max); - circle_node->dtype(loco::DataType::U8); - } - else - { - compute_sym_scale_zp(min, max, scaling_factor, zp, nudged_min, nudged_max); - circle_node->dtype(loco::DataType::S16); - } - - // Nodes fused with activation functions which need special quantization - auto fused_act_node = - dynamic_cast<CircleNodeMixin<CircleNodeTrait::FusedActFunc> *>(circle_node); - if (fused_act_node != nullptr && - fused_act_node->fusedActivationFunction() == FusedActFunc::TANH) - { - if (output_type == loco::DataType::U8) - { - scaling_factor = 2.0f / 256.0f; - zp = 128; - } - else - { - assert(output_type == loco::DataType::S16); - 
scaling_factor = 1.0f / 32768.0f; - zp = 0; - } - } - - // The output of these Ops should be integer, so scale should be integer - // TODO Handle cases where the integer scale needs to be propagated - if (circle_node->opcode() == CircleOpcode::FLOOR || - circle_node->opcode() == CircleOpcode::FLOOR_DIV || - circle_node->opcode() == CircleOpcode::FLOOR_MOD || - circle_node->opcode() == CircleOpcode::CEIL) - { - assert(scaling_factor >= 0); // FIX_ME_UNLESS - scaling_factor = scaling_factor < 1 ? 1.0f : std::round(scaling_factor); - } - - circle_node->quantparam()->scale.push_back(scaling_factor); - circle_node->quantparam()->zerop.push_back(zp); - } - // Fix special attributes - if (circle_node->opcode() == luci::CircleOpcode::CAST) - { - auto *cast = loco::must_cast<luci::CircleCast *>(circle_node); - auto *cast_input = loco::must_cast<luci::CircleNode *>(cast->x()); - - // make sure that cast_input is already quantized - assert(cast_input->dtype() != loco::DataType::FLOAT32); - cast->in_data_type(cast_input->dtype()); - cast->out_data_type(cast->dtype()); - } - } - return false; - } -}; - -struct QuantizeBias final : public luci::CircleNodeMutableVisitor<bool> -{ - QuantizeBias(loco::DataType input, loco::DataType output, QuantizationGranularity gr) - : input_type(input), output_type(output), granularity(gr) - { - } - - loco::DataType input_type; - loco::DataType output_type; - QuantizationGranularity granularity; - - // Quantize bias node - bool visit(luci::CircleNode *node) - { - // Check if this is already quantized - if (is_quantized(node)) - return false; - - auto iwo_list = get_input_weight_output_of_bias(node); - - for (auto iwo : iwo_list) - { - assert(iwo.size() == 3); - - auto input = loco::must_cast<luci::CircleNode *>(iwo[0]); - auto weight = loco::must_cast<luci::CircleNode *>(iwo[1]); - auto output = loco::must_cast<luci::CircleNode *>(iwo[2]); - - auto const_bias = loco::must_cast<luci::CircleConst *>(node); - assert(const_bias->dtype() == 
loco::DataType::FLOAT32); - - // If input is const, it is quantized here, not in QuantizeActivation - if (auto const_input = dynamic_cast<luci::CircleConst *>(input)) - { - quant_const(const_input, output_type); - } - - CircleConst *new_bias = nullptr; - - if (granularity == QuantizationGranularity::ChannelWise) - { - auto input_q = input->quantparam(); - assert(input_q); - assert(input_q->scale.size() == 1); // input scale's layer-wise - auto input_scale = input_q->scale[0]; - - assert(weight->quantparam() != nullptr); // weight scale's channel-wise - auto weight_scale = weight->quantparam()->scale; - - uint32_t size = const_bias->size<loco::DataType::FLOAT32>(); - assert(size == weight_scale.size()); - std::vector<float> scaling_factor(size); - std::vector<int64_t> zp(size); - - if (output_type == loco::DataType::U8) - { - new_bias = - quant_bias_per_channel(const_bias, input_scale, weight_scale, scaling_factor, zp); - } - else if (output_type == loco::DataType::S16) - { - new_bias = - int16_quant_bias_per_channel(const_bias, input_scale, weight_scale, scaling_factor, zp); - } - else - { - throw std::runtime_error("Unsupported quantization type."); - } - - auto quantparam = std::make_unique<CircleQuantParam>(); - quantparam->scale = scaling_factor; - quantparam->zerop = zp; - assert(new_bias->quantparam() == nullptr); // bias should not be quantized before - new_bias->quantparam(std::move(quantparam)); - - set_bias(output, new_bias); - } - else - { - auto input_q = input->quantparam(); - assert(input_q); - assert(input_q->scale.size() == 1); // Only support per-layer quant - auto input_scale = input_q->scale[0]; - - auto weight_q = weight->quantparam(); - assert(weight_q); - assert(weight_q->scale.size() == 1); // Only support per-layer quant - auto weight_scale = weight_q->scale[0]; - - float scaling_factor{0}; - int64_t zp{0}; - new_bias = - asym_quant_bias_per_layer(const_bias, input_scale, weight_scale, &scaling_factor, &zp); - auto quantparam = 
std::make_unique<CircleQuantParam>(); - quantparam->scale.push_back(scaling_factor); - quantparam->zerop.push_back(zp); - assert(new_bias->quantparam() == nullptr); // bias should not be quantized before - new_bias->quantparam(std::move(quantparam)); - - set_bias(output, new_bias); - } - } - return false; - } -}; - -/** - * @brief QuantizeWeights quantizes tensors for weights - * @details Find min/max values on the fly and then quantize - */ -struct QuantizeWeights final : public luci::CircleNodeMutableVisitor<bool> -{ - QuantizeWeights(loco::DataType input, loco::DataType output, QuantizationGranularity gr) - : input_type(input), output_type(output), granularity(gr) - { - } - - loco::DataType input_type; - loco::DataType output_type; - QuantizationGranularity granularity; - -private: - void quantize_weights(luci::CircleConst *weights) - { - // Find min/max per channel-wise - if (granularity == QuantizationGranularity::ChannelWise) - { - auto quantparam = weights->quantparam(); - if (quantparam == nullptr) - { - assert(false && "quantparam is nullptr"); - return; - } - - auto min = quantparam->min; - auto scaling_factor = quantparam->scale; - int32_t channel_dim_index = 0; - - if (output_type == loco::DataType::U8) - { - asym_wquant_per_channel(weights, min, scaling_factor, channel_dim_index); - } - else - { - sym_wquant_per_channel(weights, scaling_factor, channel_dim_index); - } - quantparam->min.clear(); - quantparam->max.clear(); - quantparam->quantized_dimension = channel_dim_index; - } - // Find min/max per layer-wise - else - { - // Quantize using recorded quantparam - auto quantparam = weights->quantparam(); - assert(quantparam != nullptr); - assert(quantparam->min.size() == 1); // only support layer-wise quant - assert(quantparam->scale.size() == 1); // only support layer-wise quant - auto min = quantparam->min[0]; - auto scaling_factor = quantparam->scale[0]; - asym_wquant_per_layer(weights, min, scaling_factor); - quantparam->min.clear(); - 
quantparam->max.clear(); - } - } - - bool visit(luci::CircleConv2D *node) - { - LOGGER(l); - INFO(l) << "QuantizeWeights visit node: " << node->name() << std::endl; - - auto weights = loco::must_cast<luci::CircleConst *>(node->filter()); - if (!is_quantized(weights)) - { - auto new_weights = luci::clone(weights); - node->filter(new_weights); - quantize_weights(new_weights); - return true; + if (auto input_quant = create_in_quantize(node->inputs(i), node)) + node->inputs(i, input_quant); } - return false; - } - - bool visit(luci::CircleDepthwiseConv2D *node) - { - LOGGER(l); - INFO(l) << "QuantizeWeights visit node: " << node->name() << std::endl; - auto weights = loco::must_cast<luci::CircleConst *>(node->filter()); - if (!is_quantized(weights)) - { - auto new_weights = luci::clone(weights); - node->filter(new_weights); - quantize_weights(new_weights); - return true; - } - return false; + insert_out_quantize(node); } - bool visit(luci::CircleInstanceNorm *node) + // Concat has arbitrary number of inputs + void visit(luci::CircleConcatenation *node) { - LOGGER(l); - INFO(l) << "QuantizeWeights visit node: " << node->name() << std::endl; - - auto gamma = loco::must_cast<luci::CircleConst *>(node->gamma()); - auto beta = loco::must_cast<luci::CircleConst *>(node->beta()); - - bool changed = false; - if (!is_quantized(gamma)) - { - assert(gamma->dtype() == loco::DataType::FLOAT32); - auto new_gamma = luci::clone(gamma); - if (granularity == QuantizationGranularity::LayerWise) - quant_const(new_gamma, output_type); - else if (granularity == QuantizationGranularity::ChannelWise) - quant_const_per_channel(new_gamma, output_type); - node->gamma(new_gamma); - changed = true; - } - if (!is_quantized(beta)) - { - assert(beta->dtype() == loco::DataType::FLOAT32); - auto new_beta = luci::clone(beta); - if (granularity == QuantizationGranularity::LayerWise) - quant_const(new_beta, output_type); - else if (granularity == QuantizationGranularity::ChannelWise) - 
quant_const_per_channel(new_beta, output_type); - node->beta(new_beta); - changed = true; - } - - return changed; - } - - bool visit(luci::CirclePRelu *node) - { - LOGGER(l); - INFO(l) << "QuantizeWeights visit node: " << node->name() << std::endl; - - auto alpha = loco::must_cast<luci::CircleConst *>(node->alpha()); - - if (!is_quantized(alpha)) + auto arity = node->arity(); + for (uint32_t i = 0; i < arity; i++) { - assert(alpha->dtype() == loco::DataType::FLOAT32); - auto new_alpha = luci::clone(alpha); - if (granularity == QuantizationGranularity::LayerWise) - quant_const(new_alpha, output_type); - else if (granularity == QuantizationGranularity::ChannelWise) - quant_const_per_channel(new_alpha, output_type); - node->alpha(new_alpha); - return true; + if (auto input_quant = create_in_quantize(node->values(i), node)) + node->values(i, input_quant); } - return false; + insert_out_quantize(node); } - bool visit(luci::CircleTransposeConv *node) + // Pack has arbitrary number of inputs + void visit(luci::CirclePack *node) { - LOGGER(l); - INFO(l) << "QuantizeWeights visit node: " << node->name() << std::endl; - - auto weights = loco::must_cast<luci::CircleConst *>(node->filter()); - if (!is_quantized(weights)) + auto arity = node->arity(); + for (uint32_t i = 0; i < arity; i++) { - auto new_weights = luci::clone(weights); - node->filter(new_weights); - quantize_weights(new_weights); - return true; + if (auto input_quant = create_in_quantize(node->values(i), node)) + node->values(i, input_quant); } - return false; - } - - bool visit(luci::CircleFullyConnected *node) - { - LOGGER(l); - INFO(l) << "QuantizeWeights visit node: " << node->name() << std::endl; - auto weights = loco::must_cast<luci::CircleConst *>(node->weights()); - if (!is_quantized(weights)) - { - auto new_weights = luci::clone(weights); - node->weights(new_weights); - quantize_weights(new_weights); - return true; - } - return false; + insert_out_quantize(node); } - bool visit(luci::CircleNode *) { 
return false; } +#undef INSERT_QUANTIZE_TO_UNARY_OP +#undef INSERT_QUANTIZE_TO_BINARY_OP +#undef INSERT_QUANTIZE_TO_UNARY_MULTI_OUTPUT_OP }; -/** EXAMPLE - * - * BEFORE - * - * [CircleNode] [CircleConst] - * (qparam1) (FP32) - * \ / - * \ / - * [CirclePack] - * (qparam2) - * - * AFTER - * - * [CircleNode] [CircleConst] [CircleConst] <- Dead node - * (qparam2) (qparam2) (FP32) - * \ / - * \ / - * [CirclePack] - * (qparam2) - * - * NOTE Quantization parameter of CirclePack (qparam2) is propagated to the inputs. - */ -void propagate_pack_quantparam(luci::CirclePack *pack, loco::DataType quant_type) -{ - assert(pack->quantparam() != nullptr); - - const auto num_inputs = pack->values_count(); - - for (uint32_t i = 0; i < num_inputs; i++) - { - auto node = loco::must_cast<luci::CircleNode *>(pack->arg(i)); - - // Skip if this input is PACK Op - if (node->opcode() == luci::CircleOpcode::PACK) - continue; - - // Quantize constant values - if (node->opcode() == luci::CircleOpcode::CIRCLECONST) - { - luci::CircleConst *const_node = loco::must_cast<luci::CircleConst *>(node); - if (const_node->dtype() != loco::DataType::FLOAT32) - throw std::runtime_error("Unsupported data type for constant input of pack Op"); - - const auto pack_qparam = pack->quantparam(); - if (pack_qparam == nullptr) - throw std::runtime_error("quantparam of pack is not found during propagation"); - - assert(pack_qparam->scale.size() == 1); - assert(pack_qparam->zerop.size() == 1); - const auto scaling_factor = pack_qparam->scale[0]; - const auto zerop = pack_qparam->zerop[0]; - - auto new_const = luci::clone(const_node); - quant_const_values(new_const, scaling_factor, zerop, quant_type); - pack->values(i, new_const); - overwrite_quantparam(pack, new_const); - } - else - { - const auto succs = loco::succs(node); - if (succs.size() > 1) - continue; - - // Non-const input must have been quantized - assert(node->quantparam() != nullptr); - overwrite_quantparam(pack, node); - } - } -} - -/** - * @brief 
Quantize const input tensors using min/max of const values - */ -void quantize_const_inputs(luci::CircleNode *node, loco::DataType output_type) -{ - auto opcode = node->opcode(); - auto arity = node->arity(); - - loco::Node *input_node{nullptr}; - luci::CircleConst *const_node{nullptr}; - - switch (opcode) - { - case luci::CircleOpcode::CONV_2D: - case luci::CircleOpcode::DEPTHWISE_CONV_2D: - case luci::CircleOpcode::FULLY_CONNECTED: - case luci::CircleOpcode::INSTANCE_NORM: - case luci::CircleOpcode::PRELU: - case luci::CircleOpcode::TRANSPOSE_CONV: - // Handled in QuantizeWeights and QuantizeBias - break; - - case luci::CircleOpcode::CONCATENATION: - // Handled in propagate_concat_quantparam - break; - - case luci::CircleOpcode::LOGICAL_OR: - // Inputs of logical Ops are bool, thus not quantized - break; - - case luci::CircleOpcode::ARG_MAX: - case luci::CircleOpcode::ARG_MIN: - case luci::CircleOpcode::BATCH_TO_SPACE_ND: - case luci::CircleOpcode::LOCAL_RESPONSE_NORMALIZATION: - case luci::CircleOpcode::MEAN: - case luci::CircleOpcode::MIRROR_PAD: - case luci::CircleOpcode::PAD: - case luci::CircleOpcode::REDUCE_ANY: - case luci::CircleOpcode::REDUCE_PROD: - case luci::CircleOpcode::REDUCE_MAX: - case luci::CircleOpcode::REDUCE_MIN: - case luci::CircleOpcode::RESHAPE: - case luci::CircleOpcode::RESIZE_BILINEAR: - case luci::CircleOpcode::RESIZE_NEAREST_NEIGHBOR: - case luci::CircleOpcode::REVERSE_SEQUENCE: - case luci::CircleOpcode::SLICE: - case luci::CircleOpcode::SPACE_TO_BATCH_ND: - case luci::CircleOpcode::SPLIT_V: - case luci::CircleOpcode::STRIDED_SLICE: - case luci::CircleOpcode::SUM: - case luci::CircleOpcode::TILE: - case luci::CircleOpcode::TOPK_V2: - case luci::CircleOpcode::TRANSPOSE: - // The second input of these Ops should not be quantized - // Ex: axis, paddings - input_node = node->arg(0); - const_node = dynamic_cast<luci::CircleConst *>(input_node); - if (const_node != nullptr && !is_quantized(const_node)) - quant_const(const_node, 
output_type); - break; - - case luci::CircleOpcode::ADD: - case luci::CircleOpcode::ADD_N: - case luci::CircleOpcode::DEPTH_TO_SPACE: - case luci::CircleOpcode::DIV: - case luci::CircleOpcode::ELU: - case luci::CircleOpcode::EQUAL: - case luci::CircleOpcode::EXP: - case luci::CircleOpcode::FLOOR: - case luci::CircleOpcode::FLOOR_DIV: - case luci::CircleOpcode::GREATER: - case luci::CircleOpcode::GREATER_EQUAL: - case luci::CircleOpcode::LESS: - case luci::CircleOpcode::LESS_EQUAL: - case luci::CircleOpcode::LOGISTIC: - case luci::CircleOpcode::MAXIMUM: - case luci::CircleOpcode::MINIMUM: - case luci::CircleOpcode::MUL: - case luci::CircleOpcode::NOT_EQUAL: - case luci::CircleOpcode::POW: - case luci::CircleOpcode::RSQRT: - case luci::CircleOpcode::SOFTMAX: - case luci::CircleOpcode::SPACE_TO_DEPTH: - case luci::CircleOpcode::SQRT: - case luci::CircleOpcode::SUB: - case luci::CircleOpcode::TANH: - case luci::CircleOpcode::UNPACK: - // Quantize all const inputs using their values - for (uint32_t i = 0; i < arity; i++) - { - input_node = node->arg(i); - const_node = dynamic_cast<luci::CircleConst *>(input_node); - if (const_node != nullptr && !is_quantized(const_node)) - quant_const(const_node, output_type); - } - break; - - case luci::CircleOpcode::SPLIT: - // Only the second input is quantized - // First input should not be quantized (e.g., split_dim) - input_node = node->arg(1); - const_node = dynamic_cast<luci::CircleConst *>(input_node); - if (const_node != nullptr && !is_quantized(const_node)) - quant_const(const_node, output_type); - break; - - case luci::CircleOpcode::PADV2: - // First and third constant inputs are quantized - // Second input should not be quantized (e.g., paddings) - // Quant params are propagated either from output range to the non-constant input - // or from input to output and constant values - propagate_pad_v2_quantparam(loco::must_cast<CirclePadV2 *>(node), output_type); - break; - - case luci::CircleOpcode::PACK: - // Quant param is 
propagated from output to inputs - propagate_pack_quantparam(loco::must_cast<CirclePack *>(node), output_type); - break; - - default: - for (uint32_t i = 0; i < arity; i++) - { - input_node = node->arg(i); - const_node = dynamic_cast<luci::CircleConst *>(input_node); - if (const_node != nullptr) - throw std::runtime_error("Unsupported Op for const inputs"); - } - break; - } -} - } // namespace -/** BEFORE - * - * [CircleNode] [CircleConst] - * (U8 qparam1) (FP32) - * \ / - * \ / - * [CircleConcatenation] - * (U8 qparam2) - * - * AFTER - * [CircleNode] [CircleConst] [CircleConst] <- Dead node - * (U8 qparam2) (U8 qparam2) (FP32) - * \ / - * \ / - * [CircleConcatenation] - * (U8 qparam2) - */ -void propagate_concat_quantparam(luci::CircleConcatenation *concat, loco::DataType quant_type) -{ - assert(concat->quantparam() != nullptr); - - const auto num_inputs = concat->numValues(); - - // Quantize const inputs using their values if concat has fused act function - if (concat->fusedActivationFunction() != luci::FusedActFunc::NONE) - { - for (uint32_t i = 0; i < num_inputs; i++) - { - auto node = concat->arg(i); - auto const_node = dynamic_cast<luci::CircleConst *>(node); - if (const_node != nullptr) - { - auto new_const = luci::clone(const_node); - quant_const(new_const, quant_type); - concat->values(i, new_const); - } - } - return; - } - - for (uint32_t i = 0; i < num_inputs; i++) - { - auto node = loco::must_cast<luci::CircleNode *>(concat->arg(i)); - - // Skip if this input is CONCAT Op - if (node->opcode() == luci::CircleOpcode::CONCATENATION) - continue; - - // Quantize constant values - if (node->opcode() == luci::CircleOpcode::CIRCLECONST) - { - luci::CircleConst *const_node = loco::must_cast<luci::CircleConst *>(node); - if (const_node->dtype() != loco::DataType::FLOAT32) - throw std::runtime_error("Unsupported data type for constant input of concatenation Op"); - - const auto concat_qparam = concat->quantparam(); - if (concat_qparam == nullptr) - throw 
std::runtime_error("quantparam of concat is not found during propagation"); - - assert(concat_qparam->scale.size() == 1); - const auto scaling_factor = concat_qparam->scale[0]; - const auto zerop = concat_qparam->zerop[0]; - - auto new_const = luci::clone(const_node); - quant_const_values(new_const, scaling_factor, zerop, quant_type); - concat->values(i, new_const); - overwrite_quantparam(concat, new_const); - } - else - { - const auto succs = loco::succs(node); - if (succs.size() > 1) - continue; - - // Non-const input must have been quantized - assert(node->quantparam() != nullptr); - overwrite_quantparam(concat, node); - } - } -} - -/** - * tells if pad_v2 quantization should ignore padding value - * In that case padding const will be quantized with input parameters, and probably clipped - */ -bool ignore_pad_v2_const_quantization(luci::CirclePadV2 *pad) -{ - // This is a workaround to quantize pad generated from MaxPoolWithArgmax operation properly - // TODO use metadata hints to detect this case - auto const_value_node = dynamic_cast<luci::CircleConst *>(pad->arg(2)); - if (!const_value_node) - return false; - if (const_value_node->dtype() == loco::DataType::FLOAT32) - { - float const_value = const_value_node->at<loco::DataType::FLOAT32>(0); - if (const_value == std::numeric_limits<float>::lowest()) - return true; - } - return false; -} - -/** BEFORE - * - * [CircleNode] [CircleConst] [CircleConst] - * (U8 qparam1) (S32) (FP32) - * \ | / - * \ | / - * [CirclePadV2] - * (U8 qparam2) - * - * AFTER (case 1) - * - * By default qparam is propagated from output to inputs to meet backend requirements. - * - * [CircleNode] [CircleConst] [CircleConst] [CircleConst] <- Dead node - * (U8 qparam2) (S32) (U8 qparam2) (FP32) - * \ | / - * \ | / - * [CirclePadV2] - * (U8 qparam2) - * - * AFTER (case 2) - * - * In case padded value is the lowest float value - * Qparam is propagated from input to output and constant. 
- * - * This is a special case for optimization constructed pad, needed to guarantee that - * extremely large negative constant do not stretch output quantization range. - * - * [CircleNode] [CircleConst] [CircleConst] [CircleConst] <- Dead node - * (U8 qparam1) (S32) (U8 qparam1) (FP32) - * \ | / - * \ | / - * [CirclePadV2] - * (U8 qparam1) - */ -void propagate_pad_v2_quantparam(luci::CirclePadV2 *pad_v2, loco::DataType quant_type) -{ - if (ignore_pad_v2_const_quantization(pad_v2)) - { - // propagate input quantization paramters from input to output and padding const value - auto pad_v2_input = loco::must_cast<luci::CircleNode *>(pad_v2->arg(0)); - overwrite_quantparam(pad_v2_input, pad_v2); - - auto const_value_node = loco::must_cast<luci::CircleConst *>( - pad_v2->arg(2)); // FIX ignore_pad_v2_const_quantization UNLESS - auto new_const = luci::clone(const_value_node); - - const auto pad_v2_input_qparam = pad_v2_input->quantparam(); - assert(pad_v2_input_qparam != nullptr); - assert(pad_v2_input_qparam->scale.size() == 1); - const auto scaling_factor = pad_v2_input_qparam->scale.at(0); - const auto zerop = pad_v2_input_qparam->zerop.at(0); - - quant_const_values(new_const, scaling_factor, zerop, quant_type); - overwrite_quantparam(pad_v2_input, new_const); - pad_v2->constant_values(new_const); - return; - } - - // Propagate quantization paramters from output to inputs, - // to fit both input and counstant_value in one quant range. 
- auto quant_input = [pad_v2, quant_type](void (CirclePadV2::*arg_setter)(loco::Node *), - uint32_t arg) { - auto node = loco::must_cast<luci::CircleNode *>(pad_v2->arg(arg)); - - // Quantize constant values - if (node->opcode() == luci::CircleOpcode::CIRCLECONST) - { - luci::CircleConst *const_node = loco::must_cast<luci::CircleConst *>(node); - if (is_quantized(const_node)) - return; - - if (const_node->dtype() != loco::DataType::FLOAT32) - throw std::runtime_error("Unsupported data type for constant input of PadV2 Op"); - - const auto pad_v2_qparam = pad_v2->quantparam(); - if (pad_v2_qparam == nullptr) - throw std::runtime_error("quantparam of PadV2 is not found during propagation"); - - assert(pad_v2_qparam->scale.size() == 1); - const auto scaling_factor = pad_v2_qparam->scale.at(0); - const auto zerop = pad_v2_qparam->zerop.at(0); - - auto new_const = luci::clone(const_node); - quant_const_values(new_const, scaling_factor, zerop, quant_type); - overwrite_quantparam(pad_v2, new_const); - (pad_v2->*arg_setter)(new_const); - } - // Subsequent PadV2 Ops quant params are not propagated - else if (node->opcode() == luci::CircleOpcode::PADV2) - { - return; - } - else - { - const auto succs = loco::succs(node); - if (succs.size() > 1) - return; - - // Non-const input must have been quantized - assert(node->quantparam() != nullptr); - overwrite_quantparam(pad_v2, node); - } - }; - - quant_input(&CirclePadV2::input, 0); - quant_input(&CirclePadV2::constant_values, 2); -} - void QuantizeWithMinMaxPass::set_input_type(loco::Graph *g) const { auto inputs = g->inputs(); for (auto node : loco::input_nodes(g)) { auto input = loco::must_cast<luci::CircleInput *>(node); - if (input->dtype() == _input_type) + if (input->dtype() == _ctx->input_type) continue; // Bool type is not quantizable if (input->dtype() == loco::DataType::BOOL) continue; + if (input->dtype() == loco::DataType::S32) + continue; + if (input->dtype() == loco::DataType::S64) + continue; // Insert Quantize Op 
auto quant_op = create_quantize_op(input, input->dtype()); @@ -1552,22 +367,22 @@ void QuantizeWithMinMaxPass::set_input_type(loco::Graph *g) const float nudged_min{0}; float nudged_max{0}; - if (_input_type == loco::DataType::U8) + if (_ctx->input_type == loco::DataType::U8) { compute_asym_scale_zp(min, max, scaling_factor, zp, nudged_min, nudged_max); } else { - assert(_input_type == loco::DataType::S16); + assert(_ctx->input_type == loco::DataType::S16); compute_sym_scale_zp(min, max, scaling_factor, zp, nudged_min, nudged_max); } - input->dtype(_input_type); + input->dtype(_ctx->input_type); input->quantparam()->scale[0] = scaling_factor; input->quantparam()->zerop[0] = zp; } auto graph_input = inputs->at(input->index()); - graph_input->dtype(_input_type); + graph_input->dtype(_ctx->input_type); } } @@ -1577,7 +392,7 @@ void QuantizeWithMinMaxPass::set_output_type(loco::Graph *g) const for (auto node : loco::output_nodes(g)) { auto output = loco::must_cast<luci::CircleOutput *>(node); - if (output->dtype() == _output_type) + if (output->dtype() == _ctx->output_type) continue; // Bool type is not quantizable @@ -1591,7 +406,7 @@ void QuantizeWithMinMaxPass::set_output_type(loco::Graph *g) const continue; // Insert Quantize Op - auto quant_op = create_quantize_op(from, _output_type); + auto quant_op = create_quantize_op(from, _ctx->output_type); loco::replace(from).with(quant_op); quant_op->input(from); @@ -1599,67 +414,165 @@ void QuantizeWithMinMaxPass::set_output_type(loco::Graph *g) const luci::add_origin(quant_op, luci::get_origin(from)); auto graph_output = outputs->at(output->index()); - graph_output->dtype(_output_type); + graph_output->dtype(_ctx->output_type); } } +/** + * How QuantizeWithMinMax works? + * + * We categorized tensors into four groups + * - Activation: Feature maps (both Const/Non-const) + * - Weights: Const tensors of specific Ops (Conv, FC, ...) + * - Bias: Const tensors of specific Ops (Conv, FC, ...) 
+ * - Others: padding value, one_hot value, axis, .. + * + * Activation is quantized in different ways + * 1. For non-constant activation, quantize using recorded min/max + * 2. For constant activation, quantize using min/max of its value + * 3. For some Ops (ex: pad_v2), output qparam is used as input qparam (backward propagation) + * 4. For some Ops (ex: reshape), input qparam is used as output qparam (forward propagation) + * 5. For some Ops (ex: tanh), output qparam has pre-defined values + * + * Weights is quantized using min/max of its value + * + * Bias is quantized using input scale (s_i) and weights scale (s_w) + * - Activation and weights should be quantized earlier than bias + * + * Quantization Steps + * 1. Quantize Activation + * - Quantize using recorded min/max (QuantizeActivation) + * - Insert Quantize Ops for mixed-precision quantization (InsertQuantizeOp) + * - Remove redundant Quantize Ops (RemoveRedundantQuantizePass) + * - Propagate qparam backward (PropagateQParamBackwardPass) + * - Quantize const inputs (QuantizeConstInputActivation) + * - Quantize using pre-defined values (QuantizeSpecialActivation) + * - Propagate qparam forward (PropagateQParamForwardPass) + * 2. Quantize Weights + * 3. Quantize Bias + * 4. Set input dtype + * 5. Set output dtype + * + * Why quantization sequence was determined as above? + * - Activation and weights should be quantized before bias (1->2->3). Input/Output + * dtype can be updated at the end (4->5). + * - During activation quantization, + * - Backward propagation is performed earlier than forward propagation. This allows + * backward-propagated qpram to be overwritten during forward propagation. + * We made this decision as Ops for forward propagation (reshape, transpose, ..) + * are more common than backward propagation. TODO Check this decision is safe. + * - QuantizeSpecialActivation is called before forward propagation to make sure that + * the pre-defined qparam values are propagated. 
+ */ bool QuantizeWithMinMaxPass::run(loco::Graph *g) { LOGGER(l); INFO(l) << "QuantizeWithMinMaxPass Start" << std::endl; + auto info_by_name = layer_info_map(g, _ctx->layers_info); + + auto quantize_dtype = [&](const luci::CircleNode *node) { + auto iter = info_by_name.find(node->name()); + + // Return designated quantization dtype + if (iter != info_by_name.end()) + return iter->second.dtype; + + // Return default quantization dtype + return _ctx->output_model_dtype; + }; + + auto quantize_granularity = [&](const luci::CircleNode *node) { + auto iter = info_by_name.find(node->name()); + + // Return designated quantization granularity + if (iter != info_by_name.end()) + return iter->second.granularity; + + // Return default quantization granularity + return _ctx->granularity; + }; + // Quantize activation for (auto node : loco::active_nodes(loco::output_nodes(g))) { - QuantizeActivation qa(_input_model_dtype, _output_model_dtype); auto circle_node = loco::must_cast<luci::CircleNode *>(node); + QuantizeActivation qa(_ctx->input_model_dtype, quantize_dtype(circle_node)); circle_node->accept(&qa); } - // Quantize weights + // Insert Quantize Op for (auto node : loco::active_nodes(loco::output_nodes(g))) { - QuantizeWeights qw(_input_model_dtype, _output_model_dtype, _granularity); auto circle_node = loco::must_cast<luci::CircleNode *>(node); - circle_node->accept(&qw); + auto op_dtype = quantize_dtype(circle_node); + if (op_dtype != _ctx->output_model_dtype) + { + InsertQuantizeOp iqo(_ctx->output_model_dtype, op_dtype); + circle_node->accept(&iqo); + } } - // Quantize bias + // Remove redundant Quantize Op + { + logo::Phase phase; + + phase.emplace_back(std::make_unique<luci::RemoveRedundantQuantizePass>()); + + ProgressReporter prog(g, logo::PhaseStrategy::Saturate); + logo::PhaseRunner<logo::PhaseStrategy::Saturate> phase_runner{g}; + phase_runner.attach(&prog); + phase_runner.run(phase); + } + + // Backward propagation of activation qparam + { + 
PropagateQParamBackwardPass pqbp(_ctx->output_model_dtype); + pqbp.run(g); + } + + // Quantize const input activation for (auto node : loco::active_nodes(loco::output_nodes(g))) { - QuantizeBias qb(_input_model_dtype, _output_model_dtype, _granularity); auto circle_node = loco::must_cast<luci::CircleNode *>(node); - circle_node->accept(&qb); + QuantizeConstInputActivation qcia(quantize_dtype(circle_node)); + circle_node->accept(&qcia); } - // Propagate quantization parameters of concat Op + // Update qparam of output of special Ops for (auto node : loco::active_nodes(loco::output_nodes(g))) { - auto concat = dynamic_cast<luci::CircleConcatenation *>(node); - if (not concat) - continue; - - // Propagate qparam of concat to its inputs if - // (1) concat is uint8-quantized - // (2) concat has no fused activation function - // (3) the input is not concatenation Op - // (4) the input is not produced to Ops other than concat - propagate_concat_quantparam(concat, _output_model_dtype); + auto circle_node = loco::must_cast<luci::CircleNode *>(node); + QuantizeSpecialActivation qsa(_ctx->input_model_dtype, quantize_dtype(circle_node)); + circle_node->accept(&qsa); } - // Quantize const inputs other than weights and bias + // Forward propagation of activation qparam + logo::Phase phase; + + phase.emplace_back(std::make_unique<luci::PropagateQParamForwardPass>(_ctx->TF_style_maxpool)); + + ProgressReporter prog(g, logo::PhaseStrategy::Saturate); + logo::PhaseRunner<logo::PhaseStrategy::Saturate> phase_runner{g}; + phase_runner.attach(&prog); + phase_runner.run(phase); + + // Quantize weights for (auto node : loco::active_nodes(loco::output_nodes(g))) { auto circle_node = loco::must_cast<luci::CircleNode *>(node); - quantize_const_inputs(circle_node, _output_model_dtype); + QuantizeWeights qw(_ctx->input_model_dtype, quantize_dtype(circle_node), + quantize_granularity(circle_node)); + circle_node->accept(&qw); } - // Update qparam of output of special Ops + // Quantize bias for 
(auto node : loco::active_nodes(loco::output_nodes(g))) { - QuantizeSpecialActivation qsa(_input_model_dtype, _output_model_dtype); auto circle_node = loco::must_cast<luci::CircleNode *>(node); - circle_node->accept(&qsa); + QuantizeBias qb(_ctx->input_model_dtype, quantize_dtype(circle_node), + quantize_granularity(circle_node)); + circle_node->accept(&qb); } // Update output dtype @@ -1667,11 +580,11 @@ bool QuantizeWithMinMaxPass::run(loco::Graph *g) for (auto node : loco::output_nodes(g)) { auto circle_node = loco::must_cast<luci::CircleOutput *>(node); - if (static_cast<luci::CircleNode *>(circle_node->from())->dtype() == _output_model_dtype) + if (static_cast<luci::CircleNode *>(circle_node->from())->dtype() == _ctx->output_model_dtype) { - circle_node->dtype(_output_model_dtype); + circle_node->dtype(_ctx->output_model_dtype); auto graph_output = graph_outputs->at(circle_node->index()); - graph_output->dtype(_output_model_dtype); + graph_output->dtype(_ctx->output_model_dtype); } } diff --git a/compiler/luci/pass/src/QuantizeWithMinMaxPass.test.cpp b/compiler/luci/pass/src/QuantizeWithMinMaxPass.test.cpp index 75ec0cfd8..d5fa21ffd 100644 --- a/compiler/luci/pass/src/QuantizeWithMinMaxPass.test.cpp +++ b/compiler/luci/pass/src/QuantizeWithMinMaxPass.test.cpp @@ -16,8 +16,41 @@ #include "luci/Pass/QuantizeWithMinMaxPass.h" +#include <luci/IR/CircleNodes.h> + #include <gtest/gtest.h> +class SimpleConcatGraph +{ +public: + SimpleConcatGraph(loco::DataType quant_type) + { + concat_node = g.nodes()->create<luci::CircleConcatenation>(2); + input_1 = g.nodes()->create<luci::CircleConst>(); + input_2 = g.nodes()->create<luci::CircleConst>(); + + concat_node->dtype(quant_type); + concat_node->fusedActivationFunction(luci::FusedActFunc::NONE); + input_1->dtype(quant_type); + input_2->dtype(quant_type); + + concat_node->values(0, input_1); + concat_node->values(1, input_2); + } + + ~SimpleConcatGraph() + { + concat_node->values(0, nullptr); + concat_node->values(1, 
nullptr); + } + +public: + loco::Graph g; + luci::CircleConcatenation *concat_node = nullptr; + luci::CircleConst *input_1 = nullptr; + luci::CircleConst *input_2 = nullptr; +}; + TEST(QuantizeWithMinMaxPassTest, name) { luci::QuantizeWithMinMaxPass pass(loco::DataType::FLOAT32, loco::DataType::U8, @@ -25,3 +58,19 @@ TEST(QuantizeWithMinMaxPassTest, name) auto const name = pass.name(); ASSERT_NE(nullptr, name); } + +// Test concat of integer tensors +// Integer tensors are not quantized +TEST(QuantizeWithMinMaxPassTest, int_concat) +{ + SimpleConcatGraph g(loco::DataType::S32); + + luci::QuantizeWithMinMaxPass qwmm(loco::DataType::FLOAT32, loco::DataType::U8, + luci::QuantizationGranularity::LayerWise); + + qwmm.run(&g.g); + + EXPECT_EQ(nullptr, g.concat_node->quantparam()); + EXPECT_EQ(nullptr, g.input_1->quantparam()); + EXPECT_EQ(nullptr, g.input_2->quantparam()); +} diff --git a/compiler/luci/pass/src/QuantizedModelVerifier.cpp b/compiler/luci/pass/src/QuantizedModelVerifier.cpp index f02301ed1..684d5d48a 100644 --- a/compiler/luci/pass/src/QuantizedModelVerifier.cpp +++ b/compiler/luci/pass/src/QuantizedModelVerifier.cpp @@ -15,10 +15,10 @@ #include "QuantizedModelVerifier.h" -#include "VerifyQuantizedNodeLayerWiseGranularity.h" -#include "VerifyQuantizedNodeChannelWiseGranularity.h" -#include "VerifyQuantizedNodeU8Type.h" -#include "VerifyQuantizedNodeS16Type.h" +#include "VerifyQuantizedNodeGranularity.h" +#include "VerifyQuantizedNodeType.h" +#include "VerifyQuantizedBiasScale.h" +#include "helpers/LayerInfoMap.h" #include <luci/IR/CircleNodes.h> #include <luci/IR/CircleNodeVisitor.h> @@ -28,12 +28,33 @@ namespace luci void QuantizedModelVerifier::verify(loco::Graph *g) { - if (_quantized_dtype != Type::U8 && _quantized_dtype != Type::S16) - throw std::runtime_error("Unsupported quantized dtype"); - - if (_granularity != Granularity::ChannelWise && _granularity != Granularity::LayerWise) + if (_ctx->granularity != Granularity::ChannelWise && 
_ctx->granularity != Granularity::LayerWise) throw std::runtime_error("Unsupported granularity"); + auto info_by_name = layer_info_map(g, _ctx->layers_info); + + auto quantize_dtype = [&](const luci::CircleNode *node) { + auto iter = info_by_name.find(node->name()); + + // Return designated quantization dtype + if (iter != info_by_name.end()) + return iter->second.dtype; + + // Return default quantization dtype + return _ctx->output_model_dtype; + }; + + auto quantize_granularity = [&](const luci::CircleNode *node) { + auto iter = info_by_name.find(node->name()); + + // Return designated quantization granularity + if (iter != info_by_name.end()) + return iter->second.granularity; + + // Return default quantization granularity + return _ctx->granularity; + }; + for (auto node : loco::active_nodes(loco::output_nodes(g))) { auto circle_node = loco::must_cast<luci::CircleNode *>(node); @@ -46,32 +67,17 @@ void QuantizedModelVerifier::verify(loco::Graph *g) }; // Verify Type - if (_quantized_dtype == Type::U8) - { - VerifyQuantizedNodeU8Type vt; - if (!circle_node->accept(&vt)) - throw std::runtime_error("Wrong data type detected in " + node_name()); - } - else if (_quantized_dtype == Type::S16) - { - VerifyQuantizedNodeS16Type vt; - if (!circle_node->accept(&vt)) - throw std::runtime_error("Wrong data type detected in " + node_name()); - } + if (!VerifyQuantizedNodeType::create(quantize_dtype(circle_node))->verify(circle_node)) + throw std::runtime_error("Wrong data type detected in " + node_name()); // Verify Granularity - if (_granularity == Granularity::LayerWise) - { - VerifyQuantizedNodeLayerWiseGranularity vg; - if (!circle_node->accept(&vg)) - throw std::runtime_error("Wrong granularity detected in " + node_name()); - } - else if (_granularity == Granularity::ChannelWise) - { - VerifyQuantizedNodeChannelWiseGranularity vg; - if (!circle_node->accept(&vg)) - throw std::runtime_error("Wrong granularity detected in " + node_name()); - } + if (!circle_node->accept( 
+ VerifyQuantizedNodeGranularity::create(quantize_granularity(circle_node)).get())) + throw std::runtime_error("Wrong granularity detected in " + node_name()); + + // Verify Bias scale + if (!VerifyQuantizedBiasScale::create()->verify(circle_node)) + throw std::runtime_error("Wrong bias scale detected in " + node_name()); } } diff --git a/compiler/luci/pass/src/QuantizedModelVerifier.h b/compiler/luci/pass/src/QuantizedModelVerifier.h index d5fbb8e74..7409a51d7 100644 --- a/compiler/luci/pass/src/QuantizedModelVerifier.h +++ b/compiler/luci/pass/src/QuantizedModelVerifier.h @@ -21,6 +21,8 @@ #include <loco.h> +#include <memory> + namespace luci { @@ -31,18 +33,40 @@ namespace luci */ struct QuantizedModelVerifier { +public: + struct Context + { + loco::DataType output_model_dtype = loco::DataType::Unknown; + QuantizationGranularity granularity = QuantizationGranularity::ChannelWise; + loco::DataType input_type = loco::DataType::Unknown; + loco::DataType output_type = loco::DataType::Unknown; + bool TF_style_maxpool = false; + std::vector<LayerInfo> layers_info; + }; public: QuantizedModelVerifier(loco::DataType quantized_dtype, QuantizationGranularity granularity) - : _quantized_dtype(quantized_dtype), _granularity(granularity) { + _ctx = std::make_unique<Context>(); + { + _ctx->output_model_dtype = quantized_dtype; + _ctx->granularity = granularity; + _ctx->input_type = quantized_dtype; + _ctx->output_type = quantized_dtype; + _ctx->TF_style_maxpool = false; + } + } + +public: + QuantizedModelVerifier(std::unique_ptr<Context> &&ctx) : _ctx{std::move(ctx)} + { + // DO NOTHING } void verify(loco::Graph *g); private: - loco::DataType _quantized_dtype; - QuantizationGranularity _granularity; + std::unique_ptr<Context> _ctx; }; } // namespace luci diff --git a/compiler/luci/pass/src/QuantizedModelVerifier.test.cpp b/compiler/luci/pass/src/QuantizedModelVerifier.test.cpp index 3a6d86c33..cebafd32b 100644 --- a/compiler/luci/pass/src/QuantizedModelVerifier.test.cpp +++ 
b/compiler/luci/pass/src/QuantizedModelVerifier.test.cpp @@ -17,6 +17,7 @@ #include "QuantizedModelVerifier.h" #include "luci/Pass/QuantizeWithMinMaxPass.h" +#include "luci/Pass/QuantizationParameters.h" #include <luci/test/TestIOGraph.h> @@ -112,57 +113,77 @@ void quantize_and_verify(loco::Graph *g, Type quantized_dtype, Granularity granu verifier.verify(g); } -// Helper function to reduce duplicate test codes -// Assumption: g->output()->from() is the target node -void quantize_and_verify_with_wrong_type(luci::test::TestIOGraph *g, Type quantized_dtype, - Granularity granularity, Type wrong_dtype) +void quantize_and_verify_with_layer_info(loco::Graph *g, Type quantized_dtype, + Granularity granularity) { - luci::QuantizeWithMinMaxPass pass(Type::FLOAT32, quantized_dtype, granularity); - pass.run(g->g()); - - auto node = loco::must_cast<luci::CircleNode *>(g->output()->from()); - node->dtype(wrong_dtype); + // A layer named "test" has dtype different from quantized_dtype + luci::LayerInfo info; + { + info.name = "test"; + // dtype is different from quantized_dtype + info.dtype = quantized_dtype == Type::U8 ? 
Type::S16 : Type::U8; + info.granularity = Granularity::ChannelWise; + } - luci::QuantizedModelVerifier verifier(quantized_dtype, granularity); - verifier.verify(g->g()); -} + // Do quantization + { + auto ctx = std::make_unique<luci::QuantizeWithMinMaxPass::Context>(); + { + ctx->input_model_dtype = Type::FLOAT32; + ctx->output_model_dtype = quantized_dtype; + ctx->granularity = granularity; + ctx->input_type = quantized_dtype; + ctx->output_type = quantized_dtype; + ctx->TF_style_maxpool = false; + ctx->layers_info.push_back(info); + } -void quantize_and_verify_with_wrong_type(luci::test::TestIOGraph *g, Type quantized_dtype, - Granularity granularity, Type wrong_dtype, - luci::CircleNode *target) -{ - luci::QuantizeWithMinMaxPass pass(Type::FLOAT32, quantized_dtype, granularity); - pass.run(g->g()); + luci::QuantizeWithMinMaxPass pass(std::move(ctx)); + pass.run(g); + } - target->dtype(wrong_dtype); + // Do verification + { + auto ctx = std::make_unique<luci::QuantizedModelVerifier::Context>(); + { + ctx->output_model_dtype = quantized_dtype; + ctx->granularity = granularity; + ctx->input_type = quantized_dtype; + ctx->output_type = quantized_dtype; + ctx->TF_style_maxpool = false; + ctx->layers_info.push_back(info); + } - luci::QuantizedModelVerifier verifier(quantized_dtype, granularity); - verifier.verify(g->g()); + luci::QuantizedModelVerifier verifier(std::move(ctx)); + verifier.verify(g); + } } // Helper function to reduce duplicate test codes // Assumption: g->output()->from() is the target node -void quantize_and_verify_with_wrong_granularity(luci::test::TestIOGraph *g, Type quantized_dtype, - Granularity granularity) +void quantize_and_verify_with_wrong_type(luci::test::TestIOGraph *g, Type quantized_dtype, + Granularity granularity, Type wrong_dtype) { luci::QuantizeWithMinMaxPass pass(Type::FLOAT32, quantized_dtype, granularity); pass.run(g->g()); auto node = loco::must_cast<luci::CircleNode *>(g->output()->from()); - insert_scale_zp(node, 1.0, 1); + 
node->dtype(wrong_dtype); luci::QuantizedModelVerifier verifier(quantized_dtype, granularity); verifier.verify(g->g()); } // Helper function to reduce duplicate test codes +// Assumption: g->output()->from() is the target node void quantize_and_verify_with_wrong_granularity(luci::test::TestIOGraph *g, Type quantized_dtype, - Granularity granularity, luci::CircleNode *target) + Granularity granularity) { luci::QuantizeWithMinMaxPass pass(Type::FLOAT32, quantized_dtype, granularity); pass.run(g->g()); - insert_scale_zp(target, 1.0, 1); + auto node = loco::must_cast<luci::CircleNode *>(g->output()->from()); + insert_scale_zp(node, 1.0, 1); luci::QuantizedModelVerifier verifier(quantized_dtype, granularity); verifier.verify(g->g()); @@ -230,6 +251,8 @@ public: _instnorm->input(input()); _instnorm->gamma(_gamma); _instnorm->beta(_beta); + _instnorm->fusedActivationFunction(luci::FusedActFunc::NONE); + _instnorm->name("test"); } output()->from(_instnorm); @@ -256,6 +279,7 @@ public: _logistic = g()->nodes()->create<luci::CircleLogistic>(); { _logistic->x(input()); + _logistic->name("test"); } output()->from(_logistic); @@ -275,6 +299,7 @@ public: _lrn = g()->nodes()->create<luci::CircleLocalResponseNormalization>(); { _lrn->input(input()); + _lrn->name("test"); } output()->from(_lrn); @@ -295,6 +320,7 @@ public: { _softmax->logits(input()); _softmax->beta(0.1); + _softmax->name("test"); } output()->from(_softmax); @@ -324,6 +350,7 @@ public: _stob->input(input()); _stob->block_shape(_block_shape); _stob->paddings(_paddings); + _stob->name("test"); } output()->from(_stob); @@ -346,6 +373,7 @@ public: { _stod->input(input()); _stod->block_size(2); + _stod->name("test"); } output()->from(_stod); @@ -375,6 +403,7 @@ public: _slice->input(input()); _slice->begin(_begin); _slice->size(_size); + _slice->name("test"); } output()->from(_slice); @@ -472,6 +501,7 @@ public: _slice->begin(_begin); _slice->end(_end); _slice->strides(_strides); + _slice->name("test"); } 
output()->from(_slice); @@ -499,6 +529,7 @@ public: { _reshape->tensor(input()); _reshape->shape(_shape); + _reshape->name("test"); } output()->from(_reshape); @@ -519,6 +550,7 @@ public: _tanh = g()->nodes()->create<luci::CircleTanh>(); { _tanh->x(input()); + _tanh->name("test"); } output()->from(_tanh); @@ -538,6 +570,7 @@ public: _floor = g()->nodes()->create<luci::CircleFloor>(); { _floor->x(input()); + _floor->name("test"); } output()->from(_floor); @@ -601,6 +634,7 @@ public: _btos->input(input()); _btos->block_shape(_block_shape); _btos->crops(_crops); + _btos->name("test"); } output()->from(_btos); @@ -623,6 +657,7 @@ public: { _dtos->input(input()); _dtos->block_size(2); + _dtos->name("test"); } output()->from(_dtos); @@ -645,6 +680,7 @@ public: _pack->values(0, input()); _pack->values(1, _param); _pack->axis(0); + _pack->name("test"); } output()->from(_pack); @@ -680,6 +716,7 @@ public: { _pad->input(input()); _pad->paddings(_paddings); + _pad->name("test"); } output()->from(_pad); @@ -707,6 +744,7 @@ public: _pad->input(input()); _pad->paddings(_paddings); _pad->constant_values(_constant_values); + _pad->name("test"); } output()->from(_pad); @@ -735,6 +773,7 @@ public: _mirror_pad->input(input()); _mirror_pad->paddings(_paddings); _mirror_pad->mode(luci::MirrorPadMode::REFLECT); + _mirror_pad->name("test"); } output()->from(_mirror_pad); @@ -761,6 +800,7 @@ public: { _transpose->a(input()); _transpose->perm(_perm); + _transpose->name("test"); } output()->from(_transpose); @@ -784,6 +824,8 @@ public: _concat->values(0, input()); _concat->values(1, _param); _concat->axis(0); + _concat->fusedActivationFunction(luci::FusedActFunc::NONE); + _concat->name("test"); } output()->from(_concat); @@ -795,6 +837,54 @@ private: luci::CircleConst *_param = nullptr; }; +template <Type indexT> class OneHotTestGraph final : public SimpleTestGraph +{ +public: + void init(void) override + { + TestIOGraph::init({32}, {32, 10}); + { + // input dtype is float by default, but 
OneHot's input should have indexType (s32/s64) + input()->dtype(indexT); + } + + _depth = g()->nodes()->template create<luci::CircleConst>(); + { + _depth->dtype(loco::DataType::S32); + } + + _on_value = g()->nodes()->template create<luci::CircleConst>(); + { + _on_value->dtype(loco::DataType::FLOAT32); + } + + _off_value = g()->nodes()->template create<luci::CircleConst>(); + { + _off_value->dtype(loco::DataType::FLOAT32); + } + + _one_hot = g()->nodes()->template create<luci::CircleOneHot>(); + { + _one_hot->indices(input()); + _one_hot->depth(_depth); + _one_hot->on_value(_on_value); + _one_hot->off_value(_off_value); + _one_hot->axis(-1); + _one_hot->dtype(loco::DataType::FLOAT32); + _one_hot->name("test"); + } + output()->from(_one_hot); + + set_minmax_to_non_const(g(), -1, 1); + } + +private: + luci::CircleOneHot *_one_hot = nullptr; + luci::CircleConst *_depth = nullptr; + luci::CircleConst *_on_value = nullptr; + luci::CircleConst *_off_value = nullptr; +}; + // Test graph for comparison Ops // GREATER, GREATER_EQUAL, LESS, LESS_EQUAL, EQUAL, NOT_EQUAL template <class Op> class ComparisonOpTestGraph final : public SimpleTestGraph @@ -866,6 +956,7 @@ public: { _div->x(input()); _div->y(_const); + _div->name("test"); } output()->from(_div); @@ -893,6 +984,7 @@ public: { _floor_div->x(input()); _floor_div->y(_const); + _floor_div->name("test"); } output()->from(_floor_div); @@ -917,6 +1009,7 @@ public: _rsqrt = g()->nodes()->create<luci::CircleRsqrt>(); { _rsqrt->x(input()); + _rsqrt->name("test"); } output()->from(_rsqrt); @@ -936,6 +1029,7 @@ public: _sqrt = g()->nodes()->create<luci::CircleSqrt>(); { _sqrt->x(input()); + _sqrt->name("test"); } output()->from(_sqrt); @@ -955,6 +1049,7 @@ public: _elu = g()->nodes()->create<luci::CircleElu>(); { _elu->features(input()); + _elu->name("test"); } output()->from(_elu); @@ -977,6 +1072,7 @@ public: { _pow->x(input()); _pow->y(_const); + _pow->name("test"); } output()->from(_pow); @@ -1004,6 +1100,7 @@ public: { 
_resize_bilinear->input(input()); _resize_bilinear->size(_size); + _resize_bilinear->name("test"); } output()->from(_resize_bilinear); @@ -1027,6 +1124,7 @@ public: { _resize_nearest_neighbor->input(input()); _resize_nearest_neighbor->size(_size); + _resize_nearest_neighbor->name("test"); } output()->from(_resize_nearest_neighbor); @@ -1067,6 +1165,62 @@ private: luci::CircleConst *_unpack_dim = nullptr; }; +class MulTestGraph final : public SimpleTestGraph +{ +public: + void init(void) override + { + TestIOGraph::init({32}, {32}); + + _const = create_dummy_const<Type::FLOAT32>(g(), {32}); + _mul = g()->nodes()->create<luci::CircleMul>(); + { + _mul->x(input()); + _mul->y(_const); + _mul->fusedActivationFunction(luci::FusedActFunc::NONE); + _mul->name("test"); + } + output()->from(_mul); + + set_minmax_to_non_const(g(), -1, 1); + } + + loco::Node *x() { return _mul->x(); } + loco::Node *y() { return _mul->y(); } + +private: + luci::CircleMul *_mul = nullptr; + luci::CircleConst *_const = nullptr; +}; + +class AddTestGraph final : public SimpleTestGraph +{ +public: + void init(void) override + { + TestIOGraph::init({32}, {32}); + + _const = create_dummy_const<Type::FLOAT32>(g(), {32}); + _add = g()->nodes()->create<luci::CircleAdd>(); + { + _add->x(input()); + _add->y(_const); + _add->fusedActivationFunction(luci::FusedActFunc::NONE); + _add->name("test"); + } + output()->from(_add); + + set_minmax_to_non_const(g(), -1, 1); + } + + loco::Node *x() { return _add->x(); } + loco::Node *y() { return _add->y(); } + +private: + luci::CircleAdd *_add = nullptr; + luci::CircleConst *_const = nullptr; +}; + } // namespace // Quantize and verify with given configurations @@ -1078,6 +1232,15 @@ private: EXPECT_NO_THROW(quantize_and_verify(g.g(), type, granularity)); \ } while (0) +// Quantize and verify with layer info +#define TEST_WITH_LAYER_INFO(graph, type, granularity) \ + do \ + { \ + graph g; \ + g.init(); \ + EXPECT_NO_THROW(quantize_and_verify_with_layer_info(g.g(), 
type, granularity)); \ + } while (0) + // Quantize and verify with wrong type #define TEST_WITH_WRONG_TYPE(graph, type, granularity, wrong_dtype) \ do \ @@ -1098,25 +1261,34 @@ private: // Quantize and verify with wrong type // Users can specify the test target -#define TEST_WITH_WRONG_TYPE_TARGET(graph, type, granularity, wrong_dtype, target) \ - do \ - { \ - graph g; \ - g.init(); \ - auto node = loco::must_cast<luci::CircleNode *>(target); \ - EXPECT_ANY_THROW( \ - quantize_and_verify_with_wrong_type(&g, type, granularity, wrong_dtype, node)); \ +#define TEST_WITH_WRONG_TYPE_TARGET(graph, type, granularity, wrong_dtype, target) \ + do \ + { \ + graph g; \ + g.init(); \ + auto node = loco::must_cast<luci::CircleNode *>(target); \ + luci::QuantizeWithMinMaxPass pass(Type::FLOAT32, type, granularity); \ + pass.run(g.g()); \ + auto after_node = loco::must_cast<luci::CircleNode *>(target); \ + after_node->dtype(wrong_dtype); \ + luci::QuantizedModelVerifier verifier(type, granularity); \ + EXPECT_ANY_THROW(verifier.verify(g.g())); \ } while (0) // Quantize and verify with wrong granularity // Users can specify the test target -#define TEST_WITH_WRONG_GRANULARITY_TARGET(graph, type, granularity, target) \ - do \ - { \ - graph g; \ - g.init(); \ - auto node = loco::must_cast<luci::CircleNode *>(target); \ - EXPECT_ANY_THROW(quantize_and_verify_with_wrong_granularity(&g, type, granularity, node)); \ +#define TEST_WITH_WRONG_GRANULARITY_TARGET(graph, type, granularity, target) \ + do \ + { \ + graph g; \ + g.init(); \ + auto node = loco::must_cast<luci::CircleNode *>(target); \ + luci::QuantizeWithMinMaxPass pass(Type::FLOAT32, type, granularity); \ + pass.run(g.g()); \ + auto after_node = loco::must_cast<luci::CircleNode *>(target); \ + insert_scale_zp(after_node, 1.0, 1); \ + luci::QuantizedModelVerifier verifier(type, granularity); \ + EXPECT_ANY_THROW(verifier.verify(g.g())); \ } while (0) // Test a local helper function @@ -1145,6 +1317,10 @@ 
TEST(QuantizedModelVerifierTest, InstanceNorm) TEST_WITH_GRAPH(InstanceNormTestGraph, Type::U8, Granularity::LayerWise); TEST_WITH_GRAPH(InstanceNormTestGraph, Type::U8, Granularity::ChannelWise); TEST_WITH_GRAPH(InstanceNormTestGraph, Type::S16, Granularity::ChannelWise); + + TEST_WITH_LAYER_INFO(InstanceNormTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_LAYER_INFO(InstanceNormTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_LAYER_INFO(InstanceNormTestGraph, Type::S16, Granularity::ChannelWise); SUCCEED(); } @@ -1169,6 +1345,10 @@ TEST(QuantizedModelVerifierTest, LocalResponseNormalization) TEST_WITH_GRAPH(LocalResponseNormalizationTestGraph, Type::U8, Granularity::LayerWise); TEST_WITH_GRAPH(LocalResponseNormalizationTestGraph, Type::U8, Granularity::ChannelWise); TEST_WITH_GRAPH(LocalResponseNormalizationTestGraph, Type::S16, Granularity::ChannelWise); + + TEST_WITH_LAYER_INFO(LocalResponseNormalizationTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_LAYER_INFO(LocalResponseNormalizationTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_LAYER_INFO(LocalResponseNormalizationTestGraph, Type::S16, Granularity::ChannelWise); SUCCEED(); } @@ -1199,6 +1379,10 @@ TEST(QuantizedModelVerifierTest, Logistic) TEST_WITH_GRAPH(LogisticTestGraph, Type::U8, Granularity::LayerWise); TEST_WITH_GRAPH(LogisticTestGraph, Type::U8, Granularity::ChannelWise); TEST_WITH_GRAPH(LogisticTestGraph, Type::S16, Granularity::ChannelWise); + + TEST_WITH_LAYER_INFO(LogisticTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_LAYER_INFO(LogisticTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_LAYER_INFO(LogisticTestGraph, Type::S16, Granularity::ChannelWise); SUCCEED(); } @@ -1223,6 +1407,10 @@ TEST(QuantizedModelVerifierTest, Softmax) TEST_WITH_GRAPH(SoftmaxTestGraph, Type::U8, Granularity::LayerWise); TEST_WITH_GRAPH(SoftmaxTestGraph, Type::U8, Granularity::ChannelWise); TEST_WITH_GRAPH(SoftmaxTestGraph, Type::S16, 
Granularity::ChannelWise); + + TEST_WITH_LAYER_INFO(SoftmaxTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_LAYER_INFO(SoftmaxTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_LAYER_INFO(SoftmaxTestGraph, Type::S16, Granularity::ChannelWise); SUCCEED(); } @@ -1247,6 +1435,10 @@ TEST(QuantizedModelVerifierTest, SpaceToBatchND) TEST_WITH_GRAPH(SpaceToBatchNDTestGraph, Type::U8, Granularity::LayerWise); TEST_WITH_GRAPH(SpaceToBatchNDTestGraph, Type::U8, Granularity::ChannelWise); TEST_WITH_GRAPH(SpaceToBatchNDTestGraph, Type::S16, Granularity::ChannelWise); + + TEST_WITH_LAYER_INFO(SpaceToBatchNDTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_LAYER_INFO(SpaceToBatchNDTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_LAYER_INFO(SpaceToBatchNDTestGraph, Type::S16, Granularity::ChannelWise); SUCCEED(); } @@ -1271,6 +1463,10 @@ TEST(QuantizedModelVerifierTest, SpaceToDepth) TEST_WITH_GRAPH(SpaceToDepthTestGraph, Type::U8, Granularity::LayerWise); TEST_WITH_GRAPH(SpaceToDepthTestGraph, Type::U8, Granularity::ChannelWise); TEST_WITH_GRAPH(SpaceToDepthTestGraph, Type::S16, Granularity::ChannelWise); + + TEST_WITH_LAYER_INFO(SpaceToDepthTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_LAYER_INFO(SpaceToDepthTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_LAYER_INFO(SpaceToDepthTestGraph, Type::S16, Granularity::ChannelWise); SUCCEED(); } @@ -1299,6 +1495,14 @@ TEST(QuantizedModelVerifierTest, Slice) TEST_WITH_GRAPH(SliceTestGraph<Type::S64>, Type::U8, Granularity::LayerWise); TEST_WITH_GRAPH(SliceTestGraph<Type::S64>, Type::U8, Granularity::ChannelWise); TEST_WITH_GRAPH(SliceTestGraph<Type::S64>, Type::S16, Granularity::ChannelWise); + + TEST_WITH_LAYER_INFO(SliceTestGraph<Type::S32>, Type::U8, Granularity::LayerWise); + TEST_WITH_LAYER_INFO(SliceTestGraph<Type::S32>, Type::U8, Granularity::ChannelWise); + TEST_WITH_LAYER_INFO(SliceTestGraph<Type::S32>, Type::S16, Granularity::ChannelWise); + + 
TEST_WITH_LAYER_INFO(SliceTestGraph<Type::S64>, Type::U8, Granularity::LayerWise); + TEST_WITH_LAYER_INFO(SliceTestGraph<Type::S64>, Type::U8, Granularity::ChannelWise); + TEST_WITH_LAYER_INFO(SliceTestGraph<Type::S64>, Type::S16, Granularity::ChannelWise); SUCCEED(); } @@ -1379,6 +1583,10 @@ TEST(QuantizedModelVerifierTest, StridedSlice) TEST_WITH_GRAPH(StridedSliceTestGraph, Type::U8, Granularity::LayerWise); TEST_WITH_GRAPH(StridedSliceTestGraph, Type::U8, Granularity::ChannelWise); TEST_WITH_GRAPH(StridedSliceTestGraph, Type::S16, Granularity::ChannelWise); + + TEST_WITH_LAYER_INFO(StridedSliceTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_LAYER_INFO(StridedSliceTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_LAYER_INFO(StridedSliceTestGraph, Type::S16, Granularity::ChannelWise); SUCCEED(); } @@ -1463,6 +1671,10 @@ TEST(QuantizedModelVerifierTest, BatchToSpaceND) TEST_WITH_GRAPH(BatchToSpaceNDTestGraph, Type::U8, Granularity::LayerWise); TEST_WITH_GRAPH(BatchToSpaceNDTestGraph, Type::U8, Granularity::ChannelWise); TEST_WITH_GRAPH(BatchToSpaceNDTestGraph, Type::S16, Granularity::ChannelWise); + + TEST_WITH_LAYER_INFO(BatchToSpaceNDTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_LAYER_INFO(BatchToSpaceNDTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_LAYER_INFO(BatchToSpaceNDTestGraph, Type::S16, Granularity::ChannelWise); SUCCEED(); } @@ -1487,6 +1699,10 @@ TEST(QuantizedModelVerifierTest, DepthToSpace) TEST_WITH_GRAPH(DepthToSpaceTestGraph, Type::U8, Granularity::LayerWise); TEST_WITH_GRAPH(DepthToSpaceTestGraph, Type::U8, Granularity::ChannelWise); TEST_WITH_GRAPH(DepthToSpaceTestGraph, Type::S16, Granularity::ChannelWise); + + TEST_WITH_LAYER_INFO(DepthToSpaceTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_LAYER_INFO(DepthToSpaceTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_LAYER_INFO(DepthToSpaceTestGraph, Type::S16, Granularity::ChannelWise); SUCCEED(); } @@ -1511,6 +1727,10 @@ 
TEST(QuantizedModelVerifierTest, Concatenation) TEST_WITH_GRAPH(ConcatenationTestGraph, Type::U8, Granularity::LayerWise); TEST_WITH_GRAPH(ConcatenationTestGraph, Type::U8, Granularity::ChannelWise); TEST_WITH_GRAPH(ConcatenationTestGraph, Type::S16, Granularity::ChannelWise); + + TEST_WITH_LAYER_INFO(ConcatenationTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_LAYER_INFO(ConcatenationTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_LAYER_INFO(ConcatenationTestGraph, Type::S16, Granularity::ChannelWise); SUCCEED(); } @@ -1557,6 +1777,10 @@ TEST(QuantizedModelVerifierTest, Reshape) TEST_WITH_GRAPH(ReshapeTestGraph, Type::U8, Granularity::LayerWise); TEST_WITH_GRAPH(ReshapeTestGraph, Type::U8, Granularity::ChannelWise); TEST_WITH_GRAPH(ReshapeTestGraph, Type::S16, Granularity::ChannelWise); + + TEST_WITH_LAYER_INFO(ReshapeTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_LAYER_INFO(ReshapeTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_LAYER_INFO(ReshapeTestGraph, Type::S16, Granularity::ChannelWise); SUCCEED(); } @@ -1581,6 +1805,10 @@ TEST(QuantizedModelVerifierTest, Tanh) TEST_WITH_GRAPH(TanhTestGraph, Type::U8, Granularity::LayerWise); TEST_WITH_GRAPH(TanhTestGraph, Type::U8, Granularity::ChannelWise); TEST_WITH_GRAPH(TanhTestGraph, Type::S16, Granularity::ChannelWise); + + TEST_WITH_LAYER_INFO(TanhTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_LAYER_INFO(TanhTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_LAYER_INFO(TanhTestGraph, Type::S16, Granularity::ChannelWise); SUCCEED(); } @@ -1606,6 +1834,10 @@ TEST(QuantizedModelVerifierTest, Pack) TEST_WITH_GRAPH(PackTestGraph, Type::U8, Granularity::ChannelWise); TEST_WITH_GRAPH(PackTestGraph, Type::S16, Granularity::ChannelWise); + TEST_WITH_LAYER_INFO(PackTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_LAYER_INFO(PackTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_LAYER_INFO(PackTestGraph, Type::S16, Granularity::ChannelWise); + 
// Test if Pack's qparam is propagated to the input { PackTestGraph g; @@ -1640,6 +1872,10 @@ TEST(QuantizedModelVerifierTest, Pad) TEST_WITH_GRAPH(PadTestGraph, Type::U8, Granularity::LayerWise); TEST_WITH_GRAPH(PadTestGraph, Type::U8, Granularity::ChannelWise); TEST_WITH_GRAPH(PadTestGraph, Type::S16, Granularity::ChannelWise); + + TEST_WITH_LAYER_INFO(PadTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_LAYER_INFO(PadTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_LAYER_INFO(PadTestGraph, Type::S16, Granularity::ChannelWise); SUCCEED(); } @@ -1664,6 +1900,10 @@ TEST(QuantizedModelVerifierTest, PadV2) TEST_WITH_GRAPH(PadV2TestGraph, Type::U8, Granularity::LayerWise); TEST_WITH_GRAPH(PadV2TestGraph, Type::U8, Granularity::ChannelWise); TEST_WITH_GRAPH(PadV2TestGraph, Type::S16, Granularity::ChannelWise); + + TEST_WITH_LAYER_INFO(PadV2TestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_LAYER_INFO(PadV2TestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_LAYER_INFO(PadV2TestGraph, Type::S16, Granularity::ChannelWise); SUCCEED(); } @@ -1688,6 +1928,10 @@ TEST(QuantizedModelVerifierTest, MirrorPad) TEST_WITH_GRAPH(MirrorPadTestGraph, Type::U8, Granularity::LayerWise); TEST_WITH_GRAPH(MirrorPadTestGraph, Type::U8, Granularity::ChannelWise); TEST_WITH_GRAPH(MirrorPadTestGraph, Type::S16, Granularity::ChannelWise); + + TEST_WITH_LAYER_INFO(MirrorPadTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_LAYER_INFO(MirrorPadTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_LAYER_INFO(MirrorPadTestGraph, Type::S16, Granularity::ChannelWise); SUCCEED(); } @@ -1712,6 +1956,10 @@ TEST(QuantizedModelVerifierTest, Transpose) TEST_WITH_GRAPH(TransposeTestGraph, Type::U8, Granularity::LayerWise); TEST_WITH_GRAPH(TransposeTestGraph, Type::U8, Granularity::ChannelWise); TEST_WITH_GRAPH(TransposeTestGraph, Type::S16, Granularity::ChannelWise); + + TEST_WITH_LAYER_INFO(TransposeTestGraph, Type::U8, Granularity::LayerWise); + 
TEST_WITH_LAYER_INFO(TransposeTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_LAYER_INFO(TransposeTestGraph, Type::S16, Granularity::ChannelWise); SUCCEED(); } @@ -1736,6 +1984,10 @@ TEST(QuantizedModelVerifierTest, Floor) TEST_WITH_GRAPH(FloorTestGraph, Type::U8, Granularity::LayerWise); TEST_WITH_GRAPH(FloorTestGraph, Type::U8, Granularity::ChannelWise); TEST_WITH_GRAPH(FloorTestGraph, Type::S16, Granularity::ChannelWise); + + TEST_WITH_LAYER_INFO(FloorTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_LAYER_INFO(FloorTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_LAYER_INFO(FloorTestGraph, Type::S16, Granularity::ChannelWise); SUCCEED(); } @@ -1869,11 +2121,59 @@ TEST(QuantizedModelVerifierTest, NotEqual_wrong_granularity_NEG) SUCCEED(); } +TEST(QuantizedModelVerifierTest, OneHot) +{ + TEST_WITH_GRAPH(OneHotTestGraph<Type::S32>, Type::U8, Granularity::LayerWise); + TEST_WITH_GRAPH(OneHotTestGraph<Type::S32>, Type::U8, Granularity::ChannelWise); + TEST_WITH_GRAPH(OneHotTestGraph<Type::S32>, Type::S16, Granularity::ChannelWise); + + TEST_WITH_GRAPH(OneHotTestGraph<Type::S64>, Type::U8, Granularity::LayerWise); + TEST_WITH_GRAPH(OneHotTestGraph<Type::S64>, Type::U8, Granularity::ChannelWise); + TEST_WITH_GRAPH(OneHotTestGraph<Type::S64>, Type::S16, Granularity::ChannelWise); + + TEST_WITH_LAYER_INFO(OneHotTestGraph<Type::S32>, Type::U8, Granularity::LayerWise); + TEST_WITH_LAYER_INFO(OneHotTestGraph<Type::S32>, Type::U8, Granularity::ChannelWise); + TEST_WITH_LAYER_INFO(OneHotTestGraph<Type::S32>, Type::S16, Granularity::ChannelWise); + + TEST_WITH_LAYER_INFO(OneHotTestGraph<Type::S64>, Type::U8, Granularity::LayerWise); + TEST_WITH_LAYER_INFO(OneHotTestGraph<Type::S64>, Type::U8, Granularity::ChannelWise); + TEST_WITH_LAYER_INFO(OneHotTestGraph<Type::S64>, Type::S16, Granularity::ChannelWise); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, OneHot_wrong_input_type_NEG) +{ + TEST_WITH_WRONG_TYPE(OneHotTestGraph<Type::S32>, 
Type::U8, Granularity::LayerWise, Type::S16); + TEST_WITH_WRONG_TYPE(OneHotTestGraph<Type::S32>, Type::U8, Granularity::ChannelWise, Type::S16); + TEST_WITH_WRONG_TYPE(OneHotTestGraph<Type::S32>, Type::S16, Granularity::ChannelWise, Type::U8); + + TEST_WITH_WRONG_TYPE(OneHotTestGraph<Type::S64>, Type::U8, Granularity::LayerWise, Type::S16); + TEST_WITH_WRONG_TYPE(OneHotTestGraph<Type::S64>, Type::U8, Granularity::ChannelWise, Type::S16); + TEST_WITH_WRONG_TYPE(OneHotTestGraph<Type::S64>, Type::S16, Granularity::ChannelWise, Type::U8); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, OneHot_wrong_granularity_NEG) +{ + TEST_WITH_WRONG_GRANULARITY(OneHotTestGraph<Type::S32>, Type::U8, Granularity::LayerWise); + TEST_WITH_WRONG_GRANULARITY(OneHotTestGraph<Type::S32>, Type::U8, Granularity::ChannelWise); + TEST_WITH_WRONG_GRANULARITY(OneHotTestGraph<Type::S32>, Type::S16, Granularity::ChannelWise); + + TEST_WITH_WRONG_GRANULARITY(OneHotTestGraph<Type::S64>, Type::U8, Granularity::LayerWise); + TEST_WITH_WRONG_GRANULARITY(OneHotTestGraph<Type::S64>, Type::U8, Granularity::ChannelWise); + TEST_WITH_WRONG_GRANULARITY(OneHotTestGraph<Type::S64>, Type::S16, Granularity::ChannelWise); + SUCCEED(); +} + TEST(QuantizedModelVerifierTest, Div) { TEST_WITH_GRAPH(DivTestGraph, Type::U8, Granularity::LayerWise); TEST_WITH_GRAPH(DivTestGraph, Type::U8, Granularity::ChannelWise); TEST_WITH_GRAPH(DivTestGraph, Type::S16, Granularity::ChannelWise); + + TEST_WITH_LAYER_INFO(DivTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_LAYER_INFO(DivTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_LAYER_INFO(DivTestGraph, Type::S16, Granularity::ChannelWise); SUCCEED(); } @@ -1902,6 +2202,10 @@ TEST(QuantizedModelVerifierTest, FloorDiv) TEST_WITH_GRAPH(FloorDivTestGraph, Type::U8, Granularity::LayerWise); TEST_WITH_GRAPH(FloorDivTestGraph, Type::U8, Granularity::ChannelWise); TEST_WITH_GRAPH(FloorDivTestGraph, Type::S16, Granularity::ChannelWise); + + 
TEST_WITH_LAYER_INFO(FloorDivTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_LAYER_INFO(FloorDivTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_LAYER_INFO(FloorDivTestGraph, Type::S16, Granularity::ChannelWise); SUCCEED(); } @@ -1930,6 +2234,10 @@ TEST(QuantizedModelVerifierTest, Rsqrt) TEST_WITH_GRAPH(RsqrtTestGraph, Type::U8, Granularity::LayerWise); TEST_WITH_GRAPH(RsqrtTestGraph, Type::U8, Granularity::ChannelWise); TEST_WITH_GRAPH(RsqrtTestGraph, Type::S16, Granularity::ChannelWise); + + TEST_WITH_LAYER_INFO(RsqrtTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_LAYER_INFO(RsqrtTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_LAYER_INFO(RsqrtTestGraph, Type::S16, Granularity::ChannelWise); SUCCEED(); } @@ -1954,6 +2262,10 @@ TEST(QuantizedModelVerifierTest, Sqrt) TEST_WITH_GRAPH(SqrtTestGraph, Type::U8, Granularity::LayerWise); TEST_WITH_GRAPH(SqrtTestGraph, Type::U8, Granularity::ChannelWise); TEST_WITH_GRAPH(SqrtTestGraph, Type::S16, Granularity::ChannelWise); + + TEST_WITH_LAYER_INFO(SqrtTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_LAYER_INFO(SqrtTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_LAYER_INFO(SqrtTestGraph, Type::S16, Granularity::ChannelWise); SUCCEED(); } @@ -1978,6 +2290,10 @@ TEST(QuantizedModelVerifierTest, Elu) TEST_WITH_GRAPH(EluTestGraph, Type::U8, Granularity::LayerWise); TEST_WITH_GRAPH(EluTestGraph, Type::U8, Granularity::ChannelWise); TEST_WITH_GRAPH(EluTestGraph, Type::S16, Granularity::ChannelWise); + + TEST_WITH_LAYER_INFO(EluTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_LAYER_INFO(EluTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_LAYER_INFO(EluTestGraph, Type::S16, Granularity::ChannelWise); SUCCEED(); } @@ -2002,6 +2318,10 @@ TEST(QuantizedModelVerifierTest, Pow) TEST_WITH_GRAPH(PowTestGraph, Type::U8, Granularity::LayerWise); TEST_WITH_GRAPH(PowTestGraph, Type::U8, Granularity::ChannelWise); TEST_WITH_GRAPH(PowTestGraph, Type::S16, 
Granularity::ChannelWise); + + TEST_WITH_LAYER_INFO(PowTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_LAYER_INFO(PowTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_LAYER_INFO(PowTestGraph, Type::S16, Granularity::ChannelWise); SUCCEED(); } @@ -2030,6 +2350,10 @@ TEST(QuantizedModelVerifierTest, ResizeBilinear) TEST_WITH_GRAPH(ResizeBilinearTestGraph, Type::U8, Granularity::LayerWise); TEST_WITH_GRAPH(ResizeBilinearTestGraph, Type::U8, Granularity::ChannelWise); TEST_WITH_GRAPH(ResizeBilinearTestGraph, Type::S16, Granularity::ChannelWise); + + TEST_WITH_LAYER_INFO(ResizeBilinearTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_LAYER_INFO(ResizeBilinearTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_LAYER_INFO(ResizeBilinearTestGraph, Type::S16, Granularity::ChannelWise); SUCCEED(); } @@ -2054,6 +2378,10 @@ TEST(QuantizedModelVerifierTest, ResizeNearestNeighbor) TEST_WITH_GRAPH(ResizeNearestNeighborTestGraph, Type::U8, Granularity::LayerWise); TEST_WITH_GRAPH(ResizeNearestNeighborTestGraph, Type::U8, Granularity::ChannelWise); TEST_WITH_GRAPH(ResizeNearestNeighborTestGraph, Type::S16, Granularity::ChannelWise); + + TEST_WITH_LAYER_INFO(ResizeBilinearTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_LAYER_INFO(ResizeBilinearTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_LAYER_INFO(ResizeBilinearTestGraph, Type::S16, Granularity::ChannelWise); SUCCEED(); } @@ -2099,6 +2427,93 @@ TEST(QuantizedModelVerifierTest, Unpack_wrong_granularity_NEG) SUCCEED(); } +TEST(QuantizedModelVerifierTest, Add) +{ + TEST_WITH_GRAPH(AddTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_GRAPH(AddTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_GRAPH(AddTestGraph, Type::S16, Granularity::ChannelWise); + + TEST_WITH_LAYER_INFO(AddTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_LAYER_INFO(AddTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_LAYER_INFO(AddTestGraph, Type::S16, 
Granularity::ChannelWise); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, Add_wrong_type_NEG) +{ + TEST_WITH_WRONG_TYPE(AddTestGraph, Type::U8, Granularity::LayerWise, Type::S16); + TEST_WITH_WRONG_TYPE(AddTestGraph, Type::U8, Granularity::ChannelWise, Type::S16); + TEST_WITH_WRONG_TYPE(AddTestGraph, Type::S16, Granularity::ChannelWise, Type::U8); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, Add_wrong_granularity_NEG) +{ + TEST_WITH_WRONG_GRANULARITY_TARGET(AddTestGraph, Type::U8, Granularity::LayerWise, g.x()); + TEST_WITH_WRONG_GRANULARITY_TARGET(AddTestGraph, Type::U8, Granularity::ChannelWise, g.x()); + TEST_WITH_WRONG_GRANULARITY_TARGET(AddTestGraph, Type::S16, Granularity::ChannelWise, g.x()); + + TEST_WITH_WRONG_GRANULARITY_TARGET(AddTestGraph, Type::U8, Granularity::LayerWise, g.y()); + TEST_WITH_WRONG_GRANULARITY_TARGET(AddTestGraph, Type::U8, Granularity::ChannelWise, g.y()); + TEST_WITH_WRONG_GRANULARITY_TARGET(AddTestGraph, Type::S16, Granularity::ChannelWise, g.y()); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, Mul) +{ + TEST_WITH_GRAPH(MulTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_GRAPH(MulTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_GRAPH(MulTestGraph, Type::S16, Granularity::ChannelWise); + + TEST_WITH_LAYER_INFO(MulTestGraph, Type::U8, Granularity::LayerWise); + TEST_WITH_LAYER_INFO(MulTestGraph, Type::U8, Granularity::ChannelWise); + TEST_WITH_LAYER_INFO(MulTestGraph, Type::S16, Granularity::ChannelWise); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, Mul_wrong_type_NEG) +{ + TEST_WITH_WRONG_TYPE(MulTestGraph, Type::U8, Granularity::LayerWise, Type::S16); + TEST_WITH_WRONG_TYPE(MulTestGraph, Type::U8, Granularity::ChannelWise, Type::S16); + TEST_WITH_WRONG_TYPE(MulTestGraph, Type::S16, Granularity::ChannelWise, Type::U8); + SUCCEED(); +} + +TEST(QuantizedModelVerifierTest, Mul_wrong_granularity_NEG) +{ + TEST_WITH_WRONG_GRANULARITY_TARGET(MulTestGraph, Type::U8, Granularity::LayerWise, g.x()); 
+ TEST_WITH_WRONG_GRANULARITY_TARGET(MulTestGraph, Type::U8, Granularity::ChannelWise, g.x()); + TEST_WITH_WRONG_GRANULARITY_TARGET(MulTestGraph, Type::S16, Granularity::ChannelWise, g.x()); + + TEST_WITH_WRONG_GRANULARITY_TARGET(MulTestGraph, Type::U8, Granularity::LayerWise, g.y()); + TEST_WITH_WRONG_GRANULARITY_TARGET(MulTestGraph, Type::U8, Granularity::ChannelWise, g.y()); + TEST_WITH_WRONG_GRANULARITY_TARGET(MulTestGraph, Type::S16, Granularity::ChannelWise, g.y()); + SUCCEED(); +} + +// TODO Add following testcases +// +// CircleConv2D +// +// CircleDepthwiseConv2D +// +// CirclePRelu +// +// CircleTransposeConv +// +// CircleFullyConnected +// +// CircleAveragePool2D +// +// CircleMaxPool2D +// +// CircleMean +// +// CircleRelu +// +// CircleCast +// + #undef TEST_WITH_GRAPH #undef TEST_WITH_WRONG_TYPE #undef TEST_WITH_WRONG_GRANULARITY diff --git a/compiler/luci/pass/src/RemoveRedundantQuantizePass.cpp b/compiler/luci/pass/src/RemoveRedundantQuantizePass.cpp new file mode 100644 index 000000000..8a10ad4a0 --- /dev/null +++ b/compiler/luci/pass/src/RemoveRedundantQuantizePass.cpp @@ -0,0 +1,104 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/RemoveRedundantQuantizePass.h" + +#include <luci/IR/CircleNode.h> + +/** + * Remove redundant quantize operations. 
For subsequent Quantize Ops, + * only the last Quantize Op is valid, so we can remove the rest of the Quantize Op. + * + * BEFORE + * [CircleNode_1] + * | + * [CircleQuantize, dtype_1, scale_1, zero_point_1] + * | + * [CircleQuantize, dtype_2, scale_2, zero_point_2] + * | + * [CircleNode_2] + * + * AFTER + * [CircleNode_1] + * / \ + * / \ + * / \ + * / \ + * / \ + * [CircleQuantize, dtype_2, scale_2, zero_point_2] [CircleQuantize, dtype_1, scale_1, zero_point_1] + * | + * [CircleNode_2] + * + */ + +namespace +{ + +bool remove_redundant_quantize(luci::CircleQuantize *node) +{ + auto pred_node = loco::must_cast<luci::CircleNode *>(node->input()); + + if (node->quantparam() == nullptr or pred_node->quantparam() == nullptr) + return false; + + if (node->quantparam()->scale.size() != 1 or node->quantparam()->zerop.size() != 1 or + pred_node->quantparam()->scale.size() != 1 or pred_node->quantparam()->zerop.size() != 1) + { + return false; + } + + if (node->dtype() != pred_node->dtype() or + pred_node->quantparam()->scale.at(0) != node->quantparam()->scale.at(0) or + pred_node->quantparam()->zerop.at(0) != node->quantparam()->zerop.at(0)) + { + return false; + } + + replace(node).with(pred_node); + + return true; +} + +bool remove_redundant_subsequent_quantize(luci::CircleQuantize *node) +{ + auto pred_node = dynamic_cast<luci::CircleQuantize *>(node->input()); + if (pred_node == nullptr) + return remove_redundant_quantize(node); + + node->input(pred_node->input()); + return true; +} + +} // namespace + +namespace luci +{ + +bool RemoveRedundantQuantizePass::run(loco::Graph *g) +{ + bool changed = false; + for (auto node : loco::postorder_traversal(loco::output_nodes(g))) + { + if (auto quantize_node = dynamic_cast<luci::CircleQuantize *>(node)) + { + if (remove_redundant_subsequent_quantize(quantize_node)) + changed = true; + } + } + return changed; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/RemoveRedundantQuantizePass.test.cpp 
b/compiler/luci/pass/src/RemoveRedundantQuantizePass.test.cpp new file mode 100644 index 000000000..d0166bd20 --- /dev/null +++ b/compiler/luci/pass/src/RemoveRedundantQuantizePass.test.cpp @@ -0,0 +1,166 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/RemoveRedundantQuantizePass.h" + +#include <luci/IR/CircleNodes.h> + +#include <luci/test/TestIOGraph.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class QuantizeGraphlet +{ +public: + QuantizeGraphlet() = default; + +public: + void init(loco::Graph *g) + { + _first_quantize = g->nodes()->create<luci::CircleQuantize>(); + _first_quantize->dtype(loco::DataType::U8); + { + auto quantize_param = std::make_unique<luci::CircleQuantParam>(); + quantize_param->scale = {0.5}; + quantize_param->zerop = {0}; + _first_quantize->quantparam(std::move(quantize_param)); + } + _first_quantize->name("first_quantize"); + + _second_quantize = g->nodes()->create<luci::CircleQuantize>(); + _second_quantize->dtype(loco::DataType::U8); + { + auto quantize_param = std::make_unique<luci::CircleQuantParam>(); + quantize_param->scale = {0.5}; + quantize_param->zerop = {0}; + _second_quantize->quantparam(std::move(quantize_param)); + } + _second_quantize->name("second_quantize"); + } + +protected: + luci::CircleQuantize *_first_quantize = nullptr; + luci::CircleQuantize *_second_quantize = 
nullptr; +}; + +class RedundantSubsequentQuantizeGraph : public TestIOGraph, public QuantizeGraphlet +{ +public: + RedundantSubsequentQuantizeGraph() = default; + +public: + void init(void) + { + TestIOGraph::init({1}, {1}); + QuantizeGraphlet::init(g()); + + input()->dtype(loco::DataType::U8); + { + auto quantize_param = std::make_unique<luci::CircleQuantParam>(); + quantize_param->scale = {1}; + quantize_param->zerop = {1}; + input()->quantparam(std::move(quantize_param)); + } + + _first_quantize->input(input()); + _second_quantize->input(_first_quantize); + + output()->from(_second_quantize); + output()->dtype(loco::DataType::U8); + } +}; + +class RedundantQuantizeGraph : public TestIOGraph, public QuantizeGraphlet +{ +public: + RedundantQuantizeGraph() = default; + +public: + void init(void) + { + TestIOGraph::init({1}, {1}); + QuantizeGraphlet::init(g()); + + input()->dtype(loco::DataType::U8); + { + auto quantize_param = std::make_unique<luci::CircleQuantParam>(); + quantize_param->scale = {0.5}; + quantize_param->zerop = {0}; + input()->quantparam(std::move(quantize_param)); + } + + _first_quantize->input(input()); + + output()->from(_first_quantize); + output()->dtype(loco::DataType::U8); + } +}; + +} // namespace + +TEST(RemoveRedundantQuantizePass, name) +{ + luci::RemoveRedundantQuantizePass pass; + auto const name = pass.name(); + ASSERT_NE(nullptr, name); +} + +TEST(RemoveRedundantQuantizePass, remove_subsequent_quantize) +{ + RedundantSubsequentQuantizeGraph g; + luci::RemoveRedundantQuantizePass pass; + + g.init(); + + EXPECT_TRUE(pass.run(g.g())); + + int count = 0; + for (auto node : loco::active_nodes(loco::output_nodes(g.g()))) + { + if (dynamic_cast<luci::CircleQuantize *>(node)) + { + count++; + } + } + + ASSERT_EQ(1, count); +} + +TEST(RemoveRedundantQuantizePass, remove_quantize) +{ + RedundantQuantizeGraph g; + luci::RemoveRedundantQuantizePass pass; + + g.init(); + + EXPECT_TRUE(pass.run(g.g())); + + int count = 0; + for (auto node : 
loco::active_nodes(loco::output_nodes(g.g()))) + { + if (dynamic_cast<luci::CircleQuantize *>(node)) + { + count++; + } + } + + ASSERT_EQ(0, count); +} diff --git a/compiler/luci/pass/src/RemoveRedundantTransposePass.cpp b/compiler/luci/pass/src/RemoveRedundantTransposePass.cpp index 71c51ecda..75cf72795 100644 --- a/compiler/luci/pass/src/RemoveRedundantTransposePass.cpp +++ b/compiler/luci/pass/src/RemoveRedundantTransposePass.cpp @@ -71,7 +71,7 @@ bool remove_consecutive_transpose_function(luci::CircleTranspose *target_node) for (uint32_t i = 0; i < pred_perm->size<loco::DataType::S32>(); i++) { new_const_node->at<loco::DataType::S32>(i) = - target_perm->at<loco::DataType::S32>(pred_perm->at<loco::DataType::S32>(i)); + pred_perm->at<loco::DataType::S32>(target_perm->at<loco::DataType::S32>(i)); } new_const_node->name(name + "/Transpose/perm"); diff --git a/compiler/luci/pass/src/RemoveRedundantTransposePass.test.cpp b/compiler/luci/pass/src/RemoveRedundantTransposePass.test.cpp index e80623499..bb8e292d4 100644 --- a/compiler/luci/pass/src/RemoveRedundantTransposePass.test.cpp +++ b/compiler/luci/pass/src/RemoveRedundantTransposePass.test.cpp @@ -271,6 +271,31 @@ TEST(RemoveRedundantTransposePass, remove_consecutive_transpose_function_type2) ASSERT_EQ(2, perm->at<loco::DataType::S32>(3)); } +TEST(RemoveRedundantTransposePass, remove_consecutive_transpose_function_type3) +{ + auto graph = loco::make_graph(); + create_redundunt_transpose(graph.get(), {0, 3, 2, 1}, {0, 2, 3, 1}); + + luci::RemoveRedundantTransposePass pass; + while (pass.run(graph.get())) + ; + luci::CircleTranspose *transpose_node = nullptr; + for (auto node : loco::active_nodes(loco::output_nodes(graph.get()))) + { + auto trans = dynamic_cast<luci::CircleTranspose *>(node); + if (not trans) + continue; + transpose_node = trans; + break; + } + ASSERT_NE(nullptr, transpose_node); + auto perm = loco::must_cast<luci::CircleConst *>(transpose_node->perm()); + ASSERT_EQ(0, 
perm->at<loco::DataType::S32>(0)); + ASSERT_EQ(2, perm->at<loco::DataType::S32>(1)); + ASSERT_EQ(1, perm->at<loco::DataType::S32>(2)); + ASSERT_EQ(3, perm->at<loco::DataType::S32>(3)); +} + /** * @brief Test case that first transpose output become input of operations more than one. */ diff --git a/compiler/luci/pass/src/RemoveUnnecessaryReshapePass.cpp b/compiler/luci/pass/src/RemoveUnnecessaryReshapePass.cpp index 3f0c4ee82..fb46f490d 100644 --- a/compiler/luci/pass/src/RemoveUnnecessaryReshapePass.cpp +++ b/compiler/luci/pass/src/RemoveUnnecessaryReshapePass.cpp @@ -58,6 +58,25 @@ bool remove_no_effect_reshape(luci::CircleNode *node) namespace luci { +/** + * BEFORE + * [CircleNode] + * | + * [CircleReshape] + * | + * [CircleNode] + * + * AFTER + * [CircleNode] + * | \ + * | [CircleReshape] + * | + * [CircleNode] + * + * NOTE + * This pass will remove Reshape when input and output has same shape + */ + bool RemoveUnnecessaryReshapePass::run(loco::Graph *g) { bool changed = false; diff --git a/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.cpp b/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.cpp index a0cc0194f..bca0a9483 100644 --- a/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.cpp +++ b/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.cpp @@ -26,8 +26,17 @@ namespace luci::CircleConst *create_weights_from_gamma(luci::CircleConst *gamma) { - assert(gamma->rank() == 1); - auto channel_size = gamma->dim(0).value(); + assert(gamma->rank() == 1 or gamma->rank() == 4); + + uint32_t channel_idx = gamma->rank() - 1; + uint32_t channel_size = gamma->dim(channel_idx).value(); + + // Gamma should be broadcastable in the channel direction + for (uint32_t i = 0; i < gamma->rank(); i++) + { + if (i != channel_idx) + assert(gamma->dim(i).value() == 1); // FIX is_batchnorm_mul UNLESS + } auto name = gamma->name(); assert(name.length() > 0); @@ -53,8 +62,17 @@ luci::CircleConst *create_weights_from_gamma(luci::CircleConst *gamma) 
luci::CircleConst *create_bias_from_beta(luci::CircleConst *beta) { - assert(beta->rank() == 1); - auto channel_size = beta->dim(0).value(); + assert(beta->rank() == 1 or beta->rank() == 4); + + uint32_t channel_idx = beta->rank() - 1; + uint32_t channel_size = beta->dim(channel_idx).value(); + + // Beta should be broadcastable in the channel direction + for (uint32_t i = 0; i < beta->rank(); i++) + { + if (i != channel_idx) + assert(beta->dim(i).value() == 1); // FIX is_batchnorm_add UNLESS + } auto name = beta->name(); assert(name.length() > 0); diff --git a/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.test.cpp b/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.test.cpp index 903d4dcc9..bac033112 100644 --- a/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.test.cpp +++ b/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.test.cpp @@ -141,6 +141,37 @@ TEST(ReplaceMulAddWithDepthwiseConv, simple) } } +TEST(ReplaceMulAddWithDepthwiseConv, simple_rank4) +{ + SimpleGraph g; + + const uint32_t channel_size = 16; + g.gamma->shape({1, 1, 1, channel_size}); + g.beta->shape({1, 1, 1, channel_size}); + + luci::ReplaceMulAddWithDepthwiseConvPass pass; + while (pass.run(&g.g)) + ; + + auto dwconv = dynamic_cast<luci::CircleDepthwiseConv2D *>(g.output->from()); + EXPECT_NE(nullptr, dwconv); + + auto weights = dynamic_cast<luci::CircleConst *>(dwconv->filter()); + auto bias = dynamic_cast<luci::CircleConst *>(dwconv->bias()); + EXPECT_NE(nullptr, weights); + EXPECT_EQ(4, weights->rank()); + EXPECT_EQ(channel_size, weights->dim(3).value()); + EXPECT_NE(nullptr, bias); + EXPECT_EQ(1, bias->rank()); + EXPECT_EQ(channel_size, bias->dim(0).value()); + + for (int i = 0; i < channel_size; i++) + { + EXPECT_FLOAT_EQ(i, weights->at<loco::DataType::FLOAT32>(i)); + EXPECT_FLOAT_EQ(i, bias->at<loco::DataType::FLOAT32>(i)); + } +} + TEST(ReplaceMulAddWithDepthwiseConv, wrong_op_NEG) { SimpleGraph g; @@ -154,3 +185,18 @@ 
TEST(ReplaceMulAddWithDepthwiseConv, wrong_op_NEG) EXPECT_EQ(false, changed); } + +TEST(ReplaceMulAddWithDepthwiseConv, rank3_NEG) +{ + SimpleGraph g; + + g.input->shape({4, 4, 16}); + g.mul->shape({4, 4, 16}); + g.add->shape({4, 4, 16}); + g.output->shape({4, 4, 16}); + + luci::ReplaceMulAddWithDepthwiseConvPass pass; + auto changed = pass.run(&g.g); + + EXPECT_EQ(false, changed); +} diff --git a/compiler/luci/pass/src/SubstituteSplitVToSplitPass.cpp b/compiler/luci/pass/src/SubstituteSplitVToSplitPass.cpp index 9cba9a9e7..57c386d99 100644 --- a/compiler/luci/pass/src/SubstituteSplitVToSplitPass.cpp +++ b/compiler/luci/pass/src/SubstituteSplitVToSplitPass.cpp @@ -24,15 +24,6 @@ namespace { -void copy_quantparam(luci::CircleNode *dst, const luci::CircleNode *src) -{ - auto q = src->quantparam(); - if (q == nullptr) - dst->quantparam(nullptr); - else - dst->quantparam(std::make_unique<luci::CircleQuantParam>(*q)); -} - // SplitV is substituted to Split if the contents of size_splits are all same // For example, // size_splits = [32, 32] -> substitute @@ -67,7 +58,7 @@ bool resolve_splitv(luci::CircleSplitV *sv) split_node->split_dim(sv->split_dim()); split_node->num_split(sv->num_split()); split_node->name(sv->name()); - copy_quantparam(split_node, sv); + copy_quantparam(sv, split_node); luci::add_origin(split_node, luci::get_origin(sv)); auto succs = loco::succs(sv); @@ -78,7 +69,7 @@ bool resolve_splitv(luci::CircleSplitV *sv) so_node->input(split_node); so_node->index(svo->index()); so_node->name(svo->name()); - copy_quantparam(so_node, svo); + copy_quantparam(svo, so_node); luci::add_origin(so_node, luci::get_origin(svo)); replace(svo).with(so_node); diff --git a/compiler/luci/pass/src/SubstituteSqueezeToReshapePass.cpp b/compiler/luci/pass/src/SubstituteSqueezeToReshapePass.cpp index f48763782..df7266df9 100644 --- a/compiler/luci/pass/src/SubstituteSqueezeToReshapePass.cpp +++ b/compiler/luci/pass/src/SubstituteSqueezeToReshapePass.cpp @@ -76,18 +76,6 @@ 
std::vector<uint32_t> node_shape(const luci::CircleNode *input) } /** - * @brief copy quantparam of src to dst - */ -void copy_quantparam(luci::CircleNode *dst, const luci::CircleNode *src) -{ - auto q = src->quantparam(); - if (q == nullptr) - dst->quantparam(nullptr); - else - dst->quantparam(std::make_unique<luci::CircleQuantParam>(*q)); -} - -/** * @brief return CircleConst ptr with values of new_shape */ luci::CircleConst *create_shape_const(loco::Graph *graph, const std::vector<uint32_t> &new_shape) @@ -142,7 +130,7 @@ bool substitute_squeeze_to_reshape(luci::CircleSqueeze *squeeze) auto graph = squeeze->graph(); auto reshape = graph->nodes()->create<luci::CircleReshape>(); auto shape_const = create_shape_const(graph, reshape_shape); - copy_quantparam(reshape, squeeze); + copy_quantparam(squeeze, reshape); reshape->name(name + "/Reshape"); luci::add_origin(reshape, luci::get_origin(squeeze)); shape_const->name(name + "/Reshape/shape"); diff --git a/compiler/luci/pass/src/SubstituteStridedSliceToReshapePass.cpp b/compiler/luci/pass/src/SubstituteStridedSliceToReshapePass.cpp index f50f2f54f..9e1c5a4a3 100644 --- a/compiler/luci/pass/src/SubstituteStridedSliceToReshapePass.cpp +++ b/compiler/luci/pass/src/SubstituteStridedSliceToReshapePass.cpp @@ -124,7 +124,7 @@ bool substitute_strided_slice_to_reshape(luci::CircleStridedSlice *ss_node) std::bitset<32> end_mask(ss_node->end_mask()); std::bitset<32> shrink_axis_mask(ss_node->shrink_axis_mask()); - uint input_rank = input_node->rank(); + uint32_t input_rank = input_node->rank(); for (uint32_t i = 0; i < input_rank; i++) { if (!input_node->dim(i).known()) diff --git a/compiler/luci/pass/src/VerifyQuantizedBiasScale.cpp b/compiler/luci/pass/src/VerifyQuantizedBiasScale.cpp new file mode 100644 index 000000000..e65d576cd --- /dev/null +++ b/compiler/luci/pass/src/VerifyQuantizedBiasScale.cpp @@ -0,0 +1,105 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. 
All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "VerifyQuantizedBiasScale.h" + +#include <cmath> + +// This macro is undef at the end of the file +#define RETURN_FALSE_UNLESS(ARG) \ + if (not(ARG)) \ + { \ + return false; \ + } + +namespace +{ + +bool same(float a, float b) +{ + constexpr float epsilon = 1e-10; + return std::abs(a - b) < epsilon; // std::abs selects the float overload; unqualified abs may bind to int abs(int) and truncate +} + +// Check bias scale = input scale * weight scale +// This function checks both LWQ and CWQ +bool check_bias_scale(const loco::Node *input, const loco::Node *weights, const loco::Node *bias) +{ + auto input_node = loco::must_cast<const luci::CircleNode *>(input); + auto input_qparam = input_node->quantparam(); + RETURN_FALSE_UNLESS(input_qparam != nullptr); + + auto weights_node = loco::must_cast<const luci::CircleNode *>(weights); + auto weights_qparam = weights_node->quantparam(); + RETURN_FALSE_UNLESS(weights_qparam != nullptr); + + auto bias_node = loco::must_cast<const luci::CircleNode *>(bias); + auto bias_qparam = bias_node->quantparam(); + RETURN_FALSE_UNLESS(bias_qparam != nullptr); + + RETURN_FALSE_UNLESS(input_qparam->scale.size() == 1); + RETURN_FALSE_UNLESS(weights_qparam->scale.size() == bias_qparam->scale.size()); + + auto input_scale = input_qparam->scale[0]; + for (uint32_t i = 0; i < weights_qparam->scale.size(); i++) + { + auto weights_scale = weights_qparam->scale[i]; + auto bias_scale = bias_qparam->scale[i]; + RETURN_FALSE_UNLESS(same(bias_scale, 
input_scale * weights_scale)); + } + return true; +} + +} // namespace + +namespace luci +{ + +bool VerifyQuantizedBiasScale::visit(const luci::CircleConv2D *node) +{ + RETURN_FALSE_UNLESS(check_bias_scale(node->input(), node->filter(), node->bias())); + return true; +} + +bool VerifyQuantizedBiasScale::visit(const luci::CircleDepthwiseConv2D *node) +{ + RETURN_FALSE_UNLESS(check_bias_scale(node->input(), node->filter(), node->bias())); + return true; +} + +bool VerifyQuantizedBiasScale::visit(const luci::CircleFullyConnected *node) +{ + luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias()); + if (bias != nullptr) + { + RETURN_FALSE_UNLESS(check_bias_scale(node->input(), node->weights(), node->bias())); + } + return true; +} + +bool VerifyQuantizedBiasScale::visit(const luci::CircleTransposeConv *node) +{ + luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias()); + if (bias != nullptr) + { + RETURN_FALSE_UNLESS(check_bias_scale(node->outBackprop(), node->filter(), node->bias())); + } + return true; +} + +} // namespace luci + +#undef RETURN_FALSE_UNLESS diff --git a/compiler/luci/pass/src/VerifyQuantizedBiasScale.h b/compiler/luci/pass/src/VerifyQuantizedBiasScale.h new file mode 100644 index 000000000..b41f78eca --- /dev/null +++ b/compiler/luci/pass/src/VerifyQuantizedBiasScale.h @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_VERIFY_QUANTIZED_BIAS_SCALE_H__ +#define __LUCI_VERIFY_QUANTIZED_BIAS_SCALE_H__ + +#include <luci/IR/CircleNodes.h> +#include <luci/IR/CircleNodeVisitor.h> + +#include <memory> + +namespace luci +{ + +/** + * @brief Verify the scale of quantized bias node + * @details + * + * Bias of CONV, DCONV, TCONV, FC layers should meet the following condition. + * + * bias scale = input scale * weights scale + */ +class VerifyQuantizedBiasScale : public luci::CircleNodeVisitor<bool> +{ +public: + static std::shared_ptr<VerifyQuantizedBiasScale> create() + { + return std::make_shared<VerifyQuantizedBiasScale>(); + }; + +public: + bool verify(luci::CircleNode *node) { return node->accept(this); } + +private: + // Operators with bias + bool visit(const luci::CircleConv2D *node); + bool visit(const luci::CircleDepthwiseConv2D *node); + bool visit(const luci::CircleFullyConnected *node); + bool visit(const luci::CircleTransposeConv *node); + + bool visit(const luci::CircleNode *) { return true; } +}; + +} // namespace luci + +#endif // __LUCI_VERIFY_QUANTIZED_BIAS_SCALE_H__ diff --git a/compiler/luci/pass/src/VerifyQuantizedNodeGranularity.cpp b/compiler/luci/pass/src/VerifyQuantizedNodeGranularity.cpp new file mode 100644 index 000000000..8697090a7 --- /dev/null +++ b/compiler/luci/pass/src/VerifyQuantizedNodeGranularity.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "VerifyQuantizedNodeGranularity.h" + +#include <luci/IR/CircleNodes.h> +#include <luci/Pass/QuantizationParameters.h> + +#include <memory> + +namespace luci +{ + +std::shared_ptr<VerifyQuantizedNodeGranularity> +VerifyQuantizedNodeGranularity::create(Granularity granularity) +{ + if (granularity == Granularity::ChannelWise) + return std::make_shared<VerifyQuantizedNodeChannelWiseGranularity>(); + else if (granularity == Granularity::LayerWise) + return std::make_shared<VerifyQuantizedNodeLayerWiseGranularity>(); + else + throw std::domain_error("Not supported Granularity type"); +} + +} // namespace luci diff --git a/compiler/luci/pass/src/VerifyQuantizedNodeChannelWiseGranularity.h b/compiler/luci/pass/src/VerifyQuantizedNodeGranularity.h index bf3ff2e8a..442183c18 100644 --- a/compiler/luci/pass/src/VerifyQuantizedNodeChannelWiseGranularity.h +++ b/compiler/luci/pass/src/VerifyQuantizedNodeGranularity.h @@ -1,5 +1,6 @@ /* - * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at @@ -13,13 +14,15 @@ * limitations under the License. 
*/ -#ifndef __LUCI_VERIFY_QUANTIZED_NODE_CHANNELWISE_GRANULARITY_H__ -#define __LUCI_VERIFY_QUANTIZED_NODE_CHANNELWISE_GRANULARITY_H__ +#ifndef __LUCI_VERIFY_QUANTIZED_NODE_GRANULARITY_H__ +#define __LUCI_VERIFY_QUANTIZED_NODE_GRANULARITY_H__ #include <luci/IR/CircleNodes.h> #include <luci/IR/CircleNodeVisitor.h> #include <luci/Pass/QuantizationParameters.h> +#include <memory> + using Granularity = luci::QuantizationGranularity; // This macro is undef at the end of the file @@ -33,16 +36,19 @@ namespace luci { /** - * @brief Verify the granualrity of channel-wise quantized node + * @brief Verify the granularity of quantized node * @details * * Targets to verify * - node's output (i.e., node itself) * - node's inputs */ -struct VerifyQuantizedNodeChannelWiseGranularity final : public luci::CircleNodeVisitor<bool> +class VerifyQuantizedNodeGranularity : public luci::CircleNodeVisitor<bool> { -private: +public: + static std::shared_ptr<VerifyQuantizedNodeGranularity> create(Granularity granularity); + +protected: bool is_lwq(const loco::Node *node) { auto circle_node = loco::must_cast<const luci::CircleNode *>(node); @@ -59,48 +65,15 @@ private: return true; } - uint32_t rank(const loco::Node *node) - { - auto circle_node = loco::must_cast<const luci::CircleNode *>(node); - return circle_node->rank(); - } - - bool is_cwq_const(const loco::Node *node, uint32_t channel_dim) - { - auto circle_node = loco::must_cast<const luci::CircleConst *>(node); - - assert(channel_dim < circle_node->rank()); // FIX_CALLER_UNLESS - auto channel_size = circle_node->dim(channel_dim).value(); - - if (circle_node->quantparam() == nullptr) - return false; - - if (circle_node->quantparam()->quantized_dimension != static_cast<int32_t>(channel_dim)) - return false; - - if (circle_node->quantparam()->scale.size() != channel_size) - return false; - - if (circle_node->quantparam()->zerop.size() != channel_size) - return false; - - return true; - } - private: - bool visit(const luci::CircleConv2D 
*node) - { - RETURN_FALSE_UNLESS(is_lwq(node)) - RETURN_FALSE_UNLESS(is_lwq(node->input())) - RETURN_FALSE_UNLESS(is_cwq_const(node->filter(), 0)) - luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias()); - if (bias != nullptr) - RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1)) - return true; - } + virtual bool visit(const luci::CircleConv2D *node) = 0; bool visit(const luci::CircleConcatenation *node) { + // Skip granularity check for concatenation of indices + if (node->dtype() == loco::DataType::S32 or node->dtype() == loco::DataType::S64) + return true; + RETURN_FALSE_UNLESS(is_lwq(node)) for (uint32_t i = 0; i < node->numValues(); i++) { @@ -116,25 +89,9 @@ private: return true; } - bool visit(const luci::CircleDepthwiseConv2D *node) - { - RETURN_FALSE_UNLESS(is_lwq(node)) - RETURN_FALSE_UNLESS(is_lwq(node->input())) - RETURN_FALSE_UNLESS(is_cwq_const(node->filter(), 3)) - luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias()); - if (bias != nullptr) - RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1)) - return true; - } + virtual bool visit(const luci::CircleDepthwiseConv2D *node) = 0; - bool visit(const luci::CircleInstanceNorm *node) - { - RETURN_FALSE_UNLESS(is_lwq(node)) - RETURN_FALSE_UNLESS(is_lwq(node->input())) - RETURN_FALSE_UNLESS(is_cwq_const(node->gamma(), rank(node->gamma()) - 1)) - RETURN_FALSE_UNLESS(is_cwq_const(node->beta(), rank(node->beta()) - 1)) - return true; - } + virtual bool visit(const luci::CircleInstanceNorm *node) = 0; bool visit(const luci::CirclePack *node) { @@ -168,37 +125,11 @@ private: return true; } - bool visit(const luci::CirclePRelu *node) - { - RETURN_FALSE_UNLESS(is_lwq(node)) - RETURN_FALSE_UNLESS(is_lwq(node->input())) - RETURN_FALSE_UNLESS(is_cwq_const(node->alpha(), rank(node->alpha()) - 1)) - return true; - } - - bool visit(const luci::CircleTransposeConv *node) - { - RETURN_FALSE_UNLESS(is_lwq(node)) - 
RETURN_FALSE_UNLESS(is_lwq(node->outBackprop())) - RETURN_FALSE_UNLESS(is_cwq_const(node->filter(), 0)) - luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias()); - if (bias != nullptr) - RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1)) + virtual bool visit(const luci::CirclePRelu *node) = 0; - return true; - } + virtual bool visit(const luci::CircleTransposeConv *node) = 0; - bool visit(const luci::CircleFullyConnected *node) - { - RETURN_FALSE_UNLESS(is_lwq(node)) - RETURN_FALSE_UNLESS(is_lwq(node->input())) - RETURN_FALSE_UNLESS(is_cwq_const(node->weights(), 0)) - luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias()); - // Bias is optional (it can be CircleOutputExclude) - if (bias != nullptr) - RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1)) - return true; - } + virtual bool visit(const luci::CircleFullyConnected *node) = 0; bool visit(const luci::CircleAdd *node) { @@ -258,6 +189,14 @@ private: return true; } + bool visit(const luci::CircleOneHot *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)); + RETURN_FALSE_UNLESS(is_lwq(node->off_value())); + RETURN_FALSE_UNLESS(is_lwq(node->on_value())); + return true; + } + bool visit(const luci::CircleRelu *node) { RETURN_FALSE_UNLESS(is_lwq(node)); @@ -480,8 +419,186 @@ private: bool visit(const luci::CircleNode *) { return true; } }; +class VerifyQuantizedNodeChannelWiseGranularity final : public VerifyQuantizedNodeGranularity +{ +private: + uint32_t rank(const loco::Node *node) + { + auto circle_node = loco::must_cast<const luci::CircleNode *>(node); + return circle_node->rank(); + } + + bool is_cwq_const(const loco::Node *node, uint32_t channel_dim) + { + auto circle_node = loco::must_cast<const luci::CircleConst *>(node); + + assert(channel_dim < circle_node->rank()); // FIX_CALLER_UNLESS + auto channel_size = circle_node->dim(channel_dim).value(); + + if (circle_node->quantparam() == nullptr) + return false; + + if 
(circle_node->quantparam()->quantized_dimension != static_cast<int32_t>(channel_dim)) + return false; + + if (circle_node->quantparam()->scale.size() != channel_size) + return false; + + if (circle_node->quantparam()->zerop.size() != channel_size) + return false; + + return true; + } + +private: + bool visit(const luci::CircleConv2D *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)) + RETURN_FALSE_UNLESS(is_lwq(node->input())) + RETURN_FALSE_UNLESS(is_cwq_const(node->filter(), 0)) + luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias()); + if (bias != nullptr) + RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1)) + return true; + } + + bool visit(const luci::CircleDepthwiseConv2D *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)) + RETURN_FALSE_UNLESS(is_lwq(node->input())) + RETURN_FALSE_UNLESS(is_cwq_const(node->filter(), 3)) + luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias()); + if (bias != nullptr) + RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1)) + return true; + } + + bool visit(const luci::CircleInstanceNorm *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)) + RETURN_FALSE_UNLESS(is_lwq(node->input())) + RETURN_FALSE_UNLESS(is_cwq_const(node->gamma(), rank(node->gamma()) - 1)) + RETURN_FALSE_UNLESS(is_cwq_const(node->beta(), rank(node->beta()) - 1)) + return true; + } + + bool visit(const luci::CirclePRelu *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)) + RETURN_FALSE_UNLESS(is_lwq(node->input())) + RETURN_FALSE_UNLESS(is_cwq_const(node->alpha(), rank(node->alpha()) - 1)) + return true; + } + + bool visit(const luci::CircleTransposeConv *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)) + RETURN_FALSE_UNLESS(is_lwq(node->outBackprop())) + RETURN_FALSE_UNLESS(is_cwq_const(node->filter(), 0)) + luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias()); + if (bias != nullptr) + RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1)) + + return true; + } + + 
bool visit(const luci::CircleFullyConnected *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)) + RETURN_FALSE_UNLESS(is_lwq(node->input())) + RETURN_FALSE_UNLESS(is_cwq_const(node->weights(), 0)) + luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias()); + // Bias is optional (it can be CircleOutputExclude) + if (bias != nullptr) + RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1)) + return true; + } +}; + +class VerifyQuantizedNodeLayerWiseGranularity final : public VerifyQuantizedNodeGranularity +{ +private: + bool is_lwq_const(const loco::Node *node) + { + auto circle_node = loco::must_cast<const luci::CircleConst *>(node); + + if (circle_node->quantparam() == nullptr) + return false; + + if (circle_node->quantparam()->scale.size() != 1) + return false; + + if (circle_node->quantparam()->zerop.size() != 1) + return false; + + return true; + } + +private: + bool visit(const luci::CircleConv2D *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)) + RETURN_FALSE_UNLESS(is_lwq(node->input())) + RETURN_FALSE_UNLESS(is_lwq_const(node->filter())) + luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias()); + if (bias != nullptr) + RETURN_FALSE_UNLESS(is_lwq_const(node->bias())) + return true; + } + + bool visit(const luci::CircleDepthwiseConv2D *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)) + RETURN_FALSE_UNLESS(is_lwq(node->input())) + RETURN_FALSE_UNLESS(is_lwq_const(node->filter())) + luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias()); + if (bias != nullptr) + RETURN_FALSE_UNLESS(is_lwq_const(node->bias())) + return true; + } + + bool visit(const luci::CircleInstanceNorm *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)) + RETURN_FALSE_UNLESS(is_lwq(node->input())) + RETURN_FALSE_UNLESS(is_lwq_const(node->gamma())) + RETURN_FALSE_UNLESS(is_lwq_const(node->beta())) + return true; + } + + bool visit(const luci::CirclePRelu *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)) + 
RETURN_FALSE_UNLESS(is_lwq(node->input())) + RETURN_FALSE_UNLESS(is_lwq_const(node->alpha())) + return true; + } + + bool visit(const luci::CircleTransposeConv *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)) + RETURN_FALSE_UNLESS(is_lwq(node->outBackprop())) + RETURN_FALSE_UNLESS(is_lwq_const(node->filter())) + luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias()); + if (bias != nullptr) + RETURN_FALSE_UNLESS(is_lwq_const(node->bias())) + return true; + } + + bool visit(const luci::CircleFullyConnected *node) + { + RETURN_FALSE_UNLESS(is_lwq(node)) + RETURN_FALSE_UNLESS(is_lwq(node->input())) + RETURN_FALSE_UNLESS(is_lwq_const(node->weights())) + luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias()); + if (bias != nullptr) + RETURN_FALSE_UNLESS(is_lwq_const(node->bias())) + return true; + } +}; + } // namespace luci #undef RETURN_FALSE_UNLESS -#endif // __LUCI_VERIFY_QUANTIZED_NODE_CHANNELWISE_GRANULARITY_H__ +#endif // __LUCI_VERIFY_QUANTIZED_NODE_GRANULARITY_H__ diff --git a/compiler/luci/pass/src/VerifyQuantizedNodeLayerWiseGranularity.h b/compiler/luci/pass/src/VerifyQuantizedNodeLayerWiseGranularity.h deleted file mode 100644 index 9bc8b31df..000000000 --- a/compiler/luci/pass/src/VerifyQuantizedNodeLayerWiseGranularity.h +++ /dev/null @@ -1,473 +0,0 @@ -/* - * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef __LUCI_VERIFY_QUANTIZED_NODE_LAYERWISE_GRANULARITY_H__ -#define __LUCI_VERIFY_QUANTIZED_NODE_LAYERWISE_GRANULARITY_H__ - -#include <luci/IR/CircleNodes.h> -#include <luci/IR/CircleNodeVisitor.h> -#include <luci/Pass/QuantizationParameters.h> - -using Granularity = luci::QuantizationGranularity; - -// This macro is undef at the end of the file -#define RETURN_FALSE_UNLESS(ARG) \ - if (not(ARG)) \ - { \ - return false; \ - } - -namespace luci -{ - -/** - * @brief Verify the granualrity of layer-wise quantized node - * @details - * - * Targets to verify - * - node's output (i.e., node itself) - * - node's inputs - */ -struct VerifyQuantizedNodeLayerWiseGranularity final : public luci::CircleNodeVisitor<bool> -{ -private: - bool is_lwq(const loco::Node *node) - { - auto circle_node = loco::must_cast<const luci::CircleNode *>(node); - - if (circle_node->quantparam() == nullptr) - return false; - - if (circle_node->quantparam()->scale.size() != 1) - return false; - - if (circle_node->quantparam()->zerop.size() != 1) - return false; - - return true; - } - - bool is_lwq_const(const loco::Node *node) - { - auto circle_node = loco::must_cast<const luci::CircleConst *>(node); - - if (circle_node->quantparam() == nullptr) - return false; - - if (circle_node->quantparam()->scale.size() != 1) - return false; - - if (circle_node->quantparam()->zerop.size() != 1) - return false; - - return true; - } - -private: - bool visit(const luci::CircleConv2D *node) - { - RETURN_FALSE_UNLESS(is_lwq(node)) - RETURN_FALSE_UNLESS(is_lwq(node->input())) - RETURN_FALSE_UNLESS(is_lwq_const(node->filter())) - luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias()); - if (bias != nullptr) - RETURN_FALSE_UNLESS(is_lwq_const(node->bias())) - return true; - } - - bool visit(const luci::CircleConcatenation *node) - { - RETURN_FALSE_UNLESS(is_lwq(node)) - for (uint32_t i = 0; i < node->numValues(); i++) - { - RETURN_FALSE_UNLESS(is_lwq(node->values(i))); - } - return 
true; - } - - bool visit(const luci::CircleDepthToSpace *node) - { - RETURN_FALSE_UNLESS(is_lwq(node)) - RETURN_FALSE_UNLESS(is_lwq(node->input())) - return true; - } - - bool visit(const luci::CircleDepthwiseConv2D *node) - { - RETURN_FALSE_UNLESS(is_lwq(node)) - RETURN_FALSE_UNLESS(is_lwq(node->input())) - RETURN_FALSE_UNLESS(is_lwq_const(node->filter())) - luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias()); - if (bias != nullptr) - RETURN_FALSE_UNLESS(is_lwq_const(node->bias())) - return true; - } - - bool visit(const luci::CircleInstanceNorm *node) - { - RETURN_FALSE_UNLESS(is_lwq(node)) - RETURN_FALSE_UNLESS(is_lwq(node->input())) - RETURN_FALSE_UNLESS(is_lwq_const(node->gamma())) - RETURN_FALSE_UNLESS(is_lwq_const(node->beta())) - return true; - } - - bool visit(const luci::CirclePack *node) - { - RETURN_FALSE_UNLESS(is_lwq(node)) - for (uint32_t i = 0; i < node->values_count(); i++) - { - RETURN_FALSE_UNLESS(is_lwq(node->values(i))); - } - return true; - } - - bool visit(const luci::CirclePad *node) - { - RETURN_FALSE_UNLESS(is_lwq(node)) - RETURN_FALSE_UNLESS(is_lwq(node->input())) - return true; - } - - bool visit(const luci::CirclePadV2 *node) - { - RETURN_FALSE_UNLESS(is_lwq(node)) - RETURN_FALSE_UNLESS(is_lwq(node->input())) - RETURN_FALSE_UNLESS(is_lwq(node->constant_values())) - return true; - } - - bool visit(const luci::CircleMirrorPad *node) - { - RETURN_FALSE_UNLESS(is_lwq(node)) - RETURN_FALSE_UNLESS(is_lwq(node->input())) - return true; - } - - bool visit(const luci::CirclePRelu *node) - { - RETURN_FALSE_UNLESS(is_lwq(node)) - RETURN_FALSE_UNLESS(is_lwq(node->input())) - RETURN_FALSE_UNLESS(is_lwq_const(node->alpha())) - return true; - } - - bool visit(const luci::CircleTransposeConv *node) - { - RETURN_FALSE_UNLESS(is_lwq(node)) - RETURN_FALSE_UNLESS(is_lwq(node->outBackprop())) - RETURN_FALSE_UNLESS(is_lwq_const(node->filter())) - luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias()); - if (bias != 
nullptr) - RETURN_FALSE_UNLESS(is_lwq_const(node->bias())) - return true; - } - - bool visit(const luci::CircleFullyConnected *node) - { - RETURN_FALSE_UNLESS(is_lwq(node)) - RETURN_FALSE_UNLESS(is_lwq(node->input())) - RETURN_FALSE_UNLESS(is_lwq_const(node->weights())) - luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias()); - if (bias != nullptr) - RETURN_FALSE_UNLESS(is_lwq_const(node->bias())) - return true; - } - - bool visit(const luci::CircleAdd *node) - { - RETURN_FALSE_UNLESS(is_lwq(node)) - RETURN_FALSE_UNLESS(is_lwq(node->x())); - RETURN_FALSE_UNLESS(is_lwq(node->y())); - return true; - } - - bool visit(const luci::CircleAveragePool2D *node) - { - RETURN_FALSE_UNLESS(is_lwq(node)) - RETURN_FALSE_UNLESS(is_lwq(node->value())); - return true; - } - - bool visit(const luci::CircleLogicalOr *) - { - // Logical OR has bool-type inputs and output - // Nothing to be checked - return true; - } - - bool visit(const luci::CircleMaxPool2D *node) - { - RETURN_FALSE_UNLESS(is_lwq(node)) - RETURN_FALSE_UNLESS(is_lwq(node->value())); - return true; - } - - bool visit(const luci::CircleLocalResponseNormalization *node) - { - RETURN_FALSE_UNLESS(is_lwq(node)) - RETURN_FALSE_UNLESS(is_lwq(node->input())); - return true; - } - - bool visit(const luci::CircleMean *node) - { - RETURN_FALSE_UNLESS(is_lwq(node)) - RETURN_FALSE_UNLESS(is_lwq(node->input())); - return true; - } - - bool visit(const luci::CircleMul *node) - { - RETURN_FALSE_UNLESS(is_lwq(node)) - RETURN_FALSE_UNLESS(is_lwq(node->x())); - RETURN_FALSE_UNLESS(is_lwq(node->y())); - return true; - } - - bool visit(const luci::CircleNotEqual *node) - { - RETURN_FALSE_UNLESS(is_lwq(node->x())); - RETURN_FALSE_UNLESS(is_lwq(node->y())); - return true; - } - - bool visit(const luci::CircleRelu *node) - { - RETURN_FALSE_UNLESS(is_lwq(node)) - RETURN_FALSE_UNLESS(is_lwq(node->features())); - return true; - } - - bool visit(const luci::CircleReshape *node) - { - auto input = loco::must_cast<const 
luci::CircleNode *>(node->tensor()); - bool input_quantized = input->quantparam() != nullptr; - bool node_quantized = node->quantparam() != nullptr; - RETURN_FALSE_UNLESS(input_quantized == node_quantized); - RETURN_FALSE_UNLESS(not node_quantized or is_lwq(node)) - RETURN_FALSE_UNLESS(not input_quantized or is_lwq(input)); - return true; - } - - bool visit(const luci::CircleLogistic *node) - { - RETURN_FALSE_UNLESS(is_lwq(node)); - RETURN_FALSE_UNLESS(is_lwq(node->x())); - return true; - } - - bool visit(const luci::CircleSoftmax *node) - { - RETURN_FALSE_UNLESS(is_lwq(node)); - RETURN_FALSE_UNLESS(is_lwq(node->logits())); - return true; - } - - bool visit(const luci::CircleSpaceToBatchND *node) - { - RETURN_FALSE_UNLESS(is_lwq(node)); - RETURN_FALSE_UNLESS(is_lwq(node->input())); - return true; - } - - bool visit(const luci::CircleSpaceToDepth *node) - { - RETURN_FALSE_UNLESS(is_lwq(node)); - RETURN_FALSE_UNLESS(is_lwq(node->input())); - return true; - } - - bool visit(const luci::CircleSlice *node) - { - RETURN_FALSE_UNLESS(is_lwq(node)); - RETURN_FALSE_UNLESS(is_lwq(node->input())); - return true; - } - - bool visit(const luci::CircleSplit *node) - { - // node's output is the input of CircleSplitOut, thus not quantized - RETURN_FALSE_UNLESS(is_lwq(node->input())); - return true; - } - - bool visit(const luci::CircleSplitOut *node) - { - RETURN_FALSE_UNLESS(is_lwq(node)); - return true; - } - - bool visit(const luci::CircleSplitV *node) - { - // node's output is the input of CircleSplitVOut, thus not quantized - RETURN_FALSE_UNLESS(is_lwq(node->input())); - return true; - } - - bool visit(const luci::CircleSplitVOut *node) - { - RETURN_FALSE_UNLESS(is_lwq(node)); - return true; - } - - bool visit(const luci::CircleStridedSlice *node) - { - RETURN_FALSE_UNLESS(is_lwq(node)); - RETURN_FALSE_UNLESS(is_lwq(node->input())); - return true; - } - - bool visit(const luci::CircleArgMax *node) - { - // node's output is index, thus not quantized - 
RETURN_FALSE_UNLESS(is_lwq(node->input())); - return true; - } - - bool visit(const luci::CircleBatchToSpaceND *node) - { - RETURN_FALSE_UNLESS(is_lwq(node)); - RETURN_FALSE_UNLESS(is_lwq(node->input())); - return true; - } - - bool visit(const luci::CircleTanh *node) - { - RETURN_FALSE_UNLESS(is_lwq(node)); - RETURN_FALSE_UNLESS(is_lwq(node->x())); - return true; - } - - bool visit(const luci::CircleTranspose *node) - { - RETURN_FALSE_UNLESS(is_lwq(node)); - RETURN_FALSE_UNLESS(is_lwq(node->a())); - return true; - } - - bool visit(const luci::CircleFloor *node) - { - RETURN_FALSE_UNLESS(is_lwq(node)); - RETURN_FALSE_UNLESS(is_lwq(node->x())); - return true; - } - - bool visit(const luci::CircleGreater *node) - { - RETURN_FALSE_UNLESS(is_lwq(node->x())); - RETURN_FALSE_UNLESS(is_lwq(node->y())); - return true; - } - - bool visit(const luci::CircleGreaterEqual *node) - { - RETURN_FALSE_UNLESS(is_lwq(node->x())); - RETURN_FALSE_UNLESS(is_lwq(node->y())); - return true; - } - - bool visit(const luci::CircleDiv *node) - { - RETURN_FALSE_UNLESS(is_lwq(node)); - RETURN_FALSE_UNLESS(is_lwq(node->x())); - RETURN_FALSE_UNLESS(is_lwq(node->y())); - return true; - } - - bool visit(const luci::CircleFloorDiv *node) - { - RETURN_FALSE_UNLESS(is_lwq(node)); - RETURN_FALSE_UNLESS(is_lwq(node->x())); - RETURN_FALSE_UNLESS(is_lwq(node->y())); - return true; - } - - bool visit(const luci::CircleRsqrt *node) - { - RETURN_FALSE_UNLESS(is_lwq(node)); - RETURN_FALSE_UNLESS(is_lwq(node->x())); - return true; - } - - bool visit(const luci::CircleSqrt *node) - { - RETURN_FALSE_UNLESS(is_lwq(node)); - RETURN_FALSE_UNLESS(is_lwq(node->x())); - return true; - } - - bool visit(const luci::CircleElu *node) - { - RETURN_FALSE_UNLESS(is_lwq(node)); - RETURN_FALSE_UNLESS(is_lwq(node->features())); - return true; - } - - bool visit(const luci::CirclePow *node) - { - RETURN_FALSE_UNLESS(is_lwq(node)); - RETURN_FALSE_UNLESS(is_lwq(node->x())); - RETURN_FALSE_UNLESS(is_lwq(node->y())); - return true; 
- } - - bool visit(const luci::CircleResizeBilinear *node) - { - RETURN_FALSE_UNLESS(is_lwq(node)); - RETURN_FALSE_UNLESS(is_lwq(node->input())); - return true; - } - - bool visit(const luci::CircleResizeNearestNeighbor *node) - { - RETURN_FALSE_UNLESS(is_lwq(node)); - RETURN_FALSE_UNLESS(is_lwq(node->input())); - return true; - } - - bool visit(const luci::CircleUnpack *node) - { - // node's output is the input of CircleUnpackOut, thus not quantized - RETURN_FALSE_UNLESS(is_lwq(node->value())); - return true; - } - - bool visit(const luci::CircleUnpackOut *node) - { - RETURN_FALSE_UNLESS(is_lwq(node)); - return true; - } - - bool visit(const luci::CircleCast *node) - { - auto input = loco::must_cast<const luci::CircleNode *>(node->x()); - bool input_quantized = input->quantparam() != nullptr; - bool node_quantized = node->quantparam() != nullptr; - RETURN_FALSE_UNLESS(not input_quantized or is_lwq(input)); - RETURN_FALSE_UNLESS(not node_quantized or is_lwq(node)); - return true; - } - - // TODO: Implement more Ops - - bool visit(const luci::CircleNode *) { return true; } -}; - -} // namespace luci - -#undef RETURN_FALSE_UNLESS - -#endif // __LUCI_VERIFY_QUANTIZED_NODE_LAYERWISE_GRANULARITY_H__ diff --git a/compiler/luci/pass/src/VerifyQuantizedNodeS16Type.h b/compiler/luci/pass/src/VerifyQuantizedNodeS16Type.h deleted file mode 100644 index eeec7b82b..000000000 --- a/compiler/luci/pass/src/VerifyQuantizedNodeS16Type.h +++ /dev/null @@ -1,516 +0,0 @@ -/* - * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef __LUCI_VERIFY_QUANTIZED_NODE_S16_TYPE_H__ -#define __LUCI_VERIFY_QUANTIZED_NODE_S16_TYPE_H__ - -#include <luci/IR/CircleNodes.h> -#include <luci/IR/CircleNodeVisitor.h> - -#include <cmath> - -using Type = loco::DataType; - -// This macro is undef at the end of the file -#define RETURN_FALSE_UNLESS(ARG) \ - if (not(ARG)) \ - { \ - return false; \ - } - -namespace luci -{ - -/** - * @brief Verify the data type of INT16 quantized node - * @details - * - * Targets to verify - * - node's output (i.e., node itself) - * - node's inputs - */ -struct VerifyQuantizedNodeS16Type final : public luci::CircleNodeVisitor<bool> -{ -private: - bool has_type(const loco::Node *node, Type dtype) - { - auto circle_node = loco::must_cast<const luci::CircleNode *>(node); - return circle_node->dtype() == dtype; - } - -private: - bool visit(const luci::CircleConv2D *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::S16)) - RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16)) - RETURN_FALSE_UNLESS(has_type(node->filter(), Type::S16)) - RETURN_FALSE_UNLESS(has_type(node->bias(), Type::S64)) - return true; - } - - bool visit(const luci::CircleConcatenation *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::S16)) - for (uint32_t i = 0; i < node->numValues(); i++) - { - RETURN_FALSE_UNLESS(has_type(node->values(i), Type::S16)) - } - return true; - } - - bool visit(const luci::CircleDepthToSpace *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::S16)) - RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16)) - return true; - } - - bool visit(const 
luci::CircleDepthwiseConv2D *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::S16)) - RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16)) - RETURN_FALSE_UNLESS(has_type(node->filter(), Type::S16)) - RETURN_FALSE_UNLESS(has_type(node->bias(), Type::S64)) - return true; - } - - bool visit(const luci::CircleInstanceNorm *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::S16)) - RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16)) - RETURN_FALSE_UNLESS(has_type(node->gamma(), Type::S16)) - RETURN_FALSE_UNLESS(has_type(node->beta(), Type::S16)) - return true; - } - - bool visit(const luci::CirclePack *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::S16)) - for (uint32_t i = 0; i < node->values_count(); i++) - { - RETURN_FALSE_UNLESS(has_type(node->values(i), Type::S16)) - } - return true; - } - - bool visit(const luci::CirclePad *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::S16)) - RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16)) - RETURN_FALSE_UNLESS(has_type(node->paddings(), Type::S32)) - return true; - } - - bool visit(const luci::CirclePadV2 *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::S16)) - RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16)) - RETURN_FALSE_UNLESS(has_type(node->paddings(), Type::S32)) - RETURN_FALSE_UNLESS(has_type(node->constant_values(), Type::S16)) - return true; - } - - bool visit(const luci::CircleMirrorPad *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::S16)) - RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16)) - RETURN_FALSE_UNLESS(has_type(node->paddings(), Type::S32)) - return true; - } - - bool visit(const luci::CirclePRelu *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::S16)) - RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16)) - RETURN_FALSE_UNLESS(has_type(node->alpha(), Type::S16)) - return true; - } - - bool visit(const luci::CircleTransposeConv *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::S16)) - RETURN_FALSE_UNLESS(has_type(node->outBackprop(), 
Type::S16)) - RETURN_FALSE_UNLESS(has_type(node->filter(), Type::S16)) - luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias()); - if (bias != nullptr) - RETURN_FALSE_UNLESS(has_type(bias, Type::S64)) - return true; - } - - bool visit(const luci::CircleFullyConnected *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::S16)) - RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16)) - RETURN_FALSE_UNLESS(has_type(node->weights(), Type::S16)) - luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias()); - if (bias != nullptr) - RETURN_FALSE_UNLESS(has_type(bias, Type::S64)) - return true; - } - - bool visit(const luci::CircleAdd *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::S16)) - RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16)) - RETURN_FALSE_UNLESS(has_type(node->y(), Type::S16)) - return true; - } - - bool visit(const luci::CircleAveragePool2D *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::S16)) - RETURN_FALSE_UNLESS(has_type(node->value(), Type::S16)) - return true; - } - - bool visit(const luci::CircleLogicalOr *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::BOOL)) - RETURN_FALSE_UNLESS(has_type(node->x(), Type::BOOL)) - RETURN_FALSE_UNLESS(has_type(node->y(), Type::BOOL)) - return true; - } - - bool visit(const luci::CircleMaxPool2D *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::S16)) - RETURN_FALSE_UNLESS(has_type(node->value(), Type::S16)) - return true; - } - - bool visit(const luci::CircleLocalResponseNormalization *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::S16)) - RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16)) - return true; - } - - bool visit(const luci::CircleMean *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::S16)) - RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16)) - RETURN_FALSE_UNLESS(has_type(node->reduction_indices(), Type::S32)) - return true; - } - - bool visit(const luci::CircleMul *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::S16)) - 
RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16)) - RETURN_FALSE_UNLESS(has_type(node->y(), Type::S16)) - return true; - } - - bool visit(const luci::CircleNotEqual *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::BOOL)) - RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16)) - RETURN_FALSE_UNLESS(has_type(node->y(), Type::S16)) - return true; - } - - bool visit(const luci::CircleRelu *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::S16)) - RETURN_FALSE_UNLESS(has_type(node->features(), Type::S16)) - return true; - } - - bool visit(const luci::CircleReshape *node) - { - if (node->quantparam()) - { - RETURN_FALSE_UNLESS(has_type(node, Type::S16)) - RETURN_FALSE_UNLESS(has_type(node->tensor(), Type::S16)) - } - else - { - RETURN_FALSE_UNLESS(has_type(node->tensor(), node->dtype())) - } - luci::CircleConst *shape = dynamic_cast<luci::CircleConst *>(node->shape()); - if (shape != nullptr) - RETURN_FALSE_UNLESS(has_type(shape, Type::S32)) - return true; - } - - bool visit(const luci::CircleLogistic *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::S16)) - RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16)) - - RETURN_FALSE_UNLESS(node->quantparam()); - RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == 1.0f / 32768.0f); - RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == 0); - return true; - } - - bool visit(const luci::CircleSoftmax *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::S16)) - RETURN_FALSE_UNLESS(has_type(node->logits(), Type::S16)) - - RETURN_FALSE_UNLESS(node->quantparam()); - RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == 1.0f / 32767.0f); - RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == 0); - return true; - } - - bool visit(const luci::CircleSpaceToBatchND *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::S16)) - RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16)) - return true; - } - - bool visit(const luci::CircleSpaceToDepth *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::S16)) - 
RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16)) - return true; - } - - bool visit(const luci::CircleSlice *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::S16)) - RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16)) - RETURN_FALSE_UNLESS(has_type(node->begin(), Type::S32) || has_type(node->begin(), Type::S64)) - RETURN_FALSE_UNLESS(has_type(node->size(), Type::S32) || has_type(node->size(), Type::S64)) - return true; - } - - bool visit(const luci::CircleSplit *node) - { - // node's output is the input of CircleSplitOut, thus not quantized - RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16)) - return true; - } - - bool visit(const luci::CircleSplitOut *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::S16)) - - // SplitOut has the same qparam with the input of Split - auto split = loco::must_cast<luci::CircleSplit *>(node->input()); - auto input = loco::must_cast<luci::CircleNode *>(split->input()); - RETURN_FALSE_UNLESS(node->quantparam()); - RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == input->quantparam()->scale[0]); - RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == input->quantparam()->zerop[0]); - return true; - } - - bool visit(const luci::CircleSplitV *node) - { - // node's output is the input of CircleSplitVOut, thus not quantized - RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16)) - return true; - } - - bool visit(const luci::CircleSplitVOut *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::S16)) - - // SplitVOut has the same qparam with the input of SplitV - auto splitv = loco::must_cast<luci::CircleSplitV *>(node->input()); - auto input = loco::must_cast<luci::CircleNode *>(splitv->input()); - RETURN_FALSE_UNLESS(node->quantparam()); - RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == input->quantparam()->scale[0]); - RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == input->quantparam()->zerop[0]); - return true; - } - - bool visit(const luci::CircleStridedSlice *node) - { - 
RETURN_FALSE_UNLESS(has_type(node, Type::S16)) - RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16)) - - auto input = loco::must_cast<luci::CircleNode *>(node->input()); - RETURN_FALSE_UNLESS(node->quantparam()); - RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == input->quantparam()->scale[0]); - RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == input->quantparam()->zerop[0]); - return true; - } - - bool visit(const luci::CircleArgMax *node) - { - RETURN_FALSE_UNLESS(has_type(node, node->output_type())) - RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16)) - RETURN_FALSE_UNLESS(has_type(node->dimension(), Type::S32) || - has_type(node->dimension(), Type::S64)) - return true; - } - - bool visit(const luci::CircleBatchToSpaceND *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::S16)) - RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16)) - return true; - } - - bool visit(const luci::CircleTanh *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::S16)) - RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16)) - - RETURN_FALSE_UNLESS(node->quantparam()); - RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == 1.0f / 32768.0f); - RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == 0); - return true; - } - - bool visit(const luci::CircleTranspose *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::S16)) - RETURN_FALSE_UNLESS(has_type(node->a(), Type::S16)) - RETURN_FALSE_UNLESS(has_type(node->perm(), Type::S32)) - return true; - } - - bool visit(const luci::CircleFloor *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::S16)) - RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16)) - - // This checks the value of scale is an integer - RETURN_FALSE_UNLESS(node->quantparam()); - RETURN_FALSE_UNLESS(std::roundf(node->quantparam()->scale[0]) == node->quantparam()->scale[0]); - return true; - } - - bool visit(const luci::CircleGreater *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::BOOL)) - RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16)) - 
RETURN_FALSE_UNLESS(has_type(node->y(), Type::S16)) - return true; - } - - bool visit(const luci::CircleGreaterEqual *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::BOOL)) - RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16)) - RETURN_FALSE_UNLESS(has_type(node->y(), Type::S16)) - return true; - } - - bool visit(const luci::CircleDiv *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::S16)) - RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16)) - RETURN_FALSE_UNLESS(has_type(node->y(), Type::S16)) - return true; - } - - bool visit(const luci::CircleFloorDiv *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::S16)) - RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16)) - RETURN_FALSE_UNLESS(has_type(node->y(), Type::S16)) - - // This checks the value of scale is an integer - RETURN_FALSE_UNLESS(node->quantparam()); - RETURN_FALSE_UNLESS(std::roundf(node->quantparam()->scale[0]) == node->quantparam()->scale[0]); - return true; - } - - bool visit(const luci::CircleRsqrt *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::S16)) - RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16)) - return true; - } - - bool visit(const luci::CircleSqrt *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::S16)) - RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16)) - return true; - } - - bool visit(const luci::CircleElu *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::S16)) - RETURN_FALSE_UNLESS(has_type(node->features(), Type::S16)) - return true; - } - - bool visit(const luci::CirclePow *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::S16)) - RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16)) - RETURN_FALSE_UNLESS(has_type(node->y(), Type::S16)) - return true; - } - - bool visit(const luci::CircleResizeBilinear *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::S16)) - RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16)) - return true; - } - - bool visit(const luci::CircleResizeNearestNeighbor *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::S16)) - 
RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16)) - return true; - } - - bool visit(const luci::CircleUnpack *node) - { - // node's output is the input of CircleUnpackOut, thus not quantized - RETURN_FALSE_UNLESS(has_type(node->value(), Type::S16)) - return true; - } - - bool visit(const luci::CircleUnpackOut *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::S16)) - - // UnpackOut has the same qparam with the input of Unpack - auto Unpack = loco::must_cast<luci::CircleUnpack *>(node->input()); - auto input = loco::must_cast<luci::CircleNode *>(Unpack->value()); - RETURN_FALSE_UNLESS(node->quantparam() && input->quantparam()); - RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == input->quantparam()->scale[0]); - RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == input->quantparam()->zerop[0]); - return true; - } - - bool visit(const luci::CircleCast *node) - { - auto *input = loco::must_cast<luci::CircleNode *>(node->x()); - RETURN_FALSE_UNLESS(has_type(input, node->in_data_type())) - - bool input_quantized = input->quantparam() != nullptr; - if (input_quantized) - RETURN_FALSE_UNLESS(has_type(input, Type::S16)) - - RETURN_FALSE_UNLESS(has_type(node, node->out_data_type())) - - bool node_quantized = node->quantparam() != nullptr; - if (node_quantized) - RETURN_FALSE_UNLESS(has_type(node, Type::S16)) - return true; - } - - // TODO: Implement more Ops - - bool visit(const luci::CircleNode *) { return true; } -}; - -} // namespace luci - -#undef RETURN_FALSE_UNLESS - -#endif // __LUCI_VERIFY_QUNTIZED_NODE_S16_TYPE_H__ diff --git a/compiler/luci/pass/src/VerifyQuantizedNodeType.cpp b/compiler/luci/pass/src/VerifyQuantizedNodeType.cpp new file mode 100644 index 000000000..4e1c062c0 --- /dev/null +++ b/compiler/luci/pass/src/VerifyQuantizedNodeType.cpp @@ -0,0 +1,554 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. 
All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "VerifyQuantizedNodeType.h" + +#include <cmath> +#include <memory> + +// This macro is undef at the end of the file +#define RETURN_FALSE_UNLESS(ARG) \ + if (not(ARG)) \ + { \ + return false; \ + } + +namespace luci +{ + +std::shared_ptr<VerifyQuantizedNodeType> VerifyQuantizedNodeType::create(loco::DataType dtype) +{ + if (dtype == loco::DataType::U8) + return std::make_shared<VerifyQuantizedNodeU8Type>(); + else if (dtype == loco::DataType::S16) + return std::make_shared<VerifyQuantizedNodeS16Type>(); + else + throw std::domain_error("Not supported Quantized type"); +} + +} // namespace luci + +namespace luci +{ + +template <loco::DataType Qtype, loco::DataType Btype> +bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleAdd *node) +{ + return group_has_type(node, Qtype); +} + +template <loco::DataType Qtype, loco::DataType Btype> +bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleArgMax *node) +{ + RETURN_FALSE_UNLESS(has_type(node, node->output_type())) + RETURN_FALSE_UNLESS(has_type(node->input(), Qtype)) + RETURN_FALSE_UNLESS(has_type(node->dimension(), loco::DataType::S32) || + has_type(node->dimension(), loco::DataType::S64)) + return true; +} + +template <loco::DataType Qtype, loco::DataType Btype> +bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleAveragePool2D *node) +{ + return group_has_type(node, 
Qtype); +} + +template <loco::DataType Qtype, loco::DataType Btype> +bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleBatchToSpaceND *node) +{ + RETURN_FALSE_UNLESS(has_type(node, Qtype)) + RETURN_FALSE_UNLESS(has_type(node->input(), Qtype)) + return true; +} + +template <loco::DataType Qtype, loco::DataType Btype> +bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleCast *node) +{ + auto *input = loco::must_cast<luci::CircleNode *>(node->x()); + bool input_quantized = input->quantparam() != nullptr; + if (input_quantized) + { + RETURN_FALSE_UNLESS(has_type(input, node->in_data_type())) + RETURN_FALSE_UNLESS(has_type(input, Qtype)) + } + + bool node_quantized = node->quantparam() != nullptr; + if (node_quantized) + { + RETURN_FALSE_UNLESS(has_type(node, node->out_data_type())) + RETURN_FALSE_UNLESS(has_type(node, Qtype)) + } + return true; +} + +template <loco::DataType Qtype, loco::DataType Btype> +bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleConv2D *node) +{ + RETURN_FALSE_UNLESS(has_type(node, Qtype)) + RETURN_FALSE_UNLESS(has_type(node->input(), Qtype)) + RETURN_FALSE_UNLESS(has_type(node->filter(), Qtype)) + RETURN_FALSE_UNLESS(has_type(node->bias(), Btype)) + return true; +} + +template <loco::DataType Qtype, loco::DataType Btype> +bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleConcatenation *node) +{ + // Allow concatenation of indices + if (group_has_type(node, loco::DataType::S32) or group_has_type(node, loco::DataType::S64)) + return true; + + return group_has_type(node, Qtype); +} + +template <loco::DataType Qtype, loco::DataType Btype> +bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleDepthToSpace *node) +{ + return group_has_type(node, Qtype); +} + +template <loco::DataType Qtype, loco::DataType Btype> +bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleDepthwiseConv2D *node) +{ + RETURN_FALSE_UNLESS(has_type(node, 
Qtype)) + RETURN_FALSE_UNLESS(has_type(node->input(), Qtype)) + RETURN_FALSE_UNLESS(has_type(node->filter(), Qtype)) + RETURN_FALSE_UNLESS(has_type(node->bias(), Btype)) + return true; +} + +template <loco::DataType Qtype, loco::DataType Btype> +bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleDiv *node) +{ + return group_has_type(node, Qtype); +} + +template <loco::DataType Qtype, loco::DataType Btype> +bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleElu *node) +{ + return group_has_type(node, Qtype); +} + +template <loco::DataType Qtype, loco::DataType Btype> +bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleFloor *node) +{ + RETURN_FALSE_UNLESS(group_has_type(node, Qtype)); + + // This checks the value of scale is an integer + RETURN_FALSE_UNLESS(node->quantparam()); + RETURN_FALSE_UNLESS(std::roundf(node->quantparam()->scale[0]) == node->quantparam()->scale[0]); + return true; +} + +template <loco::DataType Qtype, loco::DataType Btype> +bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleFloorDiv *node) +{ + RETURN_FALSE_UNLESS(group_has_type(node, Qtype)); + + // This checks the value of scale is an integer + RETURN_FALSE_UNLESS(node->quantparam()); + RETURN_FALSE_UNLESS(std::roundf(node->quantparam()->scale[0]) == node->quantparam()->scale[0]); + return true; +} + +template <loco::DataType Qtype, loco::DataType Btype> +bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleFullyConnected *node) +{ + RETURN_FALSE_UNLESS(has_type(node, Qtype)) + RETURN_FALSE_UNLESS(has_type(node->input(), Qtype)) + RETURN_FALSE_UNLESS(has_type(node->weights(), Qtype)) + luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias()); + if (bias != nullptr) + RETURN_FALSE_UNLESS(has_type(bias, Btype)) + return true; +} + +template <loco::DataType Qtype, loco::DataType Btype> +bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleGreater *node) +{ 
+ RETURN_FALSE_UNLESS(has_type(node, loco::DataType::BOOL)) + RETURN_FALSE_UNLESS(has_type(node->x(), Qtype)) + RETURN_FALSE_UNLESS(has_type(node->y(), Qtype)) + return true; +} + +template <loco::DataType Qtype, loco::DataType Btype> +bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleGreaterEqual *node) +{ + RETURN_FALSE_UNLESS(has_type(node, loco::DataType::BOOL)) + RETURN_FALSE_UNLESS(has_type(node->x(), Qtype)) + RETURN_FALSE_UNLESS(has_type(node->y(), Qtype)) + return true; +} + +template <loco::DataType Qtype, loco::DataType Btype> +bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleInstanceNorm *node) +{ + return group_has_type(node, Qtype); +} + +template <loco::DataType Qtype, loco::DataType Btype> +bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit( + const luci::CircleLocalResponseNormalization *node) +{ + return group_has_type(node, Qtype); +} + +template <loco::DataType Qtype, loco::DataType Btype> +bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleLogicalOr *node) +{ + return group_has_type(node, loco::DataType::BOOL); +} + +template <loco::DataType Qtype, loco::DataType Btype> +bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleMaxPool2D *node) +{ + return group_has_type(node, Qtype); +} + +template <loco::DataType Qtype, loco::DataType Btype> +bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleMean *node) +{ + RETURN_FALSE_UNLESS(has_type(node, Qtype)) + RETURN_FALSE_UNLESS(has_type(node->input(), Qtype)) + RETURN_FALSE_UNLESS(has_type(node->reduction_indices(), loco::DataType::S32)) + return true; +} + +template <loco::DataType Qtype, loco::DataType Btype> +bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleMirrorPad *node) +{ + RETURN_FALSE_UNLESS(has_type(node, Qtype)) + RETURN_FALSE_UNLESS(has_type(node->input(), Qtype)) + RETURN_FALSE_UNLESS(has_type(node->paddings(), loco::DataType::S32)) + return true; +} + 
+template <loco::DataType Qtype, loco::DataType Btype> +bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleMul *node) +{ + return group_has_type(node, Qtype); +} + +template <loco::DataType Qtype, loco::DataType Btype> +bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleNotEqual *node) +{ + RETURN_FALSE_UNLESS(has_type(node, loco::DataType::BOOL)) + RETURN_FALSE_UNLESS(has_type(node->x(), Qtype)) + RETURN_FALSE_UNLESS(has_type(node->y(), Qtype)) + return true; +} + +template <loco::DataType Qtype, loco::DataType Btype> +bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleOneHot *node) +{ + RETURN_FALSE_UNLESS(has_type(node, Qtype)); + RETURN_FALSE_UNLESS(has_type(node->indices(), loco::DataType::S32) || + has_type(node->indices(), loco::DataType::S64)); + RETURN_FALSE_UNLESS(has_type(node->depth(), loco::DataType::S32)); + RETURN_FALSE_UNLESS(has_type(node->on_value(), Qtype)); + RETURN_FALSE_UNLESS(has_type(node->off_value(), Qtype)); + return true; +} + +template <loco::DataType Qtype, loco::DataType Btype> +bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CirclePack *node) +{ + return group_has_type(node, Qtype); +} + +template <loco::DataType Qtype, loco::DataType Btype> +bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CirclePad *node) +{ + RETURN_FALSE_UNLESS(has_type(node, Qtype)) + RETURN_FALSE_UNLESS(has_type(node->input(), Qtype)) + RETURN_FALSE_UNLESS(has_type(node->paddings(), loco::DataType::S32)) + return true; +} + +template <loco::DataType Qtype, loco::DataType Btype> +bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CirclePadV2 *node) +{ + RETURN_FALSE_UNLESS(has_type(node, Qtype)) + RETURN_FALSE_UNLESS(has_type(node->input(), Qtype)) + RETURN_FALSE_UNLESS(has_type(node->paddings(), loco::DataType::S32)) + RETURN_FALSE_UNLESS(has_type(node->constant_values(), Qtype)) + return true; +} + +template <loco::DataType Qtype, loco::DataType 
Btype> +bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CirclePRelu *node) +{ + return group_has_type(node, Qtype); +} + +template <loco::DataType Qtype, loco::DataType Btype> +bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CirclePow *node) +{ + return group_has_type(node, Qtype); +} + +template <loco::DataType Qtype, loco::DataType Btype> +bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleRelu *node) +{ + return group_has_type(node, Qtype); +} + +template <loco::DataType Qtype, loco::DataType Btype> +bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleReshape *node) +{ + if (node->quantparam()) + { + RETURN_FALSE_UNLESS(has_type(node, Qtype)) + RETURN_FALSE_UNLESS(has_type(node->tensor(), Qtype)) + } + else + { + RETURN_FALSE_UNLESS(has_type(node->tensor(), node->dtype())) + } + luci::CircleConst *shape = dynamic_cast<luci::CircleConst *>(node->shape()); + if (shape != nullptr) + RETURN_FALSE_UNLESS(has_type(shape, loco::DataType::S32)) + return true; +} + +template <loco::DataType Qtype, loco::DataType Btype> +bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleResizeBilinear *node) +{ + RETURN_FALSE_UNLESS(has_type(node, Qtype)) + RETURN_FALSE_UNLESS(has_type(node->input(), Qtype)) + return true; +} + +template <loco::DataType Qtype, loco::DataType Btype> +bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleResizeNearestNeighbor *node) +{ + RETURN_FALSE_UNLESS(has_type(node, Qtype)) + RETURN_FALSE_UNLESS(has_type(node->input(), Qtype)) + return true; +} + +template <loco::DataType Qtype, loco::DataType Btype> +bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleRsqrt *node) +{ + return group_has_type(node, Qtype); +} + +template <loco::DataType Qtype, loco::DataType Btype> +bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleSlice *node) +{ + RETURN_FALSE_UNLESS(has_type(node, Qtype)) + 
RETURN_FALSE_UNLESS(has_type(node->input(), Qtype)) + RETURN_FALSE_UNLESS(has_type(node->begin(), loco::DataType::S32) || + has_type(node->begin(), loco::DataType::S64)) + RETURN_FALSE_UNLESS(has_type(node->size(), loco::DataType::S32) || + has_type(node->size(), loco::DataType::S64)) + return true; +} + +template <loco::DataType Qtype, loco::DataType Btype> +bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleSpaceToBatchND *node) +{ + RETURN_FALSE_UNLESS(has_type(node, Qtype)) + RETURN_FALSE_UNLESS(has_type(node->input(), Qtype)) + return true; +} + +template <loco::DataType Qtype, loco::DataType Btype> +bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleSpaceToDepth *node) +{ + return group_has_type(node, Qtype); +} + +template <loco::DataType Qtype, loco::DataType Btype> +bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleSplit *node) +{ + // node's output is the input of CircleSplitOut, thus not quantized + RETURN_FALSE_UNLESS(has_type(node->input(), Qtype)) + return true; +} + +template <loco::DataType Qtype, loco::DataType Btype> +bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleSplitOut *node) +{ + RETURN_FALSE_UNLESS(has_type(node, Qtype)) + + // SplitOut has the same qparam with the input of Split + auto split = loco::must_cast<luci::CircleSplit *>(node->input()); + auto input = loco::must_cast<luci::CircleNode *>(split->input()); + RETURN_FALSE_UNLESS(node->quantparam()); + RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == input->quantparam()->scale[0]); + RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == input->quantparam()->zerop[0]); + return true; +} + +template <loco::DataType Qtype, loco::DataType Btype> +bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleSplitV *node) +{ + // node's output is the input of CircleSplitVOut, thus not quantized + RETURN_FALSE_UNLESS(has_type(node->input(), Qtype)) + return true; +} + +template <loco::DataType 
Qtype, loco::DataType Btype> +bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleSplitVOut *node) +{ + RETURN_FALSE_UNLESS(has_type(node, Qtype)) + + // SplitVOut has the same qparam with the input of SplitV + auto splitv = loco::must_cast<luci::CircleSplitV *>(node->input()); + auto input = loco::must_cast<luci::CircleNode *>(splitv->input()); + RETURN_FALSE_UNLESS(node->quantparam()); + RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == input->quantparam()->scale[0]); + RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == input->quantparam()->zerop[0]); + return true; +} + +template <loco::DataType Qtype, loco::DataType Btype> +bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleSqrt *node) +{ + return group_has_type(node, Qtype); +} + +template <loco::DataType Qtype, loco::DataType Btype> +bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleStridedSlice *node) +{ + RETURN_FALSE_UNLESS(has_type(node, Qtype)) + RETURN_FALSE_UNLESS(has_type(node->input(), Qtype)) + + auto input = loco::must_cast<luci::CircleNode *>(node->input()); + RETURN_FALSE_UNLESS(node->quantparam()); + RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == input->quantparam()->scale[0]); + RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == input->quantparam()->zerop[0]); + return true; +} + +template <loco::DataType Qtype, loco::DataType Btype> +bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleTranspose *node) +{ + RETURN_FALSE_UNLESS(has_type(node, Qtype)) + RETURN_FALSE_UNLESS(has_type(node->a(), Qtype)) + RETURN_FALSE_UNLESS(has_type(node->perm(), loco::DataType::S32)) + return true; +} + +template <loco::DataType Qtype, loco::DataType Btype> +bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleTransposeConv *node) +{ + RETURN_FALSE_UNLESS(has_type(node, Qtype)) + RETURN_FALSE_UNLESS(has_type(node->outBackprop(), Qtype)) + RETURN_FALSE_UNLESS(has_type(node->filter(), Qtype)) + 
luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias()); + if (bias != nullptr) + RETURN_FALSE_UNLESS(has_type(bias, Btype)) + return true; +} + +template <loco::DataType Qtype, loco::DataType Btype> +bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleUnpack *node) +{ + // node's output is the input of CircleUnpackOut, thus not quantized + RETURN_FALSE_UNLESS(has_type(node->value(), Qtype)) + return true; +} + +template <loco::DataType Qtype, loco::DataType Btype> +bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleUnpackOut *node) +{ + RETURN_FALSE_UNLESS(has_type(node, Qtype)) + + // UnpackOut has the same qparam with the input of Unpack + auto Unpack = loco::must_cast<luci::CircleUnpack *>(node->input()); + auto input = loco::must_cast<luci::CircleNode *>(Unpack->value()); + RETURN_FALSE_UNLESS(node->quantparam() && input->quantparam()); + RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == input->quantparam()->scale[0]); + RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == input->quantparam()->zerop[0]); + return true; +} + +} // namespace luci + +namespace luci +{ + +bool VerifyQuantizedNodeU8Type::visit(const luci::CircleTanh *node) +{ + RETURN_FALSE_UNLESS(group_has_type(node, loco::DataType::U8)); + + RETURN_FALSE_UNLESS(node->quantparam()); + RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == 2.0f / 256.0f); + RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == 128); + return true; +} + +bool VerifyQuantizedNodeU8Type::visit(const luci::CircleLogistic *node) +{ + RETURN_FALSE_UNLESS(group_has_type(node, loco::DataType::U8)); + + RETURN_FALSE_UNLESS(node->quantparam()); + RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == 1.0f / 256.0f); + RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == 0); + return true; +} + +bool VerifyQuantizedNodeU8Type::visit(const luci::CircleSoftmax *node) +{ + RETURN_FALSE_UNLESS(group_has_type(node, loco::DataType::U8)); + + RETURN_FALSE_UNLESS(node->quantparam()); 
+ RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == 1.0f / 255.0f); + RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == 0); + return true; +} + +} // namespace luci + +namespace luci +{ + +bool VerifyQuantizedNodeS16Type::visit(const luci::CircleTanh *node) +{ + RETURN_FALSE_UNLESS(group_has_type(node, loco::DataType::S16)); + + RETURN_FALSE_UNLESS(node->quantparam()); + RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == 1.0f / 32768.0f); + RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == 0); + return true; +} + +bool VerifyQuantizedNodeS16Type::visit(const luci::CircleLogistic *node) +{ + RETURN_FALSE_UNLESS(group_has_type(node, loco::DataType::S16)); + + RETURN_FALSE_UNLESS(node->quantparam()); + RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == 1.0f / 32768.0f); + RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == 0); + return true; +} + +bool VerifyQuantizedNodeS16Type::visit(const luci::CircleSoftmax *node) +{ + RETURN_FALSE_UNLESS(group_has_type(node, loco::DataType::S16)); + + RETURN_FALSE_UNLESS(node->quantparam()); + RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == 1.0f / 32767.0f); + RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == 0); + return true; +} + +} // namespace luci + +#undef RETURN_FALSE_UNLESS diff --git a/compiler/luci/pass/src/VerifyQuantizedNodeType.h b/compiler/luci/pass/src/VerifyQuantizedNodeType.h new file mode 100644 index 000000000..ff1acbd6f --- /dev/null +++ b/compiler/luci/pass/src/VerifyQuantizedNodeType.h @@ -0,0 +1,157 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_VERIFY_QUANTIZED_NODE_TYPE_H__ +#define __LUCI_VERIFY_QUANTIZED_NODE_TYPE_H__ + +#include <luci/IR/CircleNodes.h> +#include <luci/IR/CircleNodeVisitor.h> + +namespace luci +{ + +/** + * @brief Verify the data type of quantized node + * @details + * + * Targets to verify + * - node's output (i.e., node itself) + * - node's inputs + */ +class VerifyQuantizedNodeType +{ +public: + static std::shared_ptr<VerifyQuantizedNodeType> create(loco::DataType dtype); + +public: + virtual bool verify(luci::CircleNode *node) = 0; +}; + +/** + * @brief Verify using quantization type of a node and bias + * + * @tparam Qtype Quantization type for a node (e.g. Q8, Q16, ...) + * @tparam Btype Bias quantization type (e.g. 
For Q8, S32 is used) + */ +template <loco::DataType Qtype, loco::DataType Btype> +class VerifyQuantizedNodeTypeBase : public luci::CircleNodeVisitor<bool>, + public VerifyQuantizedNodeType +{ +public: + bool verify(luci::CircleNode *node) { return node->accept(this); } + +protected: + bool has_type(const loco::Node *node, loco::DataType dtype) + { + auto circle_node = loco::must_cast<const luci::CircleNode *>(node); + return circle_node->dtype() == dtype; + } + + // Check whether a node and all of its inputs have dtype or not + bool group_has_type(const loco::Node *node, loco::DataType dtype) + { + if (!has_type(node, dtype)) + return false; + + for (uint32_t i = 0; i < node->arity(); ++i) + if (!has_type(node->arg(i), dtype)) + return false; + + return true; + } + +private: + bool visit(const luci::CircleAdd *node); + bool visit(const luci::CircleArgMax *node); + bool visit(const luci::CircleAveragePool2D *node); + bool visit(const luci::CircleBatchToSpaceND *node); + bool visit(const luci::CircleCast *node); + bool visit(const luci::CircleConv2D *node); + bool visit(const luci::CircleConcatenation *node); + bool visit(const luci::CircleDepthToSpace *node); + bool visit(const luci::CircleDepthwiseConv2D *node); + bool visit(const luci::CircleDiv *node); + bool visit(const luci::CircleElu *node); + bool visit(const luci::CircleFloor *node); + bool visit(const luci::CircleFloorDiv *node); + bool visit(const luci::CircleFullyConnected *node); + bool visit(const luci::CircleGreater *node); + bool visit(const luci::CircleGreaterEqual *node); + bool visit(const luci::CircleInstanceNorm *node); + bool visit(const luci::CircleLocalResponseNormalization *node); + bool visit(const luci::CircleLogicalOr *node); + bool visit(const luci::CircleMaxPool2D *node); + bool visit(const luci::CircleMean *node); + bool visit(const luci::CircleMirrorPad *node); + bool visit(const luci::CircleMul *node); + bool visit(const luci::CircleNotEqual *node); + bool visit(const 
luci::CircleOneHot *node); + bool visit(const luci::CirclePack *node); + bool visit(const luci::CirclePad *node); + bool visit(const luci::CirclePadV2 *node); + bool visit(const luci::CirclePRelu *node); + bool visit(const luci::CirclePow *node); + bool visit(const luci::CircleRelu *node); + bool visit(const luci::CircleReshape *node); + bool visit(const luci::CircleResizeBilinear *node); + bool visit(const luci::CircleResizeNearestNeighbor *node); + bool visit(const luci::CircleRsqrt *node); + bool visit(const luci::CircleSlice *node); + bool visit(const luci::CircleSpaceToBatchND *node); + bool visit(const luci::CircleSpaceToDepth *node); + bool visit(const luci::CircleSplit *node); + bool visit(const luci::CircleSplitOut *node); + bool visit(const luci::CircleSplitV *node); + bool visit(const luci::CircleSplitVOut *node); + bool visit(const luci::CircleSqrt *node); + bool visit(const luci::CircleStridedSlice *node); + bool visit(const luci::CircleTranspose *node); + bool visit(const luci::CircleTransposeConv *node); + bool visit(const luci::CircleUnpack *node); + bool visit(const luci::CircleUnpackOut *node); + + // NOTE below nodes has differnent implementation for Qtype/Btype and + // implementations exist in VerifyQuantizedNodeU8Type, VerifyQuantizedNodeS16Type + // bool visit(const luci::CircleLogistic *node); + // bool visit(const luci::CircleSoftmax *node); + // bool visit(const luci::CircleTanh *node); + + // TODO: Implement more Ops + + bool visit(const luci::CircleNode *) { return true; } +}; + +class VerifyQuantizedNodeU8Type + : public VerifyQuantizedNodeTypeBase<loco::DataType::U8, loco::DataType::S32> +{ +private: + bool visit(const luci::CircleLogistic *node); + bool visit(const luci::CircleSoftmax *node); + bool visit(const luci::CircleTanh *node); +}; + +class VerifyQuantizedNodeS16Type + : public VerifyQuantizedNodeTypeBase<loco::DataType::S16, loco::DataType::S64> +{ +private: + bool visit(const luci::CircleLogistic *node); + bool visit(const 
luci::CircleSoftmax *node); + bool visit(const luci::CircleTanh *node); +}; + +} // namespace luci + +#endif // __LUCI_VERIFY_QUANTIZED_NODE_TYPE_H__ diff --git a/compiler/luci/pass/src/VerifyQuantizedNodeU8Type.h b/compiler/luci/pass/src/VerifyQuantizedNodeU8Type.h deleted file mode 100644 index e7dd1b072..000000000 --- a/compiler/luci/pass/src/VerifyQuantizedNodeU8Type.h +++ /dev/null @@ -1,518 +0,0 @@ -/* - * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef __LUCI_VERIFY_QUANTIZED_NODE_U8_TYPE_H__ -#define __LUCI_VERIFY_QUANTIZED_NODE_U8_TYPE_H__ - -#include <luci/IR/CircleNodes.h> -#include <luci/IR/CircleNodeVisitor.h> - -#include <cmath> - -using Type = loco::DataType; - -// This macro is undef at the end of the file -#define RETURN_FALSE_UNLESS(ARG) \ - if (not(ARG)) \ - { \ - return false; \ - } - -namespace luci -{ - -/** - * @brief Verify the data type of UINT8 quantized node - * @details - * - * Targets to verify - * - node's output (i.e., node itself) - * - node's inputs - */ -struct VerifyQuantizedNodeU8Type final : public luci::CircleNodeVisitor<bool> -{ -private: - bool has_type(const loco::Node *node, Type dtype) - { - auto circle_node = loco::must_cast<const luci::CircleNode *>(node); - return circle_node->dtype() == dtype; - } - -private: - bool visit(const luci::CircleConv2D *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::U8)) - RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8)) - RETURN_FALSE_UNLESS(has_type(node->filter(), Type::U8)) - RETURN_FALSE_UNLESS(has_type(node->bias(), Type::S32)) - return true; - } - - bool visit(const luci::CircleConcatenation *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::U8)) - for (uint32_t i = 0; i < node->numValues(); i++) - { - RETURN_FALSE_UNLESS(has_type(node->values(i), Type::U8)) - } - return true; - } - - bool visit(const luci::CircleDepthToSpace *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::U8)) - RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8)) - return true; - } - - bool visit(const luci::CircleDepthwiseConv2D *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::U8)) - RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8)) - RETURN_FALSE_UNLESS(has_type(node->filter(), Type::U8)) - RETURN_FALSE_UNLESS(has_type(node->bias(), Type::S32)) - return true; - } - - bool visit(const luci::CircleInstanceNorm *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::U8)) - RETURN_FALSE_UNLESS(has_type(node->input(), 
Type::U8)) - RETURN_FALSE_UNLESS(has_type(node->gamma(), Type::U8)) - RETURN_FALSE_UNLESS(has_type(node->beta(), Type::U8)) - return true; - } - - bool visit(const luci::CirclePack *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::U8)) - for (uint32_t i = 0; i < node->values_count(); i++) - { - RETURN_FALSE_UNLESS(has_type(node->values(i), Type::U8)) - } - return true; - } - - bool visit(const luci::CirclePad *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::U8)) - RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8)) - RETURN_FALSE_UNLESS(has_type(node->paddings(), Type::S32)) - return true; - } - - bool visit(const luci::CirclePadV2 *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::U8)) - RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8)) - RETURN_FALSE_UNLESS(has_type(node->paddings(), Type::S32)) - RETURN_FALSE_UNLESS(has_type(node->constant_values(), Type::U8)) - return true; - } - - bool visit(const luci::CircleMirrorPad *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::U8)) - RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8)) - RETURN_FALSE_UNLESS(has_type(node->paddings(), Type::S32)) - return true; - } - - bool visit(const luci::CirclePRelu *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::U8)) - RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8)) - RETURN_FALSE_UNLESS(has_type(node->alpha(), Type::U8)) - return true; - } - - bool visit(const luci::CircleTransposeConv *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::U8)) - RETURN_FALSE_UNLESS(has_type(node->outBackprop(), Type::U8)) - RETURN_FALSE_UNLESS(has_type(node->filter(), Type::U8)) - luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias()); - if (bias != nullptr) - RETURN_FALSE_UNLESS(has_type(bias, Type::S32)) - return true; - } - - bool visit(const luci::CircleFullyConnected *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::U8)) - RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8)) - RETURN_FALSE_UNLESS(has_type(node->weights(), 
Type::U8)) - luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias()); - if (bias != nullptr) - RETURN_FALSE_UNLESS(has_type(bias, Type::S32)) - return true; - } - - bool visit(const luci::CircleAdd *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::U8)) - RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8)) - RETURN_FALSE_UNLESS(has_type(node->y(), Type::U8)) - return true; - } - - bool visit(const luci::CircleAveragePool2D *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::U8)) - RETURN_FALSE_UNLESS(has_type(node->value(), Type::U8)) - return true; - } - - bool visit(const luci::CircleBatchToSpaceND *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::U8)) - RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8)) - return true; - } - - bool visit(const luci::CircleLogicalOr *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::BOOL)) - RETURN_FALSE_UNLESS(has_type(node->x(), Type::BOOL)) - RETURN_FALSE_UNLESS(has_type(node->y(), Type::BOOL)) - return true; - } - - bool visit(const luci::CircleMaxPool2D *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::U8)) - RETURN_FALSE_UNLESS(has_type(node->value(), Type::U8)) - return true; - } - - bool visit(const luci::CircleLocalResponseNormalization *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::U8)) - RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8)) - return true; - } - - bool visit(const luci::CircleMean *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::U8)) - RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8)) - RETURN_FALSE_UNLESS(has_type(node->reduction_indices(), Type::S32)) - return true; - } - - bool visit(const luci::CircleMul *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::U8)) - RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8)) - RETURN_FALSE_UNLESS(has_type(node->y(), Type::U8)) - return true; - } - - bool visit(const luci::CircleNotEqual *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::BOOL)) - RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8)) - 
RETURN_FALSE_UNLESS(has_type(node->y(), Type::U8)) - return true; - } - - bool visit(const luci::CircleRelu *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::U8)) - RETURN_FALSE_UNLESS(has_type(node->features(), Type::U8)) - return true; - } - - bool visit(const luci::CircleReshape *node) - { - if (node->quantparam()) - { - RETURN_FALSE_UNLESS(has_type(node, Type::U8)) - RETURN_FALSE_UNLESS(has_type(node->tensor(), Type::U8)) - } - else - { - RETURN_FALSE_UNLESS(has_type(node->tensor(), node->dtype())) - } - luci::CircleConst *shape = dynamic_cast<luci::CircleConst *>(node->shape()); - if (shape != nullptr) - RETURN_FALSE_UNLESS(has_type(shape, Type::S32)) - return true; - } - - bool visit(const luci::CircleLogistic *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::U8)) - RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8)) - - RETURN_FALSE_UNLESS(node->quantparam()); - RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == 1.0f / 256.0f); - RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == 0); - return true; - } - - bool visit(const luci::CircleSoftmax *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::U8)) - RETURN_FALSE_UNLESS(has_type(node->logits(), Type::U8)) - - RETURN_FALSE_UNLESS(node->quantparam()); - RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == 1.0f / 255.0f); - RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == 0); - return true; - } - - bool visit(const luci::CircleSpaceToBatchND *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::U8)) - RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8)) - return true; - } - - bool visit(const luci::CircleSpaceToDepth *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::U8)) - RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8)) - return true; - } - - bool visit(const luci::CircleSlice *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::U8)) - RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8)) - RETURN_FALSE_UNLESS(has_type(node->begin(), Type::S32) || has_type(node->begin(), Type::S64)) 
- RETURN_FALSE_UNLESS(has_type(node->size(), Type::S32) || has_type(node->size(), Type::S64)) - return true; - } - - bool visit(const luci::CircleSplit *node) - { - // node's output is the input of CircleSplitOut, thus not quantized - RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8)) - return true; - } - - bool visit(const luci::CircleSplitOut *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::U8)) - - // SplitOut has the same qparam with the input of Split - auto split = loco::must_cast<luci::CircleSplit *>(node->input()); - auto input = loco::must_cast<luci::CircleNode *>(split->input()); - RETURN_FALSE_UNLESS(node->quantparam()); - RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == input->quantparam()->scale[0]); - RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == input->quantparam()->zerop[0]); - return true; - } - - bool visit(const luci::CircleSplitV *node) - { - // node's output is the input of CircleSplitVOut, thus not quantized - RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8)) - return true; - } - - bool visit(const luci::CircleSplitVOut *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::U8)) - - // SplitVOut has the same qparam with the input of SplitV - auto splitv = loco::must_cast<luci::CircleSplitV *>(node->input()); - auto input = loco::must_cast<luci::CircleNode *>(splitv->input()); - RETURN_FALSE_UNLESS(node->quantparam()); - RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == input->quantparam()->scale[0]); - RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == input->quantparam()->zerop[0]); - return true; - } - - bool visit(const luci::CircleStridedSlice *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::U8)) - RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8)) - - auto input = loco::must_cast<luci::CircleNode *>(node->input()); - RETURN_FALSE_UNLESS(node->quantparam()); - RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == input->quantparam()->scale[0]); - RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == 
input->quantparam()->zerop[0]); - return true; - } - - bool visit(const luci::CircleArgMax *node) - { - RETURN_FALSE_UNLESS(has_type(node, node->output_type())) - RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8)) - RETURN_FALSE_UNLESS(has_type(node->dimension(), Type::S32) || - has_type(node->dimension(), Type::S64)) - return true; - } - - bool visit(const luci::CircleTanh *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::U8)) - RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8)) - - RETURN_FALSE_UNLESS(node->quantparam()); - RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == 2.0f / 256.0f); - RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == 128); - return true; - } - - bool visit(const luci::CircleTranspose *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::U8)) - RETURN_FALSE_UNLESS(has_type(node->a(), Type::U8)) - RETURN_FALSE_UNLESS(has_type(node->perm(), Type::S32)) - return true; - } - - bool visit(const luci::CircleFloor *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::U8)) - RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8)) - - // This checks the value of scale is an integer - RETURN_FALSE_UNLESS(node->quantparam()); - RETURN_FALSE_UNLESS(std::roundf(node->quantparam()->scale[0]) == node->quantparam()->scale[0]); - return true; - } - - bool visit(const luci::CircleGreater *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::BOOL)) - RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8)) - RETURN_FALSE_UNLESS(has_type(node->y(), Type::U8)) - return true; - } - - bool visit(const luci::CircleGreaterEqual *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::BOOL)) - RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8)) - RETURN_FALSE_UNLESS(has_type(node->y(), Type::U8)) - return true; - } - - bool visit(const luci::CircleDiv *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::U8)) - RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8)) - RETURN_FALSE_UNLESS(has_type(node->y(), Type::U8)) - return true; - } - - bool visit(const 
luci::CircleFloorDiv *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::U8)) - RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8)) - RETURN_FALSE_UNLESS(has_type(node->y(), Type::U8)) - - // This checks the value of scale is an integer - RETURN_FALSE_UNLESS(node->quantparam()); - RETURN_FALSE_UNLESS(std::roundf(node->quantparam()->scale[0]) == node->quantparam()->scale[0]); - return true; - } - - bool visit(const luci::CircleRsqrt *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::U8)) - RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8)) - return true; - } - - bool visit(const luci::CircleSqrt *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::U8)) - RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8)) - return true; - } - - bool visit(const luci::CircleElu *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::U8)) - RETURN_FALSE_UNLESS(has_type(node->features(), Type::U8)) - return true; - } - - bool visit(const luci::CirclePow *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::U8)) - RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8)) - RETURN_FALSE_UNLESS(has_type(node->y(), Type::U8)) - return true; - } - - bool visit(const luci::CircleResizeBilinear *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::U8)) - RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8)) - return true; - } - - bool visit(const luci::CircleResizeNearestNeighbor *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::U8)) - RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8)) - return true; - } - - bool visit(const luci::CircleUnpack *node) - { - // node's output is the input of CircleUnpackOut, thus not quantized - RETURN_FALSE_UNLESS(has_type(node->value(), Type::U8)) - return true; - } - - bool visit(const luci::CircleUnpackOut *node) - { - RETURN_FALSE_UNLESS(has_type(node, Type::U8)) - - // UnpackOut has the same qparam with the input of Unpack - auto Unpack = loco::must_cast<luci::CircleUnpack *>(node->input()); - auto input = loco::must_cast<luci::CircleNode 
*>(Unpack->value()); - RETURN_FALSE_UNLESS(node->quantparam() && input->quantparam()); - RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == input->quantparam()->scale[0]); - RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == input->quantparam()->zerop[0]); - return true; - } - - bool visit(const luci::CircleCast *node) - { - auto *input = loco::must_cast<luci::CircleNode *>(node->x()); - bool input_quantized = input->quantparam() != nullptr; - if (input_quantized) - { - RETURN_FALSE_UNLESS(has_type(input, node->in_data_type())) - RETURN_FALSE_UNLESS(has_type(input, Type::U8)) - } - - bool node_quantized = node->quantparam() != nullptr; - if (node_quantized) - { - RETURN_FALSE_UNLESS(has_type(node, node->out_data_type())) - RETURN_FALSE_UNLESS(has_type(node, Type::U8)) - } - return true; - } - - // TODO: Implement more Ops - - bool visit(const luci::CircleNode *) { return true; } -}; - -} // namespace luci - -#undef RETURN_FALSE_UNLESS - -#endif // __LUCI_VERIFY_QUNTIZED_NODE_U8_TYPE_H__ diff --git a/compiler/luci/pass/src/helpers/LayerInfoMap.cpp b/compiler/luci/pass/src/helpers/LayerInfoMap.cpp new file mode 100644 index 000000000..ac07f9ec9 --- /dev/null +++ b/compiler/luci/pass/src/helpers/LayerInfoMap.cpp @@ -0,0 +1,189 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "LayerInfoMap.h" + +#include <luci/IR/CircleNode.h> + +#include <cassert> + +namespace luci +{ +namespace +{ + +bool is_multiple_output_node(const luci::CircleNode *node) +{ + switch (node->opcode()) + { + // The following nodes have multiple outputs. Output tensors are not produced by themselves but + // by the corresponding *Out nodes. + case luci::CircleOpcode::SPLIT: + case luci::CircleOpcode::SPLIT_V: + case luci::CircleOpcode::TOPK_V2: + case luci::CircleOpcode::UNIQUE: + case luci::CircleOpcode::UNPACK: + return true; + // TODO: Support ops + case luci::CircleOpcode::BIDIRECTIONAL_SEQUENCE_LSTM: + case luci::CircleOpcode::CUSTOM: + case luci::CircleOpcode::IF: + case luci::CircleOpcode::NON_MAX_SUPPRESSION_V4: + case luci::CircleOpcode::NON_MAX_SUPPRESSION_V5: + case luci::CircleOpcode::WHILE: + throw std::runtime_error("Unsupported op now"); + default: + return false; + } +} + +const luci::CircleNode *get_multi_output_node(const luci::CircleNode *node) +{ + if (is_multiple_output_node(node)) + return node; + + switch (node->opcode()) + { + // The following nodes denote outputs of multiple-output nodes. 
+ case luci::CircleOpcode::CIRCLESPLITOUT: + { + const auto split_out = loco::must_cast<const CircleSplitOut *>(node); + return loco::must_cast<luci::CircleNode *>(split_out->input()); + } + case luci::CircleOpcode::CIRCLESPLITVOUT: + { + const auto splitv_out = loco::must_cast<const CircleSplitVOut *>(node); + return loco::must_cast<luci::CircleNode *>(splitv_out->input()); + } + case luci::CircleOpcode::CIRCLETOPKV2OUT: + { + const auto top_kv2_out = loco::must_cast<const CircleTopKV2Out *>(node); + return loco::must_cast<luci::CircleNode *>(top_kv2_out->input()); + } + case luci::CircleOpcode::CIRCLEUNIQUEOUT: + { + const auto unique_out = loco::must_cast<const CircleUniqueOut *>(node); + return loco::must_cast<luci::CircleNode *>(unique_out->input()); + } + case luci::CircleOpcode::CIRCLEUNPACKOUT: + { + const auto unpack_out = loco::must_cast<const CircleUnpackOut *>(node); + return loco::must_cast<luci::CircleNode *>(unpack_out->input()); + } + // TODO: Support these ops + case luci::CircleOpcode::CIRCLEBIDIRECTIONAL_SEQUENCE_LSTM_OUT: + case luci::CircleOpcode::CIRCLECUSTOMOUT: + case luci::CircleOpcode::CIRCLEIFOUT: + case luci::CircleOpcode::CIRCLENONMAXSUPPRESSIONV4OUT: + case luci::CircleOpcode::CIRCLENONMAXSUPPRESSIONV5OUT: + case luci::CircleOpcode::CIRCLEWHILEOUT: + throw std::runtime_error("Unsupported op now"); + default: + return nullptr; + } +} + +bool same_setting(const LayerInfo &left, const LayerInfo &right) +{ + return left.dtype == right.dtype and left.granularity == right.granularity; +} + +void add_multi_output_node(LayerInfoMap &info_by_name, LayerInfo &layer_info, + const luci::CircleNode *node) +{ + assert(is_multiple_output_node(node)); // FIX_CALLER_UNLESS + + const auto succs_nodes = loco::succs(node); + const auto name = node->name(); + + if (info_by_name.find(name) != info_by_name.end()) + { + // Check that all outputs have equal dtype and granularity + for (const auto succs_node : succs_nodes) + { + const auto succs_circle_node = 
loco::must_cast<luci::CircleNode *>(succs_node); + + const auto it = info_by_name.find(succs_circle_node->name()); + if (it != info_by_name.end() and not same_setting(layer_info, (it->second))) + throw std::runtime_error("Outputs of multiple-output nodes should have equal dtype and " + "granularity. Check the quantization configuration file"); + } + return; + } + + // Add multiple output node to info_by_name + info_by_name[name] = {name, layer_info.dtype, layer_info.granularity}; + + // Add outputs node to info_by_name + for (const auto succs_node : succs_nodes) + { + const auto succs_circle_node = loco::must_cast<luci::CircleNode *>(succs_node); + const auto succs_circle_node_name = succs_circle_node->name(); + info_by_name[succs_circle_node_name] = {succs_circle_node_name, layer_info.dtype, + layer_info.granularity}; + } +} + +} // namespace + +LayerInfoMap layer_info_map(loco::Graph *g, std::vector<LayerInfo> &layers_info) +{ + LayerInfoMap info_by_name; + + for (auto &&info : layers_info) + { + auto name = info.name; + bool found = false; + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + auto cnode = loco::must_cast<luci::CircleNode *>(node); + if (cnode->opcode() == luci::CircleOpcode::CIRCLEOUTPUT) + continue; + + if (cnode->name() == name) + { + // Check and add multiple-output node and its outputs to info_by_name + if (const auto multi_output = get_multi_output_node(cnode)) + { + add_multi_output_node(info_by_name, info, multi_output); + found = true; + continue; + } + + if (info_by_name.find(name) != info_by_name.end()) + { + throw std::runtime_error("Duplicate layer name " + name + + ". Check layer names in the quantization configuration file."); + } + + info_by_name[name] = info; + found = true; + continue; + } + } + + if (not found) + throw std::runtime_error("No such layer named " + name + + ". 
Check layer names in the quantization configuration file."); + } + + // TODO Check all names in layers_info exist in the info_by_name + // TODO Check names in info_by_name but not in layers_info are from virtual outputs + + return info_by_name; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/helpers/LayerInfoMap.h b/compiler/luci/pass/src/helpers/LayerInfoMap.h new file mode 100644 index 000000000..bb4724a50 --- /dev/null +++ b/compiler/luci/pass/src/helpers/LayerInfoMap.h @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_PASS_HELPERS_LAYER_INFO_MAP_H__ +#define __LUCI_PASS_HELPERS_LAYER_INFO_MAP_H__ + +#include <luci/Pass/QuantizationParameters.h> + +#include <unordered_map> + +namespace luci +{ + +using LayerInfoMap = std::unordered_map<std::string, luci::LayerInfo>; + +LayerInfoMap layer_info_map(loco::Graph *g, std::vector<LayerInfo> &layers_info); + +} // namespace luci + +#endif // __LUCI_PASS_HELPERS_LAYER_INFO_MAP_H__ diff --git a/compiler/luci/pass/src/helpers/LayerInfoMap.test.cpp b/compiler/luci/pass/src/helpers/LayerInfoMap.test.cpp new file mode 100644 index 000000000..2ed28eda4 --- /dev/null +++ b/compiler/luci/pass/src/helpers/LayerInfoMap.test.cpp @@ -0,0 +1,201 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. 
All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "LayerInfoMap.h" + +#include <luci/IR/CircleNode.h> +#include <luci/test/TestIOGraph.h> + +#include <gtest/gtest.h> + +namespace +{ + +class SoftmaxTestGraph : public luci::test::TestIOGraph +{ +public: + void init(void) + { + TestIOGraph::init({32}, {32}); + _softmax = g()->nodes()->create<luci::CircleSoftmax>(); + { + _softmax->logits(input()); + _softmax->beta(0.1); + _softmax->name("test"); + } + output()->from(_softmax); + } + +private: + luci::CircleSoftmax *_softmax = nullptr; +}; + +class SplitAddTestGraph : public luci::test::TestIOGraph +{ +public: + void init(void) + { + TestIOGraph::init({6, 1, 2}, {3, 1, 2}); + _split_dim = g()->nodes()->create<luci::CircleConst>(); + { + _split_dim->rank(1); + _split_dim->dtype(loco::DataType::S32); + _split_dim->size<loco::DataType::S32>(1); + _split_dim->at<loco::DataType::S32>(0); + _split_dim->shape({1}); + _split_dim->name("split_dim"); + } + + _split = g()->nodes()->create<luci::CircleSplit>(); + { + _split->input(input()); + _split->num_split(2); + _split->split_dim(_split_dim); + _split->name("split0"); + } + + _split_out_1 = g()->nodes()->create<luci::CircleSplitOut>(); + { + _split_out_1->input(_split); + _split_out_1->index(0); + _split_out_1->name("split0"); + } + + _split_out_2 = g()->nodes()->create<luci::CircleSplitOut>(); + { + _split_out_2->input(_split); + _split_out_2->index(1); + 
_split_out_2->name("split1"); + } + + _add = g()->nodes()->create<luci::CircleAdd>(); + { + _add->x(_split_out_1); + _add->y(_split_out_2); + _add->name("add"); + } + output()->from(_add); + } + +private: + luci::CircleSplit *_split = nullptr; + luci::CircleSplitOut *_split_out_1 = nullptr; + luci::CircleSplitOut *_split_out_2 = nullptr; + luci::CircleConst *_split_dim = nullptr; + luci::CircleAdd *_add = nullptr; +}; + +} // namespace + +TEST(LayerInfoMapTest, simple_test) +{ + SoftmaxTestGraph g; + g.init(); + + luci::LayerInfo info; + { + info.name = "test"; + info.dtype = loco::DataType::U8; + info.granularity = luci::QuantizationGranularity::ChannelWise; + } + std::vector<luci::LayerInfo> v; + v.emplace_back(info); + auto map = luci::layer_info_map(g.g(), v); + + EXPECT_EQ("test", map["test"].name); + EXPECT_EQ(loco::DataType::U8, map["test"].dtype); + EXPECT_EQ(luci::QuantizationGranularity::ChannelWise, map["test"].granularity); +} + +TEST(LayerInfoMapTest, multiple_output_node_test) +{ + SplitAddTestGraph g; + g.init(); + + luci::LayerInfo info; + { + info.name = "split0"; + info.dtype = loco::DataType::U8; + info.granularity = luci::QuantizationGranularity::ChannelWise; + } + std::vector<luci::LayerInfo> v; + v.emplace_back(info); + auto map = luci::layer_info_map(g.g(), v); + + EXPECT_EQ(map.size(), 2); + EXPECT_EQ("split0", map["split0"].name); + EXPECT_EQ("split1", map["split1"].name); + + EXPECT_EQ(loco::DataType::U8, map["split0"].dtype); + EXPECT_EQ(luci::QuantizationGranularity::ChannelWise, map["split0"].granularity); +} + +TEST(LayerInfoMapTest, invalid_layer_info_multiple_output_node_NEG) +{ + SplitAddTestGraph g; + g.init(); + + luci::LayerInfo info_0; + { + info_0.name = "split0"; + info_0.dtype = loco::DataType::U8; + info_0.granularity = luci::QuantizationGranularity::ChannelWise; + } + luci::LayerInfo info_1; + { + info_1.name = "split1"; + info_1.dtype = loco::DataType::S16; + info_1.granularity = luci::QuantizationGranularity::ChannelWise; 
+ } + std::vector<luci::LayerInfo> v; + v.emplace_back(info_0); + v.emplace_back(info_1); + + EXPECT_ANY_THROW(luci::layer_info_map(g.g(), v)); +} + +TEST(LayerInfoMapTest, duplicate_name_NEG) +{ + SoftmaxTestGraph g; + g.init(); + g.input()->name("test"); + + luci::LayerInfo info; + { + info.name = "test"; + info.dtype = loco::DataType::U8; + info.granularity = luci::QuantizationGranularity::ChannelWise; + } + std::vector<luci::LayerInfo> v; + v.emplace_back(info); + EXPECT_ANY_THROW(luci::layer_info_map(g.g(), v)); +} + +TEST(LayerInfoMapTest, no_name_NEG) +{ + SoftmaxTestGraph g; + g.init(); + + luci::LayerInfo info; + { + info.name = "noname"; + info.dtype = loco::DataType::U8; + info.granularity = luci::QuantizationGranularity::ChannelWise; + } + std::vector<luci::LayerInfo> v; + v.emplace_back(info); + EXPECT_ANY_THROW(luci::layer_info_map(g.g(), v)); +} diff --git a/compiler/luci/requires.cmake b/compiler/luci/requires.cmake index 3ccc58128..e896188be 100644 --- a/compiler/luci/requires.cmake +++ b/compiler/luci/requires.cmake @@ -4,8 +4,8 @@ require("loco") require("locop") require("logo") require("logo-core") -require("mio-circle") -require("mio-tflite") +require("mio-circle04") +require("mio-tflite280") require("oops") require("hermes") require("hermes-std") diff --git a/compiler/luci/service/CMakeLists.txt b/compiler/luci/service/CMakeLists.txt index 0e6097f96..24bdfc152 100644 --- a/compiler/luci/service/CMakeLists.txt +++ b/compiler/luci/service/CMakeLists.txt @@ -10,7 +10,6 @@ add_library(luci_service ${LUCI_LIBRARY_TYPE} ${SOURCES}) target_include_directories(luci_service PRIVATE src) target_include_directories(luci_service PUBLIC include) target_link_libraries(luci_service PUBLIC luci_lang) -target_link_libraries(luci_service PUBLIC mio_circle) target_link_libraries(luci_service PUBLIC logo_core) target_link_libraries(luci_service PRIVATE luci_log) target_link_libraries(luci_service PRIVATE luci_logex) diff --git 
a/compiler/luci/service/include/luci/Service/CircleShapeInference.h b/compiler/luci/service/include/luci/Service/CircleShapeInference.h index ead12d074..2c1120941 100644 --- a/compiler/luci/service/include/luci/Service/CircleShapeInference.h +++ b/compiler/luci/service/include/luci/Service/CircleShapeInference.h @@ -17,11 +17,12 @@ #ifndef __LUCI_CIRCLE_SHAPE_INFERENCE_H__ #define __LUCI_CIRCLE_SHAPE_INFERENCE_H__ -#include <loco/IR/Nodes.h> - +#include <luci/Service/CircleShapeInferenceRule.h> #include <luci/IR/CircleNodes.h> #include <luci/IR/CircleNodeVisitor.h> -#include <luci/Service/CircleShapeInferenceRule.h> + +#include <loco/IR/NodeShape.h> +#include <loco/IR/TensorShape.h> namespace luci { diff --git a/compiler/luci/service/include/luci/Service/CircleTypeInference.h b/compiler/luci/service/include/luci/Service/CircleTypeInference.h index d62731380..e0ceabeac 100644 --- a/compiler/luci/service/include/luci/Service/CircleTypeInference.h +++ b/compiler/luci/service/include/luci/Service/CircleTypeInference.h @@ -17,13 +17,11 @@ #ifndef __LUCI_CIRCLE_TYPE_INFERENCE_H__ #define __LUCI_CIRCLE_TYPE_INFERENCE_H__ -#include <loco/IR/Nodes.h> - -#include <mio/circle/schema_generated.h> - +#include <luci/Service/CircleTypeInferenceRule.h> #include <luci/IR/CircleNodes.h> #include <luci/IR/CircleNodeVisitor.h> -#include <luci/Service/CircleTypeInferenceRule.h> + +#include <loco/IR/DataType.h> namespace luci { diff --git a/compiler/luci/service/src/CircleCloneNode.h b/compiler/luci/service/src/CircleCloneNode.h index 3926147f5..99e4561b3 100644 --- a/compiler/luci/service/src/CircleCloneNode.h +++ b/compiler/luci/service/src/CircleCloneNode.h @@ -208,6 +208,7 @@ public: luci::CircleNode *visit(const luci::CircleSquaredDifference *) final; luci::CircleNode *visit(const luci::CircleSqueeze *) final; luci::CircleNode *visit(const luci::CircleStridedSlice *) final; + luci::CircleNode *visit(const luci::CircleSVDF *) final; luci::CircleNode *visit(const luci::CircleSub *) 
final; luci::CircleNode *visit(const luci::CircleSum *) final; luci::CircleNode *visit(const luci::CircleTanh *) final; @@ -269,6 +270,7 @@ public: luci::CircleNode *visit(const luci::CircleTopKV2Out *) final; luci::CircleNode *visit(const luci::CircleUniqueOut *) final; luci::CircleNode *visit(const luci::CircleUnpackOut *) final; + luci::CircleNode *visit(const luci::CircleVariable *) final; luci::CircleNode *visit(const luci::CircleWhileOut *) final; // Handle in CircleNode diff --git a/compiler/luci/service/src/CircleNodeClone.cpp b/compiler/luci/service/src/CircleNodeClone.cpp index d2033dd0c..220c6096c 100644 --- a/compiler/luci/service/src/CircleNodeClone.cpp +++ b/compiler/luci/service/src/CircleNodeClone.cpp @@ -14,6 +14,7 @@ * limitations under the License. */ +#include "luci/IR/CircleQuantParam.h" #include "luci/Service/CircleNodeClone.h" #include "CircleCloneNode.h" @@ -45,18 +46,7 @@ void copy_common_attributes(const luci::CircleNode *src, luci::CircleNode *dst) dst->shape_status(src->shape_status()); // quantparam - const auto *quantparam = src->quantparam(); - if (quantparam != nullptr) - { - auto qparam = std::make_unique<luci::CircleQuantParam>(); - qparam->scale = quantparam->scale; - qparam->zerop = quantparam->zerop; - qparam->min = quantparam->min; - qparam->max = quantparam->max; - qparam->quantized_dimension = quantparam->quantized_dimension; - - dst->quantparam(std::move(qparam)); - } + copy_quantparam(src, dst); // sparsity const auto *sparsity = src->sparsityparam(); diff --git a/compiler/luci/service/src/CircleShapeInferenceRule.cpp b/compiler/luci/service/src/CircleShapeInferenceRule.cpp index 5d6a31050..9d156f3e2 100644 --- a/compiler/luci/service/src/CircleShapeInferenceRule.cpp +++ b/compiler/luci/service/src/CircleShapeInferenceRule.cpp @@ -1,5 +1,6 @@ /* * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -196,23 +197,18 @@ template <class CIRCLENODE> loco::NodeShape broadcast_xy(const CIRCLENODE *node) return loco::NodeShape{output_shape}; } -template <class CIRCLENODE> loco::NodeShape use_inputs(const CIRCLENODE *node) -{ - auto inputs_shape = luci::shape_get(node->inputs()).template as<loco::TensorShape>(); - return loco::NodeShape{inputs_shape}; -} +#define DECLARE_USE_SINGLE(NAME) \ + template <class CIRCLENODE> loco::NodeShape use_##NAME(const CIRCLENODE *node) \ + { \ + auto inputs_shape = luci::shape_get(node->NAME()).template as<loco::TensorShape>(); \ + return loco::NodeShape{inputs_shape}; \ + } -template <class CIRCLENODE> loco::NodeShape use_x(const CIRCLENODE *node) -{ - auto x_shape = luci::shape_get(node->x()).template as<loco::TensorShape>(); - return loco::NodeShape{x_shape}; -} +DECLARE_USE_SINGLE(inputs); +DECLARE_USE_SINGLE(x); +DECLARE_USE_SINGLE(logits); -template <class CIRCLENODE> loco::NodeShape use_logits(const CIRCLENODE *node) -{ - auto shape = luci::shape_get(node->logits()).template as<loco::TensorShape>(); - return loco::NodeShape{shape}; -} +#undef DECLARE_USE_SINGLE template <class CIRCLENODE> loco::NodeShape use_paddings(const CIRCLENODE *node, const luci::CircleConst *paddings) @@ -721,6 +717,8 @@ loco::NodeShape infer_fully_connected(const luci::CircleFullyConnected *node) auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>(); auto weights_shape = luci::shape_get(node->weights()).as<loco::TensorShape>(); +// TODO Remove following unused code +#if 0 // Checking shape capability for fully connected layer // Input: a tensor of at least rank 2 [D1, D2, ... 
Dn] // Weight: [# of units, K] @@ -741,6 +739,40 @@ loco::NodeShape infer_fully_connected(const luci::CircleFullyConnected *node) out_shape.rank(2); out_shape.dim(0) = batch_size; out_shape.dim(1) = weights_shape.dim(0); +#endif + + loco::TensorShape out_shape; + + // NOTE Some recipes in some repositories are using rank 4 input for FullyConnected. + // Until they are all fixed, disable following assert. + // TODO Enable following assert after related fixes are applied + // https://github.com/tensorflow/tensorflow/blob/ea33c1e7a25d8025e8ee405ad8ab7be261798d76/tensorflow/lite/kernels/fully_connected.cc#L194 + // LUCI_ASSERT(input_shape.rank() == 2 || input_shape.rank() == 3, + // "Input rank of FullyConnected should be 2 or 3"); + + // https://github.com/tensorflow/tensorflow/blob/ea33c1e7a25d8025e8ee405ad8ab7be261798d76/tensorflow/lite/kernels/fully_connected.cc#L225 + LUCI_ASSERT(weights_shape.rank() == 2, "Weights of FullyConnected should be 2"); + + // https://github.com/tensorflow/tensorflow/blob/ea33c1e7a25d8025e8ee405ad8ab7be261798d76/tensorflow/lite/kernels/fully_connected.cc#L353-L367 + if (node->keep_num_dims()) + { + out_shape.rank(input_shape.rank()); + for (uint32_t i = 0; i < input_shape.rank(); ++i) + out_shape.dim(i) = input_shape.dim(i); + out_shape.dim(out_shape.rank() - 1) = weights_shape.dim(0); + } + else + { + uint32_t input_size = 1; + for (uint32_t i = 0; i < input_shape.rank(); i++) + { + input_size = input_size * input_shape.dim(i).value(); + } + const uint32_t batch_size = input_size / weights_shape.dim(1).value(); + out_shape.rank(2); + out_shape.dim(0) = batch_size; + out_shape.dim(1) = weights_shape.dim(0); + } return loco::NodeShape{out_shape}; } @@ -1554,6 +1586,30 @@ loco::NodeShape infer_squeeze(const luci::CircleSqueeze *node) return loco::NodeShape{output_shape}; } +loco::NodeShape infer_svdf(const luci::CircleSVDF *node) +{ + const auto ifm_shape = luci::shape_get(node->input()).as<loco::TensorShape>(); + const auto 
weight_feature_shape = luci::shape_get(node->weight_feature()).as<loco::TensorShape>(); + + assert(ifm_shape.rank() == 2); + assert(weight_feature_shape.rank() == 2); + + assert(ifm_shape.dim(1) == weight_feature_shape.dim(1)); + assert(weight_feature_shape.dim(0).known()); + + const auto rank = node->svdf_rank(); + const auto num_filters = weight_feature_shape.dim(0).value(); + assert(num_filters % rank == 0); + const auto num_units = num_filters / rank; + + loco::TensorShape ofm_shape; + ofm_shape.rank(2); + ofm_shape.dim(0) = ifm_shape.dim(0); + ofm_shape.dim(1) = num_units; + + return loco::NodeShape{ofm_shape}; +} + loco::NodeShape infer_tile(const luci::CircleTile *node) { const loco::DataType S32 = loco::DataType::S32; @@ -2393,6 +2449,8 @@ public: return loco::NodeShape{output_shape}; } + loco::NodeShape visit(const luci::CircleSVDF *node) final { return infer_svdf(node); } + loco::NodeShape visit(const luci::CircleTanh *node) final { return use_x(node); } loco::NodeShape visit(const luci::CircleTile *node) final { return infer_tile(node); } @@ -2486,6 +2544,8 @@ public: loco::NodeShape visit(const luci::CircleUnpackOut *node) final { return infer_unpack_out(node); } + loco::NodeShape visit(const luci::CircleVariable *node) final { return use_own(node); } + loco::NodeShape visit(const luci::CircleWhileOut *node) final { return infer_while_out(node); } }; diff --git a/compiler/luci/service/src/CircleTypeInferenceRule.cpp b/compiler/luci/service/src/CircleTypeInferenceRule.cpp index 5f6d46f2b..438c4a364 100644 --- a/compiler/luci/service/src/CircleTypeInferenceRule.cpp +++ b/compiler/luci/service/src/CircleTypeInferenceRule.cpp @@ -478,6 +478,11 @@ struct TypeInferenceAlgorithm final : public luci::CircleNodeVisitor<loco::DataT loco::DataType visit(const luci::CircleSum *node) final { return luci::dtype_get(node->input()); } + loco::DataType visit(const luci::CircleSVDF *node) final + { + return luci::dtype_get(node->input()); + } + loco::DataType visit(const 
luci::CircleTanh *node) final { return luci::dtype_get(node->x()); } loco::DataType visit(const luci::CircleTile *node) final @@ -605,6 +610,8 @@ struct TypeInferenceAlgorithm final : public luci::CircleNodeVisitor<loco::DataT return loco::DataType::S32; } + loco::DataType visit(const luci::CircleVariable *node) final { return node->dtype(); } + loco::DataType visit(const luci::CircleUniqueOut *node) final { if (node->index() == 0) diff --git a/compiler/luci/service/src/Nodes/CircleSVDF.cpp b/compiler/luci/service/src/Nodes/CircleSVDF.cpp new file mode 100644 index 000000000..d4c3ce88f --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleSVDF.cpp @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNodeLet<CN::STUV>::visit(const luci::CircleSVDF *node) +{ + if (node->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED) + return nullptr; + + auto *cloned = _graph->nodes()->create<luci::CircleSVDF>(); + if (cloned != nullptr) + { + cloned->fusedActivationFunction(node->fusedActivationFunction()); + cloned->asymmetric_quantize_inputs(node->asymmetric_quantize_inputs()); + cloned->svdf_rank(node->svdf_rank()); + } + return cloned; +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleSVDF.test.cpp b/compiler/luci/service/src/Nodes/CircleSVDF.test.cpp new file mode 100644 index 000000000..d6edaf1cc --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleSVDF.test.cpp @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +TEST(CloneNodeTest, clone_SVDF) +{ + auto g = loco::make_graph(); + auto node_svdf = g->nodes()->create<luci::CircleSVDF>(); + node_svdf->fusedActivationFunction(luci::FusedActFunc::RELU); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_svdf, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_svdf = dynamic_cast<luci::CircleSVDF *>(cloned); + ASSERT_NE(nullptr, cloned_svdf); + ASSERT_EQ(node_svdf->asymmetric_quantize_inputs(), cloned_svdf->asymmetric_quantize_inputs()); + ASSERT_EQ(node_svdf->svdf_rank(), cloned_svdf->svdf_rank()); +} + +TEST(CloneNodeTest, clone_SVDF_NEG) +{ + auto g = loco::make_graph(); + auto node_svdf = g->nodes()->create<luci::CircleSVDF>(); + node_svdf->fusedActivationFunction(luci::FusedActFunc::UNDEFINED); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_svdf, gc.get()); + ASSERT_EQ(nullptr, cloned); +} diff --git a/compiler/luci/service/src/Nodes/CircleVariable.cpp b/compiler/luci/service/src/Nodes/CircleVariable.cpp new file mode 100644 index 000000000..c1430bd3a --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleVariable.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +/** + * Clone visitor for CircleVariable. + * + * Creates a new CircleVariable in _graph and returns it. The source node + * parameter is unnamed and unused, so nothing is copied from it in this + * function. + */ +luci::CircleNode *CloneNode::visit(const luci::CircleVariable *) +{ + return _graph->nodes()->create<luci::CircleVariable>(); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleVariable.test.cpp b/compiler/luci/service/src/Nodes/CircleVariable.test.cpp new file mode 100644 index 000000000..7d29438be --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleVariable.test.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "luci/Service/CircleNodeClone.h" + +#include <gtest/gtest.h> + +/** + * Clones a CircleVariable into a second graph and checks that the result is + * non-null, belongs to the target graph, and is a CircleVariable. + */ +TEST(CloneNodeTest, clone_Variable) +{ + auto g = loco::make_graph(); + auto node_dummy = g->nodes()->create<luci::CircleVariable>(); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_dummy, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_variable = dynamic_cast<luci::CircleVariable *>(cloned); + ASSERT_NE(nullptr, cloned_variable); +} diff --git a/compiler/luci/tests/CMakeLists.txt b/compiler/luci/tests/CMakeLists.txt index c03835823..1333efb7d 100644 --- a/compiler/luci/tests/CMakeLists.txt +++ b/compiler/luci/tests/CMakeLists.txt @@ -1,3 +1,14 @@ +set(CIRCLECHEF_FILE_PATH $<TARGET_FILE:circlechef-file>) +set(TFLCHEF_FILE_PATH $<TARGET_FILE:tflchef-file>) +set(TFLITE2CIRCLE_PATH $<TARGET_FILE:tflite2circle>) +if(DEFINED ENV{BUILD_HOST_EXEC}) + # TODO use better way to represent path for host executable + set(CIRCLECHEF_FILE_PATH $ENV{BUILD_HOST_EXEC}/compiler/circlechef/tools/file/circlechef-file) + set(TFLCHEF_FILE_PATH $ENV{BUILD_HOST_EXEC}/compiler/tflchef/tools/file/tflchef-file) + set(TFLITE2CIRCLE_PATH $ENV{BUILD_HOST_EXEC}/compiler/tflite2circle/tflite2circle) + message(STATUS "TFLITE2CIRCLE_PATH = ${TFLITE2CIRCLE_PATH}") +endif(DEFINED ENV{BUILD_HOST_EXEC}) + # TODO use local test.recipe files for small networks file(GLOB RECIPES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*/test.recipe") @@ -17,14 +28,14 @@ foreach(RECIPE IN ITEMS ${RECIPES}) # Generate .tflite add_custom_command(OUTPUT "${RECIPE_OUTPUT_FILE}" - COMMAND tflchef-file "${RECIPE_SOURCE_FILE}" "${RECIPE_OUTPUT_FILE}" - DEPENDS tflchef-file "${RECIPE_SOURCE_FILE}" + COMMAND ${TFLCHEF_FILE_PATH} "${RECIPE_SOURCE_FILE}" "${RECIPE_OUTPUT_FILE}" + DEPENDS ${TFLCHEF_FILE_PATH} "${RECIPE_SOURCE_FILE}" COMMENT "Generating ${RECIPE_OUTPUT_FILE}") # Generate .circle add_custom_command(OUTPUT "${CIRCLE_OUTPUT_FILE}" - COMMAND tflite2circle 
"${RECIPE_OUTPUT_FILE}" "${CIRCLE_OUTPUT_FILE}" - DEPENDS tflite2circle "${RECIPE_OUTPUT_FILE}" + COMMAND ${TFLITE2CIRCLE_PATH} "${RECIPE_OUTPUT_FILE}" "${CIRCLE_OUTPUT_FILE}" + DEPENDS ${TFLITE2CIRCLE_PATH} "${RECIPE_OUTPUT_FILE}" COMMENT "Generating ${CIRCLE_OUTPUT_FILE}") list(APPEND TESTFILES "${CIRCLE_OUTPUT_FILE}") @@ -52,14 +63,14 @@ foreach(RECIPE IN ITEMS ${RECIPES}) # Generate .tflite add_custom_command(OUTPUT "${RECIPE_OUTPUT_FILE}" - COMMAND tflchef-file "${RECIPE_SOURCE_FILE}" "${RECIPE_OUTPUT_FILE}" - DEPENDS tflchef-file "${RECIPE_SOURCE_FILE}" + COMMAND ${TFLCHEF_FILE_PATH} "${RECIPE_SOURCE_FILE}" "${RECIPE_OUTPUT_FILE}" + DEPENDS ${TFLCHEF_FILE_PATH} "${RECIPE_SOURCE_FILE}" COMMENT "Generating ${RECIPE_OUTPUT_FILE}") # Generate .circle add_custom_command(OUTPUT "${CIRCLE_OUTPUT_FILE}" - COMMAND tflite2circle "${RECIPE_OUTPUT_FILE}" "${CIRCLE_OUTPUT_FILE}" - DEPENDS tflite2circle "${RECIPE_OUTPUT_FILE}" + COMMAND ${TFLITE2CIRCLE_PATH} "${RECIPE_OUTPUT_FILE}" "${CIRCLE_OUTPUT_FILE}" + DEPENDS ${TFLITE2CIRCLE_PATH} "${RECIPE_OUTPUT_FILE}" COMMENT "Generating ${CIRCLE_OUTPUT_FILE}") list(APPEND TESTFILES "${CIRCLE_OUTPUT_FILE}") @@ -87,8 +98,8 @@ foreach(RECIPE IN ITEMS ${RECIPES2}) # Generate .circle add_custom_command(OUTPUT "${CIRCLE_OUTPUT_FILE}" - COMMAND circlechef-file "${RECIPE_SOURCE_FILE}" "${CIRCLE_OUTPUT_FILE}" - DEPENDS circlechef-file "${RECIPE_SOURCE_FILE}" + COMMAND ${CIRCLECHEF_FILE_PATH} "${RECIPE_SOURCE_FILE}" "${CIRCLE_OUTPUT_FILE}" + DEPENDS ${CIRCLECHEF_FILE_PATH} "${RECIPE_SOURCE_FILE}" COMMENT "Generating ${CIRCLE_OUTPUT_FILE}") list(APPEND TESTFILES "${CIRCLE_OUTPUT_FILE}") @@ -111,6 +122,8 @@ include("test.lst") # Read "test.local.lst" if exists include("test.local.lst" OPTIONAL) +# NOTE $<TARGET_FILE:luci_readtester> is used as-is as test itself should +# run in target device for cross build also add_test(NAME luci_unit_readtest COMMAND "${CMAKE_CURRENT_SOURCE_DIR}/readverify.sh" "${CMAKE_CURRENT_BINARY_DIR}" diff --git 
a/compiler/luci/tests/test.lst b/compiler/luci/tests/test.lst index 28ddcf672..94e723f21 100644 --- a/compiler/luci/tests/test.lst +++ b/compiler/luci/tests/test.lst @@ -180,6 +180,8 @@ addread(Sub_000) addread(Sub_U8_000) addread(Sum_000) addread(Sum_001) +addread(SVDF_000) +addread(SVDF_001) addread(Tanh_000) addread(Tanh_U8_000) addread(Tile_000) @@ -403,6 +405,8 @@ addwrite(Sub_000) addwrite(Sub_U8_000) addwrite(Sum_000) addwrite(Sum_001) +addwrite(SVDF_000) +addwrite(SVDF_001) addwrite(Tanh_000) addwrite(Tanh_U8_000) addwrite(Tile_000) |