summaryrefslogtreecommitdiff
path: root/compiler/luci
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/luci')
-rw-r--r--compiler/luci/CMakeLists.txt4
-rw-r--r--compiler/luci/export/CMakeLists.txt4
-rw-r--r--compiler/luci/export/src/CircleBuiltinTypesExtractor.h539
-rw-r--r--compiler/luci/export/src/CircleBuiltinTypesMappingRule.h79
-rw-r--r--compiler/luci/export/src/CircleExporterImpl.cpp9
-rw-r--r--compiler/luci/export/src/CircleExporterUtils.cpp58
-rw-r--r--compiler/luci/export/src/CircleExporterUtils.h6
-rw-r--r--compiler/luci/export/src/CircleOperationExporter.cpp1696
-rw-r--r--compiler/luci/export/src/CircleOperationExporter.h2
-rw-r--r--compiler/luci/export/src/CircleOperationExporterRule.cpp277
-rw-r--r--compiler/luci/export/src/CircleOperationExporterRule.h76
-rw-r--r--compiler/luci/export/src/CircleOps.lst154
-rw-r--r--compiler/luci/export/src/CircleTensorExporter.cpp15
-rw-r--r--compiler/luci/export/src/SerializedData.h6
-rw-r--r--compiler/luci/import/CMakeLists.txt7
-rw-r--r--compiler/luci/import/include/luci/Import/CircleReader.h73
-rw-r--r--compiler/luci/import/include/luci/Import/GraphBuilderRegistry.h23
-rw-r--r--compiler/luci/import/include/luci/Import/NodeBuilder.h58
-rw-r--r--compiler/luci/import/include/luci/Import/Nodes.h2
-rw-r--r--compiler/luci/import/include/luci/Import/Nodes/CircleConst.h17
-rw-r--r--compiler/luci/import/include/luci/Import/Nodes/CircleSVDF.h37
-rw-r--r--compiler/luci/import/include/luci/Import/Nodes/CircleVariable.h37
-rw-r--r--compiler/luci/import/src/CircleImportMetadata.cpp43
-rw-r--r--compiler/luci/import/src/CircleReader.cpp186
-rw-r--r--compiler/luci/import/src/GraphBuilder.cpp15
-rw-r--r--compiler/luci/import/src/GraphBuilderMultiOutput.cpp20
-rw-r--r--compiler/luci/import/src/GraphBuilderRegistry.cpp9
-rw-r--r--compiler/luci/import/src/Importer.cpp78
-rw-r--r--compiler/luci/import/src/Importer.test.cpp50
-rw-r--r--compiler/luci/import/src/Nodes/CircleCast.cpp12
-rw-r--r--compiler/luci/import/src/Nodes/CircleConst.cpp34
-rw-r--r--compiler/luci/import/src/Nodes/CircleCustom.cpp8
-rw-r--r--compiler/luci/import/src/Nodes/CircleDepthToSpace.cpp5
-rw-r--r--compiler/luci/import/src/Nodes/CircleDepthwiseConv2D.cpp12
-rw-r--r--compiler/luci/import/src/Nodes/CircleElu.cpp10
-rw-r--r--compiler/luci/import/src/Nodes/CircleEqual.cpp5
-rw-r--r--compiler/luci/import/src/Nodes/CircleExp.cpp7
-rw-r--r--compiler/luci/import/src/Nodes/CircleExpandDims.cpp5
-rw-r--r--compiler/luci/import/src/Nodes/CircleFloorDiv.cpp17
-rw-r--r--compiler/luci/import/src/Nodes/CircleFloorMod.cpp9
-rw-r--r--compiler/luci/import/src/Nodes/CircleFullyConnected.cpp1
-rw-r--r--compiler/luci/import/src/Nodes/CircleGatherNd.cpp7
-rw-r--r--compiler/luci/import/src/Nodes/CircleGreater.cpp10
-rw-r--r--compiler/luci/import/src/Nodes/CircleGreaterEqual.cpp8
-rw-r--r--compiler/luci/import/src/Nodes/CircleIf.cpp9
-rw-r--r--compiler/luci/import/src/Nodes/CircleLess.cpp13
-rw-r--r--compiler/luci/import/src/Nodes/CircleLessEqual.cpp8
-rw-r--r--compiler/luci/import/src/Nodes/CircleLog.cpp7
-rw-r--r--compiler/luci/import/src/Nodes/CircleLogicalAnd.cpp7
-rw-r--r--compiler/luci/import/src/Nodes/CircleLogicalNot.cpp7
-rw-r--r--compiler/luci/import/src/Nodes/CircleLogicalOr.cpp7
-rw-r--r--compiler/luci/import/src/Nodes/CircleLogistic.cpp5
-rw-r--r--compiler/luci/import/src/Nodes/CircleMatrixDiag.cpp7
-rw-r--r--compiler/luci/import/src/Nodes/CircleMatrixSetDiag.cpp7
-rw-r--r--compiler/luci/import/src/Nodes/CircleNonMaxSuppressionV4.cpp22
-rw-r--r--compiler/luci/import/src/Nodes/CircleNonMaxSuppressionV5.cpp25
-rw-r--r--compiler/luci/import/src/Nodes/CircleNotEqual.cpp8
-rw-r--r--compiler/luci/import/src/Nodes/CircleOneHot.cpp24
-rw-r--r--compiler/luci/import/src/Nodes/CircleReduceAny.cpp17
-rw-r--r--compiler/luci/import/src/Nodes/CircleReduceProd.cpp7
-rw-r--r--compiler/luci/import/src/Nodes/CircleReshape.cpp7
-rw-r--r--compiler/luci/import/src/Nodes/CircleReverseSequence.cpp15
-rw-r--r--compiler/luci/import/src/Nodes/CircleReverseV2.cpp15
-rw-r--r--compiler/luci/import/src/Nodes/CircleRound.cpp12
-rw-r--r--compiler/luci/import/src/Nodes/CircleRsqrt.cpp7
-rw-r--r--compiler/luci/import/src/Nodes/CircleSVDF.cpp67
-rw-r--r--compiler/luci/import/src/Nodes/CircleScatterNd.cpp9
-rw-r--r--compiler/luci/import/src/Nodes/CircleSegmentSum.cpp15
-rw-r--r--compiler/luci/import/src/Nodes/CircleSelect.cpp7
-rw-r--r--compiler/luci/import/src/Nodes/CircleSelectV2.cpp14
-rw-r--r--compiler/luci/import/src/Nodes/CircleSin.cpp7
-rw-r--r--compiler/luci/import/src/Nodes/CircleSquare.cpp12
-rw-r--r--compiler/luci/import/src/Nodes/CircleSquaredDifference.cpp13
-rw-r--r--compiler/luci/import/src/Nodes/CircleTanh.cpp5
-rw-r--r--compiler/luci/import/src/Nodes/CircleTile.cpp10
-rw-r--r--compiler/luci/import/src/Nodes/CircleTopKV2.cpp7
-rw-r--r--compiler/luci/import/src/Nodes/CircleTransposeConv.cpp14
-rw-r--r--compiler/luci/import/src/Nodes/CircleUnpack.cpp11
-rw-r--r--compiler/luci/import/src/Nodes/CircleVariable.cpp80
-rw-r--r--compiler/luci/import/src/Nodes/CircleWhere.cpp12
-rw-r--r--compiler/luci/import/src/Nodes/CircleWhile.cpp13
-rw-r--r--compiler/luci/import/src/ValidateHelpers.cpp39
-rw-r--r--compiler/luci/lang/include/luci/IR/CircleNodes.h9
-rw-r--r--compiler/luci/lang/include/luci/IR/CircleNodes.lst5
-rw-r--r--compiler/luci/lang/include/luci/IR/CircleQuantParam.h4
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleFullyConnected.h4
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleSVDF.h70
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleVariable.h39
-rw-r--r--compiler/luci/lang/src/CircleQuantParam.cpp46
-rw-r--r--compiler/luci/lang/src/CircleQuantParam.test.cpp78
-rw-r--r--compiler/luci/lang/src/Nodes/CircleFullyConnected.test.cpp1
-rw-r--r--compiler/luci/lang/src/Nodes/CircleSVDF.test.cpp101
-rw-r--r--compiler/luci/lang/src/Nodes/CircleVariable.test.cpp61
-rw-r--r--compiler/luci/logex/CMakeLists.txt14
-rw-r--r--compiler/luci/logex/src/CircleNodeSummaryBuilder.cpp265
-rw-r--r--compiler/luci/logex/src/CircleNodeSummaryBuilder.h52
-rw-r--r--compiler/luci/logex/src/CircleNodeSummaryBuilder.test.cpp309
-rw-r--r--compiler/luci/logex/src/CircleNodeSummaryBuilders.cpp1128
-rw-r--r--compiler/luci/logex/src/CircleNodeSummaryBuilders.h821
-rw-r--r--compiler/luci/logex/src/FormattedGraph.cpp2194
-rw-r--r--compiler/luci/partition/CMakeLists.txt2
-rw-r--r--compiler/luci/partition/src/ConnectNode.h2
-rw-r--r--compiler/luci/partition/src/Nodes/CircleSVDF.cpp47
-rw-r--r--compiler/luci/partition/src/Nodes/CircleSVDF.test.cpp106
-rw-r--r--compiler/luci/partition/src/Nodes/CircleVariable.cpp27
-rw-r--r--compiler/luci/partition/src/PartitionIRDump.cpp11
-rw-r--r--compiler/luci/partition/src/PartitionMerge.cpp50
-rw-r--r--compiler/luci/partition/src/PartitionPGroups.cpp240
-rw-r--r--compiler/luci/pass/CMakeLists.txt8
-rw-r--r--compiler/luci/pass/include/luci/CircleOptimizer.h20
-rw-r--r--compiler/luci/pass/include/luci/CircleQuantizer.h97
-rw-r--r--compiler/luci/pass/include/luci/Pass/ConvertToFakeQuantizedModelPass.h39
-rw-r--r--compiler/luci/pass/include/luci/Pass/CopyQuantParamPass.h53
-rw-r--r--compiler/luci/pass/include/luci/Pass/FoldGatherPass.h38
-rw-r--r--compiler/luci/pass/include/luci/Pass/PropagateQParamBackwardPass.h42
-rw-r--r--compiler/luci/pass/include/luci/Pass/PropagateQParamForwardPass.h (renamed from compiler/luci/pass/include/luci/Pass/PropagateQuantParamPass.h)19
-rw-r--r--compiler/luci/pass/include/luci/Pass/QuantizationParameters.h11
-rw-r--r--compiler/luci/pass/include/luci/Pass/QuantizeDequantizeWeightsPass.h28
-rw-r--r--compiler/luci/pass/include/luci/Pass/QuantizePreCheckerPass.h39
-rw-r--r--compiler/luci/pass/include/luci/Pass/QuantizeWithMinMaxPass.h39
-rw-r--r--compiler/luci/pass/include/luci/Pass/RemoveRedundantQuantizePass.h37
-rw-r--r--compiler/luci/pass/src/BatchNormPatternFinder.cpp40
-rw-r--r--compiler/luci/pass/src/BatchNormPatternFinder.test.cpp107
-rw-r--r--compiler/luci/pass/src/CircleOptimizer.cpp224
-rw-r--r--compiler/luci/pass/src/CircleOptimizer.test.cpp168
-rw-r--r--compiler/luci/pass/src/CircleQuantizer.cpp458
-rw-r--r--compiler/luci/pass/src/CircleQuantizer.test.cpp191
-rw-r--r--compiler/luci/pass/src/ConvertNCHWToNHWCPass.cpp6
-rw-r--r--compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp36
-rw-r--r--compiler/luci/pass/src/ConvertToFakeQuantizedModelPass.cpp214
-rw-r--r--compiler/luci/pass/src/ConvertToFakeQuantizedModelPass.test.cpp277
-rw-r--r--compiler/luci/pass/src/CopyQuantParamPass.cpp82
-rw-r--r--compiler/luci/pass/src/FoldGatherPass.cpp185
-rw-r--r--compiler/luci/pass/src/FoldGatherPass.test.cpp214
-rw-r--r--compiler/luci/pass/src/PropagateConcatenationQparam.test.cpp36
-rw-r--r--compiler/luci/pass/src/PropagateQParamBackwardPass.cpp482
-rw-r--r--compiler/luci/pass/src/PropagateQParamBackwardPass.test.cpp167
-rw-r--r--compiler/luci/pass/src/PropagateQParamForwardPass.cpp194
-rw-r--r--compiler/luci/pass/src/PropagateQParamForwardPass.test.cpp260
-rw-r--r--compiler/luci/pass/src/PropagateQuantParamPass.cpp107
-rw-r--r--compiler/luci/pass/src/PropagateQuantParamPass.test.cpp125
-rw-r--r--compiler/luci/pass/src/QuantizationUtils.cpp158
-rw-r--r--compiler/luci/pass/src/QuantizationUtils.h36
-rw-r--r--compiler/luci/pass/src/QuantizeActivation.cpp296
-rw-r--r--compiler/luci/pass/src/QuantizeActivation.h165
-rw-r--r--compiler/luci/pass/src/QuantizeBias.cpp300
-rw-r--r--compiler/luci/pass/src/QuantizeBias.h56
-rw-r--r--compiler/luci/pass/src/QuantizeDequantizeWeightsPass.cpp259
-rw-r--r--compiler/luci/pass/src/QuantizeDequantizeWeightsPass.test.cpp14
-rw-r--r--compiler/luci/pass/src/QuantizePreCheckerPass.cpp119
-rw-r--r--compiler/luci/pass/src/QuantizePreCheckerPass.test.cpp401
-rw-r--r--compiler/luci/pass/src/QuantizeWeights.cpp394
-rw-r--r--compiler/luci/pass/src/QuantizeWeights.h55
-rw-r--r--compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp1773
-rw-r--r--compiler/luci/pass/src/QuantizeWithMinMaxPass.test.cpp49
-rw-r--r--compiler/luci/pass/src/QuantizedModelVerifier.cpp70
-rw-r--r--compiler/luci/pass/src/QuantizedModelVerifier.h30
-rw-r--r--compiler/luci/pass/src/QuantizedModelVerifier.test.cpp497
-rw-r--r--compiler/luci/pass/src/RemoveRedundantQuantizePass.cpp104
-rw-r--r--compiler/luci/pass/src/RemoveRedundantQuantizePass.test.cpp166
-rw-r--r--compiler/luci/pass/src/RemoveRedundantTransposePass.cpp2
-rw-r--r--compiler/luci/pass/src/RemoveRedundantTransposePass.test.cpp25
-rw-r--r--compiler/luci/pass/src/RemoveUnnecessaryReshapePass.cpp19
-rw-r--r--compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.cpp26
-rw-r--r--compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.test.cpp46
-rw-r--r--compiler/luci/pass/src/SubstituteSplitVToSplitPass.cpp13
-rw-r--r--compiler/luci/pass/src/SubstituteSqueezeToReshapePass.cpp14
-rw-r--r--compiler/luci/pass/src/SubstituteStridedSliceToReshapePass.cpp2
-rw-r--r--compiler/luci/pass/src/VerifyQuantizedBiasScale.cpp105
-rw-r--r--compiler/luci/pass/src/VerifyQuantizedBiasScale.h59
-rw-r--r--compiler/luci/pass/src/VerifyQuantizedNodeGranularity.cpp38
-rw-r--r--compiler/luci/pass/src/VerifyQuantizedNodeGranularity.h (renamed from compiler/luci/pass/src/VerifyQuantizedNodeChannelWiseGranularity.h)301
-rw-r--r--compiler/luci/pass/src/VerifyQuantizedNodeLayerWiseGranularity.h473
-rw-r--r--compiler/luci/pass/src/VerifyQuantizedNodeS16Type.h516
-rw-r--r--compiler/luci/pass/src/VerifyQuantizedNodeType.cpp554
-rw-r--r--compiler/luci/pass/src/VerifyQuantizedNodeType.h157
-rw-r--r--compiler/luci/pass/src/VerifyQuantizedNodeU8Type.h518
-rw-r--r--compiler/luci/pass/src/helpers/LayerInfoMap.cpp189
-rw-r--r--compiler/luci/pass/src/helpers/LayerInfoMap.h33
-rw-r--r--compiler/luci/pass/src/helpers/LayerInfoMap.test.cpp201
-rw-r--r--compiler/luci/requires.cmake4
-rw-r--r--compiler/luci/service/CMakeLists.txt1
-rw-r--r--compiler/luci/service/include/luci/Service/CircleShapeInference.h7
-rw-r--r--compiler/luci/service/include/luci/Service/CircleTypeInference.h8
-rw-r--r--compiler/luci/service/src/CircleCloneNode.h2
-rw-r--r--compiler/luci/service/src/CircleNodeClone.cpp14
-rw-r--r--compiler/luci/service/src/CircleShapeInferenceRule.cpp90
-rw-r--r--compiler/luci/service/src/CircleTypeInferenceRule.cpp7
-rw-r--r--compiler/luci/service/src/Nodes/CircleSVDF.cpp37
-rw-r--r--compiler/luci/service/src/Nodes/CircleSVDF.test.cpp47
-rw-r--r--compiler/luci/service/src/Nodes/CircleVariable.cpp27
-rw-r--r--compiler/luci/service/src/Nodes/CircleVariable.test.cpp33
-rw-r--r--compiler/luci/tests/CMakeLists.txt33
-rw-r--r--compiler/luci/tests/test.lst4
194 files changed, 14022 insertions, 8555 deletions
diff --git a/compiler/luci/CMakeLists.txt b/compiler/luci/CMakeLists.txt
index b92eefb40..460dc7b23 100644
--- a/compiler/luci/CMakeLists.txt
+++ b/compiler/luci/CMakeLists.txt
@@ -23,4 +23,8 @@ add_subdirectory(import)
add_subdirectory(export)
add_subdirectory(tester)
+if(NOT ENABLE_TEST)
+ return()
+endif(NOT ENABLE_TEST)
+
add_subdirectory(tests)
diff --git a/compiler/luci/export/CMakeLists.txt b/compiler/luci/export/CMakeLists.txt
index a267d0e1f..f46181eb6 100644
--- a/compiler/luci/export/CMakeLists.txt
+++ b/compiler/luci/export/CMakeLists.txt
@@ -12,7 +12,7 @@ target_include_directories(luci_export PUBLIC include)
target_link_libraries(luci_export PRIVATE luci_lang)
target_link_libraries(luci_export PRIVATE luci_service)
target_link_libraries(luci_export PRIVATE luci_pass)
-target_link_libraries(luci_export PRIVATE mio_circle)
+target_link_libraries(luci_export PRIVATE mio_circle04)
target_link_libraries(luci_export PRIVATE luci_env)
target_link_libraries(luci_export PRIVATE luci_log)
target_link_libraries(luci_export PRIVATE luci_logex)
@@ -36,6 +36,6 @@ target_include_directories(luci_export_test PRIVATE src)
target_link_libraries(luci_export_test luci_export)
target_link_libraries(luci_export_test luci_plan)
target_link_libraries(luci_export_test luci_lang)
-target_link_libraries(luci_export_test mio_circle)
+target_link_libraries(luci_export_test mio_circle04)
target_link_libraries(luci_export_test luci_env)
target_link_libraries(luci_export_test oops)
diff --git a/compiler/luci/export/src/CircleBuiltinTypesExtractor.h b/compiler/luci/export/src/CircleBuiltinTypesExtractor.h
new file mode 100644
index 000000000..0ff21a34b
--- /dev/null
+++ b/compiler/luci/export/src/CircleBuiltinTypesExtractor.h
@@ -0,0 +1,539 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __CIRCLE_BUILTIN_TYPES_EXTRACTOR_H__
+#define __CIRCLE_BUILTIN_TYPES_EXTRACTOR_H__
+
+#include "CircleExporterUtils.h"
+
+#include <luci/IR/CircleNode.h>
+#include <luci/IR/CircleNodes.h>
+#include <luci/IR/CircleNodeVisitor.h>
+
+#include <flatbuffers/flexbuffers.h>
+
+namespace luci
+{
+
+// NOTE Virtual nodes are not circle builtin operators.
+// Therefore, they are not defined here.
+class BuiltinOptionsExtractor final
+ : public luci::CircleNodeMutableVisitor<flatbuffers::Offset<void>>
+{
+public:
+ BuiltinOptionsExtractor(flatbuffers::FlatBufferBuilder &builder) : _builder{builder}
+ {
+ // DO NOTHING
+ }
+
+public:
+ flatbuffers::Offset<void> visit(luci::CircleAbs *)
+ {
+ return circle::CreateAbsOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleAdd *node)
+ {
+ return circle::CreateAddOptions(_builder, to_circle_actfunc(node->fusedActivationFunction()))
+ .Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleAddN *)
+ {
+ return circle::CreateAddNOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleArgMax *node)
+ {
+ return circle::CreateArgMaxOptions(_builder, luci::to_circle_tensortype(node->output_type()))
+ .Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleArgMin *node)
+ {
+ return circle::CreateArgMinOptions(_builder, luci::to_circle_tensortype(node->output_type()))
+ .Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleAveragePool2D *node)
+ {
+ return circle::CreatePool2DOptions(_builder, getOpPadding(node->padding()), node->stride()->w(),
+ node->stride()->h(), node->filter()->w(),
+ node->filter()->h(),
+ to_circle_actfunc(node->fusedActivationFunction()))
+ .Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleBatchMatMul *node)
+ {
+ return circle::CreateBatchMatMulOptions(_builder, node->adj_x(), node->adj_y()).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleBatchToSpaceND *)
+ {
+ return circle::CreateBatchToSpaceNDOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleBidirectionalSequenceLSTM *node)
+ {
+ return circle::CreateBidirectionalSequenceLSTMOptions(
+ _builder, to_circle_actfunc(node->fusedActivationFunction()), node->cell_clip(),
+ node->proj_clip(), node->merge_outputs(), node->time_major(),
+ node->asymmetric_quantize_inputs())
+ .Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleCast *node)
+ {
+ if (node->out_data_type() == loco::DataType::Unknown)
+ return _no_option;
+ else
+ return circle::CreateCastOptions(_builder, luci::to_circle_tensortype(node->in_data_type()),
+ luci::to_circle_tensortype(node->out_data_type()))
+ .Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleCeil *) { return _no_option; }
+ flatbuffers::Offset<void> visit(luci::CircleConcatenation *node)
+ {
+ return circle::CreateConcatenationOptions(_builder, node->axis(),
+ to_circle_actfunc(node->fusedActivationFunction()))
+ .Union();
+ }
+ // CircleConst is not virtual but not builtinOperator
+ // flatbuffers::Offset<void> visit(luci::CircleConst *)
+ flatbuffers::Offset<void> visit(luci::CircleConv2D *node)
+ {
+ return circle::CreateConv2DOptions(_builder, getOpPadding(node->padding()), node->stride()->w(),
+ node->stride()->h(),
+ to_circle_actfunc(node->fusedActivationFunction()),
+ node->dilation()->w(), node->dilation()->h())
+ .Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleCos *)
+ {
+ return circle::CreateCosOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleCustom *) { return _no_option; }
+ flatbuffers::Offset<void> visit(luci::CircleDepthToSpace *node)
+ {
+ return circle::CreateDepthToSpaceOptions(_builder, node->block_size()).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleDepthwiseConv2D *node)
+ {
+ return circle::CreateDepthwiseConv2DOptions(
+ _builder, getOpPadding(node->padding()), node->stride()->w(), node->stride()->h(),
+ node->depthMultiplier(), to_circle_actfunc(node->fusedActivationFunction()),
+ node->dilation()->w(), node->dilation()->h())
+ .Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleDequantize *) { return _no_option; }
+ flatbuffers::Offset<void> visit(luci::CircleDiv *node)
+ {
+ return circle::CreateDivOptions(_builder, to_circle_actfunc(node->fusedActivationFunction()))
+ .Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleElu *) { return _no_option; }
+ flatbuffers::Offset<void> visit(luci::CircleEqual *)
+ {
+ return circle::CreateEqualOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleExp *)
+ {
+ return circle::CreateExpOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleExpandDims *)
+ {
+ return circle::CreateExpandDimsOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleFakeQuant *node)
+ {
+ return circle::CreateFakeQuantOptions(_builder, node->min(), node->max(), node->num_bits(),
+ node->narrow_range())
+ .Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleFill *)
+ {
+ return circle::CreateFillOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleFloor *) { return _no_option; }
+ flatbuffers::Offset<void> visit(luci::CircleFloorDiv *)
+ {
+ return circle::CreateFloorDivOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleFloorMod *)
+ {
+ return circle::CreateFloorModOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleFullyConnected *node)
+ {
+ return circle::CreateFullyConnectedOptions(
+ _builder, to_circle_actfunc(node->fusedActivationFunction()),
+ to_circle_weightsformat(node->weights_format()), node->keep_num_dims())
+ .Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleGather *node)
+ {
+ return circle::CreateGatherOptions(_builder, node->axis()).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleGatherNd *)
+ {
+ return circle::CreateGatherNdOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleGreater *)
+ {
+ return circle::CreateGreaterOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleGreaterEqual *)
+ {
+ return circle::CreateGreaterEqualOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleIf *node)
+ {
+ return circle::CreateIfOptions(_builder, node->then_branch(), node->else_branch()).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleL2Normalize *node)
+ {
+ return circle::CreateL2NormOptions(_builder, to_circle_actfunc(node->fusedActivationFunction()))
+ .Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleL2Pool2D *node)
+ {
+ return circle::CreatePool2DOptions(_builder, getOpPadding(node->padding()), node->stride()->w(),
+ node->stride()->h(), node->filter()->w(),
+ node->filter()->h(),
+ to_circle_actfunc(node->fusedActivationFunction()))
+ .Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleLeakyRelu *node)
+ {
+ return circle::CreateLeakyReluOptions(_builder, node->alpha()).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleLess *)
+ {
+ return circle::CreateLessOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleLessEqual *)
+ {
+ return circle::CreateLessEqualOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleLocalResponseNormalization *node)
+ {
+ return circle::CreateLocalResponseNormalizationOptions(_builder, node->radius(), node->bias(),
+ node->alpha(), node->beta())
+ .Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleLog *) { return _no_option; }
+ flatbuffers::Offset<void> visit(luci::CircleLogicalAnd *)
+ {
+ return circle::CreateLogicalAndOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleLogicalNot *)
+ {
+ return circle::CreateLogicalNotOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleLogicalOr *)
+ {
+ return circle::CreateLogicalOrOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleLogistic *) { return _no_option; }
+ flatbuffers::Offset<void> visit(luci::CircleLogSoftmax *)
+ {
+ return circle::CreateLogSoftmaxOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleMatrixDiag *)
+ {
+ return circle::CreateMatrixDiagOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleMatrixSetDiag *)
+ {
+ return circle::CreateMatrixSetDiagOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleMaximum *)
+ {
+ return circle::CreateMaximumMinimumOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleMaxPool2D *node)
+ {
+ return circle::CreatePool2DOptions(_builder, getOpPadding(node->padding()), node->stride()->w(),
+ node->stride()->h(), node->filter()->w(),
+ node->filter()->h(),
+ to_circle_actfunc(node->fusedActivationFunction()))
+ .Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleMean *node)
+ {
+ return circle::CreateReducerOptions(_builder, node->keep_dims()).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleMinimum *)
+ {
+ return circle::CreateMaximumMinimumOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleMirrorPad *node)
+ {
+ return circle::CreateMirrorPadOptions(_builder, to_circle_mirrorpadmode(node->mode())).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleMul *node)
+ {
+ return circle::CreateMulOptions(_builder, to_circle_actfunc(node->fusedActivationFunction()))
+ .Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleNeg *)
+ {
+ return circle::CreateNegOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleNonMaxSuppressionV4 *)
+ {
+ return circle::CreateNonMaxSuppressionV4Options(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleNonMaxSuppressionV5 *)
+ {
+ return circle::CreateNonMaxSuppressionV5Options(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleNotEqual *)
+ {
+ return circle::CreateNotEqualOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleOneHot *node)
+ {
+ return circle::CreateOneHotOptions(_builder, node->axis()).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CirclePack *node)
+ {
+ return circle::CreatePackOptions(_builder, node->values_count(), node->axis()).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CirclePad *)
+ {
+ return circle::CreatePadOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CirclePadV2 *)
+ {
+ return circle::CreatePadV2Options(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CirclePow *)
+ {
+ return circle::CreatePowOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CirclePRelu *) { return _no_option; }
+ flatbuffers::Offset<void> visit(luci::CircleQuantize *) { return _no_option; }
+ flatbuffers::Offset<void> visit(luci::CircleRange *)
+ {
+ return circle::CreateRangeOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleRank *)
+ {
+ return circle::CreateRankOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleReduceAny *node)
+ {
+ return circle::CreateReducerOptions(_builder, node->keep_dims()).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleReduceMax *node)
+ {
+ return circle::CreateReducerOptions(_builder, node->keep_dims()).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleReduceMin *node)
+ {
+ return circle::CreateReducerOptions(_builder, node->keep_dims()).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleReduceProd *node)
+ {
+ return circle::CreateReducerOptions(_builder, node->keep_dims()).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleRelu *) { return _no_option; }
+ flatbuffers::Offset<void> visit(luci::CircleRelu6 *) { return _no_option; }
+ flatbuffers::Offset<void> visit(luci::CircleReluN1To1 *) { return _no_option; }
+ flatbuffers::Offset<void> visit(luci::CircleReshape *node)
+ {
+ auto new_shape = _builder.CreateVector<int32_t>(
+ node->newShape()->rank(), [node](size_t i) { return node->newShape()->dim(i); });
+ return circle::CreateReshapeOptions(_builder, new_shape).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleResizeBilinear *node)
+ {
+ return circle::CreateResizeBilinearOptions(_builder, node->align_corners(),
+ node->half_pixel_centers())
+ .Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleResizeNearestNeighbor *node)
+ {
+ return circle::CreateResizeNearestNeighborOptions(_builder, node->align_corners()).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleReverseSequence *node)
+ {
+ return circle::CreateReverseSequenceOptions(_builder, node->seq_axis(), node->batch_axis())
+ .Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleReverseV2 *)
+ {
+ return circle::CreateReverseV2Options(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleRound *) { return _no_option; }
+ flatbuffers::Offset<void> visit(luci::CircleRsqrt *) { return _no_option; }
+ flatbuffers::Offset<void> visit(luci::CircleScatterNd *)
+ {
+ return circle::CreateScatterNdOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleSegmentSum *)
+ {
+ return circle::CreateSegmentSumOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleSelect *)
+ {
+ return circle::CreateSelectOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleSelectV2 *)
+ {
+ return circle::CreateSelectV2Options(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleShape *node)
+ {
+ return circle::CreateShapeOptions(_builder, luci::to_circle_tensortype(node->out_type()))
+ .Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleSin *) { return _no_option; }
+ flatbuffers::Offset<void> visit(luci::CircleSlice *)
+ {
+ return circle::CreateSliceOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleSoftmax *node)
+ {
+ return circle::CreateSoftmaxOptions(_builder, node->beta()).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleSpaceToBatchND *)
+ {
+ return circle::CreateSpaceToBatchNDOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleSpaceToDepth *node)
+ {
+ return circle::CreateSpaceToDepthOptions(_builder, node->block_size()).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleSparseToDense *node)
+ {
+ return circle::CreateSparseToDenseOptions(_builder, node->validate_indices()).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleSplit *node)
+ {
+ return circle::CreateSplitOptions(_builder, node->num_split()).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleSplitV *node)
+ {
+ return circle::CreateSplitVOptions(_builder, node->num_split()).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleSqrt *) { return _no_option; }
+ flatbuffers::Offset<void> visit(luci::CircleSquare *)
+ {
+ return circle::CreateSquareOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleSquaredDifference *)
+ {
+ return circle::CreateSquaredDifferenceOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleSqueeze *node)
+ {
+ auto squeeze_dims = _builder.CreateVector<int32_t>(node->squeeze_dims());
+ return circle::CreateSqueezeOptions(_builder, squeeze_dims).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleStridedSlice *node)
+ {
+ return circle::CreateStridedSliceOptions(_builder, node->begin_mask(), node->end_mask(),
+ node->ellipsis_mask(), node->new_axis_mask(),
+ node->shrink_axis_mask())
+ .Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleSub *node)
+ {
+ return circle::CreateSubOptions(_builder, to_circle_actfunc(node->fusedActivationFunction()))
+ .Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleSum *node)
+ {
+ return circle::CreateReducerOptions(_builder, node->keep_dims()).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleSVDF *node)
+ {
+ return circle::CreateSVDFOptions(_builder, node->svdf_rank(),
+ to_circle_actfunc(node->fusedActivationFunction()),
+ node->asymmetric_quantize_inputs())
+ .Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleTanh *) { return _no_option; }
+ flatbuffers::Offset<void> visit(luci::CircleTile *)
+ {
+ return circle::CreateTileOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleTopKV2 *)
+ {
+ return circle::CreateTopKV2Options(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleTranspose *)
+ {
+ return circle::CreateTransposeOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleTransposeConv *node)
+ {
+ return circle::CreateTransposeConvOptions(_builder, getOpPadding(node->padding()),
+ node->stride()->w(), node->stride()->h())
+ .Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleUnidirectionalSequenceLSTM *node)
+ {
+ return circle::CreateUnidirectionalSequenceLSTMOptions(
+ _builder, to_circle_actfunc(node->fusedActivationFunction()), node->cell_clip(),
+ node->proj_clip(), node->time_major(), node->asymmetric_quantize_inputs())
+ .Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleUnique *node)
+ {
+ return circle::CreateUniqueOptions(_builder, luci::to_circle_tensortype(node->idx_out_type()))
+ .Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleUnpack *node)
+ {
+ return circle::CreateUnpackOptions(_builder, node->num(), node->axis()).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleWhere *)
+ {
+ return circle::CreateWhereOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleWhile *node)
+ {
+ return circle::CreateWhileOptions(_builder, node->cond_branch(), node->body_branch()).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleZerosLike *)
+ {
+ return circle::CreateZerosLikeOptions(_builder).Union();
+ }
+ // Circle only
+ flatbuffers::Offset<void> visit(luci::CircleBCQFullyConnected *node)
+ {
+ return circle::CreateBCQFullyConnectedOptions(
+ _builder, node->weights_hidden_size(),
+ to_circle_actfunc(node->fusedActivationFunction()))
+ .Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleBCQGather *node)
+ {
+ return circle::CreateBCQGatherOptions(_builder, node->input_hidden_size(), node->axis())
+ .Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleInstanceNorm *node)
+ {
+ return circle::CreateInstanceNormOptions(_builder, node->epsilon(),
+ to_circle_actfunc(node->fusedActivationFunction()))
+ .Union();
+ }
+
+protected:
+ flatbuffers::FlatBufferBuilder &_builder;
+
+private:
+ const flatbuffers::Offset<void> _no_option = 0;
+};
+
+} // namespace luci
+
+#endif // __CIRCLE_BUILTIN_TYPES_EXTRACTOR_H__
diff --git a/compiler/luci/export/src/CircleBuiltinTypesMappingRule.h b/compiler/luci/export/src/CircleBuiltinTypesMappingRule.h
new file mode 100644
index 000000000..6f7c0f70e
--- /dev/null
+++ b/compiler/luci/export/src/CircleBuiltinTypesMappingRule.h
@@ -0,0 +1,79 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __CIRCLE_EXPORT_BUILTIN_TYPES_MAPPING_RULE_H__
+#define __CIRCLE_EXPORT_BUILTIN_TYPES_MAPPING_RULE_H__
+
+#include <luci/IR/CircleNode.h>
+#include <luci/IR/CircleNodes.h>
+#include <luci/IR/CircleNodeVisitor.h>
+
+namespace luci
+{
+
+class BuiltinOperatorMappingRule final : public CircleNodeVisitor<circle::BuiltinOperator>
+{
+public:
+ BuiltinOperatorMappingRule()
+ {
+ // DO NOTHING
+ }
+
+public:
+ static BuiltinOperatorMappingRule &get()
+ {
+ static BuiltinOperatorMappingRule instance;
+ return instance;
+ }
+
+public:
+#define CIRCLE_NODE(CIRCLE_NODE, OP, OPTION) \
+ circle::BuiltinOperator visit(const CIRCLE_NODE *) final { return circle::OP; }
+// Virtual nodes are not circle builtin operator
+#define CIRCLE_VNODE(CIRCLE_NODE)
+#include "CircleOps.lst"
+#undef CIRCLE_VNODE
+#undef CIRCLE_NODE
+};
+
+class BuiltinOptionsMappingRule final : public CircleNodeVisitor<circle::BuiltinOptions>
+{
+public:
+ BuiltinOptionsMappingRule()
+ {
+ // DO NOTHING
+ }
+
+public:
+ static BuiltinOptionsMappingRule &get()
+ {
+ static BuiltinOptionsMappingRule instance;
+ return instance;
+ }
+
+public:
+#define CIRCLE_NODE(CIRCLE_NODE, OP, OPTION) \
+ circle::BuiltinOptions visit(const CIRCLE_NODE *) final { return circle::OPTION; }
+// Virtual nodes are not circle builtin operator
+#define CIRCLE_VNODE(CIRCLE_NODE)
+#include "CircleOps.lst"
+#undef CIRCLE_VNODE
+#undef CIRCLE_NODE
+};
+
+} // namespace luci
+
+#endif // __CIRCLE_EXPORT_BUILTIN_TYPES_MAPPING_RULE_H__
diff --git a/compiler/luci/export/src/CircleExporterImpl.cpp b/compiler/luci/export/src/CircleExporterImpl.cpp
index 5868c176c..083add9be 100644
--- a/compiler/luci/export/src/CircleExporterImpl.cpp
+++ b/compiler/luci/export/src/CircleExporterImpl.cpp
@@ -79,14 +79,19 @@ encodeOperatorCodes(FlatBufferBuilder &builder, std::unordered_map<luci::OpCode,
for (auto it : opcodes)
{
uint32_t idx = it.second;
+ int8_t dep_code = 127; // BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES
+ if (it.first.opcode < BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES)
+ dep_code = static_cast<int8_t>(it.first.opcode);
if (it.first.opcode != BuiltinOperator_CUSTOM)
{
- operator_codes_vec[idx] = CreateOperatorCode(builder, it.first.opcode, 0, it.first.version);
+ operator_codes_vec[idx] =
+ CreateOperatorCode(builder, dep_code, 0, it.first.version, it.first.opcode);
}
else
{
operator_codes_vec[idx] =
- CreateOperatorCode(builder, it.first.opcode, builder.CreateString(it.first.custom_code));
+ CreateOperatorCode(builder, dep_code, builder.CreateString(it.first.custom_code),
+ it.first.version, it.first.opcode);
}
}
diff --git a/compiler/luci/export/src/CircleExporterUtils.cpp b/compiler/luci/export/src/CircleExporterUtils.cpp
index 3a7ba304f..9473c2c4e 100644
--- a/compiler/luci/export/src/CircleExporterUtils.cpp
+++ b/compiler/luci/export/src/CircleExporterUtils.cpp
@@ -15,6 +15,7 @@
*/
#include "CircleExporterUtils.h"
+#include "CircleBuiltinTypesMappingRule.h"
#include <oops/InternalExn.h>
@@ -163,36 +164,63 @@ circle::SparseIndexVector to_circle_sparse_index_vector_type(luci::SparseIndexVe
}
}
-} // namespace luci
+circle::BuiltinOperator circle_builtin_operator(const luci::CircleNode *node)
+{
+ return node->accept(&BuiltinOperatorMappingRule::get());
+}
-namespace luci
+circle::BuiltinOptions circle_builtin_options(const luci::CircleNode *node)
{
+ if (auto cast = dynamic_cast<const luci::CircleCast *>(node))
+ {
+ return (cast->out_data_type() == loco::DataType::Unknown) ? circle::BuiltinOptions_NONE
+ : circle::BuiltinOptions_CastOptions;
+ }
-uint32_t SerializedModelData::registerBuiltinOpcode(circle::BuiltinOperator builtin_code,
- const int32_t op_version)
+ return node->accept(&BuiltinOptionsMappingRule::get());
+}
+
+std::string circle_custom_code(const luci::CircleNode *node)
{
- assert(op_version > 0);
+ if (auto custom_node = dynamic_cast<const luci::CircleCustom *>(node))
+ {
+ return custom_node->custom_code();
+ }
- auto it = _operator_codes.find(OpCode{builtin_code, "", op_version});
- if (it != _operator_codes.end())
+ return "";
+}
+
+flatbuffers::Offset<flatbuffers::Vector<uint8_t>>
+circle_custom_options(flatbuffers::FlatBufferBuilder &fb, const luci::CircleNode *node)
+{
+ if (auto custom_node = dynamic_cast<const luci::CircleCustom *>(node))
{
- return it->second;
+ std::vector<uint8_t> custom_options_vec{custom_node->custom_options().begin(),
+ custom_node->custom_options().end()};
+ return fb.CreateVector(custom_options_vec);
}
- auto idx = static_cast<uint32_t>(_operator_codes.size());
- _operator_codes.emplace(OpCode{builtin_code, "", op_version}, idx);
- return idx;
+
+ return 0;
}
-uint32_t SerializedModelData::registerCustomOpcode(const std::string &custom_code)
+} // namespace luci
+
+namespace luci
{
- const circle::BuiltinOperator builtin_code = circle::BuiltinOperator_CUSTOM;
- auto it = _operator_codes.find(OpCode{builtin_code, custom_code});
+
+uint32_t SerializedModelData::registerBuiltinOpcode(circle::BuiltinOperator builtin_code,
+ const std::string &custom_code,
+ const int32_t op_version)
+{
+ assert(op_version > 0);
+
+ auto it = _operator_codes.find(OpCode{builtin_code, custom_code, op_version});
if (it != _operator_codes.end())
{
return it->second;
}
auto idx = static_cast<uint32_t>(_operator_codes.size());
- _operator_codes.emplace(OpCode{builtin_code, custom_code}, idx);
+ _operator_codes.emplace(OpCode{builtin_code, custom_code, op_version}, idx);
return idx;
}
diff --git a/compiler/luci/export/src/CircleExporterUtils.h b/compiler/luci/export/src/CircleExporterUtils.h
index 95310b353..4a4c54a69 100644
--- a/compiler/luci/export/src/CircleExporterUtils.h
+++ b/compiler/luci/export/src/CircleExporterUtils.h
@@ -39,6 +39,12 @@ flatbuffers::Offset<void> to_circle_sparse_index_vector(flatbuffers::FlatBufferB
const SparseIndexVector &sparse_idx_vec);
circle::SparseIndexVector to_circle_sparse_index_vector_type(luci::SparseIndexVectorType type);
+circle::BuiltinOperator circle_builtin_operator(const luci::CircleNode *node);
+circle::BuiltinOptions circle_builtin_options(const luci::CircleNode *node);
+std::string circle_custom_code(const luci::CircleNode *node);
+flatbuffers::Offset<flatbuffers::Vector<uint8_t>>
+circle_custom_options(flatbuffers::FlatBufferBuilder &fb, const luci::CircleNode *node);
+
} // namespace luci
namespace luci
diff --git a/compiler/luci/export/src/CircleOperationExporter.cpp b/compiler/luci/export/src/CircleOperationExporter.cpp
index be64a52d4..b300a7fcf 100644
--- a/compiler/luci/export/src/CircleOperationExporter.cpp
+++ b/compiler/luci/export/src/CircleOperationExporter.cpp
@@ -15,1686 +15,30 @@
*/
#include "CircleOperationExporter.h"
-#include "CircleExporterUtils.h"
-#include "Check.h"
+#include "CircleOperationExporterRule.h"
#include <luci/IR/CircleNode.h>
-#include <luci/IR/CircleNodes.h>
-#include <luci/IR/CircleNodeVisitor.h>
#include <luci/Profile/CircleNodeOrigin.h>
#include <luci/Plan/CircleNodeExecutionPlan.h>
-#include <luci/UserSettings.h>
-#include <luci/Log.h>
+#include <loco/IR/Algorithm.h>
-#include <loco/IR/CanonicalNodeVisitor.h>
-#include <oops/InternalExn.h>
-
-#include <flatbuffers/flexbuffers.h>
-
-using namespace flatbuffers;
-using namespace circle;
-
-namespace
-{
-
-using namespace luci;
-
-struct ExportContext
-{
- FlatBufferBuilder &builder;
- SerializedModelData &md;
- SerializedGraphData &gd;
-};
-
-/**
- * @brief Exports CircleMaxPool2D or CircleAveragePool2D
- *
- * @note CirclePool2D should be one of CircleMaxPool2D or CircleAveragePool2D
- */
-template <class CirclePool2D>
-void export_pool_2d(ExportContext &ctx, CirclePool2D *node, circle::BuiltinOperator builtin_op)
-{
- LUCI_ASSERT(builtin_op == circle::BuiltinOperator_MAX_POOL_2D ||
- builtin_op == circle::BuiltinOperator_L2_POOL_2D ||
- builtin_op == circle::BuiltinOperator_AVERAGE_POOL_2D,
- "Should be L2Pool, MaxPool or AvgPool");
- LUCI_ASSERT(node->padding() != luci::Padding::UNDEFINED, "Padding is not set");
-
- uint32_t op_idx = ctx.md.registerBuiltinOpcode(builtin_op, node->op_version());
- std::vector<int32_t> inputs_vec{get_tensor_index(node->value())};
- std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
- auto inputs = ctx.builder.CreateVector(inputs_vec);
- auto outputs = ctx.builder.CreateVector(outputs_vec);
-
- circle::Padding padding = getOpPadding(node->padding());
-
- auto options = CreatePool2DOptions(ctx.builder, padding, node->stride()->w(), node->stride()->h(),
- node->filter()->w(), node->filter()->h(),
- to_circle_actfunc(node->fusedActivationFunction()));
- auto op_offset = CreateOperator(ctx.builder, op_idx, inputs, outputs,
- circle::BuiltinOptions_Pool2DOptions, options.Union());
- ctx.gd._operators.push_back(op_offset);
-}
-
-/**
- * @brief export simple nodes
- */
-void export_node(ExportContext &ctx, loco::Node *node, circle::BuiltinOperator bop,
- circle::BuiltinOptions bot, flatbuffers::Offset<void> options_offset)
-{
- uint32_t op_idx =
- ctx.md.registerBuiltinOpcode(bop, loco::must_cast<luci::CircleNode *>(node)->op_version());
- std::vector<int32_t> inputs_vec;
- std::vector<int32_t> outputs_vec{get_tensor_index(node)};
- for (uint32_t i = 0; i < node->arity(); ++i)
- inputs_vec.push_back(get_tensor_index(node->arg(i)));
- auto inputs = ctx.builder.CreateVector(inputs_vec);
- auto outputs = ctx.builder.CreateVector(outputs_vec);
- auto op_offset = CreateOperator(ctx.builder, op_idx, inputs, outputs, bot, options_offset);
- ctx.gd._operators.push_back(op_offset);
-}
-
-/**
- * @brief export simple nodes having void options
- */
-void export_node(ExportContext &ctx, loco::Node *node, circle::BuiltinOperator bop)
-{
- uint32_t op_idx =
- ctx.md.registerBuiltinOpcode(bop, loco::must_cast<luci::CircleNode *>(node)->op_version());
- std::vector<int32_t> inputs_vec;
- std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
- for (uint32_t i = 0; i < node->arity(); ++i)
- inputs_vec.push_back(get_tensor_index(node->arg(i)));
- auto inputs = ctx.builder.CreateVector(inputs_vec);
- auto outputs = ctx.builder.CreateVector(outputs_vec);
- auto op_offset = CreateOperator(ctx.builder, op_idx, inputs, outputs);
- ctx.gd._operators.push_back(op_offset);
-}
-
-void export_node(ExportContext &ctx, luci::CircleAddN *node)
-{
- uint32_t op_idx = ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_ADD_N, node->op_version());
- std::vector<int32_t> inputs_vec;
- std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
-
- for (uint32_t i = 0; i < node->arity(); ++i)
- inputs_vec.push_back(get_tensor_index(node->inputs(i)));
-
- auto inputs = ctx.builder.CreateVector(inputs_vec);
- auto outputs = ctx.builder.CreateVector(outputs_vec);
- auto options = CreateAddNOptions(ctx.builder);
- auto op_offset = CreateOperator(ctx.builder, op_idx, inputs, outputs,
- circle::BuiltinOptions_AddNOptions, options.Union());
- ctx.gd._operators.push_back(op_offset);
-}
-
-void export_node(ExportContext &ctx, luci::CircleCast *node)
-{
- uint32_t op_idx = ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_CAST, node->op_version());
- std::vector<int32_t> inputs_vec{get_tensor_index(node->x())};
- std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
- auto inputs = ctx.builder.CreateVector(inputs_vec);
- auto outputs = ctx.builder.CreateVector(outputs_vec);
-
- flatbuffers::Offset<Operator> op_offset;
- if (node->out_data_type() != loco::DataType::Unknown)
- {
- auto options = CreateCastOptions(ctx.builder, to_circle_tensortype(node->in_data_type()),
- to_circle_tensortype(node->out_data_type()));
- op_offset = CreateOperator(ctx.builder, op_idx, inputs, outputs,
- circle::BuiltinOptions_CastOptions, options.Union());
- }
- else
- {
- op_offset = CreateOperator(ctx.builder, op_idx, inputs, outputs);
- }
- ctx.gd._operators.push_back(op_offset);
-}
-
-void export_node(ExportContext &ctx, luci::CircleConcatenation *node)
-{
- uint32_t op_idx =
- ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_CONCATENATION, node->op_version());
- std::vector<int32_t> inputs_vec;
- std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
-
- for (uint32_t i = 0; i < node->numValues(); ++i)
- inputs_vec.push_back(get_tensor_index(node->values(i)));
-
- auto inputs = ctx.builder.CreateVector(inputs_vec);
- auto outputs = ctx.builder.CreateVector(outputs_vec);
- auto options = CreateConcatenationOptions(ctx.builder, node->axis(),
- to_circle_actfunc(node->fusedActivationFunction()));
- auto op_offset = CreateOperator(ctx.builder, op_idx, inputs, outputs,
- circle::BuiltinOptions_ConcatenationOptions, options.Union());
- ctx.gd._operators.push_back(op_offset);
-}
-
-void export_node(ExportContext &ctx, luci::CircleCustom *node)
-{
- auto custom_outputs = loco::succs(node);
- assert(custom_outputs.size() == node->numOutputs());
-
- uint32_t op_idx = ctx.md.registerCustomOpcode(node->custom_code());
- std::vector<int32_t> inputs_vec;
- std::vector<int32_t> outputs_vec;
-
- for (uint32_t index = 0; index < node->numInputs(); index++)
- {
- inputs_vec.push_back(get_tensor_index(node->inputs(index)));
- }
- for (uint32_t index = 0; index < custom_outputs.size(); index++)
- {
- // store in order of index
- bool found = false;
- for (auto out : custom_outputs)
- {
- auto custom_out = loco::must_cast<luci::CircleCustomOut *>(out);
- if (custom_out->index() == static_cast<int32_t>(index))
- {
- outputs_vec.push_back(get_tensor_index(custom_out));
- found = true;
- break;
- }
- }
- if (!found)
- {
- INTERNAL_EXN("Invalid Custom output");
- }
- }
-
- auto inputs = ctx.builder.CreateVector(inputs_vec);
- auto outputs = ctx.builder.CreateVector(outputs_vec);
- flatbuffers::Offset<flatbuffers::Vector<uint8_t>> circle_custom_options;
- std::vector<uint8_t> custom_options_vec{node->custom_options().begin(),
- node->custom_options().end()};
- circle_custom_options = ctx.builder.CreateVector(custom_options_vec);
- auto op_offset = CreateOperator(ctx.builder, op_idx, inputs, outputs, circle::BuiltinOptions_NONE,
- flatbuffers::Offset<void>(), circle_custom_options);
- ctx.gd._operators.push_back(op_offset);
-}
-
-void export_node(ExportContext &ctx, luci::CircleIf *node)
-{
- auto if_outs = loco::succs(node);
- assert(if_outs.size() == node->output_count());
-
- uint32_t op_idx = ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_IF, node->op_version());
- std::vector<int32_t> inputs_vec;
- std::vector<int32_t> outputs_vec;
-
- inputs_vec.push_back(get_tensor_index(node->cond()));
- for (uint32_t idx = 0; idx < node->input_count(); ++idx)
- inputs_vec.push_back(get_tensor_index(node->input(idx)));
-
- for (uint32_t idx = 0; idx < node->output_count(); ++idx)
- {
- // store in order of index
- bool found = false;
- for (auto out : if_outs)
- {
- auto if_out = loco::must_cast<luci::CircleIfOut *>(out);
- if (if_out->index() == static_cast<int32_t>(idx))
- {
- outputs_vec.push_back(get_tensor_index(if_out));
- found = true;
- break;
- }
- }
- if (!found)
- {
- INTERNAL_EXN("Invalid CircleIf output");
- }
- }
-
- auto inputs = ctx.builder.CreateVector(inputs_vec);
- auto outputs = ctx.builder.CreateVector(outputs_vec);
- auto options = CreateIfOptions(ctx.builder, node->then_branch(), node->else_branch());
- auto op_offset = CreateOperator(ctx.builder, op_idx, inputs, outputs,
- circle::BuiltinOptions_IfOptions, options.Union());
- ctx.gd._operators.push_back(op_offset);
-}
-
-void export_node(ExportContext &ctx, luci::CircleNonMaxSuppressionV4 *node)
-{
- auto nms_outs = loco::succs(node);
- assert(nms_outs.size() == 2);
-
- uint32_t op_idx = ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_NON_MAX_SUPPRESSION_V4,
- node->op_version());
- std::vector<int32_t> inputs_vec{
- get_tensor_index(node->boxes()), get_tensor_index(node->scores()),
- get_tensor_index(node->max_output_size()), get_tensor_index(node->iou_threshold()),
- get_tensor_index(node->score_threshold()),
- };
- std::vector<int32_t> outputs_vec;
-
- for (uint32_t idx = 0; idx < nms_outs.size(); ++idx)
- {
- // store in order of index
- bool found = false;
- for (auto out : nms_outs)
- {
- auto nms_out = loco::must_cast<luci::CircleNonMaxSuppressionV4Out *>(out);
- if (nms_out->index() == static_cast<int32_t>(idx))
- {
- outputs_vec.push_back(get_tensor_index(nms_out));
- found = true;
- break;
- }
- }
- if (!found)
- {
- INTERNAL_EXN("Invalid NonMaxSuppressionV4 output");
- }
- }
-
- auto inputs = ctx.builder.CreateVector(inputs_vec);
- auto outputs = ctx.builder.CreateVector(outputs_vec);
- auto options = CreateNonMaxSuppressionV4Options(ctx.builder);
- auto op_offset =
- CreateOperator(ctx.builder, op_idx, inputs, outputs,
- circle::BuiltinOptions_NonMaxSuppressionV4Options, options.Union());
- ctx.gd._operators.push_back(op_offset);
-}
-
-void export_node(ExportContext &ctx, luci::CircleNonMaxSuppressionV5 *node)
-{
- auto nms_outs = loco::succs(node);
- assert(nms_outs.size() == 3);
-
- uint32_t op_idx = ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_NON_MAX_SUPPRESSION_V5,
- node->op_version());
- std::vector<int32_t> inputs_vec{
- get_tensor_index(node->boxes()), get_tensor_index(node->scores()),
- get_tensor_index(node->max_output_size()), get_tensor_index(node->iou_threshold()),
- get_tensor_index(node->score_threshold()), get_tensor_index(node->soft_nms_sigma()),
- };
- std::vector<int32_t> outputs_vec;
-
- for (uint32_t idx = 0; idx < nms_outs.size(); ++idx)
- {
- // store in order of index
- bool found = false;
- for (auto out : nms_outs)
- {
- auto nms_out = loco::must_cast<luci::CircleNonMaxSuppressionV5Out *>(out);
- if (nms_out->index() == static_cast<int32_t>(idx))
- {
- outputs_vec.push_back(get_tensor_index(nms_out));
- found = true;
- break;
- }
- }
- if (!found)
- {
- INTERNAL_EXN("Invalid NonMaxSuppressionV5 output");
- }
- }
-
- auto inputs = ctx.builder.CreateVector(inputs_vec);
- auto outputs = ctx.builder.CreateVector(outputs_vec);
- auto options = CreateNonMaxSuppressionV5Options(ctx.builder);
- auto op_offset =
- CreateOperator(ctx.builder, op_idx, inputs, outputs,
- circle::BuiltinOptions_NonMaxSuppressionV5Options, options.Union());
- ctx.gd._operators.push_back(op_offset);
-}
-
-void export_node(ExportContext &ctx, luci::CircleReverseV2 *node)
-{
- uint32_t op_idx =
- ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_REVERSE_V2, node->op_version());
- std::vector<int32_t> inputs_vec{get_tensor_index(node->tensor()), get_tensor_index(node->axis())};
- std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
- auto inputs = ctx.builder.CreateVector(inputs_vec);
- auto outputs = ctx.builder.CreateVector(outputs_vec);
- auto options = CreateReverseV2Options(ctx.builder);
- auto op_offset = CreateOperator(ctx.builder, op_idx, inputs, outputs,
- circle::BuiltinOptions_ReverseSequenceOptions, options.Union());
- ctx.gd._operators.push_back(op_offset);
-}
-
-void export_node(ExportContext &ctx, luci::CircleSplit *node)
-{
- auto split_outs = loco::succs(node);
- assert(int32_t(split_outs.size()) == node->num_split());
-
- uint32_t op_idx = ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_SPLIT, node->op_version());
- // NOTE BuiltinOperator_SPLIT input is placed at second position
- std::vector<int32_t> inputs_vec{get_tensor_index(node->split_dim()),
- get_tensor_index(node->input())};
- std::vector<int32_t> outputs_vec;
-
- for (int32_t index = 0; index < node->num_split(); index++)
- {
- // store in order of index
- bool found = false;
- for (auto out : split_outs)
- {
- auto split_out = loco::must_cast<luci::CircleSplitOut *>(out);
- if (split_out->index() == index)
- {
- outputs_vec.push_back(get_tensor_index(split_out));
- found = true;
- break;
- }
- }
- if (!found)
- {
- INTERNAL_EXN("Invalid Split output");
- }
- }
-
- auto inputs = ctx.builder.CreateVector(inputs_vec);
- auto outputs = ctx.builder.CreateVector(outputs_vec);
- auto options = CreateSplitOptions(ctx.builder, node->num_split());
- auto op_offset = CreateOperator(ctx.builder, op_idx, inputs, outputs,
- circle::BuiltinOptions_SplitOptions, options.Union());
- ctx.gd._operators.push_back(op_offset);
-}
-
-void export_node(ExportContext &ctx, luci::CircleSplitV *node)
-{
- auto split_outs = loco::succs(node);
- assert(int32_t(split_outs.size()) == node->num_split());
-
- uint32_t op_idx =
- ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_SPLIT_V, node->op_version());
- std::vector<int32_t> inputs_vec{get_tensor_index(node->input()),
- get_tensor_index(node->size_splits()),
- get_tensor_index(node->split_dim())};
- std::vector<int32_t> outputs_vec;
-
- for (int32_t index = 0; index < node->num_split(); index++)
- {
- // store in order of index
- bool found = false;
- for (auto out : split_outs)
- {
- auto split_out = loco::must_cast<luci::CircleSplitVOut *>(out);
- if (split_out->index() == index)
- {
- outputs_vec.push_back(get_tensor_index(split_out));
- found = true;
- break;
- }
- }
- if (!found)
- {
- INTERNAL_EXN("Invalid SplitV output");
- }
- }
-
- auto inputs = ctx.builder.CreateVector(inputs_vec);
- auto outputs = ctx.builder.CreateVector(outputs_vec);
- auto options = CreateSplitVOptions(ctx.builder, node->num_split());
- auto op_offset = CreateOperator(ctx.builder, op_idx, inputs, outputs,
- circle::BuiltinOptions_SplitVOptions, options.Union());
- ctx.gd._operators.push_back(op_offset);
-}
-
-void export_node(ExportContext &ctx, luci::CircleTopKV2 *node)
-{
- auto topkv2_outs = loco::succs(node);
- int outs_count = int32_t(topkv2_outs.size());
- assert(outs_count == 2);
-
- uint32_t op_idx =
- ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_TOPK_V2, node->op_version());
- std::vector<int32_t> inputs_vec{get_tensor_index(node->input()), get_tensor_index(node->k())};
- std::vector<int32_t> outputs_vec;
-
- for (int32_t index = 0; index < outs_count; index++)
- {
- // store in order of index
- bool found = false;
- for (auto out : topkv2_outs)
- {
- auto topkv2_out = loco::must_cast<luci::CircleTopKV2Out *>(out);
- if (topkv2_out->index() == index)
- {
- outputs_vec.push_back(get_tensor_index(topkv2_out));
- found = true;
- break;
- }
- }
- if (!found)
- {
- INTERNAL_EXN("Invalid TopKV2 output");
- }
- }
-
- auto inputs = ctx.builder.CreateVector(inputs_vec);
- auto outputs = ctx.builder.CreateVector(outputs_vec);
- auto options = CreateTopKV2Options(ctx.builder);
- auto op_offset = CreateOperator(ctx.builder, op_idx, inputs, outputs,
- circle::BuiltinOptions_TopKV2Options, options.Union());
- ctx.gd._operators.push_back(op_offset);
-}
-
-void export_node(ExportContext &ctx, luci::CircleUnique *node)
-{
- auto unique_outs = loco::succs(node);
- assert(int32_t(unique_outs.size()) == 2);
- uint32_t op_idx =
- ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_UNIQUE, node->op_version());
-
- std::vector<int32_t> inputs_vec{get_tensor_index(node->input())};
- std::vector<int32_t> outputs_vec;
-
- for (int32_t index = 0; index < 2; index++)
- {
- // store in order of index
- bool found = false;
- for (auto out : unique_outs)
- {
- auto unique_out = loco::must_cast<luci::CircleUniqueOut *>(out);
- if (unique_out->index() == index)
- {
- outputs_vec.push_back(get_tensor_index(unique_out));
- found = true;
- break;
- }
- }
- if (!found)
- {
- INTERNAL_EXN("Invalid Unique output");
- }
- }
-
- auto inputs = ctx.builder.CreateVector(inputs_vec);
- auto outputs = ctx.builder.CreateVector(outputs_vec);
- auto options = CreateUniqueOptions(ctx.builder, to_circle_tensortype(node->idx_out_type()));
- auto op_offset = CreateOperator(ctx.builder, op_idx, inputs, outputs,
- circle::BuiltinOptions_UniqueOptions, options.Union());
- ctx.gd._operators.push_back(op_offset);
-}
-
-void export_node(ExportContext &ctx, luci::CircleUnpack *node)
-{
- LOGGER(l);
- auto settings = luci::UserSettings::settings();
-
- auto unpack_outs = loco::succs(node);
- // NOTE real models may not use all of the outputs
- if (static_cast<int32_t>(unpack_outs.size()) != node->num())
- {
- if (settings->get(luci::UserSettings::Key::DisableValidation))
- {
- WARN(l) << "Warning: export Unpack(" << node->name() << ") 'num' not same as outputs";
- }
- else
- assert(false);
- }
-
- uint32_t op_idx =
- ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_UNPACK, node->op_version());
- std::vector<int32_t> inputs_vec{get_tensor_index(node->value())};
- std::vector<int32_t> outputs_vec;
-
- for (int32_t index = 0; index < node->num(); index++)
- {
- // store in order of index
- bool found = false;
- for (auto out : unpack_outs)
- {
- auto unpack_out = loco::must_cast<luci::CircleUnpackOut *>(out);
- if (unpack_out->index() == index)
- {
- outputs_vec.push_back(get_tensor_index(unpack_out));
- found = true;
- break;
- }
- }
- // NOTE real models may not use all of the outputs
- if (!found)
- {
- if (settings->get(luci::UserSettings::Key::DisableValidation))
- {
- WARN(l) << "Warning: export Unpack(" << node->name() << ") output " << index << " not used";
- }
- else
- assert(false);
- }
- }
-
- auto inputs = ctx.builder.CreateVector(inputs_vec);
- auto outputs = ctx.builder.CreateVector(outputs_vec);
- auto options = CreateUnpackOptions(ctx.builder, node->num(), node->axis());
- auto op_offset = CreateOperator(ctx.builder, op_idx, inputs, outputs,
- circle::BuiltinOptions_UnpackOptions, options.Union());
- ctx.gd._operators.push_back(op_offset);
-}
-
-void export_node(ExportContext &ctx, luci::CircleWhile *node)
-{
- auto while_outs = loco::succs(node);
- assert(while_outs.size() == node->output_count());
-
- uint32_t op_idx = ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_WHILE, node->op_version());
- std::vector<int32_t> inputs_vec;
- std::vector<int32_t> outputs_vec;
-
- for (uint32_t idx = 0; idx < node->input_count(); ++idx)
- inputs_vec.push_back(get_tensor_index(node->input(idx)));
-
- for (uint32_t idx = 0; idx < node->output_count(); ++idx)
- {
- // store in order of index
- bool found = false;
- for (auto out : while_outs)
- {
- auto while_out = loco::must_cast<luci::CircleWhileOut *>(out);
- if (while_out->index() == static_cast<int32_t>(idx))
- {
- outputs_vec.push_back(get_tensor_index(while_out));
- found = true;
- break;
- }
- }
- if (!found)
- {
- INTERNAL_EXN("Invalid CircleWhile output");
- }
- }
-
- auto inputs = ctx.builder.CreateVector(inputs_vec);
- auto outputs = ctx.builder.CreateVector(outputs_vec);
- auto options = CreateWhileOptions(ctx.builder, node->cond_branch(), node->body_branch());
- auto op_offset = CreateOperator(ctx.builder, op_idx, inputs, outputs,
- circle::BuiltinOptions_WhileOptions, options.Union());
- ctx.gd._operators.push_back(op_offset);
-}
-
-class ExportHelper
-{
-public:
- ExportHelper(ExportContext &ctx) : _ctx{ctx}
- {
- // DO NOTHING
- }
-
-protected:
- /**
- * @brief export simple nodes
- */
- void export_simple(loco::Node *node, circle::BuiltinOperator bop, circle::BuiltinOptions bot,
- flatbuffers::Offset<void> options_offset)
- {
- export_node(_ctx, node, bop, bot, options_offset);
- }
-
- /**
- * @brief export simple nodes having void options
- */
- void export_simple(loco::Node *node, circle::BuiltinOperator bop)
- {
- export_node(_ctx, node, bop);
- }
-
-protected:
- ExportContext &_ctx;
-};
-
-enum class OE
-{
- ABC,
- DEF,
- GHIJ,
- KLMN,
- OPQR,
- STUV,
- WXYZ,
- CIRC, // circle only
- VIRT, // virtual
-};
-
-class OperationExporter final : public ExportHelper
-{
-public:
- OperationExporter(ExportContext &ctx) : ExportHelper(ctx)
- {
- // DO NOTHING
- }
-
-public:
- void export_node(luci::CircleNode *);
-};
-
-template <OE oe> class OpExporterLet;
-
-template <>
-class OpExporterLet<OE::ABC> final : public luci::CircleNodeMutableVisitor<void>,
- public ExportHelper
-{
-public:
- OpExporterLet(ExportContext &ctx) : ExportHelper(ctx)
- {
- // DO NOTHING
- }
-
-public:
- // NOTE visit for luci::CircleNode is added NOT to throw NYI
- void visit(luci::CircleNode *) final {}
-
-public:
- void visit(luci::CircleAbs *) final;
- void visit(luci::CircleAdd *) final;
- void visit(luci::CircleAddN *) final;
- void visit(luci::CircleArgMax *) final;
- void visit(luci::CircleArgMin *) final;
- void visit(luci::CircleAveragePool2D *) final;
- void visit(luci::CircleBatchMatMul *) final;
- void visit(luci::CircleBatchToSpaceND *) final;
- void visit(luci::CircleBidirectionalSequenceLSTM *) final;
- void visit(luci::CircleCast *) final;
- void visit(luci::CircleCeil *) final;
- void visit(luci::CircleConcatenation *) final;
- void visit(luci::CircleConst *) final{/* skip, everything is done in exportOpDefinedTensors */};
- void visit(luci::CircleConv2D *) final;
- void visit(luci::CircleCos *) final;
- void visit(luci::CircleCustom *) final;
-};
-
-template <>
-class OpExporterLet<OE::DEF> final : public luci::CircleNodeMutableVisitor<void>,
- public ExportHelper
-{
-public:
- OpExporterLet(ExportContext &ctx) : ExportHelper(ctx)
- {
- // DO NOTHING
- }
-
-public:
- void visit(luci::CircleNode *) final {}
-
-public:
- void visit(luci::CircleDepthToSpace *) final;
- void visit(luci::CircleDepthwiseConv2D *) final;
- void visit(luci::CircleDequantize *) final;
- void visit(luci::CircleDiv *) final;
- void visit(luci::CircleElu *) final;
- void visit(luci::CircleEqual *) final;
- void visit(luci::CircleExp *) final;
- void visit(luci::CircleExpandDims *) final;
- void visit(luci::CircleFakeQuant *) final;
- void visit(luci::CircleFill *) final;
- void visit(luci::CircleFloor *) final;
- void visit(luci::CircleFloorDiv *) final;
- void visit(luci::CircleFloorMod *) final;
- void visit(luci::CircleFullyConnected *) final;
-};
-
-template <>
-class OpExporterLet<OE::GHIJ> final : public luci::CircleNodeMutableVisitor<void>,
- public ExportHelper
-{
-public:
- OpExporterLet(ExportContext &ctx) : ExportHelper(ctx)
- {
- // DO NOTHING
- }
-
-public:
- void visit(luci::CircleNode *) final {}
-
-public:
- void visit(luci::CircleGather *) final;
- void visit(luci::CircleGatherNd *) final;
- void visit(luci::CircleGreater *) final;
- void visit(luci::CircleGreaterEqual *) final;
- void visit(luci::CircleIf *) final;
-};
-
-template <>
-class OpExporterLet<OE::KLMN> final : public luci::CircleNodeMutableVisitor<void>,
- public ExportHelper
-{
-public:
- OpExporterLet(ExportContext &ctx) : ExportHelper(ctx)
- {
- // DO NOTHING
- }
-
-public:
- void visit(luci::CircleNode *) final {}
-
-public:
- void visit(luci::CircleL2Normalize *) final;
- void visit(luci::CircleL2Pool2D *) final;
- void visit(luci::CircleLeakyRelu *) final;
- void visit(luci::CircleLess *) final;
- void visit(luci::CircleLessEqual *) final;
- void visit(luci::CircleLocalResponseNormalization *) final;
- void visit(luci::CircleLog *) final;
- void visit(luci::CircleLogicalAnd *) final;
- void visit(luci::CircleLogicalNot *) final;
- void visit(luci::CircleLogicalOr *) final;
- void visit(luci::CircleLogistic *) final;
- void visit(luci::CircleLogSoftmax *) final;
- void visit(luci::CircleMatrixDiag *) final;
- void visit(luci::CircleMatrixSetDiag *) final;
- void visit(luci::CircleMaximum *) final;
- void visit(luci::CircleMaxPool2D *) final;
- void visit(luci::CircleMean *) final;
- void visit(luci::CircleMinimum *) final;
- void visit(luci::CircleMirrorPad *) final;
- void visit(luci::CircleMul *) final;
- void visit(luci::CircleNeg *) final;
- void visit(luci::CircleNonMaxSuppressionV4 *) final;
- void visit(luci::CircleNonMaxSuppressionV5 *) final;
- void visit(luci::CircleNotEqual *) final;
-};
-
-template <>
-class OpExporterLet<OE::OPQR> final : public luci::CircleNodeMutableVisitor<void>,
- public ExportHelper
-{
-public:
- OpExporterLet(ExportContext &ctx) : ExportHelper(ctx)
- {
- // DO NOTHING
- }
-
-public:
- void visit(luci::CircleNode *) final {}
-
-public:
- void visit(luci::CircleOneHot *) final;
- void visit(luci::CirclePack *) final;
- void visit(luci::CirclePad *) final;
- void visit(luci::CirclePadV2 *) final;
- void visit(luci::CirclePow *) final;
- void visit(luci::CirclePRelu *) final;
- void visit(luci::CircleQuantize *) final;
- void visit(luci::CircleRange *) final;
- void visit(luci::CircleRank *) final;
- void visit(luci::CircleReduceAny *) final;
- void visit(luci::CircleReduceMax *) final;
- void visit(luci::CircleReduceMin *) final;
- void visit(luci::CircleReduceProd *) final;
- void visit(luci::CircleRelu *) final;
- void visit(luci::CircleRelu6 *) final;
- void visit(luci::CircleReluN1To1 *) final;
- void visit(luci::CircleReshape *) final;
- void visit(luci::CircleResizeBilinear *) final;
- void visit(luci::CircleResizeNearestNeighbor *) final;
- void visit(luci::CircleReverseSequence *) final;
- void visit(luci::CircleReverseV2 *) final;
- void visit(luci::CircleRound *) final;
- void visit(luci::CircleRsqrt *) final;
-};
-
-template <>
-class OpExporterLet<OE::STUV> final : public luci::CircleNodeMutableVisitor<void>,
- public ExportHelper
-{
-public:
- OpExporterLet(ExportContext &ctx) : ExportHelper(ctx)
- {
- // DO NOTHING
- }
-
-public:
- void visit(luci::CircleNode *) final {}
-
-public:
- void visit(luci::CircleScatterNd *) final;
- void visit(luci::CircleSegmentSum *) final;
- void visit(luci::CircleSelect *) final;
- void visit(luci::CircleSelectV2 *) final;
- void visit(luci::CircleShape *) final;
- void visit(luci::CircleSin *) final;
- void visit(luci::CircleSlice *) final;
- void visit(luci::CircleSoftmax *) final;
- void visit(luci::CircleSpaceToBatchND *) final;
- void visit(luci::CircleSpaceToDepth *) final;
- void visit(luci::CircleSparseToDense *) final;
- void visit(luci::CircleSplit *) final;
- void visit(luci::CircleSplitV *) final;
- void visit(luci::CircleSqrt *) final;
- void visit(luci::CircleSquare *) final;
- void visit(luci::CircleSquaredDifference *) final;
- void visit(luci::CircleSqueeze *) final;
- void visit(luci::CircleStridedSlice *) final;
- void visit(luci::CircleSub *) final;
- void visit(luci::CircleSum *) final;
- void visit(luci::CircleTanh *) final;
- void visit(luci::CircleTile *) final;
- void visit(luci::CircleTopKV2 *) final;
- void visit(luci::CircleTranspose *) final;
- void visit(luci::CircleTransposeConv *) final;
- void visit(luci::CircleUnidirectionalSequenceLSTM *) final;
- void visit(luci::CircleUnique *) final;
- void visit(luci::CircleUnpack *) final;
-};
-
-template <>
-class OpExporterLet<OE::WXYZ> final : public luci::CircleNodeMutableVisitor<void>,
- public ExportHelper
-{
-public:
- OpExporterLet(ExportContext &ctx) : ExportHelper(ctx)
- {
- // DO NOTHING
- }
-
-public:
- void visit(luci::CircleNode *) final {}
-
-public:
- void visit(luci::CircleWhere *) final;
- void visit(luci::CircleWhile *) final;
- void visit(luci::CircleZerosLike *) final;
-};
-
-template <>
-class OpExporterLet<OE::CIRC> final : public luci::CircleNodeMutableVisitor<void>,
- public ExportHelper
-{
-public:
- OpExporterLet(ExportContext &ctx) : ExportHelper(ctx)
- {
- // DO NOTHING
- }
-
-public:
- void visit(luci::CircleNode *) final {}
-
-public:
- // Circle only
- void visit(luci::CircleBCQFullyConnected *) final;
- void visit(luci::CircleBCQGather *) final;
- void visit(luci::CircleInstanceNorm *) final;
-};
-
-template <>
-class OpExporterLet<OE::VIRT> final : public luci::CircleNodeMutableVisitor<void>,
- public ExportHelper
-{
-public:
- OpExporterLet(ExportContext &ctx) : ExportHelper(ctx)
- {
- // DO NOTHING
- }
-
-public:
- void visit(luci::CircleNode *) final {}
-
-public:
- // Virtual
- void visit(luci::CircleInput *) final {}
- void visit(luci::CircleOutput *) final {}
- void visit(luci::CircleOutputDummy *) final {}
- void visit(luci::CircleOutputExclude *) final {}
- // Virtual for multiple-outputs
- void visit(luci::CircleBidirectionalSequenceLSTMOut *) final {}
- void visit(luci::CircleCustomOut *) final {}
- void visit(luci::CircleIfOut *) final {}
- void visit(luci::CircleNonMaxSuppressionV4Out *) final {}
- void visit(luci::CircleNonMaxSuppressionV5Out *) final {}
- void visit(luci::CircleSplitOut *) final {}
- void visit(luci::CircleSplitVOut *) final {}
- void visit(luci::CircleTopKV2Out *) final {}
- void visit(luci::CircleUniqueOut *) final {}
- void visit(luci::CircleUnpackOut *) final {}
- void visit(luci::CircleWhileOut *) final {}
-};
-
-void OperationExporter::export_node(luci::CircleNode *node)
-{
- // TODO revise return type to bool and return if handled
-#define VISIT_OE(GRP) \
- do \
- { \
- OpExporterLet<OE::GRP> oe(_ctx); \
- node->accept(&oe); \
- } while (false)
-
- VISIT_OE(ABC);
- VISIT_OE(DEF);
- VISIT_OE(GHIJ);
- VISIT_OE(KLMN);
- VISIT_OE(OPQR);
- VISIT_OE(STUV);
- VISIT_OE(WXYZ);
- VISIT_OE(CIRC);
- VISIT_OE(VIRT);
-
-#undef VISIT_OE
-}
-
-void OpExporterLet<OE::ABC>::visit(luci::CircleAbs *node)
-{
- export_simple(node, circle::BuiltinOperator_ABS, circle::BuiltinOptions_AbsOptions,
- CreateAbsOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::ABC>::visit(luci::CircleAdd *node)
-{
- export_simple(
- node, circle::BuiltinOperator_ADD, circle::BuiltinOptions_AddOptions,
- CreateAddOptions(_ctx.builder, to_circle_actfunc(node->fusedActivationFunction())).Union());
-}
-
-void OpExporterLet<OE::ABC>::visit(luci::CircleAddN *node) { export_node(_ctx, node); }
-
-void OpExporterLet<OE::ABC>::visit(luci::CircleArgMax *node)
-{
- export_simple(
- node, circle::BuiltinOperator_ARG_MAX, circle::BuiltinOptions_ArgMaxOptions,
- CreateArgMaxOptions(_ctx.builder, to_circle_tensortype(node->output_type())).Union());
-}
-
-void OpExporterLet<OE::ABC>::visit(luci::CircleArgMin *node)
-{
- export_simple(
- node, circle::BuiltinOperator_ARG_MIN, circle::BuiltinOptions_ArgMinOptions,
- CreateArgMinOptions(_ctx.builder, to_circle_tensortype(node->output_type())).Union());
-}
-
-void OpExporterLet<OE::ABC>::visit(luci::CircleAveragePool2D *node)
-{
- export_pool_2d<luci::CircleAveragePool2D>(_ctx, node, circle::BuiltinOperator_AVERAGE_POOL_2D);
-}
-
-void OpExporterLet<OE::ABC>::visit(luci::CircleBatchMatMul *node)
-{
- export_simple(node, circle::BuiltinOperator_BATCH_MATMUL,
- circle::BuiltinOptions_BatchMatMulOptions,
- CreateBatchMatMulOptions(_ctx.builder, node->adj_x(), node->adj_y()).Union());
-}
-
-void OpExporterLet<OE::ABC>::visit(luci::CircleBidirectionalSequenceLSTM *node)
-{
- auto bidi_lstm_outs = loco::succs(node);
- assert((bidi_lstm_outs.size() == 1) || (bidi_lstm_outs.size() == 2));
- uint32_t op_idx = _ctx.md.registerBuiltinOpcode(
- circle::BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM, node->op_version());
-
- std::vector<int32_t> inputs_vec{get_tensor_index(node->input())};
- std::vector<int32_t> outputs_vec;
-
- for (int32_t index = 0; index < 2; index++)
- {
- // store in order of index
- bool found = false;
- for (auto out : bidi_lstm_outs)
- {
- auto bidi_lstm_out = loco::must_cast<luci::CircleBidirectionalSequenceLSTMOut *>(out);
- if (bidi_lstm_out->index() == index)
- {
- outputs_vec.push_back(get_tensor_index(bidi_lstm_out));
- found = true;
- break;
- }
- }
- if (!found)
- {
- INTERNAL_EXN("Invalid BidirectionalSequenceLSTM output");
- }
- }
-
- auto inputs = _ctx.builder.CreateVector(inputs_vec);
- auto outputs = _ctx.builder.CreateVector(outputs_vec);
- auto options = CreateBidirectionalSequenceLSTMOptions(
- _ctx.builder, to_circle_actfunc(node->fusedActivationFunction()), node->cell_clip(),
- node->proj_clip(), node->merge_outputs(), node->time_major(),
- node->asymmetric_quantize_inputs());
- auto op_offset =
- CreateOperator(_ctx.builder, op_idx, inputs, outputs,
- circle::BuiltinOptions_BidirectionalSequenceLSTMOptions, options.Union());
- _ctx.gd._operators.push_back(op_offset);
-}
-
-void OpExporterLet<OE::ABC>::visit(luci::CircleCast *node) { export_node(_ctx, node); }
-
-void OpExporterLet<OE::ABC>::visit(luci::CircleCeil *node)
-{
- export_simple(node, circle::BuiltinOperator_CEIL);
-}
-
-void OpExporterLet<OE::ABC>::visit(luci::CircleConcatenation *node) { export_node(_ctx, node); }
-
-void OpExporterLet<OE::ABC>::visit(luci::CircleBatchToSpaceND *node)
-{
- export_simple(node, circle::BuiltinOperator_BATCH_TO_SPACE_ND,
- circle::BuiltinOptions_BatchToSpaceNDOptions,
- CreateBatchToSpaceNDOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::ABC>::visit(luci::CircleConv2D *node)
-{
- export_simple(node, circle::BuiltinOperator_CONV_2D, circle::BuiltinOptions_Conv2DOptions,
- CreateConv2DOptions(_ctx.builder, getOpPadding(node->padding()),
- node->stride()->w(), node->stride()->h(),
- to_circle_actfunc(node->fusedActivationFunction()),
- node->dilation()->w(), node->dilation()->h())
- .Union());
-}
-
-void OpExporterLet<OE::ABC>::visit(luci::CircleCos *node)
-{
- export_simple(node, circle::BuiltinOperator_COS, circle::BuiltinOptions_CosOptions,
- CreateCosOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::ABC>::visit(luci::CircleCustom *node) { export_node(_ctx, node); }
-
-void OpExporterLet<OE::DEF>::visit(luci::CircleDepthToSpace *node)
-{
- export_simple(node, circle::BuiltinOperator_DEPTH_TO_SPACE,
- circle::BuiltinOptions_DepthToSpaceOptions,
- CreateDepthToSpaceOptions(_ctx.builder, node->block_size()).Union());
-}
-
-void OpExporterLet<OE::DEF>::visit(luci::CircleDepthwiseConv2D *node)
-{
- export_simple(
- node, circle::BuiltinOperator_DEPTHWISE_CONV_2D, circle::BuiltinOptions_DepthwiseConv2DOptions,
- CreateDepthwiseConv2DOptions(_ctx.builder, getOpPadding(node->padding()), node->stride()->w(),
- node->stride()->h(), node->depthMultiplier(),
- to_circle_actfunc(node->fusedActivationFunction()),
- node->dilation()->w(), node->dilation()->h())
- .Union());
-}
-
-void OpExporterLet<OE::DEF>::visit(luci::CircleDequantize *node)
-{
- export_simple(node, circle::BuiltinOperator_DEQUANTIZE);
-}
-
-void OpExporterLet<OE::DEF>::visit(luci::CircleDiv *node)
-{
- export_simple(
- node, circle::BuiltinOperator_DIV, circle::BuiltinOptions_DivOptions,
- CreateDivOptions(_ctx.builder, to_circle_actfunc(node->fusedActivationFunction())).Union());
-}
-
-void OpExporterLet<OE::DEF>::visit(luci::CircleElu *node)
-{
- export_simple(node, circle::BuiltinOperator_ELU);
-}
-
-void OpExporterLet<OE::DEF>::visit(luci::CircleEqual *node)
-{
- export_simple(node, circle::BuiltinOperator_EQUAL, circle::BuiltinOptions_EqualOptions,
- CreateEqualOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::DEF>::visit(luci::CircleExp *node)
-{
- export_simple(node, circle::BuiltinOperator_EXP, circle::BuiltinOptions_ExpOptions,
- CreateExpOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::DEF>::visit(luci::CircleExpandDims *node)
-{
- export_simple(node, circle::BuiltinOperator_EXPAND_DIMS, circle::BuiltinOptions_ExpandDimsOptions,
- CreateExpandDimsOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::DEF>::visit(luci::CircleFakeQuant *node)
-{
- export_simple(node, circle::BuiltinOperator_FAKE_QUANT, circle::BuiltinOptions_FakeQuantOptions,
- CreateFakeQuantOptions(_ctx.builder, node->min(), node->max(), node->num_bits(),
- node->narrow_range())
- .Union());
-}
-
-void OpExporterLet<OE::DEF>::visit(luci::CircleFill *node)
-{
- export_simple(node, circle::BuiltinOperator_FILL, circle::BuiltinOptions_FillOptions,
- CreateFillOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::DEF>::visit(luci::CircleFloor *node)
-{
- export_simple(node, circle::BuiltinOperator_FLOOR);
-}
-
-void OpExporterLet<OE::DEF>::visit(luci::CircleFloorDiv *node)
-{
- export_simple(node, circle::BuiltinOperator_FLOOR_DIV, circle::BuiltinOptions_FloorDivOptions,
- CreateFloorDivOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::DEF>::visit(luci::CircleFloorMod *node)
-{
- export_simple(node, circle::BuiltinOperator_FLOOR_MOD, circle::BuiltinOptions_FloorModOptions,
- CreateFloorModOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::DEF>::visit(luci::CircleFullyConnected *node)
-{
- export_simple(
- node, circle::BuiltinOperator_FULLY_CONNECTED, circle::BuiltinOptions_FullyConnectedOptions,
- CreateFullyConnectedOptions(_ctx.builder, to_circle_actfunc(node->fusedActivationFunction()),
- to_circle_weightsformat(node->weights_format()))
- .Union());
-}
-
-void OpExporterLet<OE::GHIJ>::visit(luci::CircleGather *node)
-{
- export_simple(node, circle::BuiltinOperator_GATHER, circle::BuiltinOptions_GatherOptions,
- CreateGatherOptions(_ctx.builder, node->axis()).Union());
-}
-
-void OpExporterLet<OE::GHIJ>::visit(luci::CircleGatherNd *node)
-{
- export_simple(node, circle::BuiltinOperator_GATHER_ND, circle::BuiltinOptions_GatherNdOptions,
- CreateGatherNdOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::GHIJ>::visit(luci::CircleGreater *node)
-{
- export_simple(node, circle::BuiltinOperator_GREATER, circle::BuiltinOptions_GreaterOptions,
- CreateGreaterOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::GHIJ>::visit(luci::CircleGreaterEqual *node)
-{
- export_simple(node, circle::BuiltinOperator_GREATER_EQUAL,
- circle::BuiltinOptions_GreaterEqualOptions,
- CreateGreaterEqualOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::GHIJ>::visit(luci::CircleIf *node) { export_node(_ctx, node); }
-
-void OpExporterLet<OE::KLMN>::visit(luci::CircleL2Normalize *node)
-{
- export_simple(
- node, circle::BuiltinOperator_L2_NORMALIZATION, circle::BuiltinOptions_L2NormOptions,
- CreateL2NormOptions(_ctx.builder, to_circle_actfunc(node->fusedActivationFunction())).Union());
-}
-
-void OpExporterLet<OE::KLMN>::visit(luci::CircleL2Pool2D *node)
-{
- export_pool_2d<luci::CircleL2Pool2D>(_ctx, node, circle::BuiltinOperator_L2_POOL_2D);
-}
-
-void OpExporterLet<OE::KLMN>::visit(luci::CircleLeakyRelu *node)
-{
- export_simple(node, circle::BuiltinOperator_LEAKY_RELU, circle::BuiltinOptions_LeakyReluOptions,
- CreateLeakyReluOptions(_ctx.builder, node->alpha()).Union());
-}
-
-void OpExporterLet<OE::KLMN>::visit(luci::CircleLess *node)
-{
- export_simple(node, circle::BuiltinOperator_LESS, circle::BuiltinOptions_LessOptions,
- CreateLessOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::KLMN>::visit(luci::CircleLessEqual *node)
-{
- export_simple(node, circle::BuiltinOperator_LESS_EQUAL, circle::BuiltinOptions_LessEqualOptions,
- CreateLessEqualOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::KLMN>::visit(luci::CircleLocalResponseNormalization *node)
-{
- export_simple(node, circle::BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION,
- circle::BuiltinOptions_LocalResponseNormalizationOptions,
- CreateLocalResponseNormalizationOptions(_ctx.builder, node->radius(), node->bias(),
- node->alpha(), node->beta())
- .Union());
-}
-
-void OpExporterLet<OE::KLMN>::visit(luci::CircleLog *node)
-{
- export_simple(node, circle::BuiltinOperator_LOG);
-}
-
-void OpExporterLet<OE::KLMN>::visit(luci::CircleLogicalAnd *node)
-{
- export_simple(node, circle::BuiltinOperator_LOGICAL_AND, circle::BuiltinOptions_LogicalAndOptions,
- CreateLogicalAndOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::KLMN>::visit(luci::CircleLogicalNot *node)
-{
- export_simple(node, circle::BuiltinOperator_LOGICAL_NOT, circle::BuiltinOptions_LogicalNotOptions,
- CreateLogicalNotOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::KLMN>::visit(luci::CircleLogicalOr *node)
-{
- export_simple(node, circle::BuiltinOperator_LOGICAL_OR, circle::BuiltinOptions_LogicalOrOptions,
- CreateLogicalOrOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::KLMN>::visit(luci::CircleLogistic *node)
-{
- export_simple(node, circle::BuiltinOperator_LOGISTIC);
-}
-
-void OpExporterLet<OE::KLMN>::visit(luci::CircleLogSoftmax *node)
-{
- export_simple(node, circle::BuiltinOperator_LOG_SOFTMAX, circle::BuiltinOptions_LogSoftmaxOptions,
- CreateLogSoftmaxOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::KLMN>::visit(luci::CircleMatrixDiag *node)
-{
- export_simple(node, circle::BuiltinOperator_MATRIX_DIAG, circle::BuiltinOptions_MatrixDiagOptions,
- CreateMatrixDiagOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::KLMN>::visit(luci::CircleMatrixSetDiag *node)
-{
- export_simple(node, circle::BuiltinOperator_MATRIX_SET_DIAG,
- circle::BuiltinOptions_MatrixSetDiagOptions,
- CreateMatrixSetDiagOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::KLMN>::visit(luci::CircleMaximum *node)
-{
- export_simple(node, circle::BuiltinOperator_MAXIMUM, circle::BuiltinOptions_MaximumMinimumOptions,
- CreateMaximumMinimumOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::KLMN>::visit(luci::CircleMaxPool2D *node)
-{
- export_pool_2d<luci::CircleMaxPool2D>(_ctx, node, circle::BuiltinOperator_MAX_POOL_2D);
-}
-
-void OpExporterLet<OE::KLMN>::visit(luci::CircleMean *node)
-{
- export_simple(node, circle::BuiltinOperator_MEAN, circle::BuiltinOptions_ReducerOptions,
- CreateReducerOptions(_ctx.builder, node->keep_dims()).Union());
-}
-
-void OpExporterLet<OE::KLMN>::visit(luci::CircleMinimum *node)
-{
- export_simple(node, circle::BuiltinOperator_MINIMUM, circle::BuiltinOptions_MaximumMinimumOptions,
- CreateMaximumMinimumOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::KLMN>::visit(luci::CircleMirrorPad *node)
-{
- export_simple(
- node, circle::BuiltinOperator_MIRROR_PAD, circle::BuiltinOptions_MirrorPadOptions,
- CreateMirrorPadOptions(_ctx.builder, to_circle_mirrorpadmode(node->mode())).Union());
-}
-
-void OpExporterLet<OE::KLMN>::visit(luci::CircleMul *node)
-{
- export_simple(
- node, circle::BuiltinOperator_MUL, circle::BuiltinOptions_MulOptions,
- CreateMulOptions(_ctx.builder, to_circle_actfunc(node->fusedActivationFunction())).Union());
-}
-
-void OpExporterLet<OE::KLMN>::visit(luci::CircleNeg *node)
-{
- export_simple(node, circle::BuiltinOperator_NEG, circle::BuiltinOptions_NegOptions,
- CreateNegOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::KLMN>::visit(luci::CircleNonMaxSuppressionV4 *node)
-{
- export_node(_ctx, node);
-}
-
-void OpExporterLet<OE::KLMN>::visit(luci::CircleNonMaxSuppressionV5 *node)
-{
- export_node(_ctx, node);
-}
-
-void OpExporterLet<OE::KLMN>::visit(luci::CircleNotEqual *node)
-{
- export_simple(node, circle::BuiltinOperator_NOT_EQUAL, circle::BuiltinOptions_NotEqualOptions,
- CreateNotEqualOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::OPQR>::visit(luci::CircleOneHot *node)
-{
- export_simple(node, circle::BuiltinOperator_ONE_HOT, circle::BuiltinOptions_OneHotOptions,
- CreateOneHotOptions(_ctx.builder, node->axis()).Union());
-}
-
-void OpExporterLet<OE::OPQR>::visit(luci::CirclePack *node)
-{
- export_simple(node, circle::BuiltinOperator_PACK, circle::BuiltinOptions_PackOptions,
- CreatePackOptions(_ctx.builder, node->values_count(), node->axis()).Union());
-}
-
-void OpExporterLet<OE::OPQR>::visit(luci::CirclePad *node)
-{
- export_simple(node, circle::BuiltinOperator_PAD, circle::BuiltinOptions_PadOptions,
- CreatePadOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::OPQR>::visit(luci::CirclePadV2 *node)
-{
- export_simple(node, circle::BuiltinOperator_PADV2, circle::BuiltinOptions_PadV2Options,
- CreatePadV2Options(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::OPQR>::visit(luci::CirclePow *node)
-{
- export_simple(node, circle::BuiltinOperator_POW, circle::BuiltinOptions_PowOptions,
- CreatePowOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::OPQR>::visit(luci::CirclePRelu *node)
-{
- export_simple(node, circle::BuiltinOperator_PRELU);
-}
-
-void OpExporterLet<OE::OPQR>::visit(luci::CircleQuantize *node)
-{
- export_simple(node, circle::BuiltinOperator_QUANTIZE);
-}
-
-void OpExporterLet<OE::OPQR>::visit(luci::CircleRange *node)
-{
- export_simple(node, circle::BuiltinOperator_RANGE, circle::BuiltinOptions_RangeOptions,
- CreateRangeOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::OPQR>::visit(luci::CircleRank *node)
-{
- export_simple(node, circle::BuiltinOperator_RANK, circle::BuiltinOptions_RankOptions,
- CreateRankOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::OPQR>::visit(luci::CircleReduceAny *node)
-{
- export_simple(node, circle::BuiltinOperator_REDUCE_ANY, circle::BuiltinOptions_ReducerOptions,
- CreateReducerOptions(_ctx.builder, node->keep_dims()).Union());
-}
-
-void OpExporterLet<OE::OPQR>::visit(luci::CircleReduceMax *node)
-{
- export_simple(node, circle::BuiltinOperator_REDUCE_MAX, circle::BuiltinOptions_ReducerOptions,
- CreateReducerOptions(_ctx.builder, node->keep_dims()).Union());
-}
-
-void OpExporterLet<OE::OPQR>::visit(luci::CircleReduceMin *node)
-{
- export_simple(node, circle::BuiltinOperator_REDUCE_MIN, circle::BuiltinOptions_ReducerOptions,
- CreateReducerOptions(_ctx.builder, node->keep_dims()).Union());
-}
-
-void OpExporterLet<OE::OPQR>::visit(luci::CircleReduceProd *node)
-{
- export_simple(node, circle::BuiltinOperator_REDUCE_PROD, circle::BuiltinOptions_ReducerOptions,
- CreateReducerOptions(_ctx.builder, node->keep_dims()).Union());
-}
-
-void OpExporterLet<OE::OPQR>::visit(luci::CircleRelu *node)
-{
- export_simple(node, circle::BuiltinOperator_RELU);
-}
-
-void OpExporterLet<OE::OPQR>::visit(luci::CircleRelu6 *node)
-{
- export_simple(node, circle::BuiltinOperator_RELU6);
-}
-
-void OpExporterLet<OE::OPQR>::visit(luci::CircleReluN1To1 *node)
-{
- export_simple(node, circle::BuiltinOperator_RELU_N1_TO_1);
-}
-
-void OpExporterLet<OE::OPQR>::visit(luci::CircleReshape *node)
-{
- auto new_shape = _ctx.builder.CreateVector<int32_t>(
- node->newShape()->rank(), [node](size_t i) { return node->newShape()->dim(i); });
-
- export_simple(node, circle::BuiltinOperator_RESHAPE, circle::BuiltinOptions_ReshapeOptions,
- CreateReshapeOptions(_ctx.builder, new_shape).Union());
-}
-
-void OpExporterLet<OE::OPQR>::visit(luci::CircleResizeBilinear *node)
-{
- export_simple(
- node, circle::BuiltinOperator_RESIZE_BILINEAR, circle::BuiltinOptions_ResizeBilinearOptions,
- CreateResizeBilinearOptions(_ctx.builder, node->align_corners(), node->half_pixel_centers())
- .Union());
-}
-
-void OpExporterLet<OE::OPQR>::visit(luci::CircleResizeNearestNeighbor *node)
-{
- export_simple(node, circle::BuiltinOperator_RESIZE_NEAREST_NEIGHBOR,
- circle::BuiltinOptions_ResizeNearestNeighborOptions,
- CreateResizeNearestNeighborOptions(_ctx.builder, node->align_corners()).Union());
-}
-
-void OpExporterLet<OE::OPQR>::visit(luci::CircleReverseSequence *node)
-{
- export_simple(
- node, circle::BuiltinOperator_REVERSE_SEQUENCE, circle::BuiltinOptions_ReverseSequenceOptions,
- CreateReverseSequenceOptions(_ctx.builder, node->seq_axis(), node->batch_axis()).Union());
-}
-
-void OpExporterLet<OE::OPQR>::visit(luci::CircleReverseV2 *node) { export_node(_ctx, node); }
-
-void OpExporterLet<OE::OPQR>::visit(luci::CircleRound *node)
-{
- export_simple(node, circle::BuiltinOperator_ROUND);
-}
-
-void OpExporterLet<OE::OPQR>::visit(luci::CircleRsqrt *node)
-{
- export_simple(node, circle::BuiltinOperator_RSQRT);
-}
-
-void OpExporterLet<OE::STUV>::visit(luci::CircleScatterNd *node)
-{
- export_simple(node, circle::BuiltinOperator_SCATTER_ND, circle::BuiltinOptions_ScatterNdOptions,
- CreateScatterNdOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::STUV>::visit(luci::CircleSegmentSum *node)
-{
- export_simple(node, circle::BuiltinOperator_SEGMENT_SUM, circle::BuiltinOptions_SegmentSumOptions,
- CreateSegmentSumOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::STUV>::visit(luci::CircleSelect *node)
-{
- export_simple(node, circle::BuiltinOperator_SELECT, circle::BuiltinOptions_SelectOptions,
- CreateSelectOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::STUV>::visit(luci::CircleSelectV2 *node)
-{
- export_simple(node, circle::BuiltinOperator_SELECT_V2, circle::BuiltinOptions_SelectV2Options,
- CreateSelectV2Options(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::STUV>::visit(luci::CircleShape *node)
-{
- export_simple(node, circle::BuiltinOperator_SHAPE, circle::BuiltinOptions_ShapeOptions,
- CreateShapeOptions(_ctx.builder, to_circle_tensortype(node->out_type())).Union());
-}
-
-void OpExporterLet<OE::STUV>::visit(luci::CircleSin *node)
-{
- export_simple(node, circle::BuiltinOperator_SIN);
-}
-
-void OpExporterLet<OE::STUV>::visit(luci::CircleSlice *node)
-{
- export_simple(node, circle::BuiltinOperator_SLICE, circle::BuiltinOptions_SliceOptions,
- CreateSliceOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::STUV>::visit(luci::CircleSoftmax *node)
-{
- export_simple(node, circle::BuiltinOperator_SOFTMAX, circle::BuiltinOptions_SoftmaxOptions,
- CreateSoftmaxOptions(_ctx.builder, node->beta()).Union());
-}
-
-void OpExporterLet<OE::STUV>::visit(luci::CircleSpaceToBatchND *node)
-{
- export_simple(node, circle::BuiltinOperator_SPACE_TO_BATCH_ND,
- circle::BuiltinOptions_SpaceToBatchNDOptions,
- CreateSpaceToBatchNDOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::STUV>::visit(luci::CircleSpaceToDepth *node)
-{
- export_simple(node, circle::BuiltinOperator_SPACE_TO_DEPTH,
- circle::BuiltinOptions_SpaceToDepthOptions,
- CreateSpaceToDepthOptions(_ctx.builder, node->block_size()).Union());
-}
-
-void OpExporterLet<OE::STUV>::visit(luci::CircleSparseToDense *node)
-{
- export_simple(node, circle::BuiltinOperator_SPARSE_TO_DENSE,
- circle::BuiltinOptions_SparseToDenseOptions,
- CreateSparseToDenseOptions(_ctx.builder, node->validate_indices()).Union());
-}
-
-void OpExporterLet<OE::STUV>::visit(luci::CircleSplit *node) { export_node(_ctx, node); }
-
-void OpExporterLet<OE::STUV>::visit(luci::CircleSplitV *node) { export_node(_ctx, node); }
-
-void OpExporterLet<OE::STUV>::visit(luci::CircleSqrt *node)
-{
- export_simple(node, circle::BuiltinOperator_SQRT);
-}
-
-void OpExporterLet<OE::STUV>::visit(luci::CircleSquare *node)
-{
- export_simple(node, circle::BuiltinOperator_SQUARE, circle::BuiltinOptions_SquareOptions,
- CreateSquareOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::STUV>::visit(luci::CircleSquaredDifference *node)
-{
- export_simple(node, circle::BuiltinOperator_SQUARED_DIFFERENCE,
- circle::BuiltinOptions_SquaredDifferenceOptions,
- CreateSquaredDifferenceOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::STUV>::visit(luci::CircleSqueeze *node)
-{
- auto squeeze_dims = _ctx.builder.CreateVector<int32_t>(node->squeeze_dims());
- export_simple(node, circle::BuiltinOperator_SQUEEZE, circle::BuiltinOptions_SqueezeOptions,
- CreateSqueezeOptions(_ctx.builder, squeeze_dims).Union());
-}
-
-void OpExporterLet<OE::STUV>::visit(luci::CircleStridedSlice *node)
-{
- export_simple(node, circle::BuiltinOperator_STRIDED_SLICE,
- circle::BuiltinOptions_StridedSliceOptions,
- CreateStridedSliceOptions(_ctx.builder, node->begin_mask(), node->end_mask(),
- node->ellipsis_mask(), node->new_axis_mask(),
- node->shrink_axis_mask())
- .Union());
-}
-
-void OpExporterLet<OE::STUV>::visit(luci::CircleSub *node)
-{
- export_simple(
- node, circle::BuiltinOperator_SUB, circle::BuiltinOptions_SubOptions,
- CreateSubOptions(_ctx.builder, to_circle_actfunc(node->fusedActivationFunction())).Union());
-}
-
-void OpExporterLet<OE::STUV>::visit(luci::CircleSum *node)
-{
- export_simple(node, circle::BuiltinOperator_SUM, circle::BuiltinOptions_ReducerOptions,
- CreateReducerOptions(_ctx.builder, node->keep_dims()).Union());
-}
-
-void OpExporterLet<OE::STUV>::visit(luci::CircleTanh *node)
-{
- export_simple(node, circle::BuiltinOperator_TANH);
-}
-
-void OpExporterLet<OE::STUV>::visit(luci::CircleTile *node)
-{
- export_simple(node, circle::BuiltinOperator_TILE, circle::BuiltinOptions_TileOptions,
- CreateTileOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::STUV>::visit(luci::CircleTopKV2 *node) { export_node(_ctx, node); }
-
-void OpExporterLet<OE::STUV>::visit(luci::CircleTranspose *node)
-{
- export_simple(node, circle::BuiltinOperator_TRANSPOSE, circle::BuiltinOptions_TransposeOptions,
- CreateTransposeOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::STUV>::visit(luci::CircleTransposeConv *node)
-{
- export_simple(node, circle::BuiltinOperator_TRANSPOSE_CONV,
- circle::BuiltinOptions_TransposeConvOptions,
- CreateTransposeConvOptions(_ctx.builder, getOpPadding(node->padding()),
- node->stride()->w(), node->stride()->h())
- .Union());
-}
-
-void OpExporterLet<OE::STUV>::visit(luci::CircleUnidirectionalSequenceLSTM *node)
-{
- export_simple(node, circle::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM,
- circle::BuiltinOptions_UnidirectionalSequenceLSTMOptions,
- CreateUnidirectionalSequenceLSTMOptions(
- _ctx.builder, to_circle_actfunc(node->fusedActivationFunction()),
- node->cell_clip(), node->proj_clip(), node->time_major(),
- node->asymmetric_quantize_inputs())
- .Union());
-}
-
-void OpExporterLet<OE::STUV>::visit(luci::CircleUnique *node) { export_node(_ctx, node); }
-
-void OpExporterLet<OE::STUV>::visit(luci::CircleUnpack *node) { export_node(_ctx, node); }
-
-void OpExporterLet<OE::WXYZ>::visit(luci::CircleWhere *node)
-{
- export_simple(node, circle::BuiltinOperator_WHERE, circle::BuiltinOptions_WhereOptions,
- CreateWhereOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::WXYZ>::visit(luci::CircleWhile *node) { export_node(_ctx, node); }
-
-void OpExporterLet<OE::WXYZ>::visit(luci::CircleZerosLike *node)
-{
- export_simple(node, circle::BuiltinOperator_ZEROS_LIKE, circle::BuiltinOptions_ZerosLikeOptions,
- CreateZerosLikeOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::CIRC>::visit(luci::CircleBCQFullyConnected *node)
-{
- export_simple(node, circle::BuiltinOperator_BCQ_FULLY_CONNECTED,
- circle::BuiltinOptions_BCQFullyConnectedOptions,
- CreateBCQFullyConnectedOptions(_ctx.builder, node->weights_hidden_size(),
- to_circle_actfunc(node->fusedActivationFunction()))
- .Union());
-}
-
-void OpExporterLet<OE::CIRC>::visit(luci::CircleBCQGather *node)
-{
- export_simple(
- node, circle::BuiltinOperator_BCQ_GATHER, circle::BuiltinOptions_BCQGatherOptions,
- CreateBCQGatherOptions(_ctx.builder, node->input_hidden_size(), node->axis()).Union());
-}
-
-void OpExporterLet<OE::CIRC>::visit(luci::CircleInstanceNorm *node)
+namespace luci
{
- export_simple(node, circle::BuiltinOperator_INSTANCE_NORM,
- circle::BuiltinOptions_InstanceNormOptions,
- CreateInstanceNormOptions(_ctx.builder, node->epsilon(),
- to_circle_actfunc(node->fusedActivationFunction()))
- .Union());
-}
-void exportNode(loco::Node *node, flatbuffers::FlatBufferBuilder &builder, SerializedModelData &md,
- SerializedGraphData &gd, uint32_t node_position)
+void exportNodes(loco::Graph *g, flatbuffers::FlatBufferBuilder &builder, SerializedModelData &md,
+ SerializedGraphData &gd)
{
- if (auto circle_node = dynamic_cast<luci::CircleNode *>(node))
+ uint32_t node_position = 0;
+ for (auto node : loco::postorder_traversal(loco::output_nodes(g)))
{
ExportContext ctx{builder, md, gd};
- OperationExporter exporter{ctx};
+ OperationExporterRule exporter_rule{ctx};
+
+ auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+ circle_node->accept(&exporter_rule);
const auto ops_size = gd._operators.size();
- exporter.export_node(circle_node);
if (has_origin(circle_node) && ops_size != gd._operators.size())
{
const auto node_id = gd._operators.size() - 1;
@@ -1716,25 +60,7 @@ void exportNode(loco::Node *node, flatbuffers::FlatBufferBuilder &builder, Seria
}
md._metadata.add_execution_plan_table(node_position, execution_plan_vector);
}
- }
- else
- {
- INTERNAL_EXN("Node with unsupported dialect found");
- }
-}
-} // namespace
-
-namespace luci
-{
-
-void exportNodes(loco::Graph *g, FlatBufferBuilder &builder, SerializedModelData &md,
- SerializedGraphData &gd)
-{
- uint32_t node_position = 0;
- for (auto node : loco::postorder_traversal(loco::output_nodes(g)))
- {
- exportNode(node, builder, md, gd, node_position);
node_position++;
}
}
diff --git a/compiler/luci/export/src/CircleOperationExporter.h b/compiler/luci/export/src/CircleOperationExporter.h
index de6abfc54..f2b3cfd6b 100644
--- a/compiler/luci/export/src/CircleOperationExporter.h
+++ b/compiler/luci/export/src/CircleOperationExporter.h
@@ -17,7 +17,7 @@
#ifndef __CIRCLE_OPERATION_EXPORTER_H__
#define __CIRCLE_OPERATION_EXPORTER_H__
-#include "CircleExporterUtils.h"
+#include "SerializedData.h"
#include <loco/IR/Graph.h>
diff --git a/compiler/luci/export/src/CircleOperationExporterRule.cpp b/compiler/luci/export/src/CircleOperationExporterRule.cpp
new file mode 100644
index 000000000..8dc59fa9c
--- /dev/null
+++ b/compiler/luci/export/src/CircleOperationExporterRule.cpp
@@ -0,0 +1,277 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleOperationExporterRule.h"
+#include "CircleBuiltinTypesExtractor.h"
+#include "Check.h"
+
+#include <loco/IR/Graph.h>
+#include <luci/IR/CircleNode.h>
+#include <luci/IR/CircleNodes.h>
+#include <luci/IR/CircleNodeVisitor.h>
+#include <oops/InternalExn.h>
+
+#include <vector>
+
+namespace
+{
+class OutputVectorExtractor final : public luci::CircleNodeMutableVisitor<std::vector<int32_t>>
+{
+public:
+ OutputVectorExtractor()
+ {
+ // DO NOTHING
+ }
+
+public:
+ std::vector<int32_t> visit(luci::CircleNode *node) final
+ {
+ std::vector<int32_t> outputs_vec{luci::get_tensor_index(node)};
+ return outputs_vec;
+ }
+
+ std::vector<int32_t> visit(luci::CircleBidirectionalSequenceLSTM *node) final
+ {
+ auto bidi_lstm_outs = loco::succs(node);
+ assert((bidi_lstm_outs.size() == 1) || (bidi_lstm_outs.size() == 2));
+
+ std::vector<int32_t> outputs_vec(bidi_lstm_outs.size());
+
+ for (auto out : bidi_lstm_outs)
+ {
+ auto bidi_lstm_out = loco::must_cast<luci::CircleBidirectionalSequenceLSTMOut *>(out);
+ if (bidi_lstm_out->index() >= int32_t(bidi_lstm_outs.size()))
+ INTERNAL_EXN("Invalid BidirectionalSequenceLSTM output");
+ outputs_vec[bidi_lstm_out->index()] = luci::get_tensor_index(bidi_lstm_out);
+ }
+
+ return outputs_vec;
+ }
+
+ std::vector<int32_t> visit(luci::CircleCustom *node) final
+ {
+ auto custom_outputs = loco::succs(node);
+ assert(custom_outputs.size() == node->numOutputs());
+
+ std::vector<int32_t> outputs_vec(node->numOutputs());
+
+ for (auto out : custom_outputs)
+ {
+ auto custom_out = loco::must_cast<luci::CircleCustomOut *>(out);
+ if (custom_out->index() >= int32_t(node->numOutputs()))
+ INTERNAL_EXN("Invalid Custom output");
+ outputs_vec[custom_out->index()] = luci::get_tensor_index(custom_out);
+ }
+
+ return outputs_vec;
+ }
+
+ std::vector<int32_t> visit(luci::CircleIf *node) final
+ {
+ auto if_outs = loco::succs(node);
+ assert(if_outs.size() == node->output_count());
+
+ std::vector<int32_t> outputs_vec(node->output_count());
+
+ for (auto out : if_outs)
+ {
+ auto if_out = loco::must_cast<luci::CircleIfOut *>(out);
+ if (if_out->index() >= int32_t(node->output_count()))
+ INTERNAL_EXN("Invalid If output");
+ outputs_vec[if_out->index()] = luci::get_tensor_index(if_out);
+ }
+
+ return outputs_vec;
+ }
+
+ std::vector<int32_t> visit(luci::CircleNonMaxSuppressionV4 *node) final
+ {
+ auto nms_outs = loco::succs(node);
+ assert(nms_outs.size() == 2);
+
+ std::vector<int32_t> outputs_vec(2);
+
+ for (auto out : nms_outs)
+ {
+ auto nms_out = loco::must_cast<luci::CircleNonMaxSuppressionV4Out *>(out);
+ if (nms_out->index() >= 2)
+ INTERNAL_EXN("Invalid NonMaxSuppressionV4 output");
+ outputs_vec[nms_out->index()] = luci::get_tensor_index(nms_out);
+ }
+
+ return outputs_vec;
+ }
+
+ std::vector<int32_t> visit(luci::CircleNonMaxSuppressionV5 *node) final
+ {
+ auto nms_outs = loco::succs(node);
+ assert(nms_outs.size() == 3);
+
+ std::vector<int32_t> outputs_vec(3);
+
+ for (auto out : nms_outs)
+ {
+ auto nms_out = loco::must_cast<luci::CircleNonMaxSuppressionV5Out *>(out);
+ if (nms_out->index() >= 3)
+ INTERNAL_EXN("Invalid NonMaxSuppressionV5 output");
+ outputs_vec[nms_out->index()] = luci::get_tensor_index(nms_out);
+ }
+
+ return outputs_vec;
+ }
+
+ std::vector<int32_t> visit(luci::CircleSplit *node) final
+ {
+ auto split_outs = loco::succs(node);
+ assert(int32_t(split_outs.size()) == node->num_split());
+
+ std::vector<int32_t> outputs_vec(node->num_split());
+
+ for (auto out : split_outs)
+ {
+ auto split_out = loco::must_cast<luci::CircleSplitOut *>(out);
+ if (split_out->index() >= node->num_split())
+ INTERNAL_EXN("Invalid Split output");
+ outputs_vec[split_out->index()] = luci::get_tensor_index(split_out);
+ }
+
+ return outputs_vec;
+ }
+
+ std::vector<int32_t> visit(luci::CircleSplitV *node) final
+ {
+ auto split_outs = loco::succs(node);
+ assert(int32_t(split_outs.size()) == node->num_split());
+
+ std::vector<int32_t> outputs_vec(node->num_split());
+
+ for (auto out : split_outs)
+ {
+ auto split_out = loco::must_cast<luci::CircleSplitVOut *>(out);
+ if (split_out->index() >= node->num_split())
+ INTERNAL_EXN("Invalid SplitV output");
+ outputs_vec[split_out->index()] = luci::get_tensor_index(split_out);
+ }
+
+ return outputs_vec;
+ }
+
+ std::vector<int32_t> visit(luci::CircleTopKV2 *node) final
+ {
+ auto topkv2_outs = loco::succs(node);
+ assert(topkv2_outs.size() == 2);
+
+ std::vector<int32_t> outputs_vec(2);
+
+ for (auto out : topkv2_outs)
+ {
+ auto topkv2_out = loco::must_cast<luci::CircleTopKV2Out *>(out);
+ if (topkv2_out->index() >= 2)
+ INTERNAL_EXN("Invalid TopKV2 output");
+ outputs_vec[topkv2_out->index()] = luci::get_tensor_index(topkv2_out);
+ }
+
+ return outputs_vec;
+ }
+
+ std::vector<int32_t> visit(luci::CircleUnique *node) final
+ {
+ auto unique_outs = loco::succs(node);
+ assert(unique_outs.size() == 2);
+
+ std::vector<int32_t> outputs_vec(2);
+
+ for (auto out : unique_outs)
+ {
+ auto unique_out = loco::must_cast<luci::CircleUniqueOut *>(out);
+ if (unique_out->index() >= 2)
+ INTERNAL_EXN("Invalid Unique output");
+ outputs_vec[unique_out->index()] = luci::get_tensor_index(unique_out);
+ }
+
+ return outputs_vec;
+ }
+
+ std::vector<int32_t> visit(luci::CircleUnpack *node) final
+ {
+ auto unpack_outs = loco::succs(node);
+ assert(int32_t(unpack_outs.size()) == node->num());
+
+ std::vector<int32_t> outputs_vec(node->num());
+
+ for (auto out : unpack_outs)
+ {
+ auto unpack_out = loco::must_cast<luci::CircleUnpackOut *>(out);
+ if (unpack_out->index() >= node->num())
+ INTERNAL_EXN("Invalid Unpack output");
+ outputs_vec[unpack_out->index()] = luci::get_tensor_index(unpack_out);
+ }
+
+ return outputs_vec;
+ }
+
+ std::vector<int32_t> visit(luci::CircleWhile *node) final
+ {
+ auto while_outs = loco::succs(node);
+ assert(while_outs.size() == node->output_count());
+
+ std::vector<int32_t> outputs_vec(node->output_count());
+
+ for (auto out : while_outs)
+ {
+ auto while_out = loco::must_cast<luci::CircleWhileOut *>(out);
+ if (while_out->index() >= int32_t(node->output_count()))
+ INTERNAL_EXN("Invalid While output");
+ outputs_vec[while_out->index()] = luci::get_tensor_index(while_out);
+ }
+
+ return outputs_vec;
+ }
+};
+
+} // namespace
+
+namespace luci
+{
+
+void OperationExporterRule::visit(luci::CircleNode *node)
+{
+ auto op_idx = _ctx.md.registerBuiltinOpcode(circle_builtin_operator(node),
+ circle_custom_code(node), node->op_version());
+
+ std::vector<int32_t> inputs_vec;
+ for (uint32_t i = 0; i < node->arity(); ++i)
+ inputs_vec.push_back(luci::get_tensor_index(node->arg(i)));
+ auto inputs = _ctx.builder.CreateVector(inputs_vec);
+
+ OutputVectorExtractor outputs_vec_extractor;
+ auto outputs_vec = node->accept(&outputs_vec_extractor);
+ auto outputs = _ctx.builder.CreateVector(outputs_vec);
+
+ auto builtin_options = circle_builtin_options(node);
+
+ luci::BuiltinOptionsExtractor builtin_options_extractor(_ctx.builder);
+ auto options_offset = node->accept(&builtin_options_extractor);
+
+ // If node is not CircleCustom, null offset(0) is returned
+ auto custom_options = circle_custom_options(_ctx.builder, node);
+
+ auto op_offset = circle::CreateOperator(_ctx.builder, op_idx, inputs, outputs, builtin_options,
+ options_offset, custom_options);
+ _ctx.gd._operators.push_back(op_offset);
+}
+
+} // namespace luci
diff --git a/compiler/luci/export/src/CircleOperationExporterRule.h b/compiler/luci/export/src/CircleOperationExporterRule.h
new file mode 100644
index 000000000..23e7546cf
--- /dev/null
+++ b/compiler/luci/export/src/CircleOperationExporterRule.h
@@ -0,0 +1,76 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __CIRCLE_OPERATION_EXPORTER_RULE_H__
+#define __CIRCLE_OPERATION_EXPORTER_RULE_H__
+
+#include "CircleOperationExporter.h"
+
+#include <luci/IR/CircleNode.h>
+#include <luci/IR/CircleNodes.h>
+#include <luci/IR/CircleNodeVisitor.h>
+
+namespace luci
+{
+
+struct ExportContext
+{
+ flatbuffers::FlatBufferBuilder &builder;
+ luci::SerializedModelData &md;
+ luci::SerializedGraphData &gd;
+};
+
+class OperationExporterRule final : public luci::CircleNodeMutableVisitor<void>
+{
+public:
+ OperationExporterRule(ExportContext &ctx) : _ctx{ctx}
+ {
+ // DO NOTHING
+ }
+
+public:
+ // Default export rule
+ void visit(luci::CircleNode *node) final;
+
+ // Non-virtual
+ void visit(luci::CircleConst *) final{/* skip, everything is done in exportOpDefinedTensors */};
+
+ // Virtual
+ void visit(luci::CircleInput *) final {}
+ void visit(luci::CircleOutput *) final {}
+ void visit(luci::CircleOutputDummy *) final {}
+ void visit(luci::CircleOutputExclude *) final {}
+ // Virtual for multiple-outputs
+ void visit(luci::CircleBidirectionalSequenceLSTMOut *) final {}
+ void visit(luci::CircleCustomOut *) final {}
+ void visit(luci::CircleIfOut *) final {}
+ void visit(luci::CircleNonMaxSuppressionV4Out *) final {}
+ void visit(luci::CircleNonMaxSuppressionV5Out *) final {}
+ void visit(luci::CircleSplitOut *) final {}
+ void visit(luci::CircleSplitVOut *) final {}
+ void visit(luci::CircleTopKV2Out *) final {}
+ void visit(luci::CircleUniqueOut *) final {}
+ void visit(luci::CircleUnpackOut *) final {}
+ void visit(luci::CircleVariable *) final {}
+ void visit(luci::CircleWhileOut *) final {}
+
+protected:
+ ExportContext &_ctx;
+};
+
+} // namespace luci
+
+#endif // __CIRCLE_OPERATION_EXPORTER_RULE_H__
diff --git a/compiler/luci/export/src/CircleOps.lst b/compiler/luci/export/src/CircleOps.lst
new file mode 100644
index 000000000..1b6909303
--- /dev/null
+++ b/compiler/luci/export/src/CircleOps.lst
@@ -0,0 +1,154 @@
+#ifndef CIRCLE_NODE
+#error "Define CIRCLE_NODE"
+#endif // CIRCLE_NODE
+
+#ifndef CIRCLE_VNODE
+#error "Define CIRCLE_VNODE"
+#endif // CIRCLE_VNODE
+
+//
+// PLEASE SORT NODE DECLS IN ALPHABETICAL ORDER
+//
+// NOTE : CIRCLE_VNODE does not have any additional parameters
+// because they are not circle builtin operators
+// Please add parameters when they are needed.
+//
+// CIRCLE_NODE(CircleNode, circle::BuiltinOperator, circle::BuiltinOptions)
+// CIRCLE_VNODE(CircleNode)
+//
+
+CIRCLE_NODE(CircleAbs, BuiltinOperator_ABS, BuiltinOptions_AbsOptions)
+CIRCLE_NODE(CircleAdd, BuiltinOperator_ADD, BuiltinOptions_AddOptions)
+CIRCLE_NODE(CircleAddN, BuiltinOperator_ADD_N, BuiltinOptions_AddNOptions)
+CIRCLE_NODE(CircleArgMax, BuiltinOperator_ARG_MAX, BuiltinOptions_ArgMaxOptions)
+CIRCLE_NODE(CircleArgMin, BuiltinOperator_ARG_MIN, BuiltinOptions_ArgMinOptions)
+CIRCLE_NODE(CircleAveragePool2D, BuiltinOperator_AVERAGE_POOL_2D , BuiltinOptions_Pool2DOptions)
+CIRCLE_NODE(CircleBatchToSpaceND, BuiltinOperator_BATCH_TO_SPACE_ND, BuiltinOptions_BatchToSpaceNDOptions)
+CIRCLE_NODE(CircleBatchMatMul, BuiltinOperator_BATCH_MATMUL, BuiltinOptions_BatchMatMulOptions)
+CIRCLE_NODE(CircleBidirectionalSequenceLSTM, BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM, BuiltinOptions_BidirectionalSequenceLSTMOptions)
+CIRCLE_NODE(CircleCast, BuiltinOperator_CAST, BuiltinOptions_CastOptions)
+CIRCLE_NODE(CircleCeil, BuiltinOperator_CEIL, BuiltinOptions_NONE)
+CIRCLE_NODE(CircleConcatenation, BuiltinOperator_CONCATENATION, BuiltinOptions_ConcatenationOptions)
+CIRCLE_NODE(CircleConv2D, BuiltinOperator_CONV_2D, BuiltinOptions_Conv2DOptions)
+CIRCLE_NODE(CircleCos, BuiltinOperator_COS, BuiltinOptions_CosOptions)
+CIRCLE_NODE(CircleCustom, BuiltinOperator_CUSTOM, BuiltinOptions_NONE)
+CIRCLE_NODE(CircleDepthToSpace, BuiltinOperator_DEPTH_TO_SPACE, BuiltinOptions_DepthToSpaceOptions)
+CIRCLE_NODE(CircleDepthwiseConv2D, BuiltinOperator_DEPTHWISE_CONV_2D, BuiltinOptions_DepthwiseConv2DOptions)
+CIRCLE_NODE(CircleDequantize, BuiltinOperator_DEQUANTIZE, BuiltinOptions_DequantizeOptions)
+CIRCLE_NODE(CircleDiv, BuiltinOperator_DIV, BuiltinOptions_DivOptions)
+CIRCLE_NODE(CircleElu, BuiltinOperator_ELU, BuiltinOptions_NONE)
+CIRCLE_NODE(CircleEqual, BuiltinOperator_EQUAL, BuiltinOptions_EqualOptions)
+CIRCLE_NODE(CircleExp, BuiltinOperator_EXP, BuiltinOptions_ExpOptions)
+CIRCLE_NODE(CircleExpandDims, BuiltinOperator_EXPAND_DIMS, BuiltinOptions_ExpandDimsOptions)
+CIRCLE_NODE(CircleFakeQuant, BuiltinOperator_FAKE_QUANT, BuiltinOptions_FakeQuantOptions)
+CIRCLE_NODE(CircleFill, BuiltinOperator_FILL, BuiltinOptions_FillOptions)
+CIRCLE_NODE(CircleFloor, BuiltinOperator_FLOOR, BuiltinOptions_NONE)
+CIRCLE_NODE(CircleFloorDiv, BuiltinOperator_FLOOR_DIV, BuiltinOptions_FloorDivOptions)
+CIRCLE_NODE(CircleFloorMod, BuiltinOperator_FLOOR_MOD, BuiltinOptions_FloorModOptions)
+CIRCLE_NODE(CircleFullyConnected, BuiltinOperator_FULLY_CONNECTED, BuiltinOptions_FullyConnectedOptions)
+CIRCLE_NODE(CircleGather, BuiltinOperator_GATHER, BuiltinOptions_GatherOptions)
+CIRCLE_NODE(CircleGatherNd, BuiltinOperator_GATHER_ND, BuiltinOptions_GatherNdOptions)
+CIRCLE_NODE(CircleGreater, BuiltinOperator_GREATER, BuiltinOptions_GreaterOptions)
+CIRCLE_NODE(CircleGreaterEqual, BuiltinOperator_GREATER_EQUAL, BuiltinOptions_GreaterEqualOptions)
+CIRCLE_NODE(CircleIf, BuiltinOperator_IF, BuiltinOptions_IfOptions)
+CIRCLE_NODE(CircleL2Normalize, BuiltinOperator_L2_NORMALIZATION, BuiltinOptions_L2NormOptions)
+CIRCLE_NODE(CircleL2Pool2D, BuiltinOperator_L2_POOL_2D, BuiltinOptions_Pool2DOptions)
+CIRCLE_NODE(CircleLeakyRelu, BuiltinOperator_LEAKY_RELU, BuiltinOptions_LeakyReluOptions)
+CIRCLE_NODE(CircleLess, BuiltinOperator_LESS, BuiltinOptions_LessOptions)
+CIRCLE_NODE(CircleLessEqual, BuiltinOperator_LESS_EQUAL, BuiltinOptions_LessEqualOptions)
+CIRCLE_NODE(CircleLocalResponseNormalization, BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION, BuiltinOptions_LocalResponseNormalizationOptions)
+CIRCLE_NODE(CircleLog, BuiltinOperator_LOG, BuiltinOptions_NONE)
+CIRCLE_NODE(CircleLogicalAnd, BuiltinOperator_LOGICAL_AND, BuiltinOptions_LogicalAndOptions)
+CIRCLE_NODE(CircleLogicalNot, BuiltinOperator_LOGICAL_NOT, BuiltinOptions_LogicalNotOptions)
+CIRCLE_NODE(CircleLogicalOr, BuiltinOperator_LOGICAL_OR, BuiltinOptions_LogicalOrOptions)
+CIRCLE_NODE(CircleLogistic, BuiltinOperator_LOGISTIC, BuiltinOptions_NONE)
+CIRCLE_NODE(CircleLogSoftmax, BuiltinOperator_LOG_SOFTMAX, BuiltinOptions_LogSoftmaxOptions)
+CIRCLE_NODE(CircleMatrixDiag, BuiltinOperator_MATRIX_DIAG, BuiltinOptions_MatrixDiagOptions)
+CIRCLE_NODE(CircleMaxPool2D, BuiltinOperator_MAX_POOL_2D, BuiltinOptions_Pool2DOptions)
+CIRCLE_NODE(CircleMatrixSetDiag, BuiltinOperator_MATRIX_SET_DIAG, BuiltinOptions_MatrixSetDiagOptions)
+CIRCLE_NODE(CircleMaximum, BuiltinOperator_MAXIMUM, BuiltinOptions_MaximumMinimumOptions)
+CIRCLE_NODE(CircleMean, BuiltinOperator_MEAN, BuiltinOptions_ReducerOptions)
+CIRCLE_NODE(CircleMinimum, BuiltinOperator_MINIMUM, BuiltinOptions_MaximumMinimumOptions)
+CIRCLE_NODE(CircleMirrorPad, BuiltinOperator_MIRROR_PAD, BuiltinOptions_MirrorPadOptions)
+CIRCLE_NODE(CircleMul, BuiltinOperator_MUL, BuiltinOptions_MulOptions)
+CIRCLE_NODE(CircleNeg, BuiltinOperator_NEG, BuiltinOptions_NegOptions)
+CIRCLE_NODE(CircleNonMaxSuppressionV4, BuiltinOperator_NON_MAX_SUPPRESSION_V4, BuiltinOptions_NonMaxSuppressionV4Options)
+CIRCLE_NODE(CircleNonMaxSuppressionV5, BuiltinOperator_NON_MAX_SUPPRESSION_V5, BuiltinOptions_NonMaxSuppressionV5Options)
+CIRCLE_NODE(CircleNotEqual, BuiltinOperator_NOT_EQUAL, BuiltinOptions_NotEqualOptions)
+CIRCLE_NODE(CircleOneHot, BuiltinOperator_ONE_HOT, BuiltinOptions_OneHotOptions)
+CIRCLE_NODE(CirclePack, BuiltinOperator_PACK, BuiltinOptions_PackOptions)
+CIRCLE_NODE(CirclePad, BuiltinOperator_PAD, BuiltinOptions_PadOptions)
+CIRCLE_NODE(CirclePadV2, BuiltinOperator_PADV2, BuiltinOptions_PadV2Options)
+CIRCLE_NODE(CirclePow, BuiltinOperator_POW, BuiltinOptions_PowOptions)
+CIRCLE_NODE(CirclePRelu, BuiltinOperator_PRELU, BuiltinOptions_NONE)
+CIRCLE_NODE(CircleQuantize, BuiltinOperator_QUANTIZE, BuiltinOptions_QuantizeOptions)
+CIRCLE_NODE(CircleRange, BuiltinOperator_RANGE, BuiltinOptions_RangeOptions)
+CIRCLE_NODE(CircleRank, BuiltinOperator_RANK, BuiltinOptions_RankOptions)
+CIRCLE_NODE(CircleReduceAny, BuiltinOperator_REDUCE_ANY, BuiltinOptions_ReducerOptions)
+CIRCLE_NODE(CircleReduceMax, BuiltinOperator_REDUCE_MAX, BuiltinOptions_ReducerOptions)
+CIRCLE_NODE(CircleReduceMin, BuiltinOperator_REDUCE_MIN, BuiltinOptions_ReducerOptions)
+CIRCLE_NODE(CircleReduceProd, BuiltinOperator_REDUCE_PROD, BuiltinOptions_ReducerOptions)
+CIRCLE_NODE(CircleRelu, BuiltinOperator_RELU, BuiltinOptions_NONE)
+CIRCLE_NODE(CircleRelu6, BuiltinOperator_RELU6, BuiltinOptions_NONE)
+CIRCLE_NODE(CircleReluN1To1, BuiltinOperator_RELU_N1_TO_1, BuiltinOptions_NONE)
+CIRCLE_NODE(CircleReshape, BuiltinOperator_RESHAPE, BuiltinOptions_ReshapeOptions)
+CIRCLE_NODE(CircleResizeBilinear, BuiltinOperator_RESIZE_BILINEAR, BuiltinOptions_ResizeBilinearOptions)
+CIRCLE_NODE(CircleResizeNearestNeighbor, BuiltinOperator_RESIZE_NEAREST_NEIGHBOR, BuiltinOptions_ResizeNearestNeighborOptions)
+CIRCLE_NODE(CircleReverseSequence, BuiltinOperator_REVERSE_SEQUENCE, BuiltinOptions_ReverseSequenceOptions)
+CIRCLE_NODE(CircleReverseV2, BuiltinOperator_REVERSE_V2, BuiltinOptions_ReverseV2Options)
+CIRCLE_NODE(CircleRound, BuiltinOperator_ROUND, BuiltinOptions_NONE)
+CIRCLE_NODE(CircleRsqrt, BuiltinOperator_RSQRT, BuiltinOptions_NONE)
+CIRCLE_NODE(CircleScatterNd, BuiltinOperator_SCATTER_ND, BuiltinOptions_ScatterNdOptions)
+CIRCLE_NODE(CircleSegmentSum, BuiltinOperator_SEGMENT_SUM, BuiltinOptions_SegmentSumOptions)
+CIRCLE_NODE(CircleSelect, BuiltinOperator_SELECT, BuiltinOptions_SelectOptions)
+CIRCLE_NODE(CircleSelectV2, BuiltinOperator_SELECT_V2, BuiltinOptions_SelectV2Options)
+CIRCLE_NODE(CircleShape, BuiltinOperator_SHAPE, BuiltinOptions_ShapeOptions)
+CIRCLE_NODE(CircleSin, BuiltinOperator_SIN, BuiltinOptions_NONE)
+CIRCLE_NODE(CircleSlice, BuiltinOperator_SLICE, BuiltinOptions_SliceOptions)
+CIRCLE_NODE(CircleSoftmax, BuiltinOperator_SOFTMAX, BuiltinOptions_SoftmaxOptions)
+CIRCLE_NODE(CircleSpaceToBatchND, BuiltinOperator_SPACE_TO_BATCH_ND, BuiltinOptions_SpaceToBatchNDOptions)
+CIRCLE_NODE(CircleSpaceToDepth, BuiltinOperator_SPACE_TO_DEPTH, BuiltinOptions_SpaceToDepthOptions)
+CIRCLE_NODE(CircleSparseToDense, BuiltinOperator_SPARSE_TO_DENSE, BuiltinOptions_SparseToDenseOptions)
+CIRCLE_NODE(CircleSplit, BuiltinOperator_SPLIT, BuiltinOptions_SplitOptions)
+CIRCLE_NODE(CircleSplitV, BuiltinOperator_SPLIT_V, BuiltinOptions_SplitVOptions)
+CIRCLE_NODE(CircleSqrt, BuiltinOperator_SQRT, BuiltinOptions_NONE)
+CIRCLE_NODE(CircleSquare, BuiltinOperator_SQUARE, BuiltinOptions_SquareOptions)
+CIRCLE_NODE(CircleSquaredDifference, BuiltinOperator_SQUARED_DIFFERENCE, BuiltinOptions_SquaredDifferenceOptions)
+CIRCLE_NODE(CircleSqueeze, BuiltinOperator_SQUEEZE, BuiltinOptions_SqueezeOptions)
+CIRCLE_NODE(CircleStridedSlice, BuiltinOperator_STRIDED_SLICE, BuiltinOptions_StridedSliceOptions)
+CIRCLE_NODE(CircleSub, BuiltinOperator_SUB, BuiltinOptions_SubOptions)
+CIRCLE_NODE(CircleSum, BuiltinOperator_SUM, BuiltinOptions_ReducerOptions)
+CIRCLE_NODE(CircleSVDF, BuiltinOperator_SVDF, BuiltinOptions_SVDFOptions)
+CIRCLE_NODE(CircleTanh, BuiltinOperator_TANH, BuiltinOptions_NONE)
+CIRCLE_NODE(CircleTile, BuiltinOperator_TILE, BuiltinOptions_TileOptions)
+CIRCLE_NODE(CircleTopKV2, BuiltinOperator_TOPK_V2, BuiltinOptions_TopKV2Options)
+CIRCLE_NODE(CircleTranspose, BuiltinOperator_TRANSPOSE, BuiltinOptions_TransposeOptions)
+CIRCLE_NODE(CircleTransposeConv, BuiltinOperator_TRANSPOSE_CONV, BuiltinOptions_TransposeConvOptions)
+CIRCLE_NODE(CircleUnidirectionalSequenceLSTM, BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM, BuiltinOptions_UnidirectionalSequenceLSTMOptions)
+CIRCLE_NODE(CircleUnique, BuiltinOperator_UNIQUE, BuiltinOptions_UniqueOptions)
+CIRCLE_NODE(CircleUnpack, BuiltinOperator_UNPACK, BuiltinOptions_UnpackOptions)
+CIRCLE_NODE(CircleWhere, BuiltinOperator_WHERE, BuiltinOptions_WhereOptions)
+CIRCLE_NODE(CircleWhile, BuiltinOperator_WHILE, BuiltinOptions_WhileOptions)
+CIRCLE_NODE(CircleZerosLike, BuiltinOperator_ZEROS_LIKE, BuiltinOptions_ZerosLikeOptions)
+// Circle Only
+CIRCLE_NODE(CircleBCQFullyConnected, BuiltinOperator_BCQ_FULLY_CONNECTED, BuiltinOptions_BCQFullyConnectedOptions)
+CIRCLE_NODE(CircleBCQGather, BuiltinOperator_BCQ_GATHER, BuiltinOptions_BCQGatherOptions)
+CIRCLE_NODE(CircleInstanceNorm, BuiltinOperator_INSTANCE_NORM, BuiltinOptions_InstanceNormOptions)
+// Virtual node(s)
+CIRCLE_VNODE(CircleBidirectionalSequenceLSTMOut)
+CIRCLE_VNODE(CircleConst)
+CIRCLE_VNODE(CircleInput)
+CIRCLE_VNODE(CircleOutput)
+CIRCLE_VNODE(CircleOutputDummy)
+CIRCLE_VNODE(CircleOutputExclude)
+CIRCLE_VNODE(CircleCustomOut)
+CIRCLE_VNODE(CircleIfOut)
+CIRCLE_VNODE(CircleNonMaxSuppressionV4Out)
+CIRCLE_VNODE(CircleNonMaxSuppressionV5Out)
+CIRCLE_VNODE(CircleSplitOut)
+CIRCLE_VNODE(CircleSplitVOut)
+CIRCLE_VNODE(CircleTopKV2Out)
+CIRCLE_VNODE(CircleUniqueOut)
+CIRCLE_VNODE(CircleUnpackOut)
+CIRCLE_VNODE(CircleVariable)
+CIRCLE_VNODE(CircleWhileOut)
diff --git a/compiler/luci/export/src/CircleTensorExporter.cpp b/compiler/luci/export/src/CircleTensorExporter.cpp
index 615402aa8..b3bb850cc 100644
--- a/compiler/luci/export/src/CircleTensorExporter.cpp
+++ b/compiler/luci/export/src/CircleTensorExporter.cpp
@@ -67,6 +67,9 @@ public:
luci::SparsityParam *sparsityparam(void) const { return _sparsityparam; }
void sparsityparam(luci::SparsityParam *sp) { _sparsityparam = sp; }
+ bool is_variable(void) const { return _is_variable; }
+ void is_variable(bool v) { _is_variable = v; }
+
private:
std::string _name;
@@ -77,6 +80,8 @@ private:
luci::CircleConst *_content = nullptr;
luci::CircleQuantParam *_quantparam = nullptr;
luci::SparsityParam *_sparsityparam = nullptr;
+
+ bool _is_variable = false;
};
class CircleTensorContext
@@ -145,6 +150,8 @@ void allocateCircleTensorInfo(CircleNode *node, CircleTensorContext &ctx)
tensor_info.quantparam(node->quantparam());
tensor_info.sparsityparam(node->sparsityparam());
+ tensor_info.is_variable(dynamic_cast<luci::CircleVariable *>(node) != nullptr);
+
set_tensor_index(node, tensor_index);
ctx.emplace_back(tensor_info);
@@ -592,9 +599,11 @@ void exportOpDefinedTensor(const CircleTensorInfo &info, FlatBufferBuilder &buil
auto buffer_id = get_buffer_id(builder, md, info.content());
auto name_offset = builder.CreateString(info.name());
- auto tensor_offset =
- CreateTensor(builder, shape_offset, info.dtype(), buffer_id, name_offset, quantparam,
- /*is_variable*/ false, sparsityparam, shape_signature_offset);
+
+ auto is_variable = info.is_variable();
+
+ auto tensor_offset = CreateTensor(builder, shape_offset, info.dtype(), buffer_id, name_offset,
+ quantparam, is_variable, sparsityparam, shape_signature_offset);
gd._tensors.push_back(tensor_offset);
}
diff --git a/compiler/luci/export/src/SerializedData.h b/compiler/luci/export/src/SerializedData.h
index a945eecf7..136a8ac49 100644
--- a/compiler/luci/export/src/SerializedData.h
+++ b/compiler/luci/export/src/SerializedData.h
@@ -23,7 +23,7 @@
#include <luci/IR/ExecutionPlanTable.h>
#include <vector>
-
+#include <string>
#include <unordered_map>
#include <map>
@@ -131,8 +131,8 @@ struct SerializedModelData final
* @param builtin_code
* @return idx of opcode in table of opcodes (see schema)
*/
- uint32_t registerBuiltinOpcode(circle::BuiltinOperator builtin_code, const int32_t op_version);
- uint32_t registerCustomOpcode(const std::string &custom_op);
+ uint32_t registerBuiltinOpcode(circle::BuiltinOperator builtin_code,
+ const std::string &custom_code, const int32_t op_version);
};
// Prerequisites for circle::Model object creation
diff --git a/compiler/luci/import/CMakeLists.txt b/compiler/luci/import/CMakeLists.txt
index 6630cab9f..1b2db23ae 100644
--- a/compiler/luci/import/CMakeLists.txt
+++ b/compiler/luci/import/CMakeLists.txt
@@ -12,13 +12,14 @@ target_include_directories(luci_import PUBLIC include)
target_link_libraries(luci_import PUBLIC luci_lang)
target_link_libraries(luci_import PUBLIC luci_profile)
target_link_libraries(luci_import PUBLIC luci_plan)
-target_link_libraries(luci_import PUBLIC mio_circle)
+target_link_libraries(luci_import PUBLIC mio_circle04)
target_link_libraries(luci_import PRIVATE luci_env)
target_link_libraries(luci_import PRIVATE luci_log)
target_link_libraries(luci_import PRIVATE luci_logex)
target_link_libraries(luci_import PRIVATE nncc_common)
target_link_libraries(luci_import PRIVATE locop)
target_link_libraries(luci_import PRIVATE oops)
+target_link_libraries(luci_import PRIVATE mio_circle04_helper)
install(TARGETS luci_import DESTINATION lib)
install(DIRECTORY include/ DESTINATION include
FILES_MATCHING PATTERN "*.h")
@@ -32,7 +33,3 @@ nnas_find_package(GTest REQUIRED)
GTest_AddTest(luci_import_test ${TESTS})
target_include_directories(luci_import_test PRIVATE src)
target_link_libraries(luci_import_test luci_import)
-target_link_libraries(luci_import_test oops)
-target_link_libraries(luci_import_test luci_plan)
-target_link_libraries(luci_import_test luci_lang)
-target_link_libraries(luci_import_test mio_circle)
diff --git a/compiler/luci/import/include/luci/Import/CircleReader.h b/compiler/luci/import/include/luci/Import/CircleReader.h
index fb38ba90b..a0519f661 100644
--- a/compiler/luci/import/include/luci/Import/CircleReader.h
+++ b/compiler/luci/import/include/luci/Import/CircleReader.h
@@ -35,19 +35,7 @@
namespace luci
{
-bool is_valid(const circle::OperatorCodeT &opcode);
-bool is_valid(const circle::OperatorCode *opcode);
-
-bool is_custom(const circle::OperatorCodeT &opcode);
-bool is_custom(const circle::OperatorCode *opcode);
-
-std::string opcode_name(const circle::OperatorCodeT &opcode);
-std::string opcode_name(const circle::OperatorCode *opcode);
-
-const char *tensor_name(const circle::TensorT &tensor);
const char *tensor_name(const circle::Tensor *tensor);
-
-const circle::QuantizationParametersT *tensor_quantization(const circle::TensorT &tensor);
const circle::QuantizationParameters *tensor_quantization(const circle::Tensor *tensor);
loco::DataType luci_datatype(circle::TensorType type);
@@ -57,14 +45,13 @@ MirrorPadMode luci_mirrorpad_mode(const circle::MirrorPadMode mode);
luci::CircleFullyConnected::WeightsFormat
luci_weights_format(const circle::FullyConnectedOptionsWeightsFormat weights_format);
std::unique_ptr<CircleQuantParam>
-luci_quantparam(const circle::QuantizationParametersT *quantization);
-std::unique_ptr<CircleQuantParam>
luci_quantparam(const circle::QuantizationParameters *quantization);
/// @brief Copy common tensor attributes such as name, type, etc. to node.
-void copy_tensor_attributes(const circle::TensorT &tensor, CircleNode *node);
void copy_tensor_attributes(const circle::Tensor *tensor, CircleNode *node);
+std::string fb_string2std_string(const flatbuffers::String *fb_str);
+
/**
* @brief Wrapper to use flatbuffers::Vector pointer as std::vector entity
*/
@@ -101,13 +88,6 @@ template <typename T> VectorWrapper<T> wrap(const flatbuffers::Vector<T> *vec)
*/
class CircleReader
{
-private: // unpack API
- using CircleBuffers_t = std::vector<std::unique_ptr<circle::BufferT>>;
- using CircleTensors_t = std::vector<std::unique_ptr<circle::TensorT>>;
- using CircleOperators_t = std::vector<std::unique_ptr<circle::OperatorT>>;
- using CircleOperatorCodes_t = std::vector<std::unique_ptr<circle::OperatorCodeT>>;
- using CircleMetadata_t = std::vector<std::unique_ptr<circle::MetadataT>>;
-
private: // direct API
using CircleBuffers = VectorWrapper<flatbuffers::Offset<circle::Buffer>>;
using CircleTensors = VectorWrapper<flatbuffers::Offset<circle::Tensor>>;
@@ -115,40 +95,21 @@ private: // direct API
using CircleOperatorCodes = VectorWrapper<flatbuffers::Offset<circle::OperatorCode>>;
using CircleMetadataSet = VectorWrapper<flatbuffers::Offset<circle::Metadata>>;
- using CircleSubGraphsPtr_t = flatbuffers::Vector<flatbuffers::Offset<circle::SubGraph>>;
- using CircleTensorsPtr_t = flatbuffers::Vector<flatbuffers::Offset<circle::Tensor>>;
-
public:
CircleReader() = default;
-public: // unpack API
- const CircleOperatorCodes_t &opcodes() const { return _model->operator_codes; }
- const CircleBuffers_t &buffers() const { return _model->buffers; }
- const CircleTensors_t &tensors() const { return _current_subgraph->tensors; }
- const CircleOperators_t &operators() const { return _current_subgraph->operators; }
- const std::vector<int32_t> &inputs() const { return _current_subgraph->inputs; }
- const std::vector<int32_t> &outputs() const { return _current_subgraph->outputs; }
- const std::string &name() const { return _current_subgraph->name; }
- const circle::DataFormat &data_format() const { return _current_subgraph->data_format; }
- const CircleMetadata_t &metadata() const { return _model->metadata; }
-
- const CircleTensorsPtr_t *tensors_ptr() const { return _tensors_ptr; }
-
- uint32_t num_subgraph() const { return _model->subgraphs.size(); }
-
- circle::BuiltinOperator builtin_code(const circle::OperatorT &op) const;
- std::string opcode_name(const circle::OperatorT &op) const;
-
public: // direct API
- CircleOperatorCodes native_opcodes() const { return wrap(_native_model->operator_codes()); }
- CircleBuffers native_buffers() const { return wrap(_native_model->buffers()); }
- CircleTensors native_tensors() const { return wrap(_native_subgraph->tensors()); }
- CircleOperators native_operators() const { return wrap(_native_subgraph->operators()); }
- VectorWrapper<int32_t> native_inputs() const { return wrap(_native_subgraph->inputs()); }
- VectorWrapper<int32_t> native_outputs() const { return wrap(_native_subgraph->outputs()); }
- std::string native_name() const { return _native_subgraph->name()->str(); }
- circle::DataFormat native_data_format() const { return _native_subgraph->data_format(); }
- CircleMetadataSet native_metadata() const { return wrap(_native_model->metadata()); }
+ CircleOperatorCodes opcodes() const { return wrap(_model->operator_codes()); }
+ CircleBuffers buffers() const { return wrap(_model->buffers()); }
+ CircleTensors tensors() const { return wrap(_current_subgraph->tensors()); }
+ CircleOperators operators() const { return wrap(_current_subgraph->operators()); }
+ VectorWrapper<int32_t> inputs() const { return wrap(_current_subgraph->inputs()); }
+ VectorWrapper<int32_t> outputs() const { return wrap(_current_subgraph->outputs()); }
+ std::string name() const { return fb_string2std_string(_current_subgraph->name()); }
+ circle::DataFormat data_format() const { return _current_subgraph->data_format(); }
+ CircleMetadataSet metadata() const { return wrap(_model->metadata()); }
+
+ uint32_t num_subgraph() const { return wrap(_model->subgraphs()).size(); }
circle::BuiltinOperator builtin_code(const circle::Operator *op) const;
std::string opcode_name(const circle::Operator *op) const;
@@ -158,12 +119,8 @@ public:
bool select_subgraph(uint32_t subgraph);
private:
- std::unique_ptr<const circle::ModelT> _model;
- const circle::SubGraphT *_current_subgraph{nullptr};
-
- const circle::Model *_native_model{nullptr};
- const CircleTensorsPtr_t *_tensors_ptr{nullptr};
- const circle::SubGraph *_native_subgraph{nullptr};
+ const circle::Model *_model{nullptr};
+ const circle::SubGraph *_current_subgraph{nullptr};
};
} // namespace luci
diff --git a/compiler/luci/import/include/luci/Import/GraphBuilderRegistry.h b/compiler/luci/import/include/luci/Import/GraphBuilderRegistry.h
index b8dc22fdd..93e34a56b 100644
--- a/compiler/luci/import/include/luci/Import/GraphBuilderRegistry.h
+++ b/compiler/luci/import/include/luci/Import/GraphBuilderRegistry.h
@@ -18,6 +18,7 @@
#define __LUCI_IMPORT_GRAPH_BUILDER_REGISTRY_H__
#include "GraphBuilderBase.h"
+#include "NodeBuilder.h"
#include <map>
@@ -32,6 +33,11 @@ struct GraphBuilderSource
* @brief Returns registered GraphBuilder pointer for operator (nullptr if not present)
*/
virtual const GraphBuilderBase *lookup(const circle::BuiltinOperator &op) const = 0;
+
+ /**
+ * @brief Returns registered NodeBuilderBase pointer for type (nullptr if not present)
+ */
+ virtual const NodeBuilderBase *lookup(const NodeBuilderType type) const = 0;
};
/**
@@ -61,6 +67,17 @@ public:
return _builder_map.at(op).get();
}
+ /**
+ * @brief Returns registered NodeBuilderBase pointer for type or nullptr if not registered
+ */
+ const NodeBuilderBase *lookup(const NodeBuilderType type) const final
+ {
+ if (_node_builders.find(type) == _node_builders.end())
+ return (_parent == nullptr) ? nullptr : _parent->lookup(type);
+
+ return _node_builders.at(type).get();
+ }
+
static GraphBuilderRegistry &get()
{
static GraphBuilderRegistry me;
@@ -73,11 +90,17 @@ public:
_builder_map[op] = std::move(builder);
}
+ void add(std::unique_ptr<NodeBuilderBase> &&builder)
+ {
+ _node_builders[builder->builder_type()] = std::move(builder);
+ }
+
private:
const GraphBuilderSource *_parent = nullptr;
private:
std::map<const circle::BuiltinOperator, std::unique_ptr<GraphBuilderBase>> _builder_map;
+ std::map<const NodeBuilderType, std::unique_ptr<NodeBuilderBase>> _node_builders;
};
} // namespace luci
diff --git a/compiler/luci/import/include/luci/Import/NodeBuilder.h b/compiler/luci/import/include/luci/Import/NodeBuilder.h
new file mode 100644
index 000000000..440b491b0
--- /dev/null
+++ b/compiler/luci/import/include/luci/Import/NodeBuilder.h
@@ -0,0 +1,58 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_IMPORT_NODE_BUILDER_H__
+#define __LUCI_IMPORT_NODE_BUILDER_H__
+
+#include "GraphBuilderContext.h"
+#include "GraphBuilderBase.h"
+
+#include <mio/circle/schema_generated.h>
+
+namespace luci
+{
+
+/**
+ * @brief Tensor types which requires separated node
+ */
+enum class NodeBuilderType
+{
+ BUFFER,
+ // TODO Extend this struct here if needed to add new type of NodeBuilderBase
+};
+
+/**
+ * @brief Creates nodes from given Tensor and context
+ */
+class NodeBuilderBase
+{
+public:
+ virtual CircleNode *build(TensorIndex tensor_idx, GraphBuilderContext *context) const = 0;
+ virtual NodeBuilderType builder_type() const = 0;
+};
+
+/**
+ * @brief Placeholder for builders of tensors with different types
+ */
+template <NodeBuilderType Type> class TypedNodeBuilder : public NodeBuilderBase
+{
+public:
+ NodeBuilderType builder_type() const final { return Type; }
+};
+
+} // namespace luci
+
+#endif // __LUCI_IMPORT_NODE_BUILDER_H__
diff --git a/compiler/luci/import/include/luci/Import/Nodes.h b/compiler/luci/import/include/luci/Import/Nodes.h
index f7d22e7aa..7a5045ede 100644
--- a/compiler/luci/import/include/luci/Import/Nodes.h
+++ b/compiler/luci/import/include/luci/Import/Nodes.h
@@ -122,6 +122,7 @@
#include "Nodes/CircleStridedSlice.h"
#include "Nodes/CircleSub.h"
#include "Nodes/CircleSum.h"
+#include "Nodes/CircleSVDF.h"
#include "Nodes/CircleTanh.h"
#include "Nodes/CircleTile.h"
#include "Nodes/CircleTopKV2.h"
@@ -130,6 +131,7 @@
#include "Nodes/CircleUnidirectionalSequenceLSTM.h"
#include "Nodes/CircleUnique.h"
#include "Nodes/CircleUnpack.h"
+#include "Nodes/CircleVariable.h"
#include "Nodes/CircleWhere.h"
#include "Nodes/CircleWhile.h"
#include "Nodes/CircleZerosLike.h"
diff --git a/compiler/luci/import/include/luci/Import/Nodes/CircleConst.h b/compiler/luci/import/include/luci/Import/Nodes/CircleConst.h
index 7d4f10a59..9e50ddbde 100644
--- a/compiler/luci/import/include/luci/Import/Nodes/CircleConst.h
+++ b/compiler/luci/import/include/luci/Import/Nodes/CircleConst.h
@@ -17,20 +17,21 @@
#ifndef __LUCI_IMPORT_OP_CIRCLE_CONST_H__
#define __LUCI_IMPORT_OP_CIRCLE_CONST_H__
-#include "luci/Import/GraphBuilderContext.h"
+#include "luci/Import/NodeBuilder.h"
#include <luci/IR/Nodes/CircleConst.h>
-/*
- * @note Circle does not have Const operator.
- * Methods here provide helper that creates CircleConst from
- * Tensor and Buffer in circle flatbuffer file.
- */
-
namespace luci
{
-CircleConst *create_circleconst(GraphBuilderContext *context, int32_t tensor_index);
+/**
+ * @brief Builder creates CircleConst node from Tensor with buffer.
+ */
+class CircleConstNodeBuilder : public TypedNodeBuilder<NodeBuilderType::BUFFER>
+{
+public:
+ CircleNode *build(TensorIndex tensor_index, GraphBuilderContext *ctx) const final;
+};
} // namespace luci
diff --git a/compiler/luci/import/include/luci/Import/Nodes/CircleSVDF.h b/compiler/luci/import/include/luci/Import/Nodes/CircleSVDF.h
new file mode 100644
index 000000000..a91f66019
--- /dev/null
+++ b/compiler/luci/import/include/luci/Import/Nodes/CircleSVDF.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_IMPORT_OP_CIRCLE_SVDF_H__
+#define __LUCI_IMPORT_OP_CIRCLE_SVDF_H__
+
+#include "luci/Import/GraphBuilder.h"
+
+namespace luci
+{
+
+class CircleSVDFBuilder : public GraphBuilder
+{
+public:
+ bool validate(const ValidateArgs &args) const final;
+
+private:
+ CircleNode *build_node(const circle::OperatorT &op, const std::vector<CircleNode *> &inputs,
+ loco::Graph *graph) const final;
+};
+
+} // namespace luci
+
+#endif // __LUCI_IMPORT_OP_CIRCLE_SVDF_H__
diff --git a/compiler/luci/import/include/luci/Import/Nodes/CircleVariable.h b/compiler/luci/import/include/luci/Import/Nodes/CircleVariable.h
new file mode 100644
index 000000000..4d8961fa5
--- /dev/null
+++ b/compiler/luci/import/include/luci/Import/Nodes/CircleVariable.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_IMPORT_OP_CIRCLE_VARIABLE_H__
+#define __LUCI_IMPORT_OP_CIRCLE_VARIABLE_H__
+
+#include "luci/Import/GraphBuilderContext.h"
+
+#include <luci/IR/Nodes/CircleVariable.h>
+
+/*
+ * @note Circle does not have node for variable tensor
+ * Methods here provide helper that creates CircleVariable from
+ * Tensor having is_variable true value.
+ */
+
+namespace luci
+{
+
+CircleVariable *create_circlevariable(GraphBuilderContext *context, int32_t tensor_index);
+
+} // namespace luci
+
+#endif // __LUCI_IMPORT_OP_CIRCLE_VARIABLE_H__
diff --git a/compiler/luci/import/src/CircleImportMetadata.cpp b/compiler/luci/import/src/CircleImportMetadata.cpp
index 42dcebdaa..9c1fe7356 100644
--- a/compiler/luci/import/src/CircleImportMetadata.cpp
+++ b/compiler/luci/import/src/CircleImportMetadata.cpp
@@ -21,8 +21,10 @@
namespace
{
-uint32_t read_u32(const std::vector<uint8_t> &buffer, uint32_t idx)
+template <typename VECTORTYPE> uint32_t read_u32(const VECTORTYPE &buffer, uint32_t idx)
{
+ static_assert(std::is_same<typename VECTORTYPE::value_type, uint8_t>::value, "Types mismatch!");
+
uint32_t val = 0;
val += (buffer.at(idx + 0) << 0 * 8);
val += (buffer.at(idx + 1) << 1 * 8);
@@ -37,9 +39,11 @@ namespace
{
// 'source_table' is decoded to std::map<uint32_t, std::string> format.
-const std::map<uint32_t, std::string>
-decoded_source_table(const std::vector<uint8_t> &source_table_data)
+template <typename VECTORTYPE>
+const std::map<uint32_t, std::string> decoded_source_table(const VECTORTYPE &source_table_data)
{
+ static_assert(std::is_same<typename VECTORTYPE::value_type, uint8_t>::value, "Types mismatch!");
+
std::map<uint32_t, std::string> source_id_name_map;
uint32_t idx = 0;
@@ -86,9 +90,11 @@ decoded_source_table(const std::vector<uint8_t> &source_table_data)
}
// 'op_table' is decoded to std::map<uint32_t, std::set<uint32_t>> format.
-const std::map<uint32_t, std::set<uint32_t>>
-decoded_op_table(const std::vector<uint8_t> &op_table_data)
+template <typename VECTORTYPE>
+const std::map<uint32_t, std::set<uint32_t>> decoded_op_table(const VECTORTYPE &op_table_data)
{
+ static_assert(std::is_same<typename VECTORTYPE::value_type, uint8_t>::value, "Types mismatch!");
+
std::map<uint32_t, std::set<uint32_t>> node_source_ids_map;
uint32_t idx = 0;
@@ -135,9 +141,11 @@ decoded_op_table(const std::vector<uint8_t> &op_table_data)
}
// 'execution_plan_table' is decoded to std::map<uint32_t, std::vector<uint32_t>> format.
-const luci::ExecutionPlanTable
-decoded_execution_plan(const std::vector<uint8_t> &execution_plan_data)
+template <typename VECTORTYPE>
+const luci::ExecutionPlanTable decoded_execution_plan(const VECTORTYPE &execution_plan_data)
{
+ static_assert(std::is_same<typename VECTORTYPE::value_type, uint8_t>::value, "Types mismatch!");
+
luci::ExecutionPlanTable execution_plan_table;
uint32_t idx = 0;
@@ -156,6 +164,10 @@ decoded_execution_plan(const std::vector<uint8_t> &execution_plan_data)
idx += sizeof(uint32_t);
uint32_t size = read_u32(execution_plan_data, idx);
+
+ if (size == 0)
+ throw std::runtime_error("Op table decode error : empty execution plan entry");
+
idx += sizeof(uint32_t);
if (idx + sizeof(uint32_t) * size > execution_plan_data.size())
@@ -190,19 +202,22 @@ namespace luci
CircleImportMetadata::CircleImportMetadata(const luci::CircleReader &reader)
{
- const auto &metadata = reader.metadata();
+ const auto metadata = reader.metadata();
for (uint32_t i = 0; i < metadata.size(); ++i)
{
- const circle::MetadataT &meta = *metadata[i];
+ const auto *meta = metadata[i];
+ assert(meta != nullptr);
- assert(meta.buffer < reader.buffers().size());
- const std::vector<uint8_t> &buffer = reader.buffers()[meta.buffer]->data;
+ assert(meta->buffer() < reader.buffers().size());
+ assert(reader.buffers()[meta->buffer()] != nullptr);
+ const auto buffer = luci::wrap(reader.buffers()[meta->buffer()]->data());
- if (meta.name.compare("ONE_op_table") == 0)
+ assert(meta->name() != nullptr);
+ if (meta->name()->str().compare("ONE_op_table") == 0)
_op_table = decoded_op_table(buffer);
- else if (meta.name.compare("ONE_source_table") == 0)
+ else if (meta->name()->str().compare("ONE_source_table") == 0)
_source_table = decoded_source_table(buffer);
- else if (meta.name.compare("ONE_execution_plan_table") == 0)
+ else if (meta->name()->str().compare("ONE_execution_plan_table") == 0)
_execution_plan_table = decoded_execution_plan(buffer);
}
}
diff --git a/compiler/luci/import/src/CircleReader.cpp b/compiler/luci/import/src/CircleReader.cpp
index 14917ba06..a42c3f913 100644
--- a/compiler/luci/import/src/CircleReader.cpp
+++ b/compiler/luci/import/src/CircleReader.cpp
@@ -16,6 +16,9 @@
#include "luci/Import/CircleReader.h"
+#include <mio_circle/Helper.h>
+
+#include <algorithm>
#include <memory>
#include <sstream>
#include <string>
@@ -23,103 +26,14 @@
namespace luci
{
-bool is_valid(const circle::OperatorCodeT &opcode)
-{
- circle::BuiltinOperator code = opcode.builtin_code;
- return (circle::BuiltinOperator_MIN <= code && code <= circle::BuiltinOperator_MAX);
-}
-
-bool is_valid(const circle::OperatorCode *opcode)
-{
- assert(opcode != nullptr);
- circle::BuiltinOperator code = opcode->builtin_code();
- return (circle::BuiltinOperator_MIN <= code && code <= circle::BuiltinOperator_MAX);
-}
-
-bool is_custom(const circle::OperatorCodeT &opcode)
-{
- circle::BuiltinOperator code = opcode.builtin_code;
- return (code == circle::BuiltinOperator_CUSTOM);
-}
-
-bool is_custom(const circle::OperatorCode *opcode)
-{
- assert(opcode != nullptr);
- circle::BuiltinOperator code = opcode->builtin_code();
- return (code == circle::BuiltinOperator_CUSTOM);
-}
-
-std::string opcode_name(const circle::OperatorCodeT &opcode)
-{
- if (!is_valid(opcode))
- {
- std::ostringstream oss;
- oss << "(invalid)";
- return oss.str();
- }
-
- if (is_custom(opcode))
- {
- if (opcode.custom_code.empty())
- return "(invalid custom)";
-
- return opcode.custom_code;
- }
-
- circle::BuiltinOperator code = opcode.builtin_code;
- return circle::EnumNameBuiltinOperator(code);
-}
-
-std::string opcode_name(const circle::OperatorCode *opcode)
-{
- assert(opcode != nullptr);
-
- if (!is_valid(opcode))
- {
- std::ostringstream oss;
- oss << "(invalid)";
- return oss.str();
- }
-
- if (is_custom(opcode))
- {
- auto custom_code = opcode->custom_code()->str();
- if (custom_code.empty())
- return "(invalid custom)";
-
- return custom_code;
- }
-
- circle::BuiltinOperator code = opcode->builtin_code();
- return circle::EnumNameBuiltinOperator(code);
-}
-
-const char *tensor_name(const circle::TensorT &tensor)
-{
- static const char *kEmptyTensorName = "(noname)";
-
- if (!tensor.name.empty())
- return tensor.name.c_str();
-
- return kEmptyTensorName;
-}
-
const char *tensor_name(const circle::Tensor *tensor)
{
assert(tensor != nullptr);
- static const char *kEmptyTensorName = "(noname)";
- const auto tensor_name = tensor->name()->c_str();
-
- if (!std::string(tensor_name).empty())
- return tensor_name;
+ if (tensor->name() == nullptr || std::string(tensor->name()->c_str()).empty())
+ return "(noname)";
- return kEmptyTensorName;
-}
-
-const circle::QuantizationParametersT *tensor_quantization(const circle::TensorT &tensor)
-{
- return tensor.quantization.get();
+ return tensor->name()->c_str();
}
const circle::QuantizationParameters *tensor_quantization(const circle::Tensor *tensor)
@@ -334,41 +248,6 @@ std::unique_ptr<SparsityParam> luci_sparsityparam(const circle::SparsityParamete
return luci_sparsityparam(&sparsity);
}
-void copy_tensor_attributes(const circle::TensorT &tensor, CircleNode *node)
-{
- node->name(tensor_name(tensor));
- node->dtype(luci_datatype(tensor.type));
-
- assert(tensor.shape_signature.size() == 0 ||
- tensor.shape_signature.size() == tensor.shape.size());
-
- std::vector<int32_t> dims = tensor.shape; // in NHWC
- node->rank(dims.size());
- for (uint32_t r = 0; r < dims.size(); ++r)
- {
- if (tensor.shape_signature.size() > 0 && tensor.shape_signature.at(r) == -1)
- node->dim(r).unset();
- else
- node->dim(r).set(dims[r]);
- }
-
- const auto *quantization = tensor.quantization.get();
- if (quantization != nullptr)
- {
- auto quantparam = luci_quantparam(quantization);
- if (quantparam)
- node->quantparam(std::move(quantparam));
- }
-
- const auto *sparsity = tensor.sparsity.get();
- if (sparsity != nullptr)
- {
- auto sparsityparam = luci_sparsityparam(sparsity);
- if (sparsityparam)
- node->sparsityparam(std::move(sparsityparam));
- }
-}
-
void copy_tensor_attributes(const circle::Tensor *tensor, CircleNode *node)
{
assert(tensor != nullptr);
@@ -408,63 +287,60 @@ void copy_tensor_attributes(const circle::Tensor *tensor, CircleNode *node)
}
}
-circle::BuiltinOperator CircleReader::builtin_code(const circle::OperatorT &op) const
+std::string fb_string2std_string(const flatbuffers::String *fb_str)
{
- const auto &op_codes = opcodes();
- uint32_t index = op.opcode_index;
+ return fb_str == nullptr ? "" : fb_str->str();
+}
+
+circle::BuiltinOperator CircleReader::builtin_code(const circle::Operator *op) const
+{
+ assert(op != nullptr);
+
+ const auto op_codes = opcodes();
+ uint32_t index = op->opcode_index();
assert(index < op_codes.size());
- const circle::OperatorCodeT &opcode = *op_codes[index];
+ const auto opcode = op_codes[index];
+ assert(opcode != nullptr);
- return opcode.builtin_code;
+ return mio::circle::builtin_code_neutral(opcode);
}
-std::string CircleReader::opcode_name(const circle::OperatorT &op) const
+std::string CircleReader::opcode_name(const circle::Operator *op) const
{
- const auto &op_codes = opcodes();
- uint32_t index = op.opcode_index;
- assert(index < op_codes.size());
- const circle::OperatorCodeT &opcode = *op_codes[index];
+ assert(op != nullptr);
- if (!is_valid(opcode))
- {
- std::ostringstream oss;
- oss << "(invalid: " << index << ")";
- return oss.str();
- }
+ const auto op_codes = opcodes();
+ uint32_t index = op->opcode_index();
+ assert(index < op_codes.size());
+ const auto opcode = op_codes[index];
- return ::luci::opcode_name(opcode);
+ return mio::circle::opcode_name(opcode);
}
bool CircleReader::parse(const circle::Model *model)
{
assert(model != nullptr);
- _model.reset(model->UnPack());
-
// for direct pointer access
- _native_model = model;
+ _model = model;
return true;
}
bool CircleReader::select_subgraph(uint32_t sgindex)
{
- if (_model->subgraphs.size() <= sgindex)
+ if (num_subgraph() <= sgindex)
{
assert(false);
return false;
}
- _current_subgraph = _model->subgraphs[sgindex].get();
-
// for direct pointer access
- auto subgraphs = _native_model->subgraphs();
+ auto subgraphs = _model->subgraphs();
assert(subgraphs != nullptr);
- _native_subgraph = subgraphs->Get(sgindex);
- assert(_native_subgraph != nullptr);
-
- _tensors_ptr = _native_subgraph->tensors();
+ _current_subgraph = subgraphs->Get(sgindex);
+ assert(_current_subgraph != nullptr);
return true;
}
diff --git a/compiler/luci/import/src/GraphBuilder.cpp b/compiler/luci/import/src/GraphBuilder.cpp
index 356501c2f..59a08b546 100644
--- a/compiler/luci/import/src/GraphBuilder.cpp
+++ b/compiler/luci/import/src/GraphBuilder.cpp
@@ -29,10 +29,9 @@ CircleNode *GraphBuilder::build(const circle::OperatorT &op, GraphBuilderContext
const std::vector<int32_t> &inputs = op.inputs;
const std::vector<int32_t> &outputs = op.outputs;
- const auto &tensors = context->reader()->tensors();
- const auto &opcodes = context->reader()->opcodes();
- auto tensors_ptr = context->reader()->tensors_ptr();
- assert(tensors_ptr != nullptr);
+ const auto tensors = context->reader()->tensors();
+ const auto opcodes = context->reader()->opcodes();
+ assert(!tensors.null());
std::vector<CircleNode *> input_nodes;
for (const int32_t input_tensor_index : inputs)
@@ -60,16 +59,18 @@ CircleNode *GraphBuilder::build(const circle::OperatorT &op, GraphBuilderContext
// Set up node parameters.
assert(outputs.size() == 1);
{
- const circle::TensorT &output_tensor = *tensors[outputs[0]];
+ const auto output_tensor = tensors[outputs[0]];
+ assert(output_tensor != nullptr);
copy_tensor_attributes(output_tensor, node);
// mark shape_status
- if (tensors_ptr->Get(outputs[0])->shape() == nullptr)
+ if (output_tensor->shape() == nullptr)
node->shape_status(ShapeStatus::NOSHAPE);
else
node->shape_status(ShapeStatus::VALID);
// mark operator version
- node->op_version(opcodes[op.opcode_index].get()->version);
+ assert(opcodes[op.opcode_index] != nullptr);
+ node->op_version(opcodes[op.opcode_index]->version());
}
// Register node's only output.
diff --git a/compiler/luci/import/src/GraphBuilderMultiOutput.cpp b/compiler/luci/import/src/GraphBuilderMultiOutput.cpp
index be553f4c0..4df8d1e5a 100644
--- a/compiler/luci/import/src/GraphBuilderMultiOutput.cpp
+++ b/compiler/luci/import/src/GraphBuilderMultiOutput.cpp
@@ -30,10 +30,9 @@ CircleNode *GraphBuilderMultiOutput::build(const circle::OperatorT &op,
const std::vector<int32_t> &inputs = op.inputs;
const std::vector<int32_t> &outputs = op.outputs;
- const auto &tensors = context->reader()->tensors();
- const auto &opcodes = context->reader()->opcodes();
- auto tensors_ptr = context->reader()->tensors_ptr();
- assert(tensors_ptr != nullptr);
+ const auto tensors = context->reader()->tensors();
+ const auto opcodes = context->reader()->opcodes();
+ assert(!tensors.null());
std::vector<CircleNode *> input_nodes;
for (const int32_t input_tensor_index : inputs)
@@ -64,12 +63,14 @@ CircleNode *GraphBuilderMultiOutput::build(const circle::OperatorT &op,
if (output_count > 0)
{
// Let's use attributes from output 0 for this node
- const circle::TensorT &output_tensor = *tensors[outputs[0]];
+ const auto output_tensor = tensors[outputs[0]];
+ assert(output_tensor != nullptr);
node->name(tensor_name(output_tensor));
- node->dtype(luci_datatype(output_tensor.type));
+ node->dtype(luci_datatype(output_tensor->type()));
// mark operator version
- node->op_version(opcodes[op.opcode_index].get()->version);
+ assert(opcodes[op.opcode_index] != nullptr);
+ node->op_version(opcodes[op.opcode_index]->version());
// NOTE We don't set quantization for multiple output nodes but to virtual outputs
}
@@ -77,7 +78,8 @@ CircleNode *GraphBuilderMultiOutput::build(const circle::OperatorT &op,
// Create virtual outputs of Virtual Output node(s)
for (uint32_t n = 0; n < output_count; ++n)
{
- const circle::TensorT &output_tensor = *tensors[outputs[n]];
+ const auto output_tensor = tensors[outputs[n]];
+ assert(output_tensor != nullptr);
BuildOutArgs boa(node, n);
auto *nodeout = build_out(boa);
@@ -85,7 +87,7 @@ CircleNode *GraphBuilderMultiOutput::build(const circle::OperatorT &op,
copy_tensor_attributes(output_tensor, nodeout);
// NOTE name of CxxxOut nodes may have same name
// mark shape_status
- if (tensors_ptr->Get(outputs[n])->shape() == nullptr)
+ if (output_tensor->shape() == nullptr)
nodeout->shape_status(ShapeStatus::NOSHAPE);
else
nodeout->shape_status(ShapeStatus::VALID);
diff --git a/compiler/luci/import/src/GraphBuilderRegistry.cpp b/compiler/luci/import/src/GraphBuilderRegistry.cpp
index df07d9e48..fe2d830e9 100644
--- a/compiler/luci/import/src/GraphBuilderRegistry.cpp
+++ b/compiler/luci/import/src/GraphBuilderRegistry.cpp
@@ -131,6 +131,7 @@ GraphBuilderRegistry::GraphBuilderRegistry()
CIRCLE_NODE(STRIDED_SLICE, CircleStridedSliceGraphBuilder); // 45
CIRCLE_NODE(SUB, CircleSubGraphBuilder); // 41
CIRCLE_NODE(SUM, CircleSumGraphBuilder); // 74
+ CIRCLE_NODE(SVDF, CircleSVDFBuilder); // 27
CIRCLE_NODE(TANH, CircleTanhGraphBuilder); // 28
CIRCLE_NODE(TILE, CircleTileGraphBuilder); // 69
CIRCLE_NODE(TOPK_V2, CircleTopKV2GraphBuilder); // 48
@@ -150,7 +151,6 @@ GraphBuilderRegistry::GraphBuilderRegistry()
// BuiltinOperator_LSH_PROJECTION = 15,
// BuiltinOperator_LSTM = 16,
// BuiltinOperator_RNN = 24,
- // BuiltinOperator_SVDF = 27,
// BuiltinOperator_CONCAT_EMBEDDINGS = 29,
// BuiltinOperator_SKIP_GRAM = 30,
// BuiltinOperator_CALL = 31,
@@ -161,6 +161,13 @@ GraphBuilderRegistry::GraphBuilderRegistry()
// BuiltinOperator_ARG_MAX = 56,
// BuiltinOperator_HARD_SWISH = 117,
// BuiltinOperator_DENSIFY = 124,
+
+ // Register builders for nodes which not handles in builders registered above.
+#define CIRCLE_NODE(CLASS) add(std::make_unique<CLASS>())
+
+ CIRCLE_NODE(CircleConstNodeBuilder);
+
+#undef CIRCLE_NODE
}
} // namespace luci
diff --git a/compiler/luci/import/src/Importer.cpp b/compiler/luci/import/src/Importer.cpp
index 3f7f78591..15de03df2 100644
--- a/compiler/luci/import/src/Importer.cpp
+++ b/compiler/luci/import/src/Importer.cpp
@@ -23,6 +23,7 @@
#include "luci/Import/GraphBuilderRegistry.h"
#include "luci/Import/CircleReader.h"
#include "luci/Import/Nodes/CircleConst.h"
+#include "luci/Import/Nodes/CircleVariable.h"
#include <luci/IR/Module.h>
#include <luci/IR/CircleNodes.h>
@@ -50,18 +51,18 @@ void convert_graph(const luci::GraphBuilderSource &source, luci::CircleReader &r
luci::GraphBuilderContext gb_context(graph, &reader, nodefinder.get(), tensoroutputs.get());
- const auto &operators = reader.operators();
- const auto &tensors = reader.tensors();
- auto tensors_ptr = reader.tensors_ptr();
- assert(tensors_ptr != nullptr);
+ const auto operators = reader.operators();
+ const auto tensors = reader.tensors();
+ assert(!tensors.null());
auto circle_metadata = std::make_unique<luci::CircleImportMetadata>(reader);
// build a cache to identify if a tensor is output of an operator
// if this is set, we should not create a CircleConst for this tensor
for (uint32_t i = 0; i < operators.size(); ++i)
{
- const circle::OperatorT &op = *operators[i];
- const auto &outputs = op.outputs;
+ const auto op = operators[i];
+ assert(op != nullptr);
+ const auto outputs = luci::wrap(op->outputs());
for (uint32_t j = 0; j < outputs.size(); ++j)
{
@@ -77,10 +78,11 @@ void convert_graph(const luci::GraphBuilderSource &source, luci::CircleReader &r
{
auto input_node = graph->nodes()->create<luci::CircleInput>();
assert(input_node != nullptr);
- const circle::TensorT &tensor = *tensors[input];
+ const auto tensor = tensors[input];
+ assert(tensor != nullptr);
luci::copy_tensor_attributes(tensor, input_node);
- if (tensors_ptr->Get(input)->shape() == nullptr)
+ if (tensor->shape() == nullptr)
input_node->shape_status(luci::ShapeStatus::NOSHAPE);
else
input_node->shape_status(luci::ShapeStatus::VALID);
@@ -101,16 +103,18 @@ void convert_graph(const luci::GraphBuilderSource &source, luci::CircleReader &r
// Data type
graph_input->dtype(input_node->dtype());
- assert(tensor.shape_signature.size() == 0 ||
- tensor.shape_signature.size() == tensor.shape.size());
+ const auto tensor_shape_signature = luci::wrap(tensor->shape_signature());
+ const auto tensor_shape = luci::wrap(tensor->shape());
+ assert(tensor_shape_signature.size() == 0 ||
+ tensor_shape_signature.size() == tensor_shape.size());
// Shape of GraphInput
auto input_shape = std::make_unique<loco::TensorShape>();
- const std::vector<int32_t> &input_dims = tensor.shape; // in NHWC
+ const auto &input_dims = tensor_shape; // in NHWC
input_shape->rank(input_dims.size());
for (uint32_t r = 0; r < input_dims.size(); ++r)
{
- if (tensor.shape_signature.size() > 0 && tensor.shape_signature.at(r) == -1)
+ if (tensor_shape_signature.size() > 0 && tensor_shape_signature.at(r) == -1)
input_shape->dim(r).unset();
else
input_shape->dim(r).set(input_dims[r]);
@@ -118,15 +122,28 @@ void convert_graph(const luci::GraphBuilderSource &source, luci::CircleReader &r
graph_input->shape(std::move(input_shape));
}
- // Create CircleConst nodes for constant tensors.
+ // Create CircleNodes for constant tensors.
// NOTE Origin is intentionally not provided for constants.
+ auto const_builder = source.lookup(luci::NodeBuilderType::BUFFER);
+ if (not const_builder)
+ throw oops::UserExn("Not supported", "tensor with buffer builder");
+
for (uint32_t i = 0; i < tensors.size(); ++i)
{
- luci::CircleConst *const_node = luci::create_circleconst(&gb_context, i);
+ auto *const_node = const_builder->build(i, &gb_context);
if (const_node != nullptr)
nodefinder->enroll(i, const_node);
}
+ // Create CircleVariable nodes for variable tensors
+ // TODO Add Origin if needed, skip for now
+ for (uint32_t i = 0; i < tensors.size(); ++i)
+ {
+ luci::CircleVariable *variable_node = luci::create_circlevariable(&gb_context, i);
+ if (variable_node != nullptr)
+ nodefinder->enroll(i, variable_node);
+ }
+
// Import the operators.
// Note that operators in model are stored in execution order. This means that when importing
// an operator, its input operators have already been imported. We exploit this fact to set up
@@ -134,18 +151,23 @@ void convert_graph(const luci::GraphBuilderSource &source, luci::CircleReader &r
auto origin_table = circle_metadata->origin_table();
for (uint32_t i = 0; i < operators.size(); ++i)
{
- const circle::OperatorT &op = *operators[i];
+ const auto op = operators[i];
+ assert(op != nullptr);
circle::BuiltinOperator builtincode = reader.builtin_code(op);
if (const auto *builder = source.lookup(builtincode))
{
- luci::GraphBuilder::ValidateArgs args(op, reader);
+ // create temporary unpack API obj
+ circle::OperatorT oper_t;
+ op->UnPackTo(&oper_t);
+
+ luci::GraphBuilder::ValidateArgs args(oper_t, reader);
if (!builder->validate(args))
{
throw oops::UserExn("Invalid operator", reader.opcode_name(op));
}
- auto built_op = builder->build(op, &gb_context);
+ auto built_op = builder->build(oper_t, &gb_context);
set_node_id(built_op, i);
if (origin_table.find(i) != origin_table.end())
add_origin(built_op, origin_table.at(i));
@@ -161,7 +183,8 @@ void convert_graph(const luci::GraphBuilderSource &source, luci::CircleReader &r
// graph outputs
for (auto output : reader.outputs())
{
- const circle::TensorT &tensor = *tensors[output];
+ const auto tensor = tensors[output];
+ assert(tensor != nullptr);
auto output_node = graph->nodes()->create<luci::CircleOutput>();
assert(output_node != nullptr);
@@ -178,7 +201,7 @@ void convert_graph(const luci::GraphBuilderSource &source, luci::CircleReader &r
output_node->from(output_dummy);
luci::copy_tensor_attributes(tensor, output_dummy);
- if (tensors_ptr->Get(output)->shape() == nullptr)
+ if (tensor->shape() == nullptr)
output_dummy->shape_status(luci::ShapeStatus::NOSHAPE);
else
output_dummy->shape_status(luci::ShapeStatus::VALID);
@@ -197,16 +220,18 @@ void convert_graph(const luci::GraphBuilderSource &source, luci::CircleReader &r
// Set GraphInputOutputIndex for graph
output_node->index(graph_output->index());
- assert(tensor.shape_signature.size() == 0 ||
- tensor.shape_signature.size() == tensor.shape.size());
+ const auto tensor_shape_signature = luci::wrap(tensor->shape_signature());
+ const auto tensor_shape = luci::wrap(tensor->shape());
+ assert(tensor_shape_signature.size() == 0 ||
+ tensor_shape_signature.size() == tensor_shape.size());
// Shape of Output
auto output_shape = std::make_unique<loco::TensorShape>();
- const std::vector<int32_t> &output_dims = tensor.shape; // in NHWC
+ const auto &output_dims = tensor_shape; // in NHWC
output_shape->rank(output_dims.size());
for (uint32_t r = 0; r < output_dims.size(); ++r)
{
- if (tensor.shape_signature.size() > 0 && tensor.shape_signature.at(r) == -1)
+ if (tensor_shape_signature.size() > 0 && tensor_shape_signature.at(r) == -1)
output_shape->dim(r).unset();
else
output_shape->dim(r).set(output_dims[r]);
@@ -214,7 +239,7 @@ void convert_graph(const luci::GraphBuilderSource &source, luci::CircleReader &r
graph_output->shape(std::move(output_shape));
// Data type
- auto dtype = luci::luci_datatype(tensor.type);
+ auto dtype = luci::luci_datatype(tensor->type());
graph_output->dtype(dtype);
}
}
@@ -355,7 +380,12 @@ std::unique_ptr<Module> Importer::importModule(const circle::Model *model) const
{
if (auto circle_node = dynamic_cast<luci::CircleNode *>(node))
{
+ if (execution_plan_table.count(node_position) == 0)
+ continue;
+
auto node_plan = execution_plan_table[node_position];
+ assert(node_plan.size() > 0);
+
luci::add_execution_plan(
circle_node,
luci::CircleNodeExecutionPlan(
diff --git a/compiler/luci/import/src/Importer.test.cpp b/compiler/luci/import/src/Importer.test.cpp
index d963b4d49..91e4860ea 100644
--- a/compiler/luci/import/src/Importer.test.cpp
+++ b/compiler/luci/import/src/Importer.test.cpp
@@ -23,7 +23,7 @@
#include <mio/circle/schema_generated.h>
#include <flatbuffers/flatbuffers.h>
-TEST(TensorFlowLiteImport, Dummy)
+TEST(CircleImport, Dummy)
{
luci::Importer import;
@@ -68,6 +68,7 @@ struct BasicCircleModel
{
uint32_t id = model->operator_codes.size();
model->operator_codes.push_back(std::make_unique<circle::OperatorCodeT>());
+ model->operator_codes[id]->deprecated_builtin_code = opcode;
model->operator_codes[id]->builtin_code = opcode;
model->operator_codes[id]->version = 1;
return id;
@@ -179,7 +180,7 @@ struct SimpleRELUModel : public BasicCircleModel
/**
* This test checks that one op RELU model with execution plan is successfully imported
*/
-TEST(TensorFlowLiteImport, simple_plan)
+TEST(CircleImport, simple_plan)
{
SimpleRELUModel model;
auto metadata_buffer_id = model.add_buffer();
@@ -240,7 +241,7 @@ TEST(TensorFlowLiteImport, simple_plan)
/**
* This test checks that model with incomplete execution plan is successfully imported
*/
-TEST(TensorFlowLiteImport, DISABLED_incomplete_plan_NEG)
+TEST(CircleImport, incomplete_plan_NEG)
{
SimpleRELUModel model;
auto metadata_buffer_id = model.add_buffer();
@@ -287,7 +288,7 @@ TEST(TensorFlowLiteImport, DISABLED_incomplete_plan_NEG)
/**
* This test checks that corrupted execution plan induce exception
*/
-TEST(TensorFlowLiteImport, corrupted_plan_NEG)
+TEST(CircleImport, corrupted_plan_NEG)
{
SimpleRELUModel model;
auto metadata_buffer_id = model.add_buffer();
@@ -309,3 +310,44 @@ TEST(TensorFlowLiteImport, corrupted_plan_NEG)
ASSERT_ANY_THROW(import.importModule(model_ptr));
}
+
+/**
+ * This test checks that empty execution plan entry induce exception
+ */
+TEST(CircleImport, corrupted_plan_entry_NEG)
+{
+ SimpleRELUModel model;
+ auto metadata_buffer_id = model.add_buffer();
+ model.add_plan_metadata(metadata_buffer_id);
+
+ model.add_plan_entry(metadata_buffer_id, 1, {100});
+
+ // add corrupted entry with 0 size
+ {
+ auto &buffer = model.model->buffers[metadata_buffer_id]->data;
+ auto old_size = buffer.size();
+
+ // Allocate space for new entry:
+ // 4 bytes for entry id
+ // 4 bytes for entry size
+ buffer.resize(old_size + 8);
+ uint32_t *number_of_entries_ptr = reinterpret_cast<uint32_t *>(buffer.data());
+ *number_of_entries_ptr += 1;
+
+ uint32_t *entry_data_ptr = reinterpret_cast<uint32_t *>(buffer.data() + old_size);
+
+ entry_data_ptr[0] = *number_of_entries_ptr - 1; // entry id
+ entry_data_ptr[1] = 0; // entry size
+ }
+
+ model.add_plan_entry(metadata_buffer_id, 3, {200});
+
+ flatbuffers::FlatBufferBuilder fbb;
+ auto model_offset = circle::Model::Pack(fbb, model.model.get(), nullptr);
+ circle::FinishModelBuffer(fbb, model_offset);
+
+ auto model_ptr = circle::GetModel(fbb.GetBufferPointer());
+ luci::Importer import;
+
+ ASSERT_ANY_THROW(import.importModule(model_ptr));
+}
diff --git a/compiler/luci/import/src/Nodes/CircleCast.cpp b/compiler/luci/import/src/Nodes/CircleCast.cpp
index 3e8c08bfa..acde823b1 100644
--- a/compiler/luci/import/src/Nodes/CircleCast.cpp
+++ b/compiler/luci/import/src/Nodes/CircleCast.cpp
@@ -42,12 +42,14 @@ bool CircleCastGraphBuilder::validate(const ValidateArgs &args) const
const auto *options = args.op.builtin_options.AsCastOptions();
if (options != nullptr)
{
- const auto &tensors = args.reader.tensors();
- const circle::TensorT &output_tensor = *tensors[outputs[0]];
+ const auto tensors = args.reader.tensors();
+ const auto output_tensor = tensors[outputs[0]];
+ assert(output_tensor != nullptr);
auto name = tensor_name(output_tensor);
- const auto &tensor_in = tensors.at(inputs.at(0));
- if (tensor_in->type != options->in_data_type)
+ const auto tensor_in = tensors.at(inputs.at(0));
+ assert(tensor_in != nullptr);
+ if (tensor_in->type() != options->in_data_type)
{
if (settings->get(luci::UserSettings::Key::DisableValidation))
{
@@ -57,7 +59,7 @@ bool CircleCastGraphBuilder::validate(const ValidateArgs &args) const
return false;
}
const auto &tensor_out = tensors.at(outputs[0]);
- if (tensor_out->type != options->out_data_type)
+ if (tensor_out->type() != options->out_data_type)
{
if (settings->get(luci::UserSettings::Key::DisableValidation))
{
diff --git a/compiler/luci/import/src/Nodes/CircleConst.cpp b/compiler/luci/import/src/Nodes/CircleConst.cpp
index 11fbb4e54..a4f190dd9 100644
--- a/compiler/luci/import/src/Nodes/CircleConst.cpp
+++ b/compiler/luci/import/src/Nodes/CircleConst.cpp
@@ -30,10 +30,10 @@
namespace
{
-std::ostream &operator<<(std::ostream &os, const std::vector<int32_t> &vect)
+std::ostream &operator<<(std::ostream &os, const luci::VectorWrapper<int32_t> &vect)
{
uint32_t seq = 0;
- for (auto &v : vect)
+ for (const auto &v : vect)
{
if (seq)
os << ", ";
@@ -46,7 +46,8 @@ std::ostream &operator<<(std::ostream &os, const std::vector<int32_t> &vect)
using namespace luci;
template <loco::DataType DT>
-void copy_data(const std::vector<uint8_t> &raw_data, uint32_t num_elements, CircleConst *const_node)
+void copy_data(const VectorWrapper<uint8_t> &raw_data, uint32_t num_elements,
+ CircleConst *const_node)
{
using T = typename loco::DataTypeImpl<DT>::Type;
@@ -67,8 +68,8 @@ void copy_data(const std::vector<uint8_t> &raw_data, uint32_t num_elements, Circ
}
template <>
-void copy_data<loco::DataType::STRING>(const std::vector<uint8_t> &raw_data, uint32_t num_elements,
- CircleConst *const_node)
+void copy_data<loco::DataType::STRING>(const VectorWrapper<uint8_t> &raw_data,
+ uint32_t num_elements, CircleConst *const_node)
{
assert(const_node->sparsityparam() == nullptr);
@@ -106,17 +107,26 @@ void copy_data<loco::DataType::STRING>(const std::vector<uint8_t> &raw_data, uin
namespace luci
{
-CircleConst *create_circleconst(GraphBuilderContext *context, int32_t tensor_index)
+CircleNode *CircleConstNodeBuilder::build(TensorIndex tensor_index,
+ GraphBuilderContext *context) const
{
+ assert(tensor_index >= 0);
LOGGER(l);
auto graph = context->graph();
auto reader = context->reader();
- const auto &tensors = reader->tensors();
- const circle::TensorT &const_tensor = *tensors[tensor_index];
+ const auto tensors = reader->tensors();
+ const auto const_tensor = tensors[tensor_index];
+ assert(const_tensor != nullptr);
+ if (const_tensor->is_variable())
+ {
+ // Create CircleVariable for variable
+ return nullptr;
+ }
- const std::vector<uint8_t> &buffer = reader->buffers()[const_tensor.buffer]->data;
- std::vector<int32_t> const_dims = const_tensor.shape; // in NHWC
+ assert(reader->buffers()[const_tensor->buffer()] != nullptr);
+ const auto buffer = wrap(reader->buffers()[const_tensor->buffer()]->data());
+ const auto const_dims = wrap(const_tensor->shape()); // in NHWC
if (const_dims.size() == 0 && buffer.empty())
{
// unknown shape tensor and scalar tensor
@@ -150,7 +160,7 @@ CircleConst *create_circleconst(GraphBuilderContext *context, int32_t tensor_ind
<< const_dims << std::endl;
if (num_elements > 0)
{
- switch (luci_datatype(const_tensor.type))
+ switch (luci_datatype(const_tensor->type()))
{
case loco::DataType::FLOAT32:
copy_data<loco::DataType::FLOAT32>(buffer, num_elements, const_node);
@@ -186,7 +196,7 @@ CircleConst *create_circleconst(GraphBuilderContext *context, int32_t tensor_ind
default:
throw oops::UserExn("Unsupported tensor type",
- circle::EnumNameTensorType(const_tensor.type));
+ circle::EnumNameTensorType(const_tensor->type()));
}
}
diff --git a/compiler/luci/import/src/Nodes/CircleCustom.cpp b/compiler/luci/import/src/Nodes/CircleCustom.cpp
index 01ac3e2a0..4e78d5fb7 100644
--- a/compiler/luci/import/src/Nodes/CircleCustom.cpp
+++ b/compiler/luci/import/src/Nodes/CircleCustom.cpp
@@ -39,13 +39,15 @@ CircleNode *CircleCustomGraphBuilder::build_node(const BuildNodeArgs &bna) const
node->inputs(idx, bna.input_nodes[idx]);
}
- const auto &opcodes = bna.context->reader()->opcodes();
+ const auto opcodes = bna.context->reader()->opcodes();
const uint32_t opcode_index = bna.op.opcode_index;
- const circle::OperatorCodeT &opcode = *opcodes[opcode_index];
+ const auto opcode = opcodes[opcode_index];
+ assert(opcode != nullptr);
node->custom_options(
std::vector<uint8_t>{bna.op.custom_options.begin(), bna.op.custom_options.end()});
- node->custom_code(opcode.custom_code);
+ assert(opcode->custom_code() != nullptr);
+ node->custom_code(opcode->custom_code()->c_str());
// NOTE Operator version of custom is always 1
diff --git a/compiler/luci/import/src/Nodes/CircleDepthToSpace.cpp b/compiler/luci/import/src/Nodes/CircleDepthToSpace.cpp
index 49eb30a83..83fc2e37d 100644
--- a/compiler/luci/import/src/Nodes/CircleDepthToSpace.cpp
+++ b/compiler/luci/import/src/Nodes/CircleDepthToSpace.cpp
@@ -34,9 +34,10 @@ bool CircleDepthToSpaceGraphBuilder::validate(const ValidateArgs &args) const
const auto &outputs = args.op.outputs;
const auto *options = args.op.builtin_options.AsDepthToSpaceOptions();
- const auto &tensors = args.reader.tensors();
+ const auto tensors = args.reader.tensors();
+ assert(tensors[outputs[0]] != nullptr && tensors[inputs.at(0)] != nullptr);
- if (tensors[outputs[0]]->type != tensors[inputs.at(0)]->type)
+ if (tensors[outputs[0]]->type() != tensors[inputs.at(0)]->type())
{
return false;
}
diff --git a/compiler/luci/import/src/Nodes/CircleDepthwiseConv2D.cpp b/compiler/luci/import/src/Nodes/CircleDepthwiseConv2D.cpp
index 727487c6a..a24e4160d 100644
--- a/compiler/luci/import/src/Nodes/CircleDepthwiseConv2D.cpp
+++ b/compiler/luci/import/src/Nodes/CircleDepthwiseConv2D.cpp
@@ -32,19 +32,21 @@ bool CircleDepthwiseConv2DGraphBuilder::validate(const ValidateArgs &args) const
if (args.op.outputs.size() != 1)
return false;
- const auto &tensors = args.reader.tensors();
+ const auto tensors = args.reader.tensors();
// input shape
- const auto &input = tensors.at(args.op.inputs.at(0));
- const auto &input_shape = input->shape;
+ const auto input = tensors.at(args.op.inputs.at(0));
+ assert(input != nullptr);
+ const auto input_shape = wrap(input->shape());
// input shape must be rank 4
if (input_shape.size() != 4)
return false;
// filter shape
- const auto &filter = tensors.at(args.op.inputs.at(1));
- const auto &filter_shape = filter->shape;
+ const auto filter = tensors.at(args.op.inputs.at(1));
+ assert(filter != nullptr);
+ const auto filter_shape = wrap(filter->shape());
// filter shape must be rank 4
if (filter_shape.size() != 4)
diff --git a/compiler/luci/import/src/Nodes/CircleElu.cpp b/compiler/luci/import/src/Nodes/CircleElu.cpp
index 41696a65a..e5d7a4c7a 100644
--- a/compiler/luci/import/src/Nodes/CircleElu.cpp
+++ b/compiler/luci/import/src/Nodes/CircleElu.cpp
@@ -31,10 +31,11 @@ bool CircleEluGraphBuilder::validate(const ValidateArgs &args) const
const auto &inputs = args.op.inputs;
const auto &outputs = args.op.outputs;
- const auto &tensors = args.reader.tensors();
- const auto &tensor = tensors.at(inputs.at(0));
+ const auto tensors = args.reader.tensors();
+ const auto tensor = tensors.at(inputs.at(0));
+ assert(tensor != nullptr);
- switch (tensor->type)
+ switch (tensor->type())
{
case circle::TensorType_FLOAT64:
break;
@@ -48,7 +49,8 @@ bool CircleEluGraphBuilder::validate(const ValidateArgs &args) const
return false;
}
- if (tensors[outputs[0]]->type != tensor->type)
+ assert(tensors[outputs[0]] != nullptr);
+ if (tensors[outputs[0]]->type() != tensor->type())
return false;
return true;
diff --git a/compiler/luci/import/src/Nodes/CircleEqual.cpp b/compiler/luci/import/src/Nodes/CircleEqual.cpp
index 4909692b4..b326d9b5d 100644
--- a/compiler/luci/import/src/Nodes/CircleEqual.cpp
+++ b/compiler/luci/import/src/Nodes/CircleEqual.cpp
@@ -29,9 +29,10 @@ bool CircleEqualGraphBuilder::validate(const ValidateArgs &args) const
return false;
const auto &inputs = args.op.inputs;
- const auto &tensors = args.reader.tensors();
+ const auto tensors = args.reader.tensors();
- return tensors[inputs.at(0)]->type == tensors[inputs.at(1)]->type;
+ assert(tensors[inputs.at(0)] != nullptr && tensors[inputs.at(1)] != nullptr);
+ return tensors[inputs.at(0)]->type() == tensors[inputs.at(1)]->type();
}
CircleNode *CircleEqualGraphBuilder::build_node(const circle::OperatorT &,
diff --git a/compiler/luci/import/src/Nodes/CircleExp.cpp b/compiler/luci/import/src/Nodes/CircleExp.cpp
index 5bb7bb664..82c26f0e5 100644
--- a/compiler/luci/import/src/Nodes/CircleExp.cpp
+++ b/compiler/luci/import/src/Nodes/CircleExp.cpp
@@ -30,9 +30,10 @@ bool CircleExpGraphBuilder::validate(const ValidateArgs &args) const
const auto &inputs = args.op.inputs;
// input type check
- const auto &tensors = args.reader.tensors();
- const auto &tensor = tensors.at(inputs.at(0));
- switch (tensor->type)
+ const auto tensors = args.reader.tensors();
+ const auto tensor = tensors.at(inputs.at(0));
+ assert(tensor != nullptr);
+ switch (tensor->type())
{
case circle::TensorType_FLOAT16:
case circle::TensorType_FLOAT32:
diff --git a/compiler/luci/import/src/Nodes/CircleExpandDims.cpp b/compiler/luci/import/src/Nodes/CircleExpandDims.cpp
index ee0fbdc7e..67d9b7e9e 100644
--- a/compiler/luci/import/src/Nodes/CircleExpandDims.cpp
+++ b/compiler/luci/import/src/Nodes/CircleExpandDims.cpp
@@ -29,9 +29,10 @@ bool CircleExpandDimsGraphBuilder::validate(const ValidateArgs &args) const
return false;
const auto &inputs = args.op.inputs;
- const auto &tensors = args.reader.tensors();
+ const auto tensors = args.reader.tensors();
- return tensors[inputs.at(1)]->type == circle::TensorType_INT32;
+ assert(tensors[inputs.at(1)] != nullptr);
+ return tensors[inputs.at(1)]->type() == circle::TensorType_INT32;
}
CircleNode *CircleExpandDimsGraphBuilder::build_node(const circle::OperatorT &,
diff --git a/compiler/luci/import/src/Nodes/CircleFloorDiv.cpp b/compiler/luci/import/src/Nodes/CircleFloorDiv.cpp
index ce329326a..67eeddf91 100644
--- a/compiler/luci/import/src/Nodes/CircleFloorDiv.cpp
+++ b/compiler/luci/import/src/Nodes/CircleFloorDiv.cpp
@@ -30,15 +30,18 @@ bool CircleFloorDivGraphBuilder::validate(const ValidateArgs &args) const
const auto &inputs = args.op.inputs;
const auto &outputs = args.op.outputs;
- const auto &tensors = args.reader.tensors();
- const auto &tensor_in_0 = tensors.at(inputs.at(0));
- const auto &tensor_in_1 = tensors.at(inputs.at(1));
- const auto &tensor_out = tensors.at(outputs[0]);
-
- if (tensor_in_0->type != tensor_in_1->type)
+ const auto tensors = args.reader.tensors();
+ const auto tensor_in_0 = tensors.at(inputs.at(0));
+ const auto tensor_in_1 = tensors.at(inputs.at(1));
+ const auto tensor_out = tensors.at(outputs[0]);
+ assert(tensor_in_0 != nullptr);
+ assert(tensor_in_1 != nullptr);
+ assert(tensor_out != nullptr);
+
+ if (tensor_in_0->type() != tensor_in_1->type())
return false;
- if (tensor_out->type != tensor_in_1->type)
+ if (tensor_out->type() != tensor_in_1->type())
{
return false;
}
diff --git a/compiler/luci/import/src/Nodes/CircleFloorMod.cpp b/compiler/luci/import/src/Nodes/CircleFloorMod.cpp
index d8420a43c..d2a275b62 100644
--- a/compiler/luci/import/src/Nodes/CircleFloorMod.cpp
+++ b/compiler/luci/import/src/Nodes/CircleFloorMod.cpp
@@ -29,10 +29,11 @@ bool CircleFloorModGraphBuilder::validate(const ValidateArgs &args) const
return false;
const auto &inputs = args.op.inputs;
- const auto &tensors = args.reader.tensors();
- const auto &tensor_in_0 = tensors.at(inputs.at(0));
- const auto &tensor_in_1 = tensors.at(inputs.at(1));
- if (tensor_in_0->type != tensor_in_1->type)
+ const auto tensors = args.reader.tensors();
+ const auto tensor_in_0 = tensors.at(inputs.at(0));
+ const auto tensor_in_1 = tensors.at(inputs.at(1));
+ assert(tensor_in_0 != nullptr && tensor_in_1 != nullptr);
+ if (tensor_in_0->type() != tensor_in_1->type())
return false;
// TODO dtype check
diff --git a/compiler/luci/import/src/Nodes/CircleFullyConnected.cpp b/compiler/luci/import/src/Nodes/CircleFullyConnected.cpp
index 58750d79a..cc7be1693 100644
--- a/compiler/luci/import/src/Nodes/CircleFullyConnected.cpp
+++ b/compiler/luci/import/src/Nodes/CircleFullyConnected.cpp
@@ -42,6 +42,7 @@ CircleNode *CircleFullyConnectedGraphBuilder::build_node(const circle::OperatorT
const auto *options = op.builtin_options.AsFullyConnectedOptions();
node->fusedActivationFunction(luci_actfunc(options->fused_activation_function));
node->weights_format(luci_weights_format(options->weights_format));
+ node->keep_num_dims(options->keep_num_dims);
return node;
}
diff --git a/compiler/luci/import/src/Nodes/CircleGatherNd.cpp b/compiler/luci/import/src/Nodes/CircleGatherNd.cpp
index a4bb26a10..d336878ad 100644
--- a/compiler/luci/import/src/Nodes/CircleGatherNd.cpp
+++ b/compiler/luci/import/src/Nodes/CircleGatherNd.cpp
@@ -31,10 +31,11 @@ bool CircleGatherNdGraphBuilder::validate(const ValidateArgs &args) const
return false;
const auto &inputs = args.op.inputs;
- auto &indices_tensor = args.reader.tensors()[inputs.at(1)];
+ auto indices_tensor = args.reader.tensors()[inputs.at(1)];
+ assert(indices_tensor != nullptr);
- if (!(indices_tensor->type == circle::TensorType::TensorType_INT32 ||
- indices_tensor->type == circle::TensorType::TensorType_INT64))
+ if (!(indices_tensor->type() == circle::TensorType::TensorType_INT32 ||
+ indices_tensor->type() == circle::TensorType::TensorType_INT64))
{
return false;
}
diff --git a/compiler/luci/import/src/Nodes/CircleGreater.cpp b/compiler/luci/import/src/Nodes/CircleGreater.cpp
index f9c00346c..7f031b0ba 100644
--- a/compiler/luci/import/src/Nodes/CircleGreater.cpp
+++ b/compiler/luci/import/src/Nodes/CircleGreater.cpp
@@ -37,17 +37,19 @@ bool CircleGreaterGraphBuilder::validate(const ValidateArgs &args) const
const auto &inputs = args.op.inputs;
const auto &outputs = args.op.outputs;
- const auto &tensors = args.reader.tensors();
+ const auto tensors = args.reader.tensors();
- if (tensors[inputs.at(0)]->type != tensors[inputs.at(1)]->type)
+ assert(tensors[inputs.at(0)] != nullptr && tensors[inputs.at(1)] != nullptr);
+ if (tensors[inputs.at(0)]->type() != tensors[inputs.at(1)]->type())
return false;
// NOTE: real models do have output dtype NOT BOOL
- if (tensors[outputs[0]]->type != circle::TensorType_BOOL)
+ assert(tensors[outputs[0]] != nullptr);
+ if (tensors[outputs[0]]->type() != circle::TensorType_BOOL)
{
if (settings->get(luci::UserSettings::Key::DisableValidation))
{
- const circle::TensorT &output_tensor = *tensors[outputs[0]];
+ const auto output_tensor = tensors[outputs[0]];
auto name = tensor_name(output_tensor);
WARN(l) << "Warning: import Greater(" << name << ") output dtype is not boolean";
}
diff --git a/compiler/luci/import/src/Nodes/CircleGreaterEqual.cpp b/compiler/luci/import/src/Nodes/CircleGreaterEqual.cpp
index e20038fd9..ac4ce62f5 100644
--- a/compiler/luci/import/src/Nodes/CircleGreaterEqual.cpp
+++ b/compiler/luci/import/src/Nodes/CircleGreaterEqual.cpp
@@ -30,14 +30,16 @@ bool CircleGreaterEqualGraphBuilder::validate(const ValidateArgs &args) const
const auto &inputs = args.op.inputs;
const auto &outputs = args.op.outputs;
- const auto &tensors = args.reader.tensors();
+ const auto tensors = args.reader.tensors();
- if (tensors[inputs.at(0)]->type != tensors[inputs.at(1)]->type)
+ assert(tensors[inputs.at(0)] != nullptr && tensors[inputs.at(1)] != nullptr);
+ if (tensors[inputs.at(0)]->type() != tensors[inputs.at(1)]->type())
{
return false;
}
- return tensors[outputs[0]]->type == circle::TensorType::TensorType_BOOL;
+ assert(tensors[outputs[0]] != nullptr);
+ return tensors[outputs[0]]->type() == circle::TensorType::TensorType_BOOL;
}
CircleNode *CircleGreaterEqualGraphBuilder::build_node(const circle::OperatorT &,
diff --git a/compiler/luci/import/src/Nodes/CircleIf.cpp b/compiler/luci/import/src/Nodes/CircleIf.cpp
index ffdbf0b79..e8a50ff32 100644
--- a/compiler/luci/import/src/Nodes/CircleIf.cpp
+++ b/compiler/luci/import/src/Nodes/CircleIf.cpp
@@ -42,12 +42,13 @@ bool CircleIfGraphBuilder::validate(const ValidateArgs &args) const
return false;
// input 0 should be BOOL type
- const auto &tensors = args.reader.tensors();
- const auto &tensor = tensors.at(inputs.at(0));
- if (tensor->type != circle::TensorType_BOOL)
+ const auto tensors = args.reader.tensors();
+ const auto tensor = tensors.at(inputs.at(0));
+ assert(tensor != nullptr);
+ if (tensor->type() != circle::TensorType_BOOL)
return false;
- const auto &shape = tensor->shape;
+ const auto shape = wrap(tensor->shape());
if (shape.size() != 1 && shape.size() != 0)
return false;
diff --git a/compiler/luci/import/src/Nodes/CircleLess.cpp b/compiler/luci/import/src/Nodes/CircleLess.cpp
index f9b99bebe..5c5ae51e1 100644
--- a/compiler/luci/import/src/Nodes/CircleLess.cpp
+++ b/compiler/luci/import/src/Nodes/CircleLess.cpp
@@ -30,10 +30,11 @@ bool CircleLessGraphBuilder::validate(const ValidateArgs &args) const
const auto &inputs = args.op.inputs;
const auto &outputs = args.op.outputs;
- const auto &tensors = args.reader.tensors();
- const auto &tensor = tensors.at(inputs.at(0));
+ const auto tensors = args.reader.tensors();
+ const auto tensor = tensors.at(inputs.at(0));
+ assert(tensor != nullptr);
- switch (tensor->type)
+ switch (tensor->type())
{
case circle::TensorType_FLOAT32:
case circle::TensorType_FLOAT64:
@@ -48,12 +49,14 @@ bool CircleLessGraphBuilder::validate(const ValidateArgs &args) const
return false;
}
- if (tensors[inputs.at(1)]->type != tensor->type)
+ assert(tensors[inputs.at(1)] != nullptr);
+ if (tensors[inputs.at(1)]->type() != tensor->type())
{
return false;
}
- return tensors[outputs[0]]->type == circle::TensorType_BOOL;
+ assert(tensors[outputs[0]] != nullptr);
+ return tensors[outputs[0]]->type() == circle::TensorType_BOOL;
}
CircleNode *CircleLessGraphBuilder::build_node(const circle::OperatorT &,
diff --git a/compiler/luci/import/src/Nodes/CircleLessEqual.cpp b/compiler/luci/import/src/Nodes/CircleLessEqual.cpp
index bb1712137..8a2aea8db 100644
--- a/compiler/luci/import/src/Nodes/CircleLessEqual.cpp
+++ b/compiler/luci/import/src/Nodes/CircleLessEqual.cpp
@@ -30,14 +30,16 @@ bool CircleLessEqualGraphBuilder::validate(const ValidateArgs &args) const
const auto &inputs = args.op.inputs;
const auto &outputs = args.op.outputs;
- const auto &tensors = args.reader.tensors();
+ const auto tensors = args.reader.tensors();
- if (tensors[inputs.at(0)]->type != tensors[inputs.at(1)]->type)
+ assert(tensors[inputs.at(0)] != nullptr && tensors[inputs.at(1)] != nullptr);
+ if (tensors[inputs.at(0)]->type() != tensors[inputs.at(1)]->type())
{
return false;
}
- return tensors[outputs[0]]->type == circle::TensorType::TensorType_BOOL;
+ assert(tensors[outputs[0]] != nullptr);
+ return tensors[outputs[0]]->type() == circle::TensorType::TensorType_BOOL;
}
CircleNode *CircleLessEqualGraphBuilder::build_node(const circle::OperatorT &,
diff --git a/compiler/luci/import/src/Nodes/CircleLog.cpp b/compiler/luci/import/src/Nodes/CircleLog.cpp
index 26b575070..f41926829 100644
--- a/compiler/luci/import/src/Nodes/CircleLog.cpp
+++ b/compiler/luci/import/src/Nodes/CircleLog.cpp
@@ -32,9 +32,10 @@ bool CircleLogGraphBuilder::validate(const ValidateArgs &args) const
// input type check
// Must be one of bfloat16, half, float32, float64, complex64, complex128.
// Currently circle supports half(float16), float32, float64, complex64.
- const auto &tensors = args.reader.tensors();
- const auto &tensor = tensors.at(inputs.at(0));
- switch (tensor->type)
+ const auto tensors = args.reader.tensors();
+ const auto tensor = tensors.at(inputs.at(0));
+ assert(tensor != nullptr);
+ switch (tensor->type())
{
case circle::TensorType_FLOAT16:
case circle::TensorType_FLOAT32:
diff --git a/compiler/luci/import/src/Nodes/CircleLogicalAnd.cpp b/compiler/luci/import/src/Nodes/CircleLogicalAnd.cpp
index b13fc2735..b61fb6f3e 100644
--- a/compiler/luci/import/src/Nodes/CircleLogicalAnd.cpp
+++ b/compiler/luci/import/src/Nodes/CircleLogicalAnd.cpp
@@ -30,11 +30,12 @@ bool CircleLogicalAndGraphBuilder::validate(const ValidateArgs &args) const
// Only BOOL type is allowed for inputs
const auto &inputs = args.op.inputs;
- const auto &tensors = args.reader.tensors();
+ const auto tensors = args.reader.tensors();
for (auto input : inputs)
{
- const auto &tensor = tensors.at(input);
- if (tensor->type != circle::TensorType::TensorType_BOOL)
+ const auto tensor = tensors.at(input);
+ assert(tensor != nullptr);
+ if (tensor->type() != circle::TensorType::TensorType_BOOL)
return false;
}
diff --git a/compiler/luci/import/src/Nodes/CircleLogicalNot.cpp b/compiler/luci/import/src/Nodes/CircleLogicalNot.cpp
index f68218349..43e9ed39f 100644
--- a/compiler/luci/import/src/Nodes/CircleLogicalNot.cpp
+++ b/compiler/luci/import/src/Nodes/CircleLogicalNot.cpp
@@ -30,9 +30,10 @@ bool CircleLogicalNotGraphBuilder::validate(const ValidateArgs &args) const
// Only BOOL type is allowed for the input
const auto &inputs = args.op.inputs;
- const auto &tensors = args.reader.tensors();
- const auto &tensor = tensors.at(inputs.at(0));
- if (tensor->type != circle::TensorType::TensorType_BOOL)
+ const auto tensors = args.reader.tensors();
+ const auto tensor = tensors.at(inputs.at(0));
+ assert(tensor != nullptr);
+ if (tensor->type() != circle::TensorType::TensorType_BOOL)
return false;
return true;
diff --git a/compiler/luci/import/src/Nodes/CircleLogicalOr.cpp b/compiler/luci/import/src/Nodes/CircleLogicalOr.cpp
index 8c9023dd3..6354e7dc1 100644
--- a/compiler/luci/import/src/Nodes/CircleLogicalOr.cpp
+++ b/compiler/luci/import/src/Nodes/CircleLogicalOr.cpp
@@ -30,11 +30,12 @@ bool CircleLogicalOrGraphBuilder::validate(const ValidateArgs &args) const
// Only BOOL type is allowed for inputs
const auto &inputs = args.op.inputs;
- const auto &tensors = args.reader.tensors();
+ const auto tensors = args.reader.tensors();
for (auto input : inputs)
{
- const auto &tensor = tensors.at(input);
- if (tensor->type != circle::TensorType::TensorType_BOOL)
+ const auto tensor = tensors.at(input);
+ assert(tensor != nullptr);
+ if (tensor->type() != circle::TensorType::TensorType_BOOL)
return false;
}
diff --git a/compiler/luci/import/src/Nodes/CircleLogistic.cpp b/compiler/luci/import/src/Nodes/CircleLogistic.cpp
index 0f92a9bb4..b0d08e039 100644
--- a/compiler/luci/import/src/Nodes/CircleLogistic.cpp
+++ b/compiler/luci/import/src/Nodes/CircleLogistic.cpp
@@ -30,8 +30,9 @@ bool CircleLogisticGraphBuilder::validate(const ValidateArgs &args) const
const auto &inputs = args.op.inputs;
const auto &outputs = args.op.outputs;
- const auto &tensors = args.reader.tensors();
- if (tensors.at(inputs.at(0))->type != tensors.at(outputs[0])->type)
+ const auto tensors = args.reader.tensors();
+ assert(tensors.at(inputs.at(0)) != nullptr && tensors.at(outputs[0]) != nullptr);
+ if (tensors.at(inputs.at(0))->type() != tensors.at(outputs[0])->type())
return false;
return true;
diff --git a/compiler/luci/import/src/Nodes/CircleMatrixDiag.cpp b/compiler/luci/import/src/Nodes/CircleMatrixDiag.cpp
index 590a07f2d..384b98586 100644
--- a/compiler/luci/import/src/Nodes/CircleMatrixDiag.cpp
+++ b/compiler/luci/import/src/Nodes/CircleMatrixDiag.cpp
@@ -30,10 +30,11 @@ bool CircleMatrixDiagGraphBuilder::validate(const ValidateArgs &args) const
const auto &inputs = args.op.inputs;
const auto &outputs = args.op.outputs;
- const auto &tensors = args.reader.tensors();
- const auto &tensor = tensors.at(inputs.at(0));
+ const auto tensors = args.reader.tensors();
+ const auto tensor = tensors.at(inputs.at(0));
- if (tensors[outputs[0]]->type != tensor->type)
+ assert(tensors[outputs[0]] != nullptr && tensor != nullptr);
+ if (tensors[outputs[0]]->type() != tensor->type())
return false;
return true;
diff --git a/compiler/luci/import/src/Nodes/CircleMatrixSetDiag.cpp b/compiler/luci/import/src/Nodes/CircleMatrixSetDiag.cpp
index edd7d2ae2..64870c057 100644
--- a/compiler/luci/import/src/Nodes/CircleMatrixSetDiag.cpp
+++ b/compiler/luci/import/src/Nodes/CircleMatrixSetDiag.cpp
@@ -30,10 +30,11 @@ bool CircleMatrixSetDiagGraphBuilder::validate(const ValidateArgs &args) const
const auto &inputs = args.op.inputs;
const auto &outputs = args.op.outputs;
- const auto &tensors = args.reader.tensors();
- const auto &tensor = tensors.at(inputs.at(0));
+ const auto tensors = args.reader.tensors();
+ const auto tensor = tensors.at(inputs.at(0));
- if (tensors[outputs[0]]->type != tensor->type)
+ assert(tensors[outputs[0]] != nullptr && tensor != nullptr);
+ if (tensors[outputs[0]]->type() != tensor->type())
return false;
return true;
diff --git a/compiler/luci/import/src/Nodes/CircleNonMaxSuppressionV4.cpp b/compiler/luci/import/src/Nodes/CircleNonMaxSuppressionV4.cpp
index d3d69506b..e86f2ba81 100644
--- a/compiler/luci/import/src/Nodes/CircleNonMaxSuppressionV4.cpp
+++ b/compiler/luci/import/src/Nodes/CircleNonMaxSuppressionV4.cpp
@@ -35,20 +35,26 @@ bool CircleNonMaxSuppressionV4GraphBuilder::validate(const ValidateArgs &args) c
if (outputs.size() != 2)
return false;
- const auto &tensors = args.reader.tensors();
- const auto &boxes_tensor = tensors.at(inputs[0]);
- if (boxes_tensor->shape.size() != 2)
+ const auto tensors = args.reader.tensors();
+ const auto boxes_tensor = tensors.at(inputs[0]);
+ assert(boxes_tensor != nullptr);
+ const auto boxes_tensor_shape = wrap(boxes_tensor->shape());
+ if (boxes_tensor_shape.size() != 2)
return false;
- if (boxes_tensor->shape.at(1) != 4)
+ if (boxes_tensor_shape.at(1) != 4)
return false;
- if (boxes_tensor->shape.at(0) != tensors.at(inputs[1])->shape.at(0))
+ assert(tensors.at(inputs[1]) != nullptr);
+ if (boxes_tensor_shape.at(0) != wrap(tensors.at(inputs[1])->shape()).at(0))
return false;
- if (tensors.at(inputs[2])->type != circle::TensorType_INT32)
+ assert(tensors.at(inputs[2]) != nullptr);
+ if (tensors.at(inputs[2])->type() != circle::TensorType_INT32)
return false;
- if (tensors.at(inputs[3])->type != circle::TensorType_FLOAT32)
+ assert(tensors.at(inputs[3]) != nullptr);
+ if (tensors.at(inputs[3])->type() != circle::TensorType_FLOAT32)
return false;
- if (tensors.at(inputs[4])->type != circle::TensorType_FLOAT32)
+ assert(tensors.at(inputs[4]) != nullptr);
+ if (tensors.at(inputs[4])->type() != circle::TensorType_FLOAT32)
return false;
return true;
diff --git a/compiler/luci/import/src/Nodes/CircleNonMaxSuppressionV5.cpp b/compiler/luci/import/src/Nodes/CircleNonMaxSuppressionV5.cpp
index d797d4cb7..a60eed4e4 100644
--- a/compiler/luci/import/src/Nodes/CircleNonMaxSuppressionV5.cpp
+++ b/compiler/luci/import/src/Nodes/CircleNonMaxSuppressionV5.cpp
@@ -35,22 +35,29 @@ bool CircleNonMaxSuppressionV5GraphBuilder::validate(const ValidateArgs &args) c
if (outputs.size() != 3)
return false;
- const auto &tensors = args.reader.tensors();
- const auto &boxes_tensor = tensors.at(inputs[0]);
- if (boxes_tensor->shape.size() != 2)
+ const auto tensors = args.reader.tensors();
+ const auto boxes_tensor = tensors.at(inputs[0]);
+ assert(boxes_tensor != nullptr);
+ const auto boxes_tensor_shape = wrap(boxes_tensor->shape());
+ if (boxes_tensor_shape.size() != 2)
return false;
- if (boxes_tensor->shape.at(1) != 4)
+ if (boxes_tensor_shape.at(1) != 4)
return false;
- if (boxes_tensor->shape.at(0) != tensors.at(inputs[1])->shape.at(0))
+ assert(tensors.at(inputs[1]) != nullptr);
+ if (boxes_tensor_shape.at(0) != wrap(tensors.at(inputs[1])->shape()).at(0))
return false;
- if (tensors.at(inputs[2])->type != circle::TensorType_INT32)
+ assert(tensors.at(inputs[2]) != nullptr);
+ if (tensors.at(inputs[2])->type() != circle::TensorType_INT32)
return false;
- if (tensors.at(inputs[3])->type != circle::TensorType_FLOAT32)
+ assert(tensors.at(inputs[3]) != nullptr);
+ if (tensors.at(inputs[3])->type() != circle::TensorType_FLOAT32)
return false;
- if (tensors.at(inputs[4])->type != circle::TensorType_FLOAT32)
+ assert(tensors.at(inputs[4]) != nullptr);
+ if (tensors.at(inputs[4])->type() != circle::TensorType_FLOAT32)
return false;
- if (tensors.at(inputs[5])->type != circle::TensorType_FLOAT32)
+ assert(tensors.at(inputs[5]) != nullptr);
+ if (tensors.at(inputs[5])->type() != circle::TensorType_FLOAT32)
return false;
return true;
diff --git a/compiler/luci/import/src/Nodes/CircleNotEqual.cpp b/compiler/luci/import/src/Nodes/CircleNotEqual.cpp
index a0b8f9e4f..3f5c1e033 100644
--- a/compiler/luci/import/src/Nodes/CircleNotEqual.cpp
+++ b/compiler/luci/import/src/Nodes/CircleNotEqual.cpp
@@ -30,14 +30,16 @@ bool CircleNotEqualGraphBuilder::validate(const ValidateArgs &args) const
const auto &inputs = args.op.inputs;
const auto &outputs = args.op.outputs;
- const auto &tensors = args.reader.tensors();
+ const auto tensors = args.reader.tensors();
- if (tensors[inputs.at(0)]->type != tensors[inputs.at(1)]->type)
+ assert(tensors[inputs.at(0)] != nullptr && tensors[inputs.at(1)] != nullptr);
+ if (tensors[inputs.at(0)]->type() != tensors[inputs.at(1)]->type())
{
return false;
}
- return tensors[outputs[0]]->type == circle::TensorType::TensorType_BOOL;
+ assert(tensors[outputs[0]] != nullptr);
+ return tensors[outputs[0]]->type() == circle::TensorType::TensorType_BOOL;
}
CircleNode *CircleNotEqualGraphBuilder::build_node(const circle::OperatorT &,
diff --git a/compiler/luci/import/src/Nodes/CircleOneHot.cpp b/compiler/luci/import/src/Nodes/CircleOneHot.cpp
index 3952cc21a..6e5f8e16f 100644
--- a/compiler/luci/import/src/Nodes/CircleOneHot.cpp
+++ b/compiler/luci/import/src/Nodes/CircleOneHot.cpp
@@ -32,21 +32,25 @@ bool CircleOneHotGraphBuilder::validate(const ValidateArgs &args) const
const auto &inputs = args.op.inputs;
const auto *options = args.op.builtin_options.AsOneHotOptions();
- const auto &tensors = args.reader.tensors();
- const auto &indices = tensors.at(inputs.at(0));
- const auto &depth = tensors.at(inputs.at(1));
- const auto &on_value = tensors.at(inputs.at(2));
- const auto &off_value = tensors.at(inputs.at(3));
+ const auto tensors = args.reader.tensors();
+ const auto indices = tensors.at(inputs.at(0));
+ const auto depth = tensors.at(inputs.at(1));
+ const auto on_value = tensors.at(inputs.at(2));
+ const auto off_value = tensors.at(inputs.at(3));
+ assert(indices != nullptr);
+ assert(depth != nullptr);
+ assert(on_value != nullptr);
+ assert(off_value != nullptr);
- if (options->axis < -1 || options->axis > static_cast<int32_t>(indices->shape.size()))
+ if (options->axis < -1 || options->axis > static_cast<int32_t>(wrap(indices->shape()).size()))
return false;
- if (depth->shape.size() != 0)
+ if (wrap(depth->shape()).size() != 0)
return false;
- if (on_value->shape.size() != 0)
+ if (wrap(on_value->shape()).size() != 0)
return false;
- if (off_value->shape.size() != 0)
+ if (wrap(off_value->shape()).size() != 0)
return false;
- if (on_value->type != off_value->type)
+ if (on_value->type() != off_value->type())
return false;
return true;
diff --git a/compiler/luci/import/src/Nodes/CircleReduceAny.cpp b/compiler/luci/import/src/Nodes/CircleReduceAny.cpp
index 13205dd7a..ebe2368e0 100644
--- a/compiler/luci/import/src/Nodes/CircleReduceAny.cpp
+++ b/compiler/luci/import/src/Nodes/CircleReduceAny.cpp
@@ -28,17 +28,20 @@ bool CircleReduceAnyGraphBuilder::validate(const ValidateArgs &args) const
const auto &inputs = args.op.inputs;
const auto &outputs = args.op.outputs;
- const auto &tensors = args.reader.tensors();
- const auto &tensor_0 = tensors.at(inputs.at(0));
- const auto &tensor_1 = tensors.at(inputs.at(1));
- const auto &tensor_o = tensors.at(outputs[0]);
+ const auto tensors = args.reader.tensors();
+ const auto tensor_0 = tensors.at(inputs.at(0));
+ const auto tensor_1 = tensors.at(inputs.at(1));
+ const auto tensor_o = tensors.at(outputs[0]);
+ assert(tensor_0 != nullptr);
+ assert(tensor_1 != nullptr);
+ assert(tensor_o != nullptr);
- if (tensor_0->type != circle::TensorType_BOOL)
+ if (tensor_0->type() != circle::TensorType_BOOL)
return false;
- if (tensor_o->type != circle::TensorType_BOOL)
+ if (tensor_o->type() != circle::TensorType_BOOL)
return false;
- switch (tensor_1->type)
+ switch (tensor_1->type())
{
case circle::TensorType_INT32:
case circle::TensorType_INT64:
diff --git a/compiler/luci/import/src/Nodes/CircleReduceProd.cpp b/compiler/luci/import/src/Nodes/CircleReduceProd.cpp
index 3549c1a18..3b874b7c9 100644
--- a/compiler/luci/import/src/Nodes/CircleReduceProd.cpp
+++ b/compiler/luci/import/src/Nodes/CircleReduceProd.cpp
@@ -27,13 +27,14 @@ bool CircleReduceProdGraphBuilder::validate(const ValidateArgs &args) const
return false;
const auto &inputs = args.op.inputs;
- const auto &tensors = args.reader.tensors();
- const auto &tensor_1 = tensors.at(inputs.at(1));
+ const auto tensors = args.reader.tensors();
+ const auto tensor_1 = tensors.at(inputs.at(1));
+ assert(tensor_1 != nullptr);
// TODO check input types
// Check for reduction_indices types
- switch (tensor_1->type)
+ switch (tensor_1->type())
{
case circle::TensorType_INT32:
case circle::TensorType_INT64:
diff --git a/compiler/luci/import/src/Nodes/CircleReshape.cpp b/compiler/luci/import/src/Nodes/CircleReshape.cpp
index 401dff0fc..3421620ce 100644
--- a/compiler/luci/import/src/Nodes/CircleReshape.cpp
+++ b/compiler/luci/import/src/Nodes/CircleReshape.cpp
@@ -34,12 +34,13 @@ bool CircleReshapeGraphBuilder::validate(const ValidateArgs &args) const
if (args.op.inputs.size() == 2)
{
const auto &inputs = args.op.inputs;
- const auto &tensors = args.reader.tensors();
- const auto &tensor_in = tensors.at(inputs.at(1));
+ const auto tensors = args.reader.tensors();
+ const auto tensor_in = tensors.at(inputs.at(1));
+ assert(tensor_in != nullptr);
// NOTE fix this if there is any other case
// TensorFlow lite and circle only supports S32
- if (tensor_in->type != circle::TensorType::TensorType_INT32)
+ if (tensor_in->type() != circle::TensorType::TensorType_INT32)
return false;
}
diff --git a/compiler/luci/import/src/Nodes/CircleReverseSequence.cpp b/compiler/luci/import/src/Nodes/CircleReverseSequence.cpp
index 2fbb7a87c..c9cc792bb 100644
--- a/compiler/luci/import/src/Nodes/CircleReverseSequence.cpp
+++ b/compiler/luci/import/src/Nodes/CircleReverseSequence.cpp
@@ -30,12 +30,15 @@ bool CircleReverseSequenceGraphBuilder::validate(const ValidateArgs &args) const
const auto &inputs = args.op.inputs;
const auto &outputs = args.op.outputs;
- const auto &tensors = args.reader.tensors();
- const auto &tensor_in = tensors.at(inputs.at(0));
- const auto &tensor_lengths = tensors.at(inputs.at(1));
- const auto &tensor_out = tensors.at(outputs[0]);
+ const auto tensors = args.reader.tensors();
+ const auto tensor_in = tensors.at(inputs.at(0));
+ const auto tensor_lengths = tensors.at(inputs.at(1));
+ const auto tensor_out = tensors.at(outputs[0]);
+ assert(tensor_in != nullptr);
+ assert(tensor_lengths != nullptr);
+ assert(tensor_out != nullptr);
- switch (tensor_lengths->type)
+ switch (tensor_lengths->type())
{
case circle::TensorType_INT32:
case circle::TensorType_INT64:
@@ -44,7 +47,7 @@ bool CircleReverseSequenceGraphBuilder::validate(const ValidateArgs &args) const
return false;
}
- if (tensor_in->type != tensor_out->type)
+ if (tensor_in->type() != tensor_out->type())
return false;
return true;
diff --git a/compiler/luci/import/src/Nodes/CircleReverseV2.cpp b/compiler/luci/import/src/Nodes/CircleReverseV2.cpp
index ca7653201..c19a0fdd2 100644
--- a/compiler/luci/import/src/Nodes/CircleReverseV2.cpp
+++ b/compiler/luci/import/src/Nodes/CircleReverseV2.cpp
@@ -30,12 +30,15 @@ bool CircleReverseV2GraphBuilder::validate(const ValidateArgs &args) const
const auto &inputs = args.op.inputs;
const auto &outputs = args.op.outputs;
- const auto &tensors = args.reader.tensors();
- const auto &tensor_in = tensors.at(inputs.at(0));
- const auto &tensor_axis = tensors.at(inputs.at(1));
- const auto &tensor_out = tensors.at(outputs[0]);
+ const auto tensors = args.reader.tensors();
+ const auto tensor_in = tensors.at(inputs.at(0));
+ const auto tensor_axis = tensors.at(inputs.at(1));
+ const auto tensor_out = tensors.at(outputs[0]);
+ assert(tensor_in != nullptr);
+ assert(tensor_axis != nullptr);
+ assert(tensor_out != nullptr);
- switch (tensor_axis->type)
+ switch (tensor_axis->type())
{
case circle::TensorType_INT32:
case circle::TensorType_INT64:
@@ -44,7 +47,7 @@ bool CircleReverseV2GraphBuilder::validate(const ValidateArgs &args) const
return false;
}
- if (tensor_out->type != tensor_in->type)
+ if (tensor_out->type() != tensor_in->type())
return false;
return true;
diff --git a/compiler/luci/import/src/Nodes/CircleRound.cpp b/compiler/luci/import/src/Nodes/CircleRound.cpp
index d13e0fafe..08cfae6c2 100644
--- a/compiler/luci/import/src/Nodes/CircleRound.cpp
+++ b/compiler/luci/import/src/Nodes/CircleRound.cpp
@@ -33,11 +33,13 @@ bool CircleRoundGraphBuilder::validate(const ValidateArgs &args) const
// Must be one of the following types
// bfloat16, half (float16), float32, float64, complex64, complex128
// Currently, circle supports float16, float32, complex64
- const auto &tensors = args.reader.tensors();
- const auto &tensor_in = tensors.at(inputs.at(0));
- const auto &tensor_out = tensors.at(outputs[0]);
+ const auto tensors = args.reader.tensors();
+ const auto tensor_in = tensors.at(inputs.at(0));
+ const auto tensor_out = tensors.at(outputs[0]);
+ assert(tensor_in != nullptr);
+ assert(tensor_out != nullptr);
- switch (tensor_in->type)
+ switch (tensor_in->type())
{
case circle::TensorType_FLOAT16:
case circle::TensorType_FLOAT32:
@@ -49,7 +51,7 @@ bool CircleRoundGraphBuilder::validate(const ValidateArgs &args) const
return false;
}
- if (tensor_out->type != tensor_in->type)
+ if (tensor_out->type() != tensor_in->type())
return false;
return true;
diff --git a/compiler/luci/import/src/Nodes/CircleRsqrt.cpp b/compiler/luci/import/src/Nodes/CircleRsqrt.cpp
index a9ca90832..e3bc68f8b 100644
--- a/compiler/luci/import/src/Nodes/CircleRsqrt.cpp
+++ b/compiler/luci/import/src/Nodes/CircleRsqrt.cpp
@@ -32,9 +32,10 @@ bool CircleRsqrtGraphBuilder::validate(const ValidateArgs &args) const
// Must be one of the following types
// bfloat16, half (float16), float32, float64, complex64, complex128
// Currently, circle supports float16, float32, complex64
- const auto &tensors = args.reader.tensors();
- const auto &tensor = tensors.at(inputs.at(0));
- switch (tensor->type)
+ const auto tensors = args.reader.tensors();
+ const auto tensor = tensors.at(inputs.at(0));
+ assert(tensor != nullptr);
+ switch (tensor->type())
{
case circle::TensorType_UINT8:
case circle::TensorType_INT16:
diff --git a/compiler/luci/import/src/Nodes/CircleSVDF.cpp b/compiler/luci/import/src/Nodes/CircleSVDF.cpp
new file mode 100644
index 000000000..83a025177
--- /dev/null
+++ b/compiler/luci/import/src/Nodes/CircleSVDF.cpp
@@ -0,0 +1,67 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Import/Nodes/CircleSVDF.h"
+
+#include <luci/IR/Nodes/CircleSVDF.h>
+
+#include <loco.h>
+
+namespace luci
+{
+
+bool CircleSVDFBuilder::validate(const ValidateArgs &args) const
+{
+ const auto &inputs = args.op.inputs;
+ if (!(inputs.size() == 4 || inputs.size() == 5))
+ return false;
+
+ return true;
+}
+
+CircleNode *CircleSVDFBuilder::build_node(const circle::OperatorT &op,
+ const std::vector<CircleNode *> &inputs,
+ loco::Graph *graph) const
+{
+ auto *node = graph->nodes()->create<CircleSVDF>();
+ node->input(inputs.at(0));
+ node->weight_feature(inputs.at(1));
+ node->weight_time(inputs.at(2));
+ if (inputs.size() == 4)
+ {
+ auto *bias = graph->nodes()->create<CircleOutputExclude>();
+ // CircleOutputExclude doesn't need a type, but since all nodes must have a type,
+ // a dummy type is inserted.
+ bias->dtype(inputs.at(0)->dtype());
+ node->bias(bias);
+
+ node->input_activation_state(inputs.at(3));
+ }
+ else
+ {
+ node->bias(inputs.at(3));
+ node->input_activation_state(inputs.at(4));
+ }
+
+ const auto *options = op.builtin_options.AsSVDFOptions();
+ node->svdf_rank(options->rank);
+ node->fusedActivationFunction(luci_actfunc(options->fused_activation_function));
+ node->asymmetric_quantize_inputs(options->asymmetric_quantize_inputs);
+
+ return node;
+}
+
+} // namespace luci
diff --git a/compiler/luci/import/src/Nodes/CircleScatterNd.cpp b/compiler/luci/import/src/Nodes/CircleScatterNd.cpp
index f8c175110..ebe252527 100644
--- a/compiler/luci/import/src/Nodes/CircleScatterNd.cpp
+++ b/compiler/luci/import/src/Nodes/CircleScatterNd.cpp
@@ -30,14 +30,15 @@ bool CircleScatterNdGraphBuilder::validate(const ValidateArgs &args) const
const auto &inputs = args.op.inputs;
// indices must have the same type as shape
- const auto &tensors = args.reader.tensors();
+ const auto tensors = args.reader.tensors();
- if (tensors[inputs.at(0)]->type != tensors[inputs.at(2)]->type)
+ assert(tensors[inputs.at(0)] != nullptr && tensors[inputs.at(2)] != nullptr);
+ if (tensors[inputs.at(0)]->type() != tensors[inputs.at(2)]->type())
return false;
// indices must be either int32 or int64
- if (tensors[inputs.at(0)]->type != circle::TensorType_INT32 &&
- tensors[inputs.at(0)]->type != circle::TensorType_INT64)
+ if (tensors[inputs.at(0)]->type() != circle::TensorType_INT32 &&
+ tensors[inputs.at(0)]->type() != circle::TensorType_INT64)
return false;
return true;
diff --git a/compiler/luci/import/src/Nodes/CircleSegmentSum.cpp b/compiler/luci/import/src/Nodes/CircleSegmentSum.cpp
index bfa333e8d..01d1aab44 100644
--- a/compiler/luci/import/src/Nodes/CircleSegmentSum.cpp
+++ b/compiler/luci/import/src/Nodes/CircleSegmentSum.cpp
@@ -30,12 +30,15 @@ bool CircleSegmentSumGraphBuilder::validate(const ValidateArgs &args) const
const auto &inputs = args.op.inputs;
const auto &outputs = args.op.outputs;
- const auto &tensors = args.reader.tensors();
- const auto &tensor_in = tensors.at(inputs.at(0));
- const auto &tensor_out = tensors.at(outputs[0]);
- const auto &tensor_ids = tensors.at(inputs.at(1));
+ const auto tensors = args.reader.tensors();
+ const auto tensor_in = tensors.at(inputs.at(0));
+ const auto tensor_out = tensors.at(outputs[0]);
+ const auto tensor_ids = tensors.at(inputs.at(1));
+ assert(tensor_in != nullptr);
+ assert(tensor_out != nullptr);
+ assert(tensor_ids != nullptr);
- switch (tensor_ids->type)
+ switch (tensor_ids->type())
{
case circle::TensorType_INT32:
case circle::TensorType_INT64:
@@ -44,7 +47,7 @@ bool CircleSegmentSumGraphBuilder::validate(const ValidateArgs &args) const
return false;
}
- if (tensor_out->type != tensor_in->type)
+ if (tensor_out->type() != tensor_in->type())
{
return false;
}
diff --git a/compiler/luci/import/src/Nodes/CircleSelect.cpp b/compiler/luci/import/src/Nodes/CircleSelect.cpp
index 36a5fa8a8..002f62f6c 100644
--- a/compiler/luci/import/src/Nodes/CircleSelect.cpp
+++ b/compiler/luci/import/src/Nodes/CircleSelect.cpp
@@ -29,9 +29,10 @@ bool CircleSelectGraphBuilder::validate(const ValidateArgs &args) const
return false;
const auto &inputs = args.op.inputs;
- const auto &tensors = args.reader.tensors();
- const auto &tensor = tensors.at(inputs.at(0));
- if (tensor->type != circle::TensorType_BOOL)
+ const auto tensors = args.reader.tensors();
+ const auto tensor = tensors.at(inputs.at(0));
+ assert(tensor != nullptr);
+ if (tensor->type() != circle::TensorType_BOOL)
return false;
// TODO check dtypes for input 1, 2
diff --git a/compiler/luci/import/src/Nodes/CircleSelectV2.cpp b/compiler/luci/import/src/Nodes/CircleSelectV2.cpp
index 556c8fa33..062fdc143 100644
--- a/compiler/luci/import/src/Nodes/CircleSelectV2.cpp
+++ b/compiler/luci/import/src/Nodes/CircleSelectV2.cpp
@@ -29,14 +29,16 @@ bool CircleSelectV2GraphBuilder::validate(const ValidateArgs &args) const
return false;
const auto &inputs = args.op.inputs;
- const auto &tensors = args.reader.tensors();
- const auto &condition = tensors.at(inputs.at(0));
- if (condition->type != circle::TensorType_BOOL)
+ const auto tensors = args.reader.tensors();
+ const auto condition = tensors.at(inputs.at(0));
+ assert(condition != nullptr);
+ if (condition->type() != circle::TensorType_BOOL)
return false;
- const auto &t = tensors.at(inputs.at(1));
- const auto &e = tensors.at(inputs.at(2));
- if (t->type != e->type)
+ const auto t = tensors.at(inputs.at(1));
+ const auto e = tensors.at(inputs.at(2));
+ assert(t != nullptr && e != nullptr);
+ if (t->type() != e->type())
return false;
return true;
diff --git a/compiler/luci/import/src/Nodes/CircleSin.cpp b/compiler/luci/import/src/Nodes/CircleSin.cpp
index 22f461123..51ebf0355 100644
--- a/compiler/luci/import/src/Nodes/CircleSin.cpp
+++ b/compiler/luci/import/src/Nodes/CircleSin.cpp
@@ -30,9 +30,10 @@ bool CircleSinGraphBuilder::validate(const ValidateArgs &args) const
const auto &inputs = args.op.inputs;
// input type check
- const auto &tensors = args.reader.tensors();
- const auto &tensor = tensors.at(inputs.at(0));
- switch (tensor->type)
+ const auto tensors = args.reader.tensors();
+ const auto tensor = tensors.at(inputs.at(0));
+ assert(tensor != nullptr);
+ switch (tensor->type())
{
case circle::TensorType_FLOAT16:
case circle::TensorType_FLOAT32:
diff --git a/compiler/luci/import/src/Nodes/CircleSquare.cpp b/compiler/luci/import/src/Nodes/CircleSquare.cpp
index 7ff2b84e6..bec84b4c0 100644
--- a/compiler/luci/import/src/Nodes/CircleSquare.cpp
+++ b/compiler/luci/import/src/Nodes/CircleSquare.cpp
@@ -29,13 +29,13 @@ bool CircleSquareGraphBuilder::validate(const ValidateArgs &args) const
return false;
const auto &inputs = args.op.inputs;
- // Must be one of the following types
- // bfloat16, half (float16), float32, float64, complex64, complex128
- // Currently, circle supports float16, float32, complex64
- const auto &tensors = args.reader.tensors();
- const auto &tensor = tensors.at(inputs.at(0));
- switch (tensor->type)
+ const auto tensors = args.reader.tensors();
+ const auto tensor = tensors.at(inputs.at(0));
+ assert(tensor != nullptr);
+ switch (tensor->type())
{
+ case circle::TensorType_UINT8:
+ case circle::TensorType_INT16:
case circle::TensorType_INT32:
case circle::TensorType_INT64:
case circle::TensorType_FLOAT16:
diff --git a/compiler/luci/import/src/Nodes/CircleSquaredDifference.cpp b/compiler/luci/import/src/Nodes/CircleSquaredDifference.cpp
index 33440d5ab..1983465d3 100644
--- a/compiler/luci/import/src/Nodes/CircleSquaredDifference.cpp
+++ b/compiler/luci/import/src/Nodes/CircleSquaredDifference.cpp
@@ -32,9 +32,10 @@ bool CircleSquaredDifferenceGraphBuilder::validate(const ValidateArgs &args) con
const auto &outputs = args.op.outputs;
// Inputs must be one of the following types
// bfloat16, half(float16), float32, float64, int32, int64, complex64, complex128
- const auto &tensors = args.reader.tensors();
- const auto &tensor = tensors.at(inputs.at(0));
- switch (tensor->type)
+ const auto tensors = args.reader.tensors();
+ const auto tensor = tensors.at(inputs.at(0));
+ assert(tensor != nullptr);
+ switch (tensor->type())
{
case circle::TensorType_FLOAT16:
case circle::TensorType_FLOAT32:
@@ -53,11 +54,13 @@ bool CircleSquaredDifferenceGraphBuilder::validate(const ValidateArgs &args) con
}
// Input types must match
- if (tensors.at(inputs.at(0))->type != tensors.at(inputs.at(1))->type)
+ assert(tensors.at(inputs.at(0)) != nullptr && tensors.at(inputs.at(1)) != nullptr);
+ if (tensors.at(inputs.at(0))->type() != tensors.at(inputs.at(1))->type())
return false;
// Input and output types must match
- if (tensors.at(inputs.at(0))->type != tensors.at(outputs[0])->type)
+ assert(tensors.at(outputs[0]) != nullptr);
+ if (tensors.at(inputs.at(0))->type() != tensors.at(outputs[0])->type())
return false;
return true;
diff --git a/compiler/luci/import/src/Nodes/CircleTanh.cpp b/compiler/luci/import/src/Nodes/CircleTanh.cpp
index 95625a0e4..80a0e887f 100644
--- a/compiler/luci/import/src/Nodes/CircleTanh.cpp
+++ b/compiler/luci/import/src/Nodes/CircleTanh.cpp
@@ -30,8 +30,9 @@ bool CircleTanhGraphBuilder::validate(const ValidateArgs &args) const
const auto &inputs = args.op.inputs;
const auto &outputs = args.op.outputs;
- const auto &tensors = args.reader.tensors();
- if (tensors.at(inputs.at(0))->type != tensors.at(outputs[0])->type)
+ const auto tensors = args.reader.tensors();
+ assert(tensors.at(inputs.at(0)) != nullptr && tensors.at(outputs[0]) != nullptr);
+ if (tensors.at(inputs.at(0))->type() != tensors.at(outputs[0])->type())
return false;
return true;
diff --git a/compiler/luci/import/src/Nodes/CircleTile.cpp b/compiler/luci/import/src/Nodes/CircleTile.cpp
index 6da44130c..c41a6ba3f 100644
--- a/compiler/luci/import/src/Nodes/CircleTile.cpp
+++ b/compiler/luci/import/src/Nodes/CircleTile.cpp
@@ -32,9 +32,10 @@ bool CircleTileGraphBuilder::validate(const ValidateArgs &args) const
auto outputs = args.op.outputs;
// Multiples (inputs.at(1)) must be one of the following types
// int32, int64
- const auto &tensors = args.reader.tensors();
- const auto &tensor = tensors.at(inputs.at(1));
- switch (tensor->type)
+ const auto tensors = args.reader.tensors();
+ const auto tensor = tensors.at(inputs.at(1));
+ assert(tensor != nullptr);
+ switch (tensor->type())
{
case circle::TensorType_INT32:
case circle::TensorType_INT64:
@@ -44,7 +45,8 @@ bool CircleTileGraphBuilder::validate(const ValidateArgs &args) const
}
// Type of input and output must be the same
- if (tensors.at(inputs.at(0))->type != tensors.at(outputs[0])->type)
+ assert(tensors.at(inputs.at(0)) != nullptr && tensors.at(outputs[0]) != nullptr);
+ if (tensors.at(inputs.at(0))->type() != tensors.at(outputs[0])->type())
return false;
return true;
diff --git a/compiler/luci/import/src/Nodes/CircleTopKV2.cpp b/compiler/luci/import/src/Nodes/CircleTopKV2.cpp
index 49f858798..9f9173738 100644
--- a/compiler/luci/import/src/Nodes/CircleTopKV2.cpp
+++ b/compiler/luci/import/src/Nodes/CircleTopKV2.cpp
@@ -35,9 +35,10 @@ bool CircleTopKV2GraphBuilder::validate(const ValidateArgs &args) const
if (outputs.size() != 2)
return false;
- const auto &tensors = args.reader.tensors();
- const auto &tensor = tensors.at(inputs.at(1));
- if (tensor->type != circle::TensorType_INT32)
+ const auto tensors = args.reader.tensors();
+ const auto tensor = tensors.at(inputs.at(1));
+ assert(tensor != nullptr);
+ if (tensor->type() != circle::TensorType_INT32)
return false;
return true;
diff --git a/compiler/luci/import/src/Nodes/CircleTransposeConv.cpp b/compiler/luci/import/src/Nodes/CircleTransposeConv.cpp
index 5a60e2f54..041983dac 100644
--- a/compiler/luci/import/src/Nodes/CircleTransposeConv.cpp
+++ b/compiler/luci/import/src/Nodes/CircleTransposeConv.cpp
@@ -31,11 +31,13 @@ bool CircleTransposeConvGraphBuilder::validate(const ValidateArgs &args) const
return false;
const auto &inputs = args.op.inputs;
- const auto &tensors = args.reader.tensors();
- const auto &filter_tensor = tensors.at(inputs.at(1));
- const auto &filter_shape = filter_tensor.get()->shape;
- const auto &ifm_tensor = tensors.at(inputs.at(2));
- const auto &ifm_shape = ifm_tensor.get()->shape;
+ const auto tensors = args.reader.tensors();
+ const auto filter_tensor = tensors.at(inputs.at(1));
+ assert(filter_tensor != nullptr);
+ const auto filter_shape = wrap(filter_tensor->shape());
+ const auto ifm_tensor = tensors.at(inputs.at(2));
+ assert(ifm_tensor != nullptr);
+ const auto ifm_shape = wrap(ifm_tensor->shape());
// ifm and filters must be 4-D tensor
if (ifm_shape.size() != 4)
@@ -45,7 +47,7 @@ bool CircleTransposeConvGraphBuilder::validate(const ValidateArgs &args) const
// input shape : [batch, height, width, in_channels]
// filters shape : [output_channels, height, weight, in_channels]
- if (ifm_tensor.get()->shape.at(3) != filter_tensor.get()->shape.at(3))
+ if (ifm_shape.at(3) != filter_shape.at(3))
return false;
return true;
diff --git a/compiler/luci/import/src/Nodes/CircleUnpack.cpp b/compiler/luci/import/src/Nodes/CircleUnpack.cpp
index 9bfc76b57..6b3401609 100644
--- a/compiler/luci/import/src/Nodes/CircleUnpack.cpp
+++ b/compiler/luci/import/src/Nodes/CircleUnpack.cpp
@@ -46,8 +46,8 @@ bool CircleUnpackGraphBuilder::validate(const ValidateArgs &args) const
{
if (settings->get(luci::UserSettings::Key::DisableValidation))
{
- const auto &tensors = args.reader.tensors();
- const circle::TensorT &output_tensor = *tensors[outputs[0]];
+ const auto tensors = args.reader.tensors();
+ const auto output_tensor = tensors[outputs[0]];
auto name = tensor_name(output_tensor);
WARN(l) << "Warning: import Unpack(" << name << ") 'num' is not same as outputs used";
}
@@ -58,9 +58,10 @@ bool CircleUnpackGraphBuilder::validate(const ValidateArgs &args) const
if (options->num < 0)
return false;
- const auto &tensors = args.reader.tensors();
- const auto &tensor = tensors.at(inputs.at(0));
- const auto &shape = tensor->shape;
+ const auto tensors = args.reader.tensors();
+ const auto tensor = tensors.at(inputs.at(0));
+ assert(tensor != nullptr);
+ const auto shape = wrap(tensor->shape());
auto shape_size = static_cast<int32_t>(shape.size());
if (shape_size > 0)
{
diff --git a/compiler/luci/import/src/Nodes/CircleVariable.cpp b/compiler/luci/import/src/Nodes/CircleVariable.cpp
new file mode 100644
index 000000000..23ae9e7be
--- /dev/null
+++ b/compiler/luci/import/src/Nodes/CircleVariable.cpp
@@ -0,0 +1,80 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Import/Nodes/CircleVariable.h"
+
+#include <luci/IR/Nodes/CircleVariable.h>
+#include <luci/Log.h>
+
+#include <cassert>
+#include <ostream>
+#include <string>
+#include <vector>
+
+namespace
+{
+
+std::ostream &operator<<(std::ostream &os, const luci::VectorWrapper<int32_t> &vect)
+{
+ uint32_t seq = 0;
+ for (const auto &v : vect)
+ {
+ if (seq)
+ os << ", ";
+ os << v;
+ seq++;
+ }
+ return os;
+}
+
+} // namespace
+
+namespace luci
+{
+
+CircleVariable *create_circlevariable(GraphBuilderContext *context, int32_t tensor_index)
+{
+ LOGGER(l);
+
+ auto graph = context->graph();
+ auto reader = context->reader();
+ const auto tensors = reader->tensors();
+ const auto variable_tensor = tensors[tensor_index];
+ assert(variable_tensor != nullptr);
+
+ if (not variable_tensor->is_variable())
+ {
+ // not a variable
+ return nullptr;
+ }
+ {
+ // check if there is no buffer as we don't support this for now
+ // TODO use buffer when this is enabled in Kernel
+ assert(reader->buffers()[variable_tensor->buffer()] != nullptr);
+ assert(reader->buffers()[variable_tensor->buffer()]->data() == nullptr);
+ }
+
+ auto variable_node = graph->nodes()->create<CircleVariable>();
+ copy_tensor_attributes(variable_tensor, variable_node);
+ variable_node->shape_status(luci::ShapeStatus::VALID);
+
+ INFO(l) << "[luci] NodeFinder variable node(" << tensor_index << ") -> " << variable_node << " "
+ << wrap(variable_tensor->shape()) << std::endl;
+
+ return variable_node;
+}
+
+} // namespace luci
diff --git a/compiler/luci/import/src/Nodes/CircleWhere.cpp b/compiler/luci/import/src/Nodes/CircleWhere.cpp
index 8e4f1a0c4..bc6199ace 100644
--- a/compiler/luci/import/src/Nodes/CircleWhere.cpp
+++ b/compiler/luci/import/src/Nodes/CircleWhere.cpp
@@ -30,14 +30,16 @@ bool CircleWhereGraphBuilder::validate(const ValidateArgs &args) const
const auto &inputs = args.op.inputs;
const auto &outputs = args.op.outputs;
- const auto &tensors = args.reader.tensors();
- const auto &tensor_condition = tensors.at(inputs.at(0));
- const auto &tensor_out = tensors.at(outputs[0]);
+ const auto tensors = args.reader.tensors();
+ const auto tensor_condition = tensors.at(inputs.at(0));
+ const auto tensor_out = tensors.at(outputs[0]);
+ assert(tensor_condition != nullptr);
+ assert(tensor_out != nullptr);
- if (tensor_condition->type != circle::TensorType_BOOL)
+ if (tensor_condition->type() != circle::TensorType_BOOL)
return false;
- if (tensor_out->type != circle::TensorType_INT64)
+ if (tensor_out->type() != circle::TensorType_INT64)
return false;
return true;
diff --git a/compiler/luci/import/src/Nodes/CircleWhile.cpp b/compiler/luci/import/src/Nodes/CircleWhile.cpp
index 26147562f..27a392b2a 100644
--- a/compiler/luci/import/src/Nodes/CircleWhile.cpp
+++ b/compiler/luci/import/src/Nodes/CircleWhile.cpp
@@ -67,8 +67,8 @@ CircleNode *CircleWhileGraphBuilder::build(const circle::OperatorT &op,
const std::vector<int32_t> &inputs = op.inputs;
const std::vector<int32_t> &outputs = op.outputs;
- const auto &tensors = context->reader()->tensors();
- const auto &opcodes = context->reader()->opcodes();
+ const auto tensors = context->reader()->tensors();
+ const auto opcodes = context->reader()->opcodes();
std::vector<CircleNode *> input_nodes;
for (const int32_t input_tensor_index : inputs)
@@ -96,9 +96,11 @@ CircleNode *CircleWhileGraphBuilder::build(const circle::OperatorT &op,
assert(outputs.size() > 0);
{
// Lets use name of output 0 as While name
- const circle::TensorT &output_tensor = *tensors[outputs[0]];
+ const auto output_tensor = tensors[outputs[0]];
+ assert(output_tensor != nullptr);
node->name(tensor_name(output_tensor));
- node->op_version(opcodes[op.opcode_index].get()->version);
+ assert(opcodes[op.opcode_index] != nullptr);
+ node->op_version(opcodes[op.opcode_index]->version());
// NOTE We don't set quantization for While itself but to virtual outputs
}
@@ -106,7 +108,8 @@ CircleNode *CircleWhileGraphBuilder::build(const circle::OperatorT &op,
// Create virtual outputs of While
for (uint32_t n = 0; n < output_count; ++n)
{
- const circle::TensorT &output_tensor = *tensors[outputs[n]];
+ const auto output_tensor = tensors[outputs[n]];
+ assert(output_tensor != nullptr);
auto *nodeout = graph->nodes()->create<CircleWhileOut>();
diff --git a/compiler/luci/import/src/ValidateHelpers.cpp b/compiler/luci/import/src/ValidateHelpers.cpp
index 27306ba90..fc027704b 100644
--- a/compiler/luci/import/src/ValidateHelpers.cpp
+++ b/compiler/luci/import/src/ValidateHelpers.cpp
@@ -26,9 +26,10 @@ bool validate_batch_space_nd(const GraphBuilderBase::ValidateArgs &args)
return false;
// input 1 and 2 should have INT32/INT64 type
- const auto &tensors = args.reader.tensors();
- const auto &tensor_1 = tensors.at(inputs.at(1));
- switch (tensor_1->type)
+ const auto tensors = args.reader.tensors();
+ const auto tensor_1 = tensors.at(inputs.at(1));
+ assert(tensor_1 != nullptr);
+ switch (tensor_1->type())
{
case circle::TensorType_INT32:
case circle::TensorType_INT64:
@@ -36,8 +37,9 @@ bool validate_batch_space_nd(const GraphBuilderBase::ValidateArgs &args)
default:
return false;
}
- const auto &tensor_2 = tensors.at(inputs.at(2));
- switch (tensor_2->type)
+ const auto tensor_2 = tensors.at(inputs.at(2));
+ assert(tensor_2 != nullptr);
+ switch (tensor_2->type())
{
case circle::TensorType_INT32:
case circle::TensorType_INT64:
@@ -47,8 +49,9 @@ bool validate_batch_space_nd(const GraphBuilderBase::ValidateArgs &args)
}
// Only support input shape dimension 3 and 4 only
- const auto &tensor_0 = tensors.at(inputs.at(0));
- const auto t_0_s = tensor_0->shape.size();
+ const auto tensor_0 = tensors.at(inputs.at(0));
+ assert(tensor_0 != nullptr);
+ const auto t_0_s = wrap(tensor_0->shape()).size();
if (t_0_s != 3 && t_0_s != 4)
return false;
@@ -68,10 +71,10 @@ bool validate_minmax(const GraphBuilderBase::ValidateArgs &args)
if (outputs.size() != 1)
return false;
- const auto &tensors = args.reader.tensors();
- const auto &tensor = tensors.at(inputs.at(0));
-
- switch (tensor->type)
+ const auto tensors = args.reader.tensors();
+ const auto tensor = tensors.at(inputs.at(0));
+ assert(tensor != nullptr);
+ switch (tensor->type())
{
case circle::TensorType_FLOAT16:
case circle::TensorType_FLOAT32:
@@ -84,10 +87,12 @@ bool validate_minmax(const GraphBuilderBase::ValidateArgs &args)
return false;
}
- if (tensors[inputs.at(1)]->type != tensor->type)
+ assert(tensors[inputs.at(1)] != nullptr);
+ if (tensors[inputs.at(1)]->type() != tensor->type())
return false;
- if (tensors[outputs[0]]->type != tensor->type)
+ assert(tensors[outputs[0]] != nullptr);
+ if (tensors[outputs[0]]->type() != tensor->type())
return false;
return true;
@@ -104,10 +109,10 @@ bool validate_reduce_minmax(const GraphBuilderBase::ValidateArgs &args)
if (outputs.size() != 1)
return false;
- const auto &tensors = args.reader.tensors();
- const auto &tensor_axis = tensors.at(inputs.at(1));
-
- switch (tensor_axis->type)
+ const auto tensors = args.reader.tensors();
+ const auto tensor_axis = tensors.at(inputs.at(1));
+ assert(tensor_axis != nullptr);
+ switch (tensor_axis->type())
{
case circle::TensorType_INT32:
case circle::TensorType_INT64:
diff --git a/compiler/luci/lang/include/luci/IR/CircleNodes.h b/compiler/luci/lang/include/luci/IR/CircleNodes.h
index a313f9d5b..d89ea03cc 100644
--- a/compiler/luci/lang/include/luci/IR/CircleNodes.h
+++ b/compiler/luci/lang/include/luci/IR/CircleNodes.h
@@ -29,7 +29,6 @@
#include "Nodes/CircleCast.h"
#include "Nodes/CircleCeil.h"
#include "Nodes/CircleConcatenation.h"
-#include "Nodes/CircleConst.h"
#include "Nodes/CircleConv2D.h"
#include "Nodes/CircleCos.h"
#include "Nodes/CircleCustom.h"
@@ -119,6 +118,7 @@
#include "Nodes/CircleStridedSlice.h"
#include "Nodes/CircleSub.h"
#include "Nodes/CircleSum.h"
+#include "Nodes/CircleSVDF.h"
#include "Nodes/CircleTanh.h"
#include "Nodes/CircleTile.h"
#include "Nodes/CircleTopKV2.h"
@@ -135,18 +135,21 @@
#include "Nodes/CircleBCQGather.h"
#include "Nodes/CircleInstanceNorm.h"
// Virtual nodes
+#include "Nodes/CircleConst.h"
#include "Nodes/CircleInput.h"
#include "Nodes/CircleOutput.h"
+#include "Nodes/CircleVariable.h"
+// Multi-output virtual nodes
#include "Nodes/CircleBidirectionalSequenceLSTMOut.h"
#include "Nodes/CircleCustomOut.h"
#include "Nodes/CircleIfOut.h"
#include "Nodes/CircleNonMaxSuppressionV4Out.h"
#include "Nodes/CircleNonMaxSuppressionV5Out.h"
-#include "Nodes/CircleUnpackOut.h"
-#include "Nodes/CircleUniqueOut.h"
#include "Nodes/CircleSplitOut.h"
#include "Nodes/CircleSplitVOut.h"
#include "Nodes/CircleTopKV2Out.h"
+#include "Nodes/CircleUniqueOut.h"
+#include "Nodes/CircleUnpackOut.h"
#include "Nodes/CircleWhileOut.h"
#include <loco/IR/Graph.h>
diff --git a/compiler/luci/lang/include/luci/IR/CircleNodes.lst b/compiler/luci/lang/include/luci/IR/CircleNodes.lst
index 914aa16e4..1472008df 100644
--- a/compiler/luci/lang/include/luci/IR/CircleNodes.lst
+++ b/compiler/luci/lang/include/luci/IR/CircleNodes.lst
@@ -116,6 +116,7 @@ CIRCLE_NODE(SQUEEZE, CircleSqueeze)
CIRCLE_NODE(STRIDED_SLICE, CircleStridedSlice)
CIRCLE_NODE(SUB, CircleSub)
CIRCLE_NODE(SUM, CircleSum)
+CIRCLE_NODE(SVDF, CircleSVDF)
CIRCLE_NODE(TANH, CircleTanh)
CIRCLE_NODE(TILE, CircleTile)
CIRCLE_NODE(TOPK_V2, CircleTopKV2)
@@ -132,12 +133,14 @@ CIRCLE_NODE(BCQ_FULLY_CONNECTED, CircleBCQFullyConnected)
CIRCLE_NODE(BCQ_GATHER, CircleBCQGather)
CIRCLE_NODE(INSTANCE_NORM, CircleInstanceNorm)
// Virtual node(s)
-CIRCLE_VNODE(CIRCLEBIDIRECTIONAL_SEQUENCE_LSTM_OUT, CircleBidirectionalSequenceLSTMOut)
CIRCLE_VNODE(CIRCLECONST, CircleConst)
CIRCLE_VNODE(CIRCLEINPUT, CircleInput)
CIRCLE_VNODE(CIRCLEOUTPUT, CircleOutput)
CIRCLE_VNODE(CIRCLEOUTPUTDUMMY, CircleOutputDummy)
CIRCLE_VNODE(CIRCLEOUTPUTEXCLUDE, CircleOutputExclude)
+CIRCLE_VNODE(CIRCLEVARIABLE, CircleVariable)
+// Multi-output virtual nodes
+CIRCLE_VNODE(CIRCLEBIDIRECTIONAL_SEQUENCE_LSTM_OUT, CircleBidirectionalSequenceLSTMOut)
CIRCLE_VNODE(CIRCLECUSTOMOUT, CircleCustomOut)
CIRCLE_VNODE(CIRCLEIFOUT, CircleIfOut)
CIRCLE_VNODE(CIRCLENONMAXSUPPRESSIONV4OUT, CircleNonMaxSuppressionV4Out)
diff --git a/compiler/luci/lang/include/luci/IR/CircleQuantParam.h b/compiler/luci/lang/include/luci/IR/CircleQuantParam.h
index 694437303..8afc80a76 100644
--- a/compiler/luci/lang/include/luci/IR/CircleQuantParam.h
+++ b/compiler/luci/lang/include/luci/IR/CircleQuantParam.h
@@ -32,6 +32,10 @@ struct CircleQuantParam
int32_t quantized_dimension{0};
};
+struct CircleNode;
+
+void copy_quantparam(const luci::CircleNode *src, luci::CircleNode *dst);
+
} // namespace luci
#endif // __LUCI_IR_CIRCLEQUANTPARAM_H__
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleFullyConnected.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleFullyConnected.h
index 2862cadb2..dc5aeb267 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleFullyConnected.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleFullyConnected.h
@@ -58,8 +58,12 @@ public:
WeightsFormat weights_format(void) const { return _weights_format; }
void weights_format(WeightsFormat weights_format) { _weights_format = weights_format; }
+ bool keep_num_dims(void) const { return _keep_num_dims; }
+ void keep_num_dims(bool keep_num_dims) { _keep_num_dims = keep_num_dims; }
+
private:
WeightsFormat _weights_format{WeightsFormat::DEFAULT};
+ bool _keep_num_dims{false};
};
} // namespace luci
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleSVDF.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleSVDF.h
new file mode 100644
index 000000000..839d11e04
--- /dev/null
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleSVDF.h
@@ -0,0 +1,70 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_IR_CIRCLE_SVDF_H__
+#define __LUCI_IR_CIRCLE_SVDF_H__
+
+#include "luci/IR/CircleNodeDecl.h"
+#include "luci/IR/CircleOpcode.h"
+
+#include "luci/IR/LuciNodeMixins.h"
+
+namespace luci
+{
+
+/**
+ * @brief SVDF in Circle
+ */
+class CircleSVDF final : public FixedArityNode<5, CircleNodeImpl<CircleOpcode::SVDF>>,
+ public CircleNodeMixin<CircleNodeTrait::FusedActFunc>
+{
+public:
+ CircleSVDF() = default;
+
+public:
+ loco::Node *input(void) const { return at(0)->node(); }
+ void input(loco::Node *node) { at(0)->node(node); }
+
+ loco::Node *weight_feature(void) const { return at(1)->node(); }
+ void weight_feature(loco::Node *node) { at(1)->node(node); }
+
+ loco::Node *weight_time(void) const { return at(2)->node(); }
+ void weight_time(loco::Node *node) { at(2)->node(node); }
+
+ loco::Node *bias(void) const { return at(3)->node(); }
+ void bias(loco::Node *node) { at(3)->node(node); }
+
+ loco::Node *input_activation_state(void) const { return at(4)->node(); }
+ void input_activation_state(loco::Node *node) { at(4)->node(node); }
+
+public:
+ bool asymmetric_quantize_inputs() const { return _asymmetric_quantize_inputs; }
+ void asymmetric_quantize_inputs(bool asymmetric_quantize_inputs)
+ {
+ _asymmetric_quantize_inputs = asymmetric_quantize_inputs;
+ }
+
+ int32_t svdf_rank() const { return _rank; }
+ void svdf_rank(int32_t svdf_rank) { _rank = svdf_rank; }
+
+private:
+ bool _asymmetric_quantize_inputs = false;
+ int32_t _rank = 0;
+};
+
+} // namespace luci
+
+#endif // __LUCI_IR_CIRCLE_SVDF_H__
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleVariable.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleVariable.h
new file mode 100644
index 000000000..8c15b66c9
--- /dev/null
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleVariable.h
@@ -0,0 +1,39 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_IR_CIRCLE_VARIABLE_H__
+#define __LUCI_IR_CIRCLE_VARIABLE_H__
+
+#include "luci/IR/CircleNodeDecl.h"
+#include "luci/IR/CircleOpcode.h"
+
+#include "luci/IR/CircleNodeMixins.h"
+
+namespace luci
+{
+
+/**
+ * @brief Virtual CircleVariable in Circle for 'variable' Tensor
+ */
+class CircleVariable final : public FixedArityNode<0, CircleNodeImpl<CircleOpcode::CIRCLEVARIABLE>>
+{
+public:
+ CircleVariable() = default;
+};
+
+} // namespace luci
+
+#endif // __LUCI_IR_CIRCLE_VARIABLE_H__
diff --git a/compiler/luci/lang/src/CircleQuantParam.cpp b/compiler/luci/lang/src/CircleQuantParam.cpp
new file mode 100644
index 000000000..89671d3c3
--- /dev/null
+++ b/compiler/luci/lang/src/CircleQuantParam.cpp
@@ -0,0 +1,46 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/IR/CircleQuantParam.h"
+#include "luci/IR/CircleNode.h"
+
+#include <memory>
+
+namespace luci
+{
+
+/**
+ * @brief copy CircleQuantParam of src to dst
+ */
+void copy_quantparam(const luci::CircleNode *src, luci::CircleNode *dst)
+{
+ auto q = src->quantparam();
+ if (q == nullptr)
+ dst->quantparam(nullptr);
+ else
+ {
+ auto qparam = std::make_unique<luci::CircleQuantParam>();
+ qparam->scale = q->scale;
+ qparam->zerop = q->zerop;
+ qparam->min = q->min;
+ qparam->max = q->max;
+ qparam->quantized_dimension = q->quantized_dimension;
+
+ dst->quantparam(std::move(qparam));
+ }
+}
+
+} // namespace luci
diff --git a/compiler/luci/lang/src/CircleQuantParam.test.cpp b/compiler/luci/lang/src/CircleQuantParam.test.cpp
new file mode 100644
index 000000000..520ca05cc
--- /dev/null
+++ b/compiler/luci/lang/src/CircleQuantParam.test.cpp
@@ -0,0 +1,78 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// NOTE any node will do for testing
+#include "luci/IR/Nodes/CircleAdd.h"
+
+#include <loco/IR/Graph.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+luci::CircleAdd *build_simple_add_graph(loco::Graph *g)
+{
+ auto node = g->nodes()->create<luci::CircleAdd>();
+
+ node->name("name");
+ node->dtype(loco::DataType::FLOAT32);
+ node->rank(1);
+ node->dim(0).set(3);
+ node->shape_status(luci::ShapeStatus::VALID);
+ node->fusedActivationFunction(luci::FusedActFunc::NONE);
+
+ auto qparam = std::make_unique<luci::CircleQuantParam>();
+ qparam->scale = {1.0};
+ qparam->zerop = {0};
+ qparam->min = {0.0};
+ qparam->max = {1.0};
+ qparam->quantized_dimension = 0;
+ node->quantparam(std::move(qparam));
+
+ return node;
+}
+
+} // namespace
+
+TEST(CircleNodeCloneTest, copy_quantparam)
+{
+ auto g = loco::make_graph();
+ auto node = build_simple_add_graph(g.get());
+
+ auto copy = g->nodes()->create<luci::CircleAdd>();
+ luci::copy_quantparam(node, copy);
+
+ const auto *qparam_node = node->quantparam();
+ const auto *qparam_copy = copy->quantparam();
+ ASSERT_EQ(qparam_node->scale, qparam_copy->scale);
+ ASSERT_EQ(qparam_node->zerop, qparam_copy->zerop);
+ ASSERT_EQ(qparam_node->quantized_dimension, qparam_copy->quantized_dimension);
+}
+
+TEST(CircleNodeCloneTest, copy_quantparam_NEG)
+{
+ auto g = loco::make_graph();
+ auto node = build_simple_add_graph(g.get());
+
+ node->quantparam(nullptr);
+
+ auto copy = g->nodes()->create<luci::CircleAdd>();
+ luci::copy_quantparam(node, copy);
+
+ const auto *qparam_copy = copy->quantparam();
+ ASSERT_EQ(qparam_copy, nullptr);
+}
diff --git a/compiler/luci/lang/src/Nodes/CircleFullyConnected.test.cpp b/compiler/luci/lang/src/Nodes/CircleFullyConnected.test.cpp
index bb0e3c51b..15a780085 100644
--- a/compiler/luci/lang/src/Nodes/CircleFullyConnected.test.cpp
+++ b/compiler/luci/lang/src/Nodes/CircleFullyConnected.test.cpp
@@ -32,6 +32,7 @@ TEST(CircleFullyConnectedTest, constructor)
ASSERT_EQ(nullptr, fc_node.weights());
ASSERT_EQ(nullptr, fc_node.bias());
ASSERT_EQ(luci::FusedActFunc::UNDEFINED, fc_node.fusedActivationFunction());
+ ASSERT_EQ(false, fc_node.keep_num_dims());
}
TEST(CircleFullyConnectedTest, input_NEG)
diff --git a/compiler/luci/lang/src/Nodes/CircleSVDF.test.cpp b/compiler/luci/lang/src/Nodes/CircleSVDF.test.cpp
new file mode 100644
index 000000000..833ae0732
--- /dev/null
+++ b/compiler/luci/lang/src/Nodes/CircleSVDF.test.cpp
@@ -0,0 +1,101 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/IR/Nodes/CircleSVDF.h"
+
+#include "luci/IR/CircleDialect.h"
+#include "luci/IR/CircleNodeVisitor.h"
+
+#include <gtest/gtest.h>
+
+TEST(CircleSVDFTest, constructor)
+{
+ luci::CircleSVDF svdf_node;
+
+ ASSERT_EQ(luci::CircleDialect::get(), svdf_node.dialect());
+ ASSERT_EQ(luci::CircleOpcode::SVDF, svdf_node.opcode());
+
+ ASSERT_EQ(nullptr, svdf_node.input());
+ ASSERT_EQ(nullptr, svdf_node.weight_feature());
+ ASSERT_EQ(nullptr, svdf_node.weight_time());
+ ASSERT_EQ(nullptr, svdf_node.bias());
+ ASSERT_EQ(nullptr, svdf_node.input_activation_state());
+
+ ASSERT_EQ(false, svdf_node.asymmetric_quantize_inputs());
+ ASSERT_EQ(0, svdf_node.svdf_rank());
+}
+
+TEST(CircleSVDFTest, input_NEG)
+{
+ luci::CircleSVDF svdf_node;
+ luci::CircleSVDF node;
+
+ svdf_node.input(&node);
+ svdf_node.weight_feature(&node);
+ svdf_node.weight_time(&node);
+ svdf_node.bias(&node);
+ svdf_node.input_activation_state(&node);
+
+ ASSERT_NE(nullptr, svdf_node.input());
+ ASSERT_NE(nullptr, svdf_node.weight_feature());
+ ASSERT_NE(nullptr, svdf_node.weight_time());
+ ASSERT_NE(nullptr, svdf_node.bias());
+ ASSERT_NE(nullptr, svdf_node.input_activation_state());
+
+ svdf_node.input(nullptr);
+ svdf_node.weight_feature(nullptr);
+ svdf_node.weight_time(nullptr);
+ svdf_node.bias(nullptr);
+ svdf_node.input_activation_state(nullptr);
+
+ ASSERT_EQ(nullptr, svdf_node.input());
+ ASSERT_EQ(nullptr, svdf_node.weight_feature());
+ ASSERT_EQ(nullptr, svdf_node.weight_time());
+ ASSERT_EQ(nullptr, svdf_node.bias());
+ ASSERT_EQ(nullptr, svdf_node.input_activation_state());
+}
+
+TEST(CircleSVDFTest, arity_NEG)
+{
+ luci::CircleSVDF svdf_node;
+
+ ASSERT_NO_THROW(svdf_node.arg(4));
+ ASSERT_THROW(svdf_node.arg(5), std::out_of_range);
+}
+
+TEST(CircleSVDFTest, visit_mutable_NEG)
+{
+ struct TestVisitor final : public luci::CircleNodeMutableVisitor<void>
+ {
+ };
+
+ luci::CircleSVDF svdf_node;
+
+ TestVisitor tv;
+ ASSERT_THROW(svdf_node.accept(&tv), std::exception);
+}
+
+TEST(CircleSVDFTest, visit_NEG)
+{
+ struct TestVisitor final : public luci::CircleNodeVisitor<void>
+ {
+ };
+
+ luci::CircleSVDF svdf_node;
+
+ TestVisitor tv;
+ ASSERT_THROW(svdf_node.accept(&tv), std::exception);
+}
diff --git a/compiler/luci/lang/src/Nodes/CircleVariable.test.cpp b/compiler/luci/lang/src/Nodes/CircleVariable.test.cpp
new file mode 100644
index 000000000..e1864f8da
--- /dev/null
+++ b/compiler/luci/lang/src/Nodes/CircleVariable.test.cpp
@@ -0,0 +1,61 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/IR/Nodes/CircleVariable.h"
+
+#include "luci/IR/CircleDialect.h"
+#include "luci/IR/CircleNodeVisitor.h"
+
+#include <gtest/gtest.h>
+
+TEST(CircleVariableTest, constructor)
+{
+ luci::CircleVariable var_node;
+
+ ASSERT_EQ(luci::CircleDialect::get(), var_node.dialect());
+ ASSERT_EQ(luci::CircleOpcode::CIRCLEVARIABLE, var_node.opcode());
+}
+
+TEST(CircleVariableTest, arity_NEG)
+{
+ luci::CircleVariable var_node;
+
+ ASSERT_THROW(var_node.arg(0), std::out_of_range);
+}
+
+TEST(CircleVariableTest, visit_mutable_NEG)
+{
+ struct TestVisitor final : public luci::CircleNodeMutableVisitor<void>
+ {
+ };
+
+ luci::CircleVariable var_node;
+
+ TestVisitor tv;
+ ASSERT_THROW(var_node.accept(&tv), std::exception);
+}
+
+TEST(CircleVariableTest, visit_NEG)
+{
+ struct TestVisitor final : public luci::CircleNodeVisitor<void>
+ {
+ };
+
+ luci::CircleVariable var_node;
+
+ TestVisitor tv;
+ ASSERT_THROW(var_node.accept(&tv), std::exception);
+}
diff --git a/compiler/luci/logex/CMakeLists.txt b/compiler/luci/logex/CMakeLists.txt
index aed9fb79b..b8a2111dd 100644
--- a/compiler/luci/logex/CMakeLists.txt
+++ b/compiler/luci/logex/CMakeLists.txt
@@ -1,5 +1,7 @@
# TODO Find how to test logging-ex utility
file(GLOB_RECURSE SOURCES "src/*.cpp")
+file(GLOB_RECURSE TESTS "src/*.test.cpp")
+list(REMOVE_ITEM SOURCES ${TESTS})
if (NOT LUCI_LIBRARY_TYPE)
set(LUCI_LIBRARY_TYPE "SHARED")
@@ -13,7 +15,17 @@ target_link_libraries(luci_logex PRIVATE luci_log)
target_link_libraries(luci_logex PRIVATE luci_lang)
target_link_libraries(luci_logex PRIVATE hermes_std)
target_link_libraries(luci_logex PRIVATE nncc_common)
-target_link_libraries(luci_logex PRIVATE pepper_str)
install(TARGETS luci_logex DESTINATION lib)
install(DIRECTORY include/ DESTINATION include
FILES_MATCHING PATTERN "*.h")
+
+if(NOT ENABLE_TEST)
+ return()
+endif(NOT ENABLE_TEST)
+
+nnas_find_package(GTest REQUIRED)
+
+GTest_AddTest(luci_logex_test ${TESTS})
+target_include_directories(luci_logex_test PRIVATE src)
+target_link_libraries(luci_logex_test luci_logex)
+target_link_libraries(luci_logex_test luci_lang)
diff --git a/compiler/luci/logex/src/CircleNodeSummaryBuilder.cpp b/compiler/luci/logex/src/CircleNodeSummaryBuilder.cpp
new file mode 100644
index 000000000..eff0830b4
--- /dev/null
+++ b/compiler/luci/logex/src/CircleNodeSummaryBuilder.cpp
@@ -0,0 +1,265 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License")
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleNodeSummaryBuilder.h"
+#include "CircleNodeSummaryBuilders.h"
+
+#include <luci/IR/CircleDialect.h>
+
+#include <memory>
+
+namespace
+{
+
+std::string circle_opname(luci::CircleOpcode opcode)
+{
+ static const std::string prefix{"circle."};
+
+ switch (opcode)
+ {
+#define CIRCLE_NODE(OPCODE, CLASS) \
+ case luci::CircleOpcode::OPCODE: \
+ return prefix + #OPCODE;
+#define CIRCLE_VNODE CIRCLE_NODE
+#include <luci/IR/CircleNodes.lst>
+#undef CIRCLE_VNODE
+#undef CIRCLE_NODE
+ default:
+ break;
+ };
+
+ return prefix + "Invalid";
+}
+
+} // namespace
+
+namespace luci
+{
+
+bool CircleNodeSummaryBuilder::build(const loco::Node *node, const locop::SymbolTable *tbl,
+ locop::NodeSummary &s)
+{
+ if (node->dialect() != luci::CircleDialect::get())
+ return false;
+
+ auto ptr_to_str = [](const void *ptr) {
+ std::stringstream ss;
+ ss << ptr;
+ return ss.str();
+ };
+
+ auto circle_node = loco::must_cast<const luci::CircleNode *>(node);
+ if (const auto builder = create_builder(circle_node))
+ {
+ if (!builder->validate(circle_node))
+ {
+ s.state(locop::NodeDesc::State::Invalid);
+ return false;
+ }
+
+ auto input_names = builder->get_input_names(circle_node);
+ assert(node->arity() == input_names.size());
+ for (uint32_t i = 0; i < node->arity(); ++i)
+ s.args().append(input_names.at(i), tbl->lookup(node->arg(i)));
+
+ builder->build_attributes(circle_node, s);
+ builder->update_status(s);
+
+ s.opname(circle_opname(circle_node->opcode()));
+ s.comments().append("[" + circle_node->name() + "] = " + ptr_to_str(node));
+
+ return true;
+ }
+ else
+ {
+ // When SummaryBuilder is not implemented, return false
+ return false;
+ }
+}
+
+bool CircleNodeSummaryBuilder::validate(const luci::CircleNode *) { return true; }
+
+std::vector<std::string> CircleNodeSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ // Return empty names for default
+ return std::vector<std::string>();
+}
+
+void CircleNodeSummaryBuilder::build_attributes(const luci::CircleNode *, locop::NodeSummary &)
+{
+ // Do nothing for default
+}
+
+void CircleNodeSummaryBuilder::update_status(locop::NodeSummary &s)
+{
+ s.state(locop::NodeDesc::State::Complete);
+}
+
+std::unique_ptr<CircleNodeSummaryBuilder>
+CircleNodeSummaryBuilder::create_builder(const luci::CircleNode *node)
+{
+ switch (node->opcode())
+ {
+#define CIRCLE_NODE(OPCODE, CLASS) \
+ case luci::CircleOpcode::OPCODE: \
+ { \
+ return std::make_unique<CLASS>(); \
+ }
+
+ CIRCLE_NODE(ABS, CircleAbsSummaryBuilder)
+ CIRCLE_NODE(ADD, CircleAddSummaryBuilder)
+ CIRCLE_NODE(ADD_N, CircleAddNSummaryBuilder)
+ CIRCLE_NODE(ARG_MAX, CircleArgMaxSummaryBuilder)
+ CIRCLE_NODE(ARG_MIN, CircleArgMinSummaryBuilder)
+ CIRCLE_NODE(AVERAGE_POOL_2D, CircleAveragePool2DSummaryBuilder)
+ CIRCLE_NODE(BATCH_MATMUL, CircleBatchMatMulSummaryBuilder)
+ CIRCLE_NODE(BATCH_TO_SPACE_ND, CircleBatchToSpaceNDSummaryBuilder)
+ CIRCLE_NODE(BCQ_FULLY_CONNECTED, CircleBCQFullyConnectedSummaryBuilder)
+ CIRCLE_NODE(BCQ_GATHER, CircleBCQGatherSummaryBuilder)
+ CIRCLE_NODE(BIDIRECTIONAL_SEQUENCE_LSTM, CircleBidirectionalSequenceLSTMSummaryBuilder)
+ CIRCLE_NODE(CAST, CircleCastSummaryBuilder)
+ CIRCLE_NODE(CEIL, CircleCeilSummaryBuilder)
+ CIRCLE_NODE(CONCATENATION, CircleConcatenationSummaryBuilder)
+ CIRCLE_NODE(CIRCLECONST, CircleConstSummaryBuilder)
+ CIRCLE_NODE(CONV_2D, CircleConv2DSummaryBuilder)
+ CIRCLE_NODE(COS, CircleCosSummaryBuilder)
+ CIRCLE_NODE(CUSTOM, CircleCustomSummaryBuilder)
+ CIRCLE_NODE(DEPTH_TO_SPACE, CircleDepthToSpaceSummaryBuilder)
+ CIRCLE_NODE(DEPTHWISE_CONV_2D, CircleDepthwiseConv2DSummaryBuilder)
+ CIRCLE_NODE(DEQUANTIZE, CircleDequantizeSummaryBuilder)
+ CIRCLE_NODE(DIV, CircleDivSummaryBuilder)
+ CIRCLE_NODE(ELU, CircleEluSummaryBuilder)
+ CIRCLE_NODE(EQUAL, CircleEqualSummaryBuilder)
+ CIRCLE_NODE(EXP, CircleExpSummaryBuilder)
+ CIRCLE_NODE(EXPAND_DIMS, CircleExpandDimsSummaryBuilder)
+ CIRCLE_NODE(FAKE_QUANT, CircleFakeQuantSummaryBuilder)
+ CIRCLE_NODE(FILL, CircleFillSummaryBuilder)
+ CIRCLE_NODE(FLOOR, CircleFloorSummaryBuilder)
+ CIRCLE_NODE(FLOOR_DIV, CircleFloorDivSummaryBuilder)
+ CIRCLE_NODE(FLOOR_MOD, CircleFloorModSummaryBuilder)
+ CIRCLE_NODE(FULLY_CONNECTED, CircleFullyConnectedSummaryBuilder)
+ CIRCLE_NODE(GATHER, CircleGatherSummaryBuilder)
+ CIRCLE_NODE(GATHER_ND, CircleGatherNdSummaryBuilder)
+ CIRCLE_NODE(GREATER, CircleGreaterSummaryBuilder)
+ CIRCLE_NODE(GREATER_EQUAL, CircleGreaterEqualSummaryBuilder)
+ CIRCLE_NODE(IF, CircleIfSummaryBuilder)
+ CIRCLE_NODE(INSTANCE_NORM, CircleInstanceNormSummaryBuilder)
+ CIRCLE_NODE(L2_NORMALIZATION, CircleL2NormalizeSummaryBuilder)
+ CIRCLE_NODE(L2_POOL_2D, CircleL2Pool2DSummaryBuilder)
+ CIRCLE_NODE(LEAKY_RELU, CircleLeakyReluSummaryBuilder)
+ CIRCLE_NODE(LESS, CircleLessSummaryBuilder)
+ CIRCLE_NODE(LESS_EQUAL, CircleLessEqualSummaryBuilder)
+ CIRCLE_NODE(LOCAL_RESPONSE_NORMALIZATION, CircleLocalResponseNormalizationSummaryBuilder)
+ CIRCLE_NODE(LOG, CircleLogSummaryBuilder)
+ CIRCLE_NODE(LOGICAL_AND, CircleLogicalAndSummaryBuilder)
+ CIRCLE_NODE(LOGICAL_NOT, CircleLogicalNotSummaryBuilder)
+ CIRCLE_NODE(LOGICAL_OR, CircleLogicalOrSummaryBuilder)
+ CIRCLE_NODE(LOGISTIC, CircleLogisticSummaryBuilder)
+ CIRCLE_NODE(LOG_SOFTMAX, CircleLogSoftmaxSummaryBuilder)
+ CIRCLE_NODE(MATRIX_DIAG, CircleMatrixDiagSummaryBuilder)
+ CIRCLE_NODE(MATRIX_SET_DIAG, CircleMatrixSetDiagSummaryBuilder)
+ CIRCLE_NODE(MAXIMUM, CircleMaximumSummaryBuilder)
+ CIRCLE_NODE(MAX_POOL_2D, CircleMaxPool2DSummaryBuilder)
+ CIRCLE_NODE(MEAN, CircleMeanSummaryBuilder)
+ CIRCLE_NODE(MINIMUM, CircleMinimumSummaryBuilder)
+ CIRCLE_NODE(MIRROR_PAD, CircleMirrorPadSummaryBuilder)
+ CIRCLE_NODE(MUL, CircleMulSummaryBuilder)
+ CIRCLE_NODE(NEG, CircleNegSummaryBuilder)
+ CIRCLE_NODE(NON_MAX_SUPPRESSION_V4, CircleNonMaxSuppressionV4SummaryBuilder)
+ CIRCLE_NODE(NON_MAX_SUPPRESSION_V5, CircleNonMaxSuppressionV5SummaryBuilder)
+ CIRCLE_NODE(NOT_EQUAL, CircleNotEqualSummaryBuilder)
+ CIRCLE_NODE(ONE_HOT, CircleOneHotSummaryBuilder)
+ CIRCLE_NODE(PACK, CirclePackSummaryBuilder)
+ CIRCLE_NODE(PAD, CirclePadSummaryBuilder)
+ CIRCLE_NODE(PADV2, CirclePadV2SummaryBuilder)
+ CIRCLE_NODE(POW, CirclePowSummaryBuilder)
+ CIRCLE_NODE(PRELU, CirclePReluSummaryBuilder)
+ CIRCLE_NODE(QUANTIZE, CircleQuantizeSummaryBuilder)
+ CIRCLE_NODE(RANGE, CircleRangeSummaryBuilder)
+ CIRCLE_NODE(RANK, CircleRankSummaryBuilder)
+ CIRCLE_NODE(REDUCE_ANY, CircleReduceAnySummaryBuilder)
+ CIRCLE_NODE(REDUCE_MAX, CircleReduceMaxSummaryBuilder)
+ CIRCLE_NODE(REDUCE_MIN, CircleReduceMinSummaryBuilder)
+ CIRCLE_NODE(REDUCE_PROD, CircleReduceProdSummaryBuilder)
+ CIRCLE_NODE(RELU, CircleReluSummaryBuilder)
+ CIRCLE_NODE(RELU6, CircleRelu6SummaryBuilder)
+ CIRCLE_NODE(RELU_N1_TO_1, CircleReluN1To1SummaryBuilder)
+ CIRCLE_NODE(RESHAPE, CircleReshapeSummaryBuilder)
+ CIRCLE_NODE(RESIZE_BILINEAR, CircleResizeBilinearSummaryBuilder)
+ CIRCLE_NODE(RESIZE_NEAREST_NEIGHBOR, CircleResizeNearestNeighborSummaryBuilder)
+ CIRCLE_NODE(REVERSE_SEQUENCE, CircleReverseSequenceSummaryBuilder)
+ CIRCLE_NODE(REVERSE_V2, CircleReverseV2SummaryBuilder)
+ CIRCLE_NODE(ROUND, CircleRoundSummaryBuilder)
+ CIRCLE_NODE(RSQRT, CircleRsqrtSummaryBuilder)
+ CIRCLE_NODE(SCATTER_ND, CircleScatterNdSummaryBuilder)
+ CIRCLE_NODE(SEGMENT_SUM, CircleSegmentSumSummaryBuilder)
+ CIRCLE_NODE(SELECT, CircleSelectSummaryBuilder)
+ CIRCLE_NODE(SELECT_V2, CircleSelectV2SummaryBuilder)
+ CIRCLE_NODE(SHAPE, CircleShapeSummaryBuilder)
+ CIRCLE_NODE(SIN, CircleSinSummaryBuilder)
+ CIRCLE_NODE(SLICE, CircleSliceSummaryBuilder)
+ CIRCLE_NODE(SOFTMAX, CircleSoftmaxSummaryBuilder)
+ CIRCLE_NODE(SPACE_TO_BATCH_ND, CircleSpaceToBatchNDSummaryBuilder)
+ CIRCLE_NODE(SPACE_TO_DEPTH, CircleSpaceToDepthSummaryBuilder)
+ CIRCLE_NODE(SPARSE_TO_DENSE, CircleSparseToDenseSummaryBuilder)
+ CIRCLE_NODE(SPLIT, CircleSplitSummaryBuilder)
+ CIRCLE_NODE(SPLIT_V, CircleSplitVSummaryBuilder)
+ CIRCLE_NODE(SQRT, CircleSqrtSummaryBuilder)
+ CIRCLE_NODE(SQUARE, CircleSquareSummaryBuilder)
+ CIRCLE_NODE(SQUARED_DIFFERENCE, CircleSquaredDifferenceSummaryBuilder)
+ CIRCLE_NODE(SQUEEZE, CircleSqueezeSummaryBuilder)
+ CIRCLE_NODE(STRIDED_SLICE, CircleStridedSliceSummaryBuilder)
+ CIRCLE_NODE(SUB, CircleSubSummaryBuilder)
+ CIRCLE_NODE(SUM, CircleSumSummaryBuilder)
+ CIRCLE_NODE(SVDF, CircleSVDFSummaryBuilder)
+ CIRCLE_NODE(TANH, CircleTanhSummaryBuilder)
+ CIRCLE_NODE(TILE, CircleTileSummaryBuilder)
+ CIRCLE_NODE(TOPK_V2, CircleTopKV2SummaryBuilder)
+ CIRCLE_NODE(TRANSPOSE, CircleTransposeSummaryBuilder)
+ CIRCLE_NODE(TRANSPOSE_CONV, CircleTransposeConvSummaryBuilder)
+ CIRCLE_NODE(UNIDIRECTIONAL_SEQUENCE_LSTM, CircleUnidirectionalSequenceLSTMSummaryBuilder)
+ CIRCLE_NODE(UNIQUE, CircleUniqueSummaryBuilder)
+ CIRCLE_NODE(UNPACK, CircleUnpackSummaryBuilder)
+ CIRCLE_NODE(WHERE, CircleWhereSummaryBuilder)
+ CIRCLE_NODE(WHILE, CircleWhileSummaryBuilder)
+ CIRCLE_NODE(ZEROS_LIKE, CircleZerosLikeSummaryBuilder)
+
+ CIRCLE_NODE(CIRCLEBIDIRECTIONAL_SEQUENCE_LSTM_OUT,
+ CircleBidirectionalSequenceLSTMOutSummaryBuilder)
+ CIRCLE_NODE(CIRCLECUSTOMOUT, CircleCustomOutSummaryBuilder)
+ CIRCLE_NODE(CIRCLEIFOUT, CircleIfOutSummaryBuilder)
+ CIRCLE_NODE(CIRCLEINPUT, CircleInputSummaryBuilder)
+ CIRCLE_NODE(CIRCLENONMAXSUPPRESSIONV4OUT, CircleNonMaxSuppressionV4OutSummaryBuilder)
+ CIRCLE_NODE(CIRCLENONMAXSUPPRESSIONV5OUT, CircleNonMaxSuppressionV5OutSummaryBuilder)
+ CIRCLE_NODE(CIRCLEOUTPUT, CircleOutputSummaryBuilder)
+ CIRCLE_NODE(CIRCLEOUTPUTDUMMY, CircleOutputDummySummaryBuilder)
+ CIRCLE_NODE(CIRCLEOUTPUTEXCLUDE, CircleOutputExcludeSummaryBuilder)
+ CIRCLE_NODE(CIRCLESPLITOUT, CircleSplitOutSummaryBuilder)
+ CIRCLE_NODE(CIRCLESPLITVOUT, CircleSplitVOutSummaryBuilder)
+ CIRCLE_NODE(CIRCLETOPKV2OUT, CircleTopKV2OutSummaryBuilder)
+ CIRCLE_NODE(CIRCLEUNIQUEOUT, CircleUniqueOutSummaryBuilder)
+ CIRCLE_NODE(CIRCLEUNPACKOUT, CircleUnpackOutSummaryBuilder)
+ CIRCLE_NODE(CIRCLEVARIABLE, CircleVariableSummaryBuilder)
+ CIRCLE_NODE(CIRCLEWHILEOUT, CircleWhileOutSummaryBuilder)
+
+ default:
+ return nullptr;
+
+#undef CIRCLE_NODE
+ }
+}
+
+} // namespace luci
diff --git a/compiler/luci/logex/src/CircleNodeSummaryBuilder.h b/compiler/luci/logex/src/CircleNodeSummaryBuilder.h
new file mode 100644
index 000000000..e21d77310
--- /dev/null
+++ b/compiler/luci/logex/src/CircleNodeSummaryBuilder.h
@@ -0,0 +1,52 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_LOGEX_CIRCLE_NODE_SUMMARY_BUILDER__
+#define __LUCI_LOGEX_CIRCLE_NODE_SUMMARY_BUILDER__
+
+#include <luci/IR/CircleNode.h>
+#include <locop/NodeSummary.h>
+#include <locop/SymbolTable.h>
+
+#include <memory>
+#include <sstream>
+#include <vector>
+
+namespace luci
+{
+
+class CircleNodeSummaryBuilder
+{
+public:
+ bool build(const loco::Node *node, const locop::SymbolTable *tbl, locop::NodeSummary &s);
+
+private:
+ /**
+ * @brief Template methods for building node summary.
+ * Default behavior is building a node which has no input.
+ */
+ virtual bool validate(const luci::CircleNode *node);
+ virtual std::vector<std::string> get_input_names(const luci::CircleNode *node);
+ virtual void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+ virtual void update_status(locop::NodeSummary &s);
+
+private:
+ std::unique_ptr<CircleNodeSummaryBuilder> create_builder(const luci::CircleNode *node);
+};
+
+} // namespace luci
+
+#endif // __LUCI_LOGEX_CIRCLE_NODE_SUMMARY_BUILDER__
diff --git a/compiler/luci/logex/src/CircleNodeSummaryBuilder.test.cpp b/compiler/luci/logex/src/CircleNodeSummaryBuilder.test.cpp
new file mode 100644
index 000000000..89ea213e0
--- /dev/null
+++ b/compiler/luci/logex/src/CircleNodeSummaryBuilder.test.cpp
@@ -0,0 +1,309 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleNodeSummaryBuilder.h"
+
+#include <luci/IR/CircleNodes.h>
+#include <locop/NodeSummary.h>
+#include <locop/SymbolTable.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+class MockSymbolTable : public locop::SymbolTable
+{
+ std::string lookup(const loco::Node *) const override
+ {
+ return "Do nothing because it is mocking Symbol Table!";
+ }
+};
+
+class CircleNodeSummaryBuilderTest : public ::testing::Test
+{
+protected:
+ bool mock_build(const loco::Node *node)
+ {
+ return luci::CircleNodeSummaryBuilder().build(node, &_tbl, _s);
+ }
+
+protected:
+ MockSymbolTable _tbl;
+ locop::NodeSummary _s;
+};
+
+} // namespace
+
+TEST_F(CircleNodeSummaryBuilderTest, Add_validate)
+{
+ luci::CircleAdd node;
+ node.fusedActivationFunction(luci::FusedActFunc::RELU);
+ EXPECT_TRUE(mock_build(&node));
+}
+
+TEST_F(CircleNodeSummaryBuilderTest, Add_validate_fused_NEG)
+{
+ luci::CircleAdd node;
+ node.fusedActivationFunction(luci::FusedActFunc::UNDEFINED);
+ EXPECT_FALSE(mock_build(&node));
+}
+
+TEST_F(CircleNodeSummaryBuilderTest, AveragePool2D_validate)
+{
+ luci::CircleAveragePool2D node;
+ node.fusedActivationFunction(luci::FusedActFunc::RELU);
+ node.padding(luci::Padding::SAME);
+ EXPECT_TRUE(mock_build(&node));
+}
+
+TEST_F(CircleNodeSummaryBuilderTest, AveragePool2D_validate_fused_NEG)
+{
+ luci::CircleAveragePool2D node;
+ node.fusedActivationFunction(luci::FusedActFunc::UNDEFINED);
+ node.padding(luci::Padding::SAME);
+ EXPECT_FALSE(mock_build(&node));
+}
+
+TEST_F(CircleNodeSummaryBuilderTest, AveragePool2D_validate_padding_NEG)
+{
+ luci::CircleAveragePool2D node;
+ node.fusedActivationFunction(luci::FusedActFunc::RELU);
+ node.padding(luci::Padding::UNDEFINED);
+ EXPECT_FALSE(mock_build(&node));
+}
+
+TEST_F(CircleNodeSummaryBuilderTest, BCQFullyConnected_validate)
+{
+ luci::CircleBCQFullyConnected node;
+ node.fusedActivationFunction(luci::FusedActFunc::RELU);
+ EXPECT_TRUE(mock_build(&node));
+}
+
+TEST_F(CircleNodeSummaryBuilderTest, BCQFullyConnected_validate_fused_NEG)
+{
+ luci::CircleBCQFullyConnected node;
+ node.fusedActivationFunction(luci::FusedActFunc::UNDEFINED);
+ EXPECT_FALSE(mock_build(&node));
+}
+
+TEST_F(CircleNodeSummaryBuilderTest, Concatenation_validate)
+{
+ luci::CircleConcatenation node(2);
+ node.fusedActivationFunction(luci::FusedActFunc::RELU);
+ EXPECT_TRUE(mock_build(&node));
+}
+
+TEST_F(CircleNodeSummaryBuilderTest, Concatenation_validate_fused_NEG)
+{
+ luci::CircleConcatenation node(2);
+ node.fusedActivationFunction(luci::FusedActFunc::UNDEFINED);
+ EXPECT_FALSE(mock_build(&node));
+}
+
+TEST_F(CircleNodeSummaryBuilderTest, Conv2D_validate)
+{
+ luci::CircleConv2D node;
+ node.fusedActivationFunction(luci::FusedActFunc::RELU);
+ node.padding(luci::Padding::SAME);
+ EXPECT_TRUE(mock_build(&node));
+}
+
+TEST_F(CircleNodeSummaryBuilderTest, Conv2D_validate_fused_NEG)
+{
+ luci::CircleConv2D node;
+ node.fusedActivationFunction(luci::FusedActFunc::UNDEFINED);
+ node.padding(luci::Padding::SAME);
+ EXPECT_FALSE(mock_build(&node));
+}
+
+TEST_F(CircleNodeSummaryBuilderTest, Conv2D_validate_padding_NEG)
+{
+ luci::CircleConv2D node;
+ node.fusedActivationFunction(luci::FusedActFunc::RELU);
+ node.padding(luci::Padding::UNDEFINED);
+ EXPECT_FALSE(mock_build(&node));
+}
+
+TEST_F(CircleNodeSummaryBuilderTest, DepthwiseConv2D_validate)
+{
+ luci::CircleDepthwiseConv2D node;
+ node.fusedActivationFunction(luci::FusedActFunc::RELU);
+ node.padding(luci::Padding::SAME);
+ EXPECT_TRUE(mock_build(&node));
+}
+
+TEST_F(CircleNodeSummaryBuilderTest, DepthwiseConv2D_validate_fused_NEG)
+{
+ luci::CircleDepthwiseConv2D node;
+ node.fusedActivationFunction(luci::FusedActFunc::UNDEFINED);
+ node.padding(luci::Padding::SAME);
+ EXPECT_FALSE(mock_build(&node));
+}
+
+TEST_F(CircleNodeSummaryBuilderTest, DepthwiseConv2D_validate_padding_NEG)
+{
+ luci::CircleDepthwiseConv2D node;
+ node.fusedActivationFunction(luci::FusedActFunc::RELU);
+ node.padding(luci::Padding::UNDEFINED);
+ EXPECT_FALSE(mock_build(&node));
+}
+
+TEST_F(CircleNodeSummaryBuilderTest, FullyConnected_validate)
+{
+ luci::CircleFullyConnected node;
+ node.fusedActivationFunction(luci::FusedActFunc::RELU);
+ EXPECT_TRUE(mock_build(&node));
+}
+
+TEST_F(CircleNodeSummaryBuilderTest, FullyConnected_validate_fused_NEG)
+{
+ luci::CircleFullyConnected node;
+ node.fusedActivationFunction(luci::FusedActFunc::UNDEFINED);
+ EXPECT_FALSE(mock_build(&node));
+}
+
+TEST_F(CircleNodeSummaryBuilderTest, InstanceNorm_validate)
+{
+ luci::CircleInstanceNorm node;
+ node.fusedActivationFunction(luci::FusedActFunc::RELU);
+ EXPECT_TRUE(mock_build(&node));
+}
+
+TEST_F(CircleNodeSummaryBuilderTest, InstanceNorm_validate_fused_NEG)
+{
+ luci::CircleInstanceNorm node;
+ node.fusedActivationFunction(luci::FusedActFunc::UNDEFINED);
+ EXPECT_FALSE(mock_build(&node));
+}
+
+TEST_F(CircleNodeSummaryBuilderTest, L2Normalize_validate)
+{
+ luci::CircleL2Normalize node;
+ node.fusedActivationFunction(luci::FusedActFunc::RELU);
+ EXPECT_TRUE(mock_build(&node));
+}
+
+TEST_F(CircleNodeSummaryBuilderTest, L2Normalize_validate_fused_NEG)
+{
+ luci::CircleL2Normalize node;
+ node.fusedActivationFunction(luci::FusedActFunc::UNDEFINED);
+ EXPECT_FALSE(mock_build(&node));
+}
+
+TEST_F(CircleNodeSummaryBuilderTest, L2Pool2D_validate)
+{
+ luci::CircleL2Pool2D node;
+ node.fusedActivationFunction(luci::FusedActFunc::RELU);
+ node.padding(luci::Padding::SAME);
+ EXPECT_TRUE(mock_build(&node));
+}
+
+TEST_F(CircleNodeSummaryBuilderTest, L2Pool2D_validate_fused_NEG)
+{
+ luci::CircleL2Pool2D node;
+ node.fusedActivationFunction(luci::FusedActFunc::UNDEFINED);
+ node.padding(luci::Padding::SAME);
+ EXPECT_FALSE(mock_build(&node));
+}
+
+TEST_F(CircleNodeSummaryBuilderTest, L2Pool2D_validate_padding_NEG)
+{
+ luci::CircleL2Pool2D node;
+ node.fusedActivationFunction(luci::FusedActFunc::RELU);
+ node.padding(luci::Padding::UNDEFINED);
+ EXPECT_FALSE(mock_build(&node));
+}
+
+TEST_F(CircleNodeSummaryBuilderTest, MaxPool2D_validate)
+{
+ luci::CircleMaxPool2D node;
+ node.fusedActivationFunction(luci::FusedActFunc::RELU);
+ node.padding(luci::Padding::SAME);
+ EXPECT_TRUE(mock_build(&node));
+}
+
+TEST_F(CircleNodeSummaryBuilderTest, MaxPool2D_validate_fused_NEG)
+{
+ luci::CircleMaxPool2D node;
+ node.fusedActivationFunction(luci::FusedActFunc::UNDEFINED);
+ node.padding(luci::Padding::SAME);
+ EXPECT_FALSE(mock_build(&node));
+}
+
+TEST_F(CircleNodeSummaryBuilderTest, MaxPool2D_validate_padding_NEG)
+{
+ luci::CircleMaxPool2D node;
+ node.fusedActivationFunction(luci::FusedActFunc::RELU);
+ node.padding(luci::Padding::UNDEFINED);
+ EXPECT_FALSE(mock_build(&node));
+}
+
+TEST_F(CircleNodeSummaryBuilderTest, MirrorPad_validate)
+{
+ luci::CircleMirrorPad node;
+ node.mode(luci::MirrorPadMode::REFLECT);
+ EXPECT_TRUE(mock_build(&node));
+}
+
+TEST_F(CircleNodeSummaryBuilderTest, MirrorPad_validate_mirror_padding_NEG)
+{
+ luci::CircleMirrorPad node;
+ node.mode(luci::MirrorPadMode::UNDEFINED);
+ EXPECT_FALSE(mock_build(&node));
+}
+
+TEST_F(CircleNodeSummaryBuilderTest, Mul_validate)
+{
+ luci::CircleMul node;
+ node.fusedActivationFunction(luci::FusedActFunc::RELU);
+ EXPECT_TRUE(mock_build(&node));
+}
+
+TEST_F(CircleNodeSummaryBuilderTest, Mul_validate_fused_NEG)
+{
+ luci::CircleMul node;
+ node.fusedActivationFunction(luci::FusedActFunc::UNDEFINED);
+ EXPECT_FALSE(mock_build(&node));
+}
+
+TEST_F(CircleNodeSummaryBuilderTest, SVDF_validate)
+{
+ luci::CircleSVDF node;
+ node.fusedActivationFunction(luci::FusedActFunc::RELU);
+ EXPECT_TRUE(mock_build(&node));
+}
+
+TEST_F(CircleNodeSummaryBuilderTest, SVDF_validate_fused_NEG)
+{
+ luci::CircleSVDF node;
+ node.fusedActivationFunction(luci::FusedActFunc::UNDEFINED);
+ EXPECT_FALSE(mock_build(&node));
+}
+
+TEST_F(CircleNodeSummaryBuilderTest, TransposeConv_validate)
+{
+ luci::CircleTransposeConv node;
+ node.padding(luci::Padding::SAME);
+ EXPECT_TRUE(mock_build(&node));
+}
+
+TEST_F(CircleNodeSummaryBuilderTest, TransposeConv_validate_padding_NEG)
+{
+ luci::CircleTransposeConv node;
+ node.padding(luci::Padding::UNDEFINED);
+ EXPECT_FALSE(mock_build(&node));
+}
diff --git a/compiler/luci/logex/src/CircleNodeSummaryBuilders.cpp b/compiler/luci/logex/src/CircleNodeSummaryBuilders.cpp
new file mode 100644
index 000000000..6df9270e3
--- /dev/null
+++ b/compiler/luci/logex/src/CircleNodeSummaryBuilders.cpp
@@ -0,0 +1,1128 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleNodeSummaryBuilders.h"
+
+#include <luci/IR/CircleNode.h>
+#include <luci/IR/CircleNodes.h>
+#include <loco/IR/Node.h>
+
+#include <string>
+#include <vector>
+
+namespace
+{
+
+std::string to_str(loco::DataType type)
+{
+ switch (type)
+ {
+ case loco::DataType::U8:
+ return "UINT8";
+ case loco::DataType::U16:
+ return "UINT16";
+ case loco::DataType::U32:
+ return "UINT32";
+ case loco::DataType::U64:
+ return "UINT64";
+
+ case loco::DataType::S8:
+ return "INT8";
+ case loco::DataType::S16:
+ return "INT16";
+ case loco::DataType::S32:
+ return "INT32";
+ case loco::DataType::S64:
+ return "INT64";
+
+ case loco::DataType::FLOAT16:
+ return "FLOAT16";
+ case loco::DataType::FLOAT32:
+ return "FLOAT32";
+ case loco::DataType::FLOAT64:
+ return "FLOAT64";
+
+ case loco::DataType::BOOL:
+ return "BOOL";
+
+ default:
+ return "Error";
+ }
+}
+
+std::string to_str(bool value) { return value ? "true" : "false"; }
+
+std::string to_str(luci::FusedActFunc fused)
+{
+ switch (fused)
+ {
+ case luci::FusedActFunc::NONE:
+ return "NONE";
+ case luci::FusedActFunc::RELU:
+ return "RELU";
+ case luci::FusedActFunc::RELU_N1_TO_1:
+ return "RELU_N1_TO_1";
+ case luci::FusedActFunc::RELU6:
+ return "RELU6";
+ case luci::FusedActFunc::TANH:
+ return "TANH";
+ case luci::FusedActFunc::SIGN_BIT:
+ return "SIGN_BIT";
+ default:
+ return "Error";
+ }
+}
+
+std::string to_str(luci::Padding padding)
+{
+ switch (padding)
+ {
+ case luci::Padding::SAME:
+ return "SAME";
+ case luci::Padding::VALID:
+ return "VALID";
+ default:
+ return "Error";
+ }
+}
+
+std::string to_str(const luci::Stride *stride)
+{
+ return std::to_string(stride->h()) + "," + std::to_string(stride->w());
+}
+
+std::string to_str(const luci::Filter *filter)
+{
+ return std::to_string(filter->h()) + "," + std::to_string(filter->w());
+}
+
+std::string to_str(luci::MirrorPadMode mode)
+{
+ switch (mode)
+ {
+ case luci::MirrorPadMode::REFLECT:
+ return "REFLECT";
+ case luci::MirrorPadMode::SYMMETRIC:
+ return "SYMMETRIC";
+ default:
+ return "Error";
+ }
+}
+
+} // namespace
+
+namespace luci
+{
+
+std::vector<std::string> CircleNodeWithXSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"x"};
+}
+
+std::vector<std::string>
+CircleNodeWithINPUTSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"input"};
+}
+
+std::vector<std::string> CircleNodeWithXYSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"x", "y"};
+}
+
+std::vector<std::string>
+CircleNodeWithFEATURESSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"features"};
+}
+
+} // namespace luci
+
+namespace luci
+{
+
+bool CircleAddSummaryBuilder::validate(const luci::CircleNode *node)
+{
+ auto add = loco::must_cast<const luci::CircleAdd *>(node);
+ if (add->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED)
+ return false;
+
+ return true;
+}
+
+void CircleAddSummaryBuilder::build_attributes(const luci::CircleNode *node, locop::NodeSummary &s)
+{
+ auto add = loco::must_cast<const luci::CircleAdd *>(node);
+ s.args().append("fused_activation_function", to_str(add->fusedActivationFunction()));
+}
+
+std::vector<std::string> CircleAddNSummaryBuilder::get_input_names(const luci::CircleNode *node)
+{
+ return std::vector<std::string>(node->arity(), "inputs");
+}
+
+std::vector<std::string> CircleArgMaxSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"input", "dimension"};
+}
+
+void CircleArgMaxSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto argmax = loco::must_cast<const luci::CircleArgMax *>(node);
+ s.args().append("output_type", to_str(argmax->output_type()));
+}
+
+std::vector<std::string> CircleArgMinSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"input", "dimension"};
+}
+
+void CircleArgMinSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto argmin = loco::must_cast<const luci::CircleArgMin *>(node);
+ s.args().append("output_type", to_str(argmin->output_type()));
+}
+
+bool CircleAveragePool2DSummaryBuilder::validate(const luci::CircleNode *node)
+{
+ auto avgpool = loco::must_cast<const luci::CircleAveragePool2D *>(node);
+ if (avgpool->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED)
+ return false;
+ if (avgpool->padding() == luci::Padding::UNDEFINED)
+ return false;
+
+ return true;
+}
+
+std::vector<std::string>
+CircleAveragePool2DSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"value"};
+}
+
+void CircleAveragePool2DSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto avgpool = loco::must_cast<const luci::CircleAveragePool2D *>(node);
+ s.args().append("filter(h,w)", to_str(avgpool->filter()));
+ s.args().append("stride(h,w)", to_str(avgpool->stride()));
+ s.args().append("padding", to_str(avgpool->padding()));
+ s.args().append("fused_activation_function", to_str(avgpool->fusedActivationFunction()));
+}
+
+void CircleBatchMatMulSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto batchmatmul = loco::must_cast<const luci::CircleBatchMatMul *>(node);
+ s.args().append("adj_x", to_str(batchmatmul->adj_x()));
+ s.args().append("adj_y", to_str(batchmatmul->adj_y()));
+}
+
+std::vector<std::string>
+CircleBatchToSpaceNDSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"input", "block_shape", "crops"};
+}
+
+bool CircleBCQFullyConnectedSummaryBuilder::validate(const luci::CircleNode *node)
+{
+ auto bcq_fc = loco::must_cast<const luci::CircleBCQFullyConnected *>(node);
+ if (bcq_fc->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED)
+ return false;
+
+ return true;
+}
+
+std::vector<std::string>
+CircleBCQFullyConnectedSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"input", "weights_scales", "weights_binary", "bias", "weights_clusters"};
+}
+
+void CircleBCQFullyConnectedSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto bcq_fc = loco::must_cast<const luci::CircleBCQFullyConnected *>(node);
+ s.args().append("fused_activation_function", to_str(bcq_fc->fusedActivationFunction()));
+ s.args().append("weights_hidden_size", std::to_string(bcq_fc->weights_hidden_size()));
+}
+
+std::vector<std::string> CircleBCQGatherSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"input_scales", "input_binary", "indices", "input_clusters"};
+}
+
+void CircleBCQGatherSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto bcq_gather = loco::must_cast<const luci::CircleBCQGather *>(node);
+ s.args().append("axis", std::to_string(bcq_gather->axis()));
+ s.args().append("input_hidden_size", std::to_string(bcq_gather->input_hidden_size()));
+}
+
+std::vector<std::string>
+CircleBidirectionalSequenceLSTMSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"input",
+ "fw_input_to_input_weights",
+ "fw_input_to_forget_weights",
+ "fw_input_to_cell_weights",
+ "fw_input_to_output_weights",
+ "fw_recurrent_to_input_weights",
+ "fw_recurrent_to_forget_weights",
+ "fw_recurrent_to_cell_weights",
+ "fw_recurrent_to_output_weights",
+ "fw_cell_to_input_weights",
+ "fw_cell_to_forget_weights",
+ "fw_cell_to_output_weights",
+ "fw_input_gate_bias",
+ "fw_forget_gate_bias",
+ "fw_cell_gate_bias",
+ "fw_output_gate_bias",
+ "fw_projection_weights",
+ "fw_projection_bias",
+ "bw_input_to_input_weights",
+ "bw_input_to_forget_weights",
+ "bw_input_to_cell_weights",
+ "bw_input_to_output_weights",
+ "bw_recurrent_to_input_weights",
+ "bw_recurrent_to_forget_weights",
+ "bw_recurrent_to_cell_weights",
+ "bw_recurrent_to_output_weights",
+ "bw_cell_to_input_weights",
+ "bw_cell_to_forget_weights",
+ "bw_cell_to_output_weights",
+ "bw_input_gate_bias",
+ "bw_forget_gate_bias",
+ "bw_cell_gate_bias",
+ "bw_output_gate_bias",
+ "bw_projection_weights",
+ "bw_projection_bias",
+ "fw_activation_state",
+ "fw_cell_state",
+ "bw_activation_state",
+ "bw_cell_state",
+ "auxillary_input",
+ "fw_auxillary_input_to_input_weights",
+ "fw_auxillary_input_to_forget_weights",
+ "fw_auxillary_input_to_cell_weights",
+ "fw_auxillary_input_to_output_weights",
+ "bw_auxillary_input_to_input_weights",
+ "bw_auxillary_input_to_forget_weights",
+ "bw_auxillary_input_to_cell_weights",
+ "bw_auxillary_input_to_output_weights"};
+}
+
+void CircleBidirectionalSequenceLSTMSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto lstm = loco::must_cast<const luci::CircleBidirectionalSequenceLSTM *>(node);
+ s.args().append("cell_clip", to_str(lstm->cell_clip()));
+ s.args().append("proj_clip", to_str(lstm->proj_clip()));
+ s.args().append("merge_outputs", to_str(lstm->merge_outputs()));
+ s.args().append("time_major", to_str(lstm->time_major()));
+ s.args().append("asymmetric_quantize_inputs", to_str(lstm->asymmetric_quantize_inputs()));
+}
+
+std::vector<std::string> CircleCastSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"x"};
+}
+
+void CircleCastSummaryBuilder::build_attributes(const luci::CircleNode *node, locop::NodeSummary &s)
+{
+ auto cast = loco::must_cast<const luci::CircleCast *>(node);
+ s.args().append("in_data_type", to_str(cast->in_data_type()));
+ s.args().append("out_data_type", to_str(cast->out_data_type()));
+}
+
+bool CircleConcatenationSummaryBuilder::validate(const luci::CircleNode *node)
+{
+ auto concat = loco::must_cast<const luci::CircleConcatenation *>(node);
+ if (concat->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED)
+ return false;
+
+ return true;
+}
+
+std::vector<std::string>
+CircleConcatenationSummaryBuilder::get_input_names(const luci::CircleNode *node)
+{
+ return std::vector<std::string>(node->arity(), "values");
+}
+
+void CircleConcatenationSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto concat = loco::must_cast<const luci::CircleConcatenation *>(node);
+ s.args().append("axis", std::to_string(concat->axis()));
+ s.args().append("fused_activation_function", to_str(concat->fusedActivationFunction()));
+}
+
+void CircleConstSummaryBuilder::update_status(locop::NodeSummary &s)
+{
+ s.state(locop::NodeDesc::State::PartiallyKnown);
+}
+
+bool CircleConv2DSummaryBuilder::validate(const luci::CircleNode *node)
+{
+ auto conv2d = loco::must_cast<const luci::CircleConv2D *>(node);
+ if (conv2d->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED)
+ return false;
+ if (conv2d->padding() == luci::Padding::UNDEFINED)
+ return false;
+
+ return true;
+}
+
+std::vector<std::string> CircleConv2DSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"input", "filter", "bias"};
+}
+
+void CircleConv2DSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto conv2d = loco::must_cast<const luci::CircleConv2D *>(node);
+ s.args().append("stride(h,w)", to_str(conv2d->stride()));
+ s.args().append("dilation(h,w)", to_str(conv2d->dilation()));
+ s.args().append("padding", to_str(conv2d->padding()));
+ s.args().append("fused_activation_function", to_str(conv2d->fusedActivationFunction()));
+}
+
+std::vector<std::string> CircleCustomSummaryBuilder::get_input_names(const luci::CircleNode *node)
+{
+ auto input_names = std::vector<std::string>();
+ for (uint32_t i = 0; i < node->arity(); ++i)
+ input_names.push_back("input" + std::to_string(i));
+ return input_names;
+}
+
+void CircleCustomSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto custom = loco::must_cast<const luci::CircleCustom *>(node);
+ s.args().append("custom_code", custom->custom_code());
+}
+
+void CircleDepthToSpaceSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto depth_to_space = loco::must_cast<const luci::CircleDepthToSpace *>(node);
+ s.args().append("block_size", std::to_string(depth_to_space->block_size()));
+}
+
+bool CircleDepthwiseConv2DSummaryBuilder::validate(const luci::CircleNode *node)
+{
+ auto dw_conv2d = loco::must_cast<const luci::CircleDepthwiseConv2D *>(node);
+ if (dw_conv2d->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED)
+ return false;
+ if (dw_conv2d->padding() == luci::Padding::UNDEFINED)
+ return false;
+
+ return true;
+}
+
+std::vector<std::string>
+CircleDepthwiseConv2DSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"input", "filter", "bias"};
+}
+
+void CircleDepthwiseConv2DSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto dw_conv2d = loco::must_cast<const luci::CircleDepthwiseConv2D *>(node);
+ s.args().append("stride(h,w)", to_str(dw_conv2d->stride()));
+ s.args().append("dilation(h,w)", to_str(dw_conv2d->dilation()));
+ s.args().append("padding", to_str(dw_conv2d->padding()));
+ s.args().append("depthMultiplier", std::to_string(dw_conv2d->depthMultiplier()));
+ s.args().append("fused_activation_function", to_str(dw_conv2d->fusedActivationFunction()));
+}
+
+std::vector<std::string> CircleExpandDimsSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"input", "axis"};
+}
+
+std::vector<std::string> CircleFakeQuantSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"inputs"};
+}
+
+void CircleFakeQuantSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto fake_quant = loco::must_cast<const luci::CircleFakeQuant *>(node);
+ s.args().append("min", std::to_string(fake_quant->min()));
+ s.args().append("max", std::to_string(fake_quant->max()));
+ s.args().append("num_bits", std::to_string(fake_quant->num_bits()));
+ s.args().append("narrow_range", to_str(fake_quant->narrow_range()));
+}
+
+std::vector<std::string> CircleFillSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"dims", "value"};
+}
+
+bool CircleFullyConnectedSummaryBuilder::validate(const luci::CircleNode *node)
+{
+ auto fc = loco::must_cast<const luci::CircleFullyConnected *>(node);
+ if (fc->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED)
+ return false;
+
+ return true;
+}
+
+std::vector<std::string>
+CircleFullyConnectedSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"input", "weights", "bias"};
+}
+
+void CircleFullyConnectedSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto fc = loco::must_cast<const luci::CircleFullyConnected *>(node);
+ s.args().append("fused_activation_function", to_str(fc->fusedActivationFunction()));
+}
+
+std::vector<std::string> CircleGatherSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"params", "indices"};
+}
+
+void CircleGatherSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto gather = loco::must_cast<const luci::CircleGather *>(node);
+ s.args().append("axis", std::to_string(gather->axis()));
+}
+
+std::vector<std::string> CircleGatherNdSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"params", "indices"};
+}
+
+std::vector<std::string> CircleIfSummaryBuilder::get_input_names(const luci::CircleNode *node)
+{
+ auto circle_if = loco::must_cast<const luci::CircleIf *>(node);
+
+ auto input_names = std::vector<std::string>();
+ input_names.push_back("cond");
+ for (uint32_t i = 0; i < circle_if->input_count(); ++i)
+ input_names.push_back("input");
+
+ return input_names;
+}
+
+void CircleIfSummaryBuilder::build_attributes(const luci::CircleNode *node, locop::NodeSummary &s)
+{
+ auto circle_if = loco::must_cast<const luci::CircleIf *>(node);
+
+ if (circle_if->then_graph() != nullptr)
+ s.args().append("then_graph", circle_if->then_graph()->name());
+ else
+ s.args().append("then_branch", std::to_string(circle_if->then_branch()));
+
+ if (circle_if->else_graph() != nullptr)
+ s.args().append("else_graph", circle_if->else_graph()->name());
+ else
+ s.args().append("else_branch", std::to_string(circle_if->else_branch()));
+}
+
+bool CircleInstanceNormSummaryBuilder::validate(const luci::CircleNode *node)
+{
+ auto instnorm = loco::must_cast<const luci::CircleInstanceNorm *>(node);
+ if (instnorm->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED)
+ return false;
+
+ return true;
+}
+
+std::vector<std::string> CircleInstanceNormSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"input", "gamma", "beta"};
+}
+
+void CircleInstanceNormSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto instnorm = loco::must_cast<const luci::CircleInstanceNorm *>(node);
+ s.args().append("epsilon", std::to_string(instnorm->epsilon()));
+ s.args().append("fused_activation_function", to_str(instnorm->fusedActivationFunction()));
+}
+
+bool CircleL2NormalizeSummaryBuilder::validate(const luci::CircleNode *node)
+{
+ auto l2norm = loco::must_cast<const luci::CircleL2Normalize *>(node);
+ if (l2norm->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED)
+ return false;
+
+ return true;
+}
+
+std::vector<std::string> CircleL2NormalizeSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"x"};
+}
+
+void CircleL2NormalizeSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto l2norm = loco::must_cast<const luci::CircleL2Normalize *>(node);
+ s.args().append("fused_activation_function", to_str(l2norm->fusedActivationFunction()));
+}
+
+bool CircleL2Pool2DSummaryBuilder::validate(const luci::CircleNode *node)
+{
+ auto l2pool = loco::must_cast<const luci::CircleL2Pool2D *>(node);
+ if (l2pool->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED)
+ return false;
+ if (l2pool->padding() == luci::Padding::UNDEFINED)
+ return false;
+
+ return true;
+}
+
+std::vector<std::string> CircleL2Pool2DSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"value"};
+}
+
+void CircleL2Pool2DSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto l2pool = loco::must_cast<const luci::CircleL2Pool2D *>(node);
+ s.args().append("filter(h,w)", to_str(l2pool->filter()));
+ s.args().append("stride(h,w)", to_str(l2pool->stride()));
+ s.args().append("padding", to_str(l2pool->padding()));
+ s.args().append("fused_activation_function", to_str(l2pool->fusedActivationFunction()));
+}
+
+void CircleLeakyReluSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto leaky_relu = loco::must_cast<const luci::CircleLeakyRelu *>(node);
+ s.args().append("alpha", std::to_string(leaky_relu->alpha()));
+}
+
+void CircleLocalResponseNormalizationSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto lrn = loco::must_cast<const luci::CircleLocalResponseNormalization *>(node);
+ s.args().append("radius", std::to_string(lrn->radius()));
+ s.args().append("bias", std::to_string(lrn->bias()));
+ s.args().append("alpha", std::to_string(lrn->alpha()));
+ s.args().append("beta", std::to_string(lrn->beta()));
+}
+
+std::vector<std::string> CircleLogSoftmaxSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"logits"};
+}
+
+std::vector<std::string> CircleMatrixDiagSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"diagonal"};
+}
+
+std::vector<std::string>
+CircleMatrixSetDiagSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"input", "diagonal"};
+}
+
+bool CircleMaxPool2DSummaryBuilder::validate(const luci::CircleNode *node)
+{
+ auto maxpool = loco::must_cast<const luci::CircleMaxPool2D *>(node);
+ if (maxpool->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED)
+ return false;
+ if (maxpool->padding() == luci::Padding::UNDEFINED)
+ return false;
+
+ return true;
+}
+
+std::vector<std::string> CircleMaxPool2DSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"value"};
+}
+
+void CircleMaxPool2DSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto maxpool = loco::must_cast<const luci::CircleMaxPool2D *>(node);
+ s.args().append("filter(h,w)", to_str(maxpool->filter()));
+ s.args().append("stride(h,w)", to_str(maxpool->stride()));
+ s.args().append("padding", to_str(maxpool->padding()));
+ s.args().append("fused_activation_function", to_str(maxpool->fusedActivationFunction()));
+}
+
+bool CircleMirrorPadSummaryBuilder::validate(const luci::CircleNode *node)
+{
+ auto mirror_pad = loco::must_cast<const luci::CircleMirrorPad *>(node);
+ if (mirror_pad->mode() == luci::MirrorPadMode::UNDEFINED)
+ return false;
+
+ return true;
+}
+
+std::vector<std::string> CircleMirrorPadSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"input", "paddings"};
+}
+
+void CircleMirrorPadSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto mirror_pad = loco::must_cast<const luci::CircleMirrorPad *>(node);
+ s.args().append("mode", to_str(mirror_pad->mode()));
+}
+
+bool CircleMulSummaryBuilder::validate(const luci::CircleNode *node)
+{
+ auto mul = loco::must_cast<const luci::CircleMul *>(node);
+ if (mul->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED)
+ return false;
+
+ return true;
+}
+
+void CircleMulSummaryBuilder::build_attributes(const luci::CircleNode *node, locop::NodeSummary &s)
+{
+ auto mul = loco::must_cast<const luci::CircleMul *>(node);
+ s.args().append("fused_activation_function", to_str(mul->fusedActivationFunction()));
+}
+
+std::vector<std::string>
+CircleNonMaxSuppressionV4SummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"boxes", "scores", "max_output_size", "iou_threshold", "score_threshold"};
+}
+
+std::vector<std::string>
+CircleNonMaxSuppressionV5SummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"boxes", "scores", "max_output_size",
+ "iou_threshold", "score_threshold", "soft_nms_sigma"};
+}
+
+std::vector<std::string> CircleOneHotSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"indices", "depth", "on_value", "off_value"};
+}
+
+void CircleOneHotSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto onehot = loco::must_cast<const luci::CircleOneHot *>(node);
+ s.args().append("axis", std::to_string(onehot->axis()));
+}
+
+std::vector<std::string> CirclePackSummaryBuilder::get_input_names(const luci::CircleNode *node)
+{
+ return std::vector<std::string>(node->arity(), "values");
+}
+
+void CirclePackSummaryBuilder::build_attributes(const luci::CircleNode *node, locop::NodeSummary &s)
+{
+ auto pack = loco::must_cast<const luci::CirclePack *>(node);
+ s.args().append("values_count", std::to_string(pack->values_count()));
+ s.args().append("axis", std::to_string(pack->axis()));
+}
+
+std::vector<std::string> CirclePadSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"input", "paddings"};
+}
+
+std::vector<std::string> CirclePadV2SummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"input", "paddings", "constant_values"};
+}
+
+std::vector<std::string> CirclePReluSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"input", "alpha"};
+}
+
+std::vector<std::string> CircleRangeSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"start", "limit", "delta"};
+}
+
+std::vector<std::string> CircleReshapeSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"tensor", "shape"};
+}
+
+void CircleReshapeSummaryBuilder::update_status(locop::NodeSummary &s)
+{
+ s.state(locop::NodeDesc::State::PartiallyKnown);
+}
+
+std::vector<std::string>
+CircleResizeBilinearSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"input", "size"};
+}
+
+void CircleResizeBilinearSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto resize_bilinear = loco::must_cast<const luci::CircleResizeBilinear *>(node);
+ s.args().append("align_corners", to_str(resize_bilinear->align_corners()));
+ s.args().append("half_pixel_centers", to_str(resize_bilinear->half_pixel_centers()));
+}
+
+std::vector<std::string>
+CircleResizeNearestNeighborSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"input", "size"};
+}
+
+void CircleResizeNearestNeighborSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto resize_nn = loco::must_cast<const luci::CircleResizeNearestNeighbor *>(node);
+ s.args().append("align_corners", to_str(resize_nn->align_corners()));
+}
+
+std::vector<std::string>
+CircleReverseSequenceSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"input", "seq_lengths"};
+}
+
+void CircleReverseSequenceSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto reverse_seq = loco::must_cast<const luci::CircleReverseSequence *>(node);
+ s.args().append("seq_axis", std::to_string(reverse_seq->seq_axis()));
+ s.args().append("batch_axis", std::to_string(reverse_seq->batch_axis()));
+}
+
+std::vector<std::string> CircleReverseV2SummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"tensor", "axis"};
+}
+
+std::vector<std::string> CircleScatterNdSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"indices", "updates", "shape"};
+}
+
+std::vector<std::string> CircleSegmentSumSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"input", "segment_ids"};
+}
+
+std::vector<std::string> CircleSelectSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"condition", "t", "e"};
+}
+
+std::vector<std::string> CircleSelectV2SummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"condition", "t", "e"};
+}
+
+void CircleShapeSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto shape = loco::must_cast<const luci::CircleShape *>(node);
+ s.args().append("out_type", to_str(shape->out_type()));
+}
+
+std::vector<std::string> CircleSliceSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"input", "begin", "size"};
+}
+
+std::vector<std::string> CircleSoftmaxSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"logits"};
+}
+
+void CircleSoftmaxSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto softmax = loco::must_cast<const luci::CircleSoftmax *>(node);
+ s.args().append("beta", to_str(softmax->beta()));
+}
+
+std::vector<std::string>
+CircleSpaceToBatchNDSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"input", "block_shape", "paddings"};
+}
+
+void CircleSpaceToDepthSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto space_to_depth = loco::must_cast<const luci::CircleSpaceToDepth *>(node);
+ s.args().append("block_size", to_str(space_to_depth->block_size()));
+}
+
+std::vector<std::string>
+CircleSparseToDenseSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"indices", "output_shape", "values", "default_value"};
+}
+
+void CircleSparseToDenseSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto sparse_to_dense = loco::must_cast<const luci::CircleSparseToDense *>(node);
+ s.args().append("validate_indices", to_str(sparse_to_dense->validate_indices()));
+}
+
+std::vector<std::string> CircleSplitSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"split_dim", "input"};
+}
+
+void CircleSplitSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto split = loco::must_cast<const luci::CircleSplit *>(node);
+ s.args().append("num_split", std::to_string(split->num_split()));
+}
+
+std::vector<std::string> CircleSplitVSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"input", "size_splits", "split_dim"};
+}
+
+void CircleSplitVSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto split_v = loco::must_cast<const luci::CircleSplitV *>(node);
+ s.args().append("num_split", std::to_string(split_v->num_split()));
+}
+
+void CircleSqueezeSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto squeeze = loco::must_cast<const luci::CircleSqueeze *>(node);
+
+ std::string squeeze_dims = "(";
+ for (size_t i = 0; i < squeeze->squeeze_dims().size(); ++i)
+ {
+ if (i != 0)
+ squeeze_dims += ", ";
+ squeeze_dims += std::to_string(squeeze->squeeze_dims().at(i));
+ }
+ squeeze_dims += ")";
+
+ s.args().append("squeeze_dims", squeeze_dims);
+}
+
+std::vector<std::string> CircleStridedSliceSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"input", "begin", "end", "strides"};
+}
+
+void CircleStridedSliceSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto strided_slice = loco::must_cast<const luci::CircleStridedSlice *>(node);
+ s.args().append("begin_mask", std::to_string(strided_slice->begin_mask()));
+ s.args().append("end_mask", std::to_string(strided_slice->end_mask()));
+ s.args().append("ellipsis_mask", std::to_string(strided_slice->ellipsis_mask()));
+ s.args().append("new_axis_mask", std::to_string(strided_slice->new_axis_mask()));
+ s.args().append("shrink_axis_mask", std::to_string(strided_slice->shrink_axis_mask()));
+}
+
+bool CircleSVDFSummaryBuilder::validate(const luci::CircleNode *node)
+{
+ auto svdf = loco::must_cast<const luci::CircleSVDF *>(node);
+ if (svdf->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED)
+ return false;
+
+ return true;
+}
+
+std::vector<std::string> CircleSVDFSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"input", "weight_feature", "weight_time", "bias", "State"};
+}
+
+void CircleSVDFSummaryBuilder::build_attributes(const luci::CircleNode *node, locop::NodeSummary &s)
+{
+ auto svdf = loco::must_cast<const luci::CircleSVDF *>(node);
+ s.args().append("rank", to_str(svdf->svdf_rank()));
+ s.args().append("asymmetric_quantize_inputs", to_str(svdf->asymmetric_quantize_inputs()));
+ s.args().append("fused_activation_function", to_str(svdf->fusedActivationFunction()));
+}
+
+std::vector<std::string> CircleTileSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"input", "multiples"};
+}
+
+std::vector<std::string> CircleTopKV2SummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"input", "k"};
+}
+
+std::vector<std::string> CircleTransposeSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"a", "perm"};
+}
+
+bool CircleTransposeConvSummaryBuilder::validate(const luci::CircleNode *node)
+{
+ auto transpose_conv = loco::must_cast<const luci::CircleTransposeConv *>(node);
+ if (transpose_conv->padding() == luci::Padding::UNDEFINED)
+ return false;
+
+ return true;
+}
+
+std::vector<std::string>
+CircleTransposeConvSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"inputSizes", "filter", "outBackProp", "bias"};
+}
+
+void CircleTransposeConvSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto transpose_conv = loco::must_cast<const luci::CircleTransposeConv *>(node);
+ s.args().append("stride(h,w)", to_str(transpose_conv->stride()));
+ s.args().append("padding", to_str(transpose_conv->padding()));
+}
+
+std::vector<std::string>
+CircleUnidirectionalSequenceLSTMSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"input",
+ "input_to_input_weights",
+ "input_to_forget_weights",
+ "input_to_cell_weights",
+ "input_to_output_weights",
+ "recurrent_to_input_weights",
+ "recurrent_to_forget_weights",
+ "recurrent_to_cell_weights",
+ "recurrent_to_output_weights",
+ "cell_to_input_weights",
+ "cell_to_forget_weights",
+ "cell_to_output_weights",
+ "input_gate_bias",
+ "forget_gate_bias",
+ "cell_gate_bias",
+ "output_gate_bias",
+ "projection_weights",
+ "projection_bias",
+ "activation_state",
+ "cell_state",
+ "input_layer_norm_coefficients",
+ "forget_layer_norm_coefficients",
+ "cell_layer_norm_coefficients",
+ "output_layer_norm_coefficients"};
+}
+
+void CircleUnidirectionalSequenceLSTMSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto lstm = loco::must_cast<const luci::CircleUnidirectionalSequenceLSTM *>(node);
+ s.args().append("cell_clip", to_str(lstm->cell_clip()));
+ s.args().append("proj_clip", to_str(lstm->proj_clip()));
+ s.args().append("time_major", to_str(lstm->time_major()));
+ s.args().append("asymmetric_quantize_inputs", to_str(lstm->asymmetric_quantize_inputs()));
+}
+
+void CircleUniqueSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto unique = loco::must_cast<const luci::CircleUnique *>(node);
+ s.args().append("idx_out_type", to_str(unique->idx_out_type()));
+}
+
+std::vector<std::string> CircleUnpackSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"value"};
+}
+
+void CircleUnpackSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto unpack = loco::must_cast<const luci::CircleUnpack *>(node);
+ s.args().append("num", std::to_string(unpack->num()));
+ s.args().append("axis", std::to_string(unpack->axis()));
+}
+std::vector<std::string> CircleWhereSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"condition"};
+}
+
+std::vector<std::string> CircleWhileSummaryBuilder::get_input_names(const luci::CircleNode *node)
+{
+ auto circle_while = loco::must_cast<const luci::CircleWhile *>(node);
+
+ auto input_names = std::vector<std::string>();
+ for (uint32_t i = 0; i < circle_while->input_count(); ++i)
+ input_names.push_back("input");
+
+ return input_names;
+}
+
+void CircleWhileSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto circle_while = loco::must_cast<const luci::CircleWhile *>(node);
+
+ if (circle_while->cond_graph() != nullptr)
+ s.args().append("then_graph", circle_while->cond_graph()->name());
+ else
+ s.args().append("then_branch", std::to_string(circle_while->cond_branch()));
+
+ if (circle_while->body_graph() != nullptr)
+ s.args().append("else_graph", circle_while->body_graph()->name());
+ else
+ s.args().append("else_branch", std::to_string(circle_while->body_branch()));
+}
+
+std::vector<std::string> CircleOutputSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"from"};
+}
+
+std::vector<std::string> CircleTopKV2OutSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"topkv2"};
+}
+
+std::vector<std::string> CircleUniqueOutSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"unique"};
+}
+
+std::vector<std::string> CircleUnpackOutSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"unpack"};
+}
+
+std::vector<std::string> CircleWhileOutSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"while"};
+}
+
+} // namespace luci
diff --git a/compiler/luci/logex/src/CircleNodeSummaryBuilders.h b/compiler/luci/logex/src/CircleNodeSummaryBuilders.h
new file mode 100644
index 000000000..6cd24b7f1
--- /dev/null
+++ b/compiler/luci/logex/src/CircleNodeSummaryBuilders.h
@@ -0,0 +1,821 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_LOGEX_CIRCLE_NODE_SUMMARY_BUILDERS__
+#define __LUCI_LOGEX_CIRCLE_NODE_SUMMARY_BUILDERS__
+
+#include "CircleNodeSummaryBuilder.h"
+
+#include <luci/IR/CircleNode.h>
+
+#include <string>
+#include <vector>
+
+namespace luci
+{
+
+class CircleNodeWithXSummaryBuilder : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CircleNodeWithINPUTSummaryBuilder : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CircleNodeWithXYSummaryBuilder : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CircleNodeWithFEATURESSummaryBuilder : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+template <class REDUCER_NODE>
+class CircleNodeWithReducerSummaryBuilder : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *)
+ {
+ return {"input", "reduction_indices"};
+ }
+
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s)
+ {
+ auto mean = loco::must_cast<const REDUCER_NODE *>(node);
+ s.args().append("keep_dims", mean->keep_dims() ? "true" : "false");
+ }
+};
+
+} // namespace luci
+
+namespace luci
+{
+
+class CircleAbsSummaryBuilder final : public CircleNodeWithXSummaryBuilder
+{
+};
+
+class CircleAddSummaryBuilder final : public CircleNodeWithXYSummaryBuilder
+{
+private:
+ bool validate(const luci::CircleNode *node);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleAddNSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *node);
+};
+
+class CircleArgMaxSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleArgMinSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleAveragePool2DSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ bool validate(const luci::CircleNode *node);
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleBatchMatMulSummaryBuilder final : public CircleNodeWithXYSummaryBuilder
+{
+private:
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleBatchToSpaceNDSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CircleBCQFullyConnectedSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ bool validate(const luci::CircleNode *node);
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleBCQGatherSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleBidirectionalSequenceLSTMSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleCastSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleCeilSummaryBuilder final : public CircleNodeWithXSummaryBuilder
+{
+};
+
+class CircleConcatenationSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ bool validate(const luci::CircleNode *node);
+ std::vector<std::string> get_input_names(const luci::CircleNode *node);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleConstSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ void update_status(locop::NodeSummary &s);
+};
+
+class CircleConv2DSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ bool validate(const luci::CircleNode *node);
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleCosSummaryBuilder final : public CircleNodeWithXSummaryBuilder
+{
+};
+
+class CircleCustomSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *node);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleDepthToSpaceSummaryBuilder final : public CircleNodeWithINPUTSummaryBuilder
+{
+private:
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleDepthwiseConv2DSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ bool validate(const luci::CircleNode *node);
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleDequantizeSummaryBuilder final : public CircleNodeWithINPUTSummaryBuilder
+{
+};
+
+class CircleDivSummaryBuilder final : public CircleNodeWithXYSummaryBuilder
+{
+};
+
+class CircleEluSummaryBuilder final : public CircleNodeWithFEATURESSummaryBuilder
+{
+};
+
+class CircleEqualSummaryBuilder final : public CircleNodeWithXYSummaryBuilder
+{
+};
+
+class CircleExpSummaryBuilder final : public CircleNodeWithXSummaryBuilder
+{
+};
+
+class CircleExpandDimsSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CircleFakeQuantSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleFillSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CircleFloorSummaryBuilder final : public CircleNodeWithXSummaryBuilder
+{
+};
+
+class CircleFloorDivSummaryBuilder final : public CircleNodeWithXYSummaryBuilder
+{
+};
+
+class CircleFloorModSummaryBuilder final : public CircleNodeWithXYSummaryBuilder
+{
+};
+
+class CircleFullyConnectedSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ bool validate(const luci::CircleNode *node);
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleGatherSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleGatherNdSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CircleGreaterSummaryBuilder final : public CircleNodeWithXYSummaryBuilder
+{
+};
+
+class CircleGreaterEqualSummaryBuilder final : public CircleNodeWithXYSummaryBuilder
+{
+};
+
+class CircleIfSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *node);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleInstanceNormSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ bool validate(const luci::CircleNode *node);
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleL2NormalizeSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ bool validate(const luci::CircleNode *node);
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleL2Pool2DSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ bool validate(const luci::CircleNode *node);
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleLeakyReluSummaryBuilder final : public CircleNodeWithFEATURESSummaryBuilder
+{
+private:
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleLessSummaryBuilder final : public CircleNodeWithXYSummaryBuilder
+{
+};
+
+class CircleLessEqualSummaryBuilder final : public CircleNodeWithXYSummaryBuilder
+{
+};
+
+class CircleLocalResponseNormalizationSummaryBuilder final
+ : public CircleNodeWithINPUTSummaryBuilder
+{
+private:
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleLogSummaryBuilder final : public CircleNodeWithXSummaryBuilder
+{
+};
+
+class CircleLogicalAndSummaryBuilder final : public CircleNodeWithXYSummaryBuilder
+{
+};
+
+class CircleLogicalNotSummaryBuilder final : public CircleNodeWithXSummaryBuilder
+{
+};
+
+class CircleLogicalOrSummaryBuilder final : public CircleNodeWithXYSummaryBuilder
+{
+};
+
+class CircleLogisticSummaryBuilder final : public CircleNodeWithXSummaryBuilder
+{
+};
+
+class CircleLogSoftmaxSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CircleMatrixDiagSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CircleMatrixSetDiagSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CircleMaximumSummaryBuilder final : public CircleNodeWithXYSummaryBuilder
+{
+};
+
+class CircleMaxPool2DSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ bool validate(const luci::CircleNode *node);
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleMeanSummaryBuilder final : public CircleNodeWithReducerSummaryBuilder<luci::CircleMean>
+{
+};
+
+class CircleMinimumSummaryBuilder final : public CircleNodeWithXYSummaryBuilder
+{
+};
+
+class CircleMirrorPadSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ bool validate(const luci::CircleNode *node);
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleMulSummaryBuilder final : public CircleNodeWithXYSummaryBuilder
+{
+private:
+ bool validate(const luci::CircleNode *node);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleNegSummaryBuilder final : public CircleNodeWithXSummaryBuilder
+{
+};
+
+class CircleNonMaxSuppressionV4SummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CircleNonMaxSuppressionV5SummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CircleNotEqualSummaryBuilder final : public CircleNodeWithXYSummaryBuilder
+{
+};
+
+class CircleOneHotSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CirclePackSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *node);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CirclePadSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CirclePadV2SummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CirclePowSummaryBuilder final : public CircleNodeWithXYSummaryBuilder
+{
+};
+
+class CirclePReluSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CircleQuantizeSummaryBuilder final : public CircleNodeWithINPUTSummaryBuilder
+{
+};
+
+class CircleRangeSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CircleRankSummaryBuilder final : public CircleNodeWithINPUTSummaryBuilder
+{
+};
+
+class CircleReduceAnySummaryBuilder final
+ : public CircleNodeWithReducerSummaryBuilder<luci::CircleReduceAny>
+{
+};
+
+class CircleReduceMaxSummaryBuilder final
+ : public CircleNodeWithReducerSummaryBuilder<luci::CircleReduceMax>
+{
+};
+
+class CircleReduceMinSummaryBuilder final
+ : public CircleNodeWithReducerSummaryBuilder<luci::CircleReduceMin>
+{
+};
+
+class CircleReduceProdSummaryBuilder final
+ : public CircleNodeWithReducerSummaryBuilder<luci::CircleReduceProd>
+{
+};
+
+class CircleReluSummaryBuilder final : public CircleNodeWithFEATURESSummaryBuilder
+{
+};
+
+class CircleRelu6SummaryBuilder final : public CircleNodeWithFEATURESSummaryBuilder
+{
+};
+
+class CircleReluN1To1SummaryBuilder final : public CircleNodeWithFEATURESSummaryBuilder
+{
+};
+
+class CircleReshapeSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+ void update_status(locop::NodeSummary &s);
+};
+
+class CircleResizeBilinearSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleResizeNearestNeighborSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleReverseSequenceSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleReverseV2SummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CircleRoundSummaryBuilder final : public CircleNodeWithXSummaryBuilder
+{
+};
+
+class CircleRsqrtSummaryBuilder final : public CircleNodeWithXSummaryBuilder
+{
+};
+
+class CircleScatterNdSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CircleSegmentSumSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CircleSelectSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CircleSelectV2SummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CircleShapeSummaryBuilder final : public CircleNodeWithINPUTSummaryBuilder
+{
+private:
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleSinSummaryBuilder final : public CircleNodeWithXSummaryBuilder
+{
+};
+
+class CircleSliceSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CircleSoftmaxSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleSpaceToBatchNDSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CircleSpaceToDepthSummaryBuilder final : public CircleNodeWithINPUTSummaryBuilder
+{
+private:
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleSparseToDenseSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleSplitSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleSplitVSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleSqrtSummaryBuilder final : public CircleNodeWithXSummaryBuilder
+{
+};
+
+class CircleSquareSummaryBuilder final : public CircleNodeWithXSummaryBuilder
+{
+};
+
+class CircleSquaredDifferenceSummaryBuilder final : public CircleNodeWithXYSummaryBuilder
+{
+};
+
+class CircleSqueezeSummaryBuilder final : public CircleNodeWithINPUTSummaryBuilder
+{
+private:
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleStridedSliceSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleSubSummaryBuilder final : public CircleNodeWithXYSummaryBuilder
+{
+};
+
+class CircleSumSummaryBuilder final : public CircleNodeWithReducerSummaryBuilder<luci::CircleSum>
+{
+};
+
+class CircleSVDFSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ bool validate(const luci::CircleNode *node);
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleTanhSummaryBuilder final : public CircleNodeWithXSummaryBuilder
+{
+};
+
+class CircleTileSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CircleTopKV2SummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CircleTransposeSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CircleTransposeConvSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ bool validate(const luci::CircleNode *node);
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleUnidirectionalSequenceLSTMSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleUniqueSummaryBuilder final : public CircleNodeWithINPUTSummaryBuilder
+{
+private:
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleUnpackSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleWhereSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CircleWhileSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *node);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleZerosLikeSummaryBuilder final : public CircleNodeWithINPUTSummaryBuilder
+{
+};
+
+class CircleBidirectionalSequenceLSTMOutSummaryBuilder final
+ : public CircleNodeWithINPUTSummaryBuilder
+{
+};
+
+class CircleCustomOutSummaryBuilder final : public CircleNodeWithINPUTSummaryBuilder
+{
+};
+
+class CircleIfOutSummaryBuilder final : public CircleNodeWithINPUTSummaryBuilder
+{
+};
+
+class CircleInputSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+};
+
+class CircleNonMaxSuppressionV4OutSummaryBuilder final : public CircleNodeWithINPUTSummaryBuilder
+{
+};
+
+class CircleNonMaxSuppressionV5OutSummaryBuilder final : public CircleNodeWithINPUTSummaryBuilder
+{
+};
+
+class CircleOutputSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CircleOutputDummySummaryBuilder final : public CircleNodeSummaryBuilder
+{
+};
+
+class CircleOutputExcludeSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+};
+
+class CircleSplitOutSummaryBuilder final : public CircleNodeWithINPUTSummaryBuilder
+{
+};
+
+class CircleSplitVOutSummaryBuilder final : public CircleNodeWithINPUTSummaryBuilder
+{
+};
+
+class CircleTopKV2OutSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CircleUniqueOutSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CircleUnpackOutSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CircleVariableSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+};
+
+class CircleWhileOutSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+} // namespace luci
+
+#endif // __LUCI_LOGEX_CIRCLE_NODE_SUMMARY_BUILDERS__
diff --git a/compiler/luci/logex/src/FormattedGraph.cpp b/compiler/luci/logex/src/FormattedGraph.cpp
index 0588ed79e..d3b2170b0 100644
--- a/compiler/luci/logex/src/FormattedGraph.cpp
+++ b/compiler/luci/logex/src/FormattedGraph.cpp
@@ -14,6 +14,7 @@
* limitations under the License.
*/
+#include "CircleNodeSummaryBuilder.h"
#include "luci/FormattedGraph.h"
#include <luci/IR/CircleDialect.h>
@@ -25,2179 +26,6 @@
#include <sstream>
#include <vector>
-using namespace luci;
-/**
- * @brief dump std::vector<int64_t> values to stream
- */
-std::ostream &operator<<(std::ostream &os, const std::vector<int64_t> &vi64)
-{
- for (auto vi : vi64)
- {
- os << vi << " ";
- }
- return os;
-}
-
-// For TF lite
-namespace
-{
-
-const char *to_str(loco::DataType type)
-{
- switch (type)
- {
- case loco::DataType::U8:
- return "UINT8";
- case loco::DataType::U16:
- return "UINT16";
- case loco::DataType::U32:
- return "UINT32";
- case loco::DataType::U64:
- return "UINT64";
-
- case loco::DataType::S8:
- return "INT8";
- case loco::DataType::S16:
- return "INT16";
- case loco::DataType::S32:
- return "INT32";
- case loco::DataType::S64:
- return "INT64";
-
- case loco::DataType::FLOAT16:
- return "FLOAT16";
- case loco::DataType::FLOAT32:
- return "FLOAT32";
- case loco::DataType::FLOAT64:
- return "FLOAT64";
-
- case loco::DataType::BOOL:
- return "BOOL";
-
- default:
- return "Error";
- }
-}
-
-const char *to_str(bool value) { return value ? "true" : "false"; }
-
-const char *to_str(luci::FusedActFunc fused)
-{
- switch (fused)
- {
- case luci::FusedActFunc::NONE:
- return "NONE";
- case luci::FusedActFunc::RELU:
- return "RELU";
- case luci::FusedActFunc::RELU_N1_TO_1:
- return "RELU_N1_TO_1";
- case luci::FusedActFunc::RELU6:
- return "RELU6";
- case luci::FusedActFunc::TANH:
- return "TANH";
- case luci::FusedActFunc::SIGN_BIT:
- return "SIGN_BIT";
- default:
- return "Error";
- }
-}
-
-const char *to_str(luci::Padding padding)
-{
- switch (padding)
- {
- case luci::Padding::SAME:
- return "SAME";
- case luci::Padding::VALID:
- return "VALID";
- default:
- return "Error";
- }
-}
-
-const char *to_str(luci::MirrorPadMode mode)
-{
- switch (mode)
- {
- case luci::MirrorPadMode::REFLECT:
- return "REFLECT";
- case luci::MirrorPadMode::SYMMETRIC:
- return "SYMMETRIC";
- default:
- return "Error";
- }
-}
-
-std::string to_str(const luci::Stride *stride)
-{
- return pepper::str(stride->h(), ",", stride->w());
-}
-
-std::string to_str(const luci::Filter *filter)
-{
- return pepper::str(filter->h(), ",", filter->w());
-}
-
-std::string circle_opname(uint32_t opnum)
-{
- static const std::string prefix{"circle."};
-
- switch (static_cast<luci::CircleOpcode>(opnum))
- {
-#define CIRCLE_NODE(OPCODE, CLASS) \
- case luci::CircleOpcode::OPCODE: \
- return prefix + #OPCODE;
-#define CIRCLE_VNODE CIRCLE_NODE
-#include <luci/IR/CircleNodes.lst>
-#undef CIRCLE_VNODE
-#undef CIRCLE_NODE
- default:
- break;
- };
-
- return prefix + "Invalid";
-}
-
-// CircleNodeSummaryBuilder with default implementation
-class CircleNodeSummaryBuilderBase : public locop::NodeSummaryBuilder
-{
-public:
- CircleNodeSummaryBuilderBase(const locop::SymbolTable *tbl) : _tbl{tbl}
- {
- // DO NOTHING
- }
-
-public:
- bool build(const loco::Node *, locop::NodeSummary &s) const final;
-
-protected:
-#define CIRCLE_NODE(OPCODE, CLASS) \
- virtual bool summary(const CLASS *, locop::NodeSummary &) const { return false; }
-#define CIRCLE_VNODE CIRCLE_NODE
-#include <luci/IR/CircleNodes.lst>
-#undef CIRCLE_VNODE
-#undef CIRCLE_NODE
-
-protected:
- const locop::SymbolTable *tbl(void) const { return _tbl; }
-
-private:
- const locop::SymbolTable *_tbl;
-};
-
-template <class CIRCLENODE>
-bool use_x(const locop::SymbolTable *tbl, const CIRCLENODE *node, locop::NodeSummary &s)
-{
- s.args().append("x", tbl->lookup(node->x()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-template <class CIRCLENODE>
-bool use_input(const locop::SymbolTable *tbl, const CIRCLENODE *node, locop::NodeSummary &s)
-{
- s.args().append("input", tbl->lookup(node->input()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-template <class CIRCLENODE>
-bool use_features(const locop::SymbolTable *tbl, const CIRCLENODE *node, locop::NodeSummary &s)
-{
- s.args().append("features", tbl->lookup(node->features()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-template <class CIRCLENODE>
-bool use_xy(const locop::SymbolTable *tbl, const CIRCLENODE *node, locop::NodeSummary &s)
-{
- s.args().append("x", tbl->lookup(node->x()));
- s.args().append("y", tbl->lookup(node->y()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-template <class CIRCLENODE>
-bool use_xy_act(const locop::SymbolTable *tbl, const CIRCLENODE *node, locop::NodeSummary &s)
-{
- assert(node->fusedActivationFunction() != luci::FusedActFunc::UNDEFINED);
-
- s.args().append("x", tbl->lookup(node->x()));
- s.args().append("y", tbl->lookup(node->y()));
- s.args().append("fused_activation_function", to_str(node->fusedActivationFunction()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-template <class CIRCLENODE>
-bool use_reducer(const locop::SymbolTable *tbl, const CIRCLENODE *node, locop::NodeSummary &s)
-{
- s.args().append("input", tbl->lookup(node->input()));
- s.args().append("reduction_indices", tbl->lookup(node->reduction_indices()));
- s.args().append("keep_dims", node->keep_dims() ? "true" : "false");
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-template <class CIRCLENODE>
-bool use_ido(const locop::SymbolTable *tbl, const CIRCLENODE *node, locop::NodeSummary &s)
-{
- s.args().append("input", tbl->lookup(node->input()));
- s.args().append("dimension", tbl->lookup(node->dimension()));
- s.args().append("output_type", to_str(node->output_type()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleAddN *node,
- locop::NodeSummary &s)
-{
- for (uint32_t i = 0; i < node->arity(); ++i)
- s.args().append("inputs", tbl->lookup(node->inputs(i)));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleAveragePool2D *node,
- locop::NodeSummary &s)
-{
- assert(node->fusedActivationFunction() != luci::FusedActFunc::UNDEFINED);
-
- s.args().append("value", tbl->lookup(node->value()));
- s.args().append("filter(h,w)", to_str(node->filter()));
- s.args().append("stride(h,w)", to_str(node->stride()));
- s.args().append("padding", to_str(node->padding()));
- s.args().append("fused", to_str(node->fusedActivationFunction()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleBatchMatMul *node,
- locop::NodeSummary &s)
-{
- s.args().append("x", tbl->lookup(node->x()));
- s.args().append("y", tbl->lookup(node->y()));
- s.args().append("adj_x", to_str(node->adj_x()));
- s.args().append("adj_y", to_str(node->adj_y()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleBatchToSpaceND *node,
- locop::NodeSummary &s)
-{
- s.args().append("input", tbl->lookup(node->input()));
- s.args().append("block_shape", tbl->lookup(node->block_shape()));
- s.args().append("crops", tbl->lookup(node->crops()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleBidirectionalSequenceLSTM *node,
- locop::NodeSummary &s)
-{
- s.args().append("input", tbl->lookup(node->input()));
-
- s.args().append("fw_input_to_input_weights", tbl->lookup(node->fw_input_to_input_weights()));
- s.args().append("fw_input_to_forget_weights", tbl->lookup(node->fw_input_to_forget_weights()));
- s.args().append("fw_input_to_cell_weights", tbl->lookup(node->fw_input_to_cell_weights()));
- s.args().append("fw_input_to_output_weights", tbl->lookup(node->fw_input_to_output_weights()));
-
- s.args().append("fw_recurrent_to_input_weights",
- tbl->lookup(node->fw_recurrent_to_input_weights()));
- s.args().append("fw_recurrent_to_forget_weights",
- tbl->lookup(node->fw_recurrent_to_forget_weights()));
- s.args().append("fw_recurrent_to_cell_weights",
- tbl->lookup(node->fw_recurrent_to_cell_weights()));
- s.args().append("fw_recurrent_to_output_weights",
- tbl->lookup(node->fw_recurrent_to_output_weights()));
-
- s.args().append("fw_cell_to_input_weights", tbl->lookup(node->fw_cell_to_input_weights()));
- s.args().append("fw_cell_to_forget_weights", tbl->lookup(node->fw_cell_to_forget_weights()));
- s.args().append("fw_cell_to_output_weights", tbl->lookup(node->fw_cell_to_output_weights()));
-
- s.args().append("fw_input_gate_bias", tbl->lookup(node->fw_input_gate_bias()));
- s.args().append("fw_forget_gate_bias", tbl->lookup(node->fw_forget_gate_bias()));
- s.args().append("fw_cell_gate_bias", tbl->lookup(node->fw_cell_gate_bias()));
- s.args().append("fw_output_gate_bias", tbl->lookup(node->fw_output_gate_bias()));
-
- s.args().append("fw_projection_weights", tbl->lookup(node->fw_projection_weights()));
- s.args().append("fw_projection_bias", tbl->lookup(node->fw_projection_bias()));
-
- s.args().append("bw_input_to_input_weights", tbl->lookup(node->bw_input_to_input_weights()));
- s.args().append("bw_input_to_forget_weights", tbl->lookup(node->bw_input_to_forget_weights()));
- s.args().append("bw_input_to_cell_weights", tbl->lookup(node->bw_input_to_cell_weights()));
- s.args().append("bw_input_to_output_weights", tbl->lookup(node->bw_input_to_output_weights()));
-
- s.args().append("bw_recurrent_to_input_weights",
- tbl->lookup(node->bw_recurrent_to_input_weights()));
- s.args().append("bw_recurrent_to_forget_weights",
- tbl->lookup(node->bw_recurrent_to_forget_weights()));
- s.args().append("bw_recurrent_to_cell_weights",
- tbl->lookup(node->bw_recurrent_to_cell_weights()));
- s.args().append("bw_recurrent_to_output_weights",
- tbl->lookup(node->bw_recurrent_to_output_weights()));
-
- s.args().append("bw_cell_to_input_weights", tbl->lookup(node->bw_cell_to_input_weights()));
- s.args().append("bw_cell_to_forget_weights", tbl->lookup(node->bw_cell_to_forget_weights()));
- s.args().append("bw_cell_to_output_weights", tbl->lookup(node->bw_cell_to_output_weights()));
-
- s.args().append("bw_input_gate_bias", tbl->lookup(node->bw_input_gate_bias()));
- s.args().append("bw_forget_gate_bias", tbl->lookup(node->bw_forget_gate_bias()));
- s.args().append("bw_cell_gate_bias", tbl->lookup(node->bw_cell_gate_bias()));
- s.args().append("bw_output_gate_bias", tbl->lookup(node->bw_output_gate_bias()));
-
- s.args().append("bw_projection_weights", tbl->lookup(node->bw_projection_weights()));
- s.args().append("bw_projection_bias", tbl->lookup(node->bw_projection_bias()));
-
- s.args().append("fw_activation_state", tbl->lookup(node->fw_activation_state()));
- s.args().append("fw_cell_state", tbl->lookup(node->fw_cell_state()));
- s.args().append("bw_activation_state", tbl->lookup(node->bw_activation_state()));
- s.args().append("bw_cell_state", tbl->lookup(node->bw_cell_state()));
-
- s.args().append("auxillary_input", tbl->lookup(node->auxillary_input()));
- s.args().append("fw_auxillary_input_to_input_weights",
- tbl->lookup(node->fw_auxillary_input_to_input_weights()));
- s.args().append("fw_auxillary_input_to_forget_weights",
- tbl->lookup(node->fw_auxillary_input_to_forget_weights()));
- s.args().append("fw_auxillary_input_to_cell_weights",
- tbl->lookup(node->fw_auxillary_input_to_cell_weights()));
- s.args().append("fw_auxillary_input_to_output_weights",
- tbl->lookup(node->fw_auxillary_input_to_output_weights()));
- s.args().append("bw_auxillary_input_to_input_weights",
- tbl->lookup(node->bw_auxillary_input_to_input_weights()));
- s.args().append("bw_auxillary_input_to_forget_weights",
- tbl->lookup(node->bw_auxillary_input_to_forget_weights()));
- s.args().append("bw_auxillary_input_to_cell_weights",
- tbl->lookup(node->bw_auxillary_input_to_cell_weights()));
- s.args().append("bw_auxillary_input_to_output_weights",
- tbl->lookup(node->bw_auxillary_input_to_output_weights()));
-
- s.args().append("cell_clip", to_str(node->cell_clip()));
- s.args().append("proj_clip", to_str(node->proj_clip()));
- s.args().append("merge_outputs", to_str(node->merge_outputs()));
- s.args().append("time_major", to_str(node->time_major()));
- s.args().append("asymmetric_quantize_inputs", to_str(node->asymmetric_quantize_inputs()));
-
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleCast *node,
- locop::NodeSummary &s)
-{
- s.args().append("x", tbl->lookup(node->x()));
- s.args().append("in_data_type", to_str(node->in_data_type()));
- s.args().append("out_data_type", to_str(node->out_data_type()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleConcatenation *node,
- locop::NodeSummary &s)
-{
- assert(node->fusedActivationFunction() != luci::FusedActFunc::UNDEFINED);
-
- for (uint32_t i = 0; i < node->numValues(); ++i)
- s.args().append("values", tbl->lookup(node->values(i)));
- s.args().append("axis", pepper::str(node->axis()));
- s.args().append("fused", to_str(node->fusedActivationFunction()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleConv2D *node,
- locop::NodeSummary &s)
-{
- assert(node->fusedActivationFunction() != luci::FusedActFunc::UNDEFINED);
- assert(node->padding() != luci::Padding::UNDEFINED);
-
- s.args().append("input", tbl->lookup(node->input()));
- s.args().append("filter", tbl->lookup(node->filter()));
- s.args().append("bias", tbl->lookup(node->bias()));
- s.args().append("stride(h,w)", to_str(node->stride()));
- s.args().append("dilation(h,w)", to_str(node->dilation()));
- s.args().append("padding", to_str(node->padding()));
- s.args().append("fused", to_str(node->fusedActivationFunction()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleCustom *node,
- locop::NodeSummary &s)
-{
- for (uint32_t i = 0; i < node->numInputs(); i++)
- {
- s.args().append("input" + std::to_string(i), tbl->lookup(node->inputs(i)));
- }
- s.args().append("custom_code", node->custom_code());
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleDepthToSpace *node,
- locop::NodeSummary &s)
-{
- s.args().append("input", tbl->lookup(node->input()));
- s.args().append("block_size", std::to_string(node->block_size()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleDepthwiseConv2D *node,
- locop::NodeSummary &s)
-{
- assert(node->fusedActivationFunction() != luci::FusedActFunc::UNDEFINED);
- assert(node->padding() != luci::Padding::UNDEFINED);
-
- s.args().append("input", tbl->lookup(node->input()));
- s.args().append("filter", tbl->lookup(node->filter()));
- s.args().append("bias", tbl->lookup(node->bias()));
- s.args().append("stride(h,w)", to_str(node->stride()));
- s.args().append("dilation(h,w)", to_str(node->dilation()));
- s.args().append("padding", to_str(node->padding()));
- s.args().append("depthMultiplier", std::to_string(node->depthMultiplier()));
- s.args().append("fused", to_str(node->fusedActivationFunction()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleExpandDims *node,
- locop::NodeSummary &s)
-{
- s.args().append("input", tbl->lookup(node->input()));
- s.args().append("axis", tbl->lookup(node->axis()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleFakeQuant *node,
- locop::NodeSummary &s)
-{
- s.args().append("inputs", tbl->lookup(node->inputs()));
- s.args().append("min", pepper::str(node->min()));
- s.args().append("max", pepper::str(node->max()));
- s.args().append("num_bits", pepper::str(node->num_bits()));
- s.args().append("narrow_range", node->narrow_range() ? "true" : "false");
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleFill *node,
- locop::NodeSummary &s)
-{
- s.args().append("dims", tbl->lookup(node->dims()));
- s.args().append("value", tbl->lookup(node->value()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleFullyConnected *node,
- locop::NodeSummary &s)
-{
- assert(node->fusedActivationFunction() != luci::FusedActFunc::UNDEFINED);
-
- s.args().append("input", tbl->lookup(node->input()));
- s.args().append("weights", tbl->lookup(node->weights()));
- s.args().append("bias", tbl->lookup(node->bias()));
- s.args().append("fused", to_str(node->fusedActivationFunction()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleGather *node,
- locop::NodeSummary &s)
-{
- s.args().append("params", tbl->lookup(node->params()));
- s.args().append("indices", tbl->lookup(node->indices()));
- s.args().append("axis", pepper::str(node->axis()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleGatherNd *node,
- locop::NodeSummary &s)
-{
- s.args().append("params", tbl->lookup(node->params()));
- s.args().append("indices", tbl->lookup(node->indices()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleIf *node, locop::NodeSummary &s)
-{
- s.args().append("cond", tbl->lookup(node->cond()));
- for (uint32_t i = 0; i < node->input_count(); ++i)
- s.args().append("input", tbl->lookup(node->input(i)));
-
- if (node->then_graph() != nullptr)
- s.args().append("then_graph", node->then_graph()->name());
- else
- s.args().append("then_branch", pepper::str(node->then_branch()));
-
- if (node->else_graph() != nullptr)
- s.args().append("else_graph", node->else_graph()->name());
- else
- s.args().append("else_branch", pepper::str(node->else_branch()));
-
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleL2Normalize *node,
- locop::NodeSummary &s)
-{
- s.args().append("x", tbl->lookup(node->x()));
- s.args().append("fused_activation_function", to_str(node->fusedActivationFunction()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleL2Pool2D *node,
- locop::NodeSummary &s)
-{
- assert(node->fusedActivationFunction() != luci::FusedActFunc::UNDEFINED);
-
- s.args().append("value", tbl->lookup(node->value()));
- s.args().append("filter(h,w)", to_str(node->filter()));
- s.args().append("stride(h,w)", to_str(node->stride()));
- s.args().append("padding", to_str(node->padding()));
- s.args().append("fused", to_str(node->fusedActivationFunction()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleLeakyRelu *node,
- locop::NodeSummary &s)
-{
- s.args().append("features", tbl->lookup(node->features()));
- s.args().append("alpha", std::to_string(node->alpha()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleLocalResponseNormalization *node,
- locop::NodeSummary &s)
-{
- s.args().append("input", tbl->lookup(node->input()));
- s.args().append("radius", pepper::str(node->radius()));
- s.args().append("bias", pepper::str(node->bias()));
- s.args().append("alpha", pepper::str(node->alpha()));
- s.args().append("beta", pepper::str(node->beta()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleLogSoftmax *node,
- locop::NodeSummary &s)
-{
- s.args().append("logits", tbl->lookup(node->logits()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleMatrixDiag *node,
- locop::NodeSummary &s)
-{
- s.args().append("diagonal", tbl->lookup(node->diagonal()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleMatrixSetDiag *node,
- locop::NodeSummary &s)
-{
- s.args().append("input", tbl->lookup(node->input()));
- s.args().append("diagonal", tbl->lookup(node->diagonal()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleMaxPool2D *node,
- locop::NodeSummary &s)
-{
- assert(node->fusedActivationFunction() != luci::FusedActFunc::UNDEFINED);
-
- s.args().append("value", tbl->lookup(node->value()));
- s.args().append("filter(h,w)", to_str(node->filter()));
- s.args().append("stride(h,w)", to_str(node->stride()));
- s.args().append("padding", to_str(node->padding()));
- s.args().append("fused", to_str(node->fusedActivationFunction()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleMirrorPad *node,
- locop::NodeSummary &s)
-{
- s.args().append("input", tbl->lookup(node->input()));
- s.args().append("paddings", tbl->lookup(node->paddings()));
- s.args().append("mode", to_str(node->mode()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleNonMaxSuppressionV4 *node,
- locop::NodeSummary &s)
-{
- s.args().append("boxes", tbl->lookup(node->boxes()));
- s.args().append("scores", tbl->lookup(node->scores()));
- s.args().append("max_output_size", tbl->lookup(node->max_output_size()));
- s.args().append("iou_threshold", tbl->lookup(node->iou_threshold()));
- s.args().append("score_threshold", tbl->lookup(node->score_threshold()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleNonMaxSuppressionV5 *node,
- locop::NodeSummary &s)
-{
- s.args().append("boxes", tbl->lookup(node->boxes()));
- s.args().append("scores", tbl->lookup(node->scores()));
- s.args().append("max_output_size", tbl->lookup(node->max_output_size()));
- s.args().append("iou_threshold", tbl->lookup(node->iou_threshold()));
- s.args().append("score_threshold", tbl->lookup(node->score_threshold()));
- s.args().append("soft_nms_sigma", tbl->lookup(node->soft_nms_sigma()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleOneHot *node,
- locop::NodeSummary &s)
-{
- s.args().append("indices", tbl->lookup(node->indices()));
- s.args().append("depth", tbl->lookup(node->depth()));
- s.args().append("on_value", tbl->lookup(node->on_value()));
- s.args().append("off_value", tbl->lookup(node->off_value()));
- s.args().append("axis", pepper::str(node->axis()));
-
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CirclePack *node,
- locop::NodeSummary &s)
-{
- for (uint32_t i = 0; i < node->values_count(); ++i)
- s.args().append("values", tbl->lookup(node->values(i)));
- s.args().append("values_count", pepper::str(node->values_count()));
- s.args().append("axis", pepper::str(node->axis()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CirclePad *node, locop::NodeSummary &s)
-{
- s.args().append("input", tbl->lookup(node->input()));
- s.args().append("paddings", tbl->lookup(node->paddings()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CirclePadV2 *node,
- locop::NodeSummary &s)
-{
- s.args().append("input", tbl->lookup(node->input()));
- s.args().append("paddings", tbl->lookup(node->paddings()));
- s.args().append("constant_values", tbl->lookup(node->constant_values()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CirclePRelu *node,
- locop::NodeSummary &s)
-{
- s.args().append("input", tbl->lookup(node->input()));
- s.args().append("alpha", tbl->lookup(node->alpha()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleRange *node,
- locop::NodeSummary &s)
-{
- s.args().append("start", tbl->lookup(node->start()));
- s.args().append("limit", tbl->lookup(node->limit()));
- s.args().append("delta", tbl->lookup(node->delta()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleReshape *node,
- locop::NodeSummary &s)
-{
- s.args().append("tensor", tbl->lookup(node->tensor()));
- s.args().append("shape", tbl->lookup(node->shape()));
- // TODO Show newShape info
- s.state(locop::NodeSummary::State::PartiallyKnown);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleResizeBilinear *node,
- locop::NodeSummary &s)
-{
- s.args().append("input", tbl->lookup(node->input()));
- s.args().append("size", tbl->lookup(node->size()));
- s.args().append("align_corners", node->align_corners() ? "true" : "false");
- s.args().append("half_pixel_centers", node->half_pixel_centers() ? "true" : "false");
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleResizeNearestNeighbor *node,
- locop::NodeSummary &s)
-{
- s.args().append("input", tbl->lookup(node->input()));
- s.args().append("size", tbl->lookup(node->size()));
- s.args().append("align_corners", node->align_corners() ? "true" : "false");
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleReverseSequence *node,
- locop::NodeSummary &s)
-{
- s.args().append("input", tbl->lookup(node->input()));
- s.args().append("seq_lengths", tbl->lookup(node->seq_lengths()));
- s.args().append("seq_axis", std::to_string(node->seq_axis()));
- s.args().append("batch_axis", std::to_string(node->batch_axis()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleReverseV2 *node,
- locop::NodeSummary &s)
-{
- s.args().append("tensor", tbl->lookup(node->tensor()));
- s.args().append("axis", tbl->lookup(node->axis()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleScatterNd *node,
- locop::NodeSummary &s)
-{
- s.args().append("indices", tbl->lookup(node->indices()));
- s.args().append("updates", tbl->lookup(node->updates()));
- s.args().append("shape", tbl->lookup(node->shape()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleSegmentSum *node,
- locop::NodeSummary &s)
-{
- s.args().append("input", tbl->lookup(node->input()));
- s.args().append("segment_ids", tbl->lookup(node->segment_ids()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleSelect *node,
- locop::NodeSummary &s)
-{
- s.args().append("condition", tbl->lookup(node->condition()));
- s.args().append("t", tbl->lookup(node->t()));
- s.args().append("e", tbl->lookup(node->e()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleSelectV2 *node,
- locop::NodeSummary &s)
-{
- s.args().append("condition", tbl->lookup(node->condition()));
- s.args().append("t", tbl->lookup(node->t()));
- s.args().append("e", tbl->lookup(node->e()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleShape *node,
- locop::NodeSummary &s)
-{
- s.args().append("input", tbl->lookup(node->input()));
- s.args().append("out_type", to_str(node->out_type()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleSlice *node,
- locop::NodeSummary &s)
-{
- s.args().append("input", tbl->lookup(node->input()));
- s.args().append("begin", tbl->lookup(node->begin()));
- s.args().append("size", tbl->lookup(node->size()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleSoftmax *node,
- locop::NodeSummary &s)
-{
- s.args().append("logits", tbl->lookup(node->logits()));
- s.args().append("beta", pepper::str(node->beta()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleSpaceToBatchND *node,
- locop::NodeSummary &s)
-{
- s.args().append("input", tbl->lookup(node->input()));
- s.args().append("block_shape", tbl->lookup(node->block_shape()));
- s.args().append("paddings", tbl->lookup(node->paddings()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleSpaceToDepth *node,
- locop::NodeSummary &s)
-{
- s.args().append("input", tbl->lookup(node->input()));
- s.args().append("block_size", pepper::str(node->block_size()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleSparseToDense *node,
- locop::NodeSummary &s)
-{
- s.args().append("indices", tbl->lookup(node->indices()));
- s.args().append("output_shape", tbl->lookup(node->output_shape()));
- s.args().append("values", tbl->lookup(node->values()));
- s.args().append("default_value", tbl->lookup(node->default_value()));
- s.args().append("Validate_indices", pepper::str(node->validate_indices()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleSplit *node,
- locop::NodeSummary &s)
-{
- s.args().append("split_dim", tbl->lookup(node->split_dim()));
- s.args().append("input", tbl->lookup(node->input()));
- s.args().append("num_split", pepper::str(node->num_split()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleSplitV *node,
- locop::NodeSummary &s)
-{
- s.args().append("input", tbl->lookup(node->input()));
- s.args().append("size_splits", tbl->lookup(node->size_splits()));
- s.args().append("split_dim", tbl->lookup(node->split_dim()));
- s.args().append("num_split", pepper::str(node->num_split()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleSqueeze *node,
- locop::NodeSummary &s)
-{
- s.args().append("input", tbl->lookup(node->input()));
-
- std::stringstream ss{"("};
- for (size_t i = 0; i < node->squeeze_dims().size(); ++i)
- {
- if (i != 0)
- ss << ", ";
- ss << node->squeeze_dims()[i];
- }
- ss << ")";
- s.args().append("squeeze_dims", ss.str());
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleStridedSlice *node,
- locop::NodeSummary &s)
-{
- s.args().append("input", tbl->lookup(node->input()));
- s.args().append("begin", tbl->lookup(node->begin()));
- s.args().append("end", tbl->lookup(node->end()));
- s.args().append("strides", tbl->lookup(node->strides()));
- s.args().append("begin_mask", pepper::str(node->begin_mask()));
- s.args().append("end_mask", pepper::str(node->end_mask()));
- s.args().append("ellipsis_mask", pepper::str(node->ellipsis_mask()));
- s.args().append("new_axis_mask", pepper::str(node->new_axis_mask()));
- s.args().append("shrink_axis_mask", pepper::str(node->shrink_axis_mask()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleTile *node,
- locop::NodeSummary &s)
-{
- s.args().append("input", tbl->lookup(node->input()));
- s.args().append("multiples", tbl->lookup(node->multiples()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleTopKV2 *node,
- locop::NodeSummary &s)
-{
- s.args().append("input", tbl->lookup(node->input()));
- s.args().append("k", tbl->lookup(node->k()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleTranspose *node,
- locop::NodeSummary &s)
-{
- s.args().append("a", tbl->lookup(node->a()));
- s.args().append("perm", tbl->lookup(node->perm()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleTransposeConv *node,
- locop::NodeSummary &s)
-{
- assert(node->padding() != luci::Padding::UNDEFINED);
-
- s.args().append("inputSizes", tbl->lookup(node->inputSizes()));
- s.args().append("filter", tbl->lookup(node->filter()));
- s.args().append("outBackprop", tbl->lookup(node->outBackprop()));
- s.args().append("bias", tbl->lookup(node->bias()));
- s.args().append("stride(h,w)", to_str(node->stride()));
- s.args().append("padding", to_str(node->padding()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleUnidirectionalSequenceLSTM *node,
- locop::NodeSummary &s)
-{
- s.args().append("input", tbl->lookup(node->input()));
-
- s.args().append("input_to_input_weights", tbl->lookup(node->input_to_input_weights()));
- s.args().append("input_to_forget_weights", tbl->lookup(node->input_to_forget_weights()));
- s.args().append("input_to_cell_weights", tbl->lookup(node->input_to_cell_weights()));
- s.args().append("input_to_output_weights", tbl->lookup(node->input_to_output_weights()));
-
- s.args().append("recurrent_to_input_weights", tbl->lookup(node->recurrent_to_input_weights()));
- s.args().append("recurrent_to_forget_weights", tbl->lookup(node->recurrent_to_forget_weights()));
- s.args().append("recurrent_to_cell_weights", tbl->lookup(node->recurrent_to_cell_weights()));
- s.args().append("recurrent_to_output_weights", tbl->lookup(node->recurrent_to_output_weights()));
-
- s.args().append("cell_to_input_weights", tbl->lookup(node->cell_to_input_weights()));
- s.args().append("cell_to_forget_weights", tbl->lookup(node->cell_to_forget_weights()));
- s.args().append("cell_to_output_weights", tbl->lookup(node->cell_to_output_weights()));
-
- s.args().append("input_gate_bias", tbl->lookup(node->input_gate_bias()));
- s.args().append("forget_gate_bias", tbl->lookup(node->forget_gate_bias()));
- s.args().append("cell_gate_bias", tbl->lookup(node->cell_gate_bias()));
- s.args().append("output_gate_bias", tbl->lookup(node->output_gate_bias()));
-
- s.args().append("projection_weights", tbl->lookup(node->projection_weights()));
- s.args().append("projection_bias", tbl->lookup(node->projection_bias()));
-
- s.args().append("activation_state", tbl->lookup(node->activation_state()));
- s.args().append("cell_state", tbl->lookup(node->cell_state()));
-
- s.args().append("input_layer_norm_coefficients",
- tbl->lookup(node->input_layer_norm_coefficients()));
- s.args().append("forget_layer_norm_coefficients",
- tbl->lookup(node->forget_layer_norm_coefficients()));
- s.args().append("cell_layer_norm_coefficients",
- tbl->lookup(node->cell_layer_norm_coefficients()));
- s.args().append("output_layer_norm_coefficients",
- tbl->lookup(node->output_layer_norm_coefficients()));
-
- s.args().append("cell_clip", to_str(node->cell_clip()));
- s.args().append("proj_clip", to_str(node->proj_clip()));
- s.args().append("time_major", to_str(node->time_major()));
- s.args().append("asymmetric_quantize_inputs", to_str(node->asymmetric_quantize_inputs()));
-
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleUnique *node,
- locop::NodeSummary &s)
-{
- s.args().append("input", tbl->lookup(node->input()));
- s.args().append("idx_out_type", to_str(node->idx_out_type()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleUnpack *node,
- locop::NodeSummary &s)
-{
- s.args().append("value", tbl->lookup(node->value()));
- s.args().append("num", pepper::str(node->num()));
- s.args().append("axis", pepper::str(node->axis()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleWhere *node,
- locop::NodeSummary &s)
-{
- s.args().append("condition", tbl->lookup(node->condition()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleWhile *node,
- locop::NodeSummary &s)
-{
- for (uint32_t i = 0; i < node->input_count(); ++i)
- s.args().append("input", tbl->lookup(node->input(i)));
-
- if (node->cond_graph() != nullptr)
- s.args().append("cond_graph", node->cond_graph()->name());
- else
- s.args().append("cond_branch", pepper::str(node->cond_branch()));
-
- if (node->body_graph() != nullptr)
- s.args().append("body_graph", node->body_graph()->name());
- else
- s.args().append("body_branch", pepper::str(node->body_branch()));
-
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleTopKV2Out *node,
- locop::NodeSummary &s)
-{
- s.args().append("topkv2", tbl->lookup(node->input()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleUniqueOut *node,
- locop::NodeSummary &s)
-{
- s.args().append("unique", tbl->lookup(node->input()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleUnpackOut *node,
- locop::NodeSummary &s)
-{
- s.args().append("unpack", tbl->lookup(node->input()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleWhileOut *node,
- locop::NodeSummary &s)
-{
- s.args().append("while", tbl->lookup(node->input()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleOutput *node,
- locop::NodeSummary &s)
-{
- s.args().append("from", tbl->lookup(node->from()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *, const luci::CircleOutputDummy *,
- locop::NodeSummary &s)
-{
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *, const luci::CircleOutputExclude *,
- locop::NodeSummary &s)
-{
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleBCQFullyConnected *node,
- locop::NodeSummary &s)
-{
- assert(node->fusedActivationFunction() != luci::FusedActFunc::UNDEFINED);
- s.args().append("input", tbl->lookup(node->input()));
- s.args().append("weights_scales", tbl->lookup(node->weights_scales()));
- s.args().append("weights_binary", tbl->lookup(node->weights_binary()));
- s.args().append("bias", tbl->lookup(node->bias()));
- s.args().append("weights_clusters", tbl->lookup(node->weights_clusters()));
- s.args().append("fused", to_str(node->fusedActivationFunction()));
- s.args().append("weights_hidden_size", pepper::str(node->weights_hidden_size()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleBCQGather *node,
- locop::NodeSummary &s)
-{
- s.args().append("input_scales", tbl->lookup(node->input_scales()));
- s.args().append("input_binary", tbl->lookup(node->input_binary()));
- s.args().append("indices", tbl->lookup(node->indices()));
- s.args().append("input_clusters", tbl->lookup(node->input_clusters()));
- s.args().append("axis", pepper::str(node->axis()));
- s.args().append("input_hidden_size", pepper::str(node->input_hidden_size()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleInstanceNorm *node,
- locop::NodeSummary &s)
-{
- auto fused = node->fusedActivationFunction();
- assert(fused != luci::FusedActFunc::UNDEFINED);
-
- s.args().append("input", tbl->lookup(node->input()));
- s.args().append("gamma", tbl->lookup(node->gamma()));
- s.args().append("beta", tbl->lookup(node->beta()));
- s.args().append("epsilon", pepper::str(node->epsilon()));
- s.args().append("fused_activation_function", to_str(fused));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-// SummaryBuilderLet type
-enum class SB
-{
- ABC,
- DEF,
- GHIJ,
- KLMN,
- OPQR,
- STUV,
- WXYZ,
- CIRC, // circle only
- VIRT, // virtual
-};
-
-template <SB sb> class SummaryBuilderLet;
-
-#define IMPLEMENT(CLASS) bool summary(const CLASS *, locop::NodeSummary &) const final;
-
-template <> class SummaryBuilderLet<SB::ABC> final : public CircleNodeSummaryBuilderBase
-{
-public:
- SummaryBuilderLet(const locop::SymbolTable *tbl) : CircleNodeSummaryBuilderBase(tbl)
- {
- // DO NOTHING
- }
-
-private:
- IMPLEMENT(luci::CircleAbs)
- IMPLEMENT(luci::CircleAdd)
- IMPLEMENT(luci::CircleAddN)
- IMPLEMENT(luci::CircleArgMax)
- IMPLEMENT(luci::CircleArgMin)
- IMPLEMENT(luci::CircleAveragePool2D)
- IMPLEMENT(luci::CircleBatchMatMul)
- IMPLEMENT(luci::CircleBatchToSpaceND)
- IMPLEMENT(luci::CircleBidirectionalSequenceLSTM)
- IMPLEMENT(luci::CircleCast)
- IMPLEMENT(luci::CircleCeil)
- IMPLEMENT(luci::CircleConcatenation)
- IMPLEMENT(luci::CircleConst)
- IMPLEMENT(luci::CircleConv2D)
- IMPLEMENT(luci::CircleCos)
- IMPLEMENT(luci::CircleCustom)
-};
-
-template <> class SummaryBuilderLet<SB::DEF> final : public CircleNodeSummaryBuilderBase
-{
-public:
- SummaryBuilderLet(const locop::SymbolTable *tbl) : CircleNodeSummaryBuilderBase(tbl)
- {
- // DO NOTHING
- }
-
-private:
- IMPLEMENT(luci::CircleDepthToSpace)
- IMPLEMENT(luci::CircleDepthwiseConv2D)
- IMPLEMENT(luci::CircleDequantize)
- IMPLEMENT(luci::CircleDiv)
- IMPLEMENT(luci::CircleElu)
- IMPLEMENT(luci::CircleEqual)
- IMPLEMENT(luci::CircleExp)
- IMPLEMENT(luci::CircleExpandDims)
- IMPLEMENT(luci::CircleFakeQuant)
- IMPLEMENT(luci::CircleFill)
- IMPLEMENT(luci::CircleFloor)
- IMPLEMENT(luci::CircleFloorDiv)
- IMPLEMENT(luci::CircleFloorMod)
- IMPLEMENT(luci::CircleFullyConnected)
-};
-
-template <> class SummaryBuilderLet<SB::GHIJ> final : public CircleNodeSummaryBuilderBase
-{
-public:
- SummaryBuilderLet(const locop::SymbolTable *tbl) : CircleNodeSummaryBuilderBase(tbl)
- {
- // DO NOTHING
- }
-
-private:
- IMPLEMENT(luci::CircleGather)
- IMPLEMENT(luci::CircleGatherNd)
- IMPLEMENT(luci::CircleGreater)
- IMPLEMENT(luci::CircleGreaterEqual)
- IMPLEMENT(luci::CircleIf)
-};
-
-template <> class SummaryBuilderLet<SB::KLMN> final : public CircleNodeSummaryBuilderBase
-{
-public:
- SummaryBuilderLet(const locop::SymbolTable *tbl) : CircleNodeSummaryBuilderBase(tbl)
- {
- // DO NOTHING
- }
-
-private:
- IMPLEMENT(luci::CircleL2Normalize)
- IMPLEMENT(luci::CircleL2Pool2D)
- IMPLEMENT(luci::CircleLeakyRelu)
- IMPLEMENT(luci::CircleLess)
- IMPLEMENT(luci::CircleLessEqual)
- IMPLEMENT(luci::CircleLocalResponseNormalization)
- IMPLEMENT(luci::CircleLog)
- IMPLEMENT(luci::CircleLogicalAnd)
- IMPLEMENT(luci::CircleLogicalNot)
- IMPLEMENT(luci::CircleLogicalOr)
- IMPLEMENT(luci::CircleLogistic)
- IMPLEMENT(luci::CircleLogSoftmax)
- IMPLEMENT(luci::CircleMatrixDiag)
- IMPLEMENT(luci::CircleMatrixSetDiag)
- IMPLEMENT(luci::CircleMaximum)
- IMPLEMENT(luci::CircleMaxPool2D)
- IMPLEMENT(luci::CircleMean)
- IMPLEMENT(luci::CircleMinimum)
- IMPLEMENT(luci::CircleMirrorPad)
- IMPLEMENT(luci::CircleMul)
- IMPLEMENT(luci::CircleNeg)
- IMPLEMENT(luci::CircleNonMaxSuppressionV4)
- IMPLEMENT(luci::CircleNonMaxSuppressionV5)
- IMPLEMENT(luci::CircleNotEqual)
-};
-
-template <> class SummaryBuilderLet<SB::OPQR> final : public CircleNodeSummaryBuilderBase
-{
-public:
- SummaryBuilderLet(const locop::SymbolTable *tbl) : CircleNodeSummaryBuilderBase(tbl)
- {
- // DO NOTHING
- }
-
-private:
- IMPLEMENT(luci::CircleOneHot)
- IMPLEMENT(luci::CirclePack)
- IMPLEMENT(luci::CirclePad)
- IMPLEMENT(luci::CirclePadV2)
- IMPLEMENT(luci::CirclePow)
- IMPLEMENT(luci::CirclePRelu)
- IMPLEMENT(luci::CircleQuantize)
- IMPLEMENT(luci::CircleRange)
- IMPLEMENT(luci::CircleRank)
- IMPLEMENT(luci::CircleReduceAny)
- IMPLEMENT(luci::CircleReduceMax)
- IMPLEMENT(luci::CircleReduceMin)
- IMPLEMENT(luci::CircleReduceProd)
- IMPLEMENT(luci::CircleRelu)
- IMPLEMENT(luci::CircleRelu6)
- IMPLEMENT(luci::CircleReluN1To1)
- IMPLEMENT(luci::CircleReshape)
- IMPLEMENT(luci::CircleResizeBilinear)
- IMPLEMENT(luci::CircleResizeNearestNeighbor)
- IMPLEMENT(luci::CircleReverseSequence)
- IMPLEMENT(luci::CircleReverseV2)
- IMPLEMENT(luci::CircleRound)
- IMPLEMENT(luci::CircleRsqrt)
-};
-
-template <> class SummaryBuilderLet<SB::STUV> final : public CircleNodeSummaryBuilderBase
-{
-public:
- SummaryBuilderLet(const locop::SymbolTable *tbl) : CircleNodeSummaryBuilderBase(tbl)
- {
- // DO NOTHING
- }
-
-private:
- IMPLEMENT(luci::CircleScatterNd)
- IMPLEMENT(luci::CircleSegmentSum)
- IMPLEMENT(luci::CircleSelect)
- IMPLEMENT(luci::CircleSelectV2)
- IMPLEMENT(luci::CircleShape)
- IMPLEMENT(luci::CircleSin)
- IMPLEMENT(luci::CircleSlice)
- IMPLEMENT(luci::CircleSoftmax)
- IMPLEMENT(luci::CircleSpaceToBatchND)
- IMPLEMENT(luci::CircleSpaceToDepth)
- IMPLEMENT(luci::CircleSparseToDense)
- IMPLEMENT(luci::CircleSplit)
- IMPLEMENT(luci::CircleSplitV)
- IMPLEMENT(luci::CircleSqrt)
- IMPLEMENT(luci::CircleSquare)
- IMPLEMENT(luci::CircleSquaredDifference)
- IMPLEMENT(luci::CircleSqueeze)
- IMPLEMENT(luci::CircleStridedSlice)
- IMPLEMENT(luci::CircleSub)
- IMPLEMENT(luci::CircleSum)
- IMPLEMENT(luci::CircleTanh)
- IMPLEMENT(luci::CircleTile)
- IMPLEMENT(luci::CircleTopKV2)
- IMPLEMENT(luci::CircleTranspose)
- IMPLEMENT(luci::CircleTransposeConv)
- IMPLEMENT(luci::CircleUnidirectionalSequenceLSTM)
- IMPLEMENT(luci::CircleUnique)
- IMPLEMENT(luci::CircleUnpack)
-};
-
-template <> class SummaryBuilderLet<SB::WXYZ> final : public CircleNodeSummaryBuilderBase
-{
-public:
- SummaryBuilderLet(const locop::SymbolTable *tbl) : CircleNodeSummaryBuilderBase(tbl)
- {
- // DO NOTHING
- }
-
-private:
- IMPLEMENT(luci::CircleWhere)
- IMPLEMENT(luci::CircleWhile)
- IMPLEMENT(luci::CircleZerosLike)
-};
-
-template <> class SummaryBuilderLet<SB::CIRC> final : public CircleNodeSummaryBuilderBase
-{
-public:
- SummaryBuilderLet(const locop::SymbolTable *tbl) : CircleNodeSummaryBuilderBase(tbl)
- {
- // DO NOTHING
- }
-
-private:
- IMPLEMENT(luci::CircleBCQFullyConnected)
- IMPLEMENT(luci::CircleBCQGather)
- IMPLEMENT(luci::CircleInstanceNorm)
-};
-
-template <> class SummaryBuilderLet<SB::VIRT> final : public CircleNodeSummaryBuilderBase
-{
-public:
- SummaryBuilderLet(const locop::SymbolTable *tbl) : CircleNodeSummaryBuilderBase(tbl)
- {
- // DO NOTHING
- }
-
-private:
- IMPLEMENT(luci::CircleInput)
- IMPLEMENT(luci::CircleOutput)
- IMPLEMENT(luci::CircleCustomOut)
- IMPLEMENT(luci::CircleIfOut)
- IMPLEMENT(luci::CircleNonMaxSuppressionV4Out)
- IMPLEMENT(luci::CircleNonMaxSuppressionV5Out)
- IMPLEMENT(luci::CircleOutputDummy)
- IMPLEMENT(luci::CircleOutputExclude)
- IMPLEMENT(luci::CircleSplitOut)
- IMPLEMENT(luci::CircleSplitVOut)
- IMPLEMENT(luci::CircleTopKV2Out)
- IMPLEMENT(luci::CircleUniqueOut)
- IMPLEMENT(luci::CircleUnpackOut)
- IMPLEMENT(luci::CircleWhileOut)
-};
-
-#undef IMPLEMENT
-
-bool CircleNodeSummaryBuilderBase::build(const loco::Node *node, locop::NodeSummary &s) const
-{
- if (node->dialect() != luci::CircleDialect::get())
- return false;
-
- auto ptr_to_str = [](const void *ptr) {
- std::stringstream ss;
- ss << ptr;
- return ss.str();
- };
-
- auto add_comment = [&]() {
- auto cnode = loco::must_cast<const luci::CircleNode *>(node);
- s.opname(circle_opname(node->opnum()));
- s.comments().append("[" + cnode->name() + "] = " + ptr_to_str(node));
- };
-
-#define CIRCLE_NODE(OPCODE, CLASS) \
- if (dynamic_cast<const CLASS *>(node)) \
- { \
- if (summary(dynamic_cast<const CLASS *>(node), s)) \
- { \
- add_comment(); \
- return true; \
- } \
- }
-#define CIRCLE_VNODE CIRCLE_NODE
-#include <luci/IR/CircleNodes.lst>
-#undef CIRCLE_VNODE
-#undef CIRCLE_NODE
-
- return false;
-}
-
-bool SummaryBuilderLet<SB::ABC>::summary(const luci::CircleAbs *node, locop::NodeSummary &s) const
-{
- return use_x(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::ABC>::summary(const luci::CircleAdd *node, locop::NodeSummary &s) const
-{
- return use_xy_act(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::ABC>::summary(const luci::CircleAddN *node, locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::ABC>::summary(const luci::CircleArgMax *node,
- locop::NodeSummary &s) const
-{
- return use_ido(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::ABC>::summary(const luci::CircleArgMin *node,
- locop::NodeSummary &s) const
-{
- return use_ido(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::ABC>::summary(const luci::CircleAveragePool2D *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::ABC>::summary(const luci::CircleBatchMatMul *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::ABC>::summary(const luci::CircleBatchToSpaceND *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::ABC>::summary(const luci::CircleBidirectionalSequenceLSTM *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::ABC>::summary(const luci::CircleCast *node, locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::ABC>::summary(const luci::CircleCeil *node, locop::NodeSummary &s) const
-{
- return use_x(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::ABC>::summary(const luci::CircleConcatenation *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::ABC>::summary(const luci::CircleConst *, locop::NodeSummary &s) const
-{
- s.state(locop::NodeSummary::State::PartiallyKnown);
- return true;
-}
-
-bool SummaryBuilderLet<SB::ABC>::summary(const luci::CircleConv2D *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::ABC>::summary(const luci::CircleCos *node, locop::NodeSummary &s) const
-{
- return use_x(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::ABC>::summary(const luci::CircleCustom *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::DEF>::summary(const luci::CircleDepthToSpace *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::DEF>::summary(const luci::CircleDepthwiseConv2D *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::DEF>::summary(const luci::CircleDequantize *node,
- locop::NodeSummary &s) const
-{
- return use_input(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::DEF>::summary(const luci::CircleDiv *node, locop::NodeSummary &s) const
-{
- return use_xy(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::DEF>::summary(const luci::CircleElu *node, locop::NodeSummary &s) const
-{
- return use_features(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::DEF>::summary(const luci::CircleEqual *node, locop::NodeSummary &s) const
-{
- return use_xy(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::DEF>::summary(const luci::CircleExp *node, locop::NodeSummary &s) const
-{
- return use_x(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::DEF>::summary(const luci::CircleExpandDims *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::DEF>::summary(const luci::CircleFakeQuant *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::DEF>::summary(const luci::CircleFill *node, locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::DEF>::summary(const luci::CircleFloor *node, locop::NodeSummary &s) const
-{
- return use_x(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::DEF>::summary(const luci::CircleFloorDiv *node,
- locop::NodeSummary &s) const
-{
- return use_xy(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::DEF>::summary(const luci::CircleFloorMod *node,
- locop::NodeSummary &s) const
-{
- return use_xy(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::DEF>::summary(const luci::CircleFullyConnected *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::GHIJ>::summary(const luci::CircleGather *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::GHIJ>::summary(const luci::CircleGatherNd *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::GHIJ>::summary(const luci::CircleGreater *node,
- locop::NodeSummary &s) const
-{
- return use_xy(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::GHIJ>::summary(const luci::CircleGreaterEqual *node,
- locop::NodeSummary &s) const
-{
- return use_xy(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::GHIJ>::summary(const luci::CircleIf *node, locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleL2Normalize *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleL2Pool2D *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleLess *node, locop::NodeSummary &s) const
-{
- return use_xy(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleLessEqual *node,
- locop::NodeSummary &s) const
-{
- return use_xy(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleLeakyRelu *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleLocalResponseNormalization *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleLog *node, locop::NodeSummary &s) const
-{
- return use_x(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleLogicalAnd *node,
- locop::NodeSummary &s) const
-{
- return use_xy(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleLogicalNot *node,
- locop::NodeSummary &s) const
-{
- return use_x(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleLogicalOr *node,
- locop::NodeSummary &s) const
-{
- return use_xy(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleLogistic *node,
- locop::NodeSummary &s) const
-{
- return use_x(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleLogSoftmax *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleMatrixDiag *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleMatrixSetDiag *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleMaximum *node,
- locop::NodeSummary &s) const
-{
- return use_xy(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleMaxPool2D *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleMean *node, locop::NodeSummary &s) const
-{
- return use_reducer(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleMinimum *node,
- locop::NodeSummary &s) const
-{
- return use_xy(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleMirrorPad *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleMul *node, locop::NodeSummary &s) const
-{
- return use_xy_act(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleNeg *node, locop::NodeSummary &s) const
-{
- return use_x(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleNonMaxSuppressionV4 *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleNonMaxSuppressionV5 *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleNotEqual *node,
- locop::NodeSummary &s) const
-{
- return use_xy(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleOneHot *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CirclePack *node, locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CirclePad *node, locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CirclePadV2 *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CirclePow *node, locop::NodeSummary &s) const
-{
- return use_xy(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CirclePRelu *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleQuantize *node,
- locop::NodeSummary &s) const
-{
- return use_input(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleRange *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleRank *node, locop::NodeSummary &s) const
-{
- return use_input(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleReduceAny *node,
- locop::NodeSummary &s) const
-{
- return use_reducer(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleReduceMax *node,
- locop::NodeSummary &s) const
-{
- return use_reducer(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleReduceMin *node,
- locop::NodeSummary &s) const
-{
- return use_reducer(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleReduceProd *node,
- locop::NodeSummary &s) const
-{
- return use_reducer(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleRelu *node, locop::NodeSummary &s) const
-{
- return use_features(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleRelu6 *node,
- locop::NodeSummary &s) const
-{
- return use_features(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleReluN1To1 *node,
- locop::NodeSummary &s) const
-{
- return use_features(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleReshape *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleResizeBilinear *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleResizeNearestNeighbor *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleReverseSequence *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleReverseV2 *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleRound *node,
- locop::NodeSummary &s) const
-{
- return use_x(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleRsqrt *node,
- locop::NodeSummary &s) const
-{
- return use_x(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleScatterNd *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleSegmentSum *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleSelect *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleSelectV2 *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleShape *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleSin *node, locop::NodeSummary &s) const
-{
- return use_x(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleSlice *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleSoftmax *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleSpaceToBatchND *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleSpaceToDepth *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleSparseToDense *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleSplit *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleSplitV *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleSqrt *node, locop::NodeSummary &s) const
-{
- return use_x(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleSquare *node,
- locop::NodeSummary &s) const
-{
- return use_x(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleSquaredDifference *node,
- locop::NodeSummary &s) const
-{
- return use_xy(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleSqueeze *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleStridedSlice *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleSub *node, locop::NodeSummary &s) const
-{
- return use_xy(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleSum *node, locop::NodeSummary &s) const
-{
- return use_reducer(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleTanh *node, locop::NodeSummary &s) const
-{
- return use_x(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleTile *node, locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleTopKV2 *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleTranspose *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleTransposeConv *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleUnidirectionalSequenceLSTM *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleUnique *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleUnpack *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::WXYZ>::summary(const luci::CircleWhere *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::WXYZ>::summary(const luci::CircleWhile *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::WXYZ>::summary(const luci::CircleZerosLike *node,
- locop::NodeSummary &s) const
-{
- return use_input(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::CIRC>::summary(const luci::CircleBCQFullyConnected *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::CIRC>::summary(const luci::CircleBCQGather *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::CIRC>::summary(const luci::CircleInstanceNorm *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::VIRT>::summary(const luci::CircleInput *, locop::NodeSummary &s) const
-{
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool SummaryBuilderLet<SB::VIRT>::summary(const luci::CircleOutput *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::VIRT>::summary(const luci::CircleCustomOut *node,
- locop::NodeSummary &s) const
-{
- return use_input(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::VIRT>::summary(const luci::CircleIfOut *node,
- locop::NodeSummary &s) const
-{
- return use_input(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::VIRT>::summary(const luci::CircleNonMaxSuppressionV4Out *node,
- locop::NodeSummary &s) const
-{
- return use_input(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::VIRT>::summary(const luci::CircleNonMaxSuppressionV5Out *node,
- locop::NodeSummary &s) const
-{
- return use_input(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::VIRT>::summary(const luci::CircleOutputDummy *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::VIRT>::summary(const luci::CircleOutputExclude *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::VIRT>::summary(const luci::CircleSplitOut *node,
- locop::NodeSummary &s) const
-{
- return use_input(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::VIRT>::summary(const luci::CircleSplitVOut *node,
- locop::NodeSummary &s) const
-{
- return use_input(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::VIRT>::summary(const luci::CircleTopKV2Out *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::VIRT>::summary(const luci::CircleUniqueOut *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::VIRT>::summary(const luci::CircleUnpackOut *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::VIRT>::summary(const luci::CircleWhileOut *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-} // namespace
-
namespace luci
{
@@ -2208,22 +36,10 @@ bool NodeSummaryBuilder::build(const loco::Node *node, locop::NodeSummary &s) co
return true;
}
-#define BUILD_GRP(GRP) \
- do \
- { \
- if (SummaryBuilderLet<SB::GRP>(_tbl).build(node, s)) \
- return true; \
- } while (false)
-
- BUILD_GRP(ABC);
- BUILD_GRP(DEF);
- BUILD_GRP(GHIJ);
- BUILD_GRP(KLMN);
- BUILD_GRP(OPQR);
- BUILD_GRP(STUV);
- BUILD_GRP(WXYZ);
- BUILD_GRP(CIRC);
- BUILD_GRP(VIRT);
+ if (CircleNodeSummaryBuilder().build(node, _tbl, s))
+ {
+ return true;
+ }
return false;
}
diff --git a/compiler/luci/partition/CMakeLists.txt b/compiler/luci/partition/CMakeLists.txt
index ec8e0b0d6..f28207df2 100644
--- a/compiler/luci/partition/CMakeLists.txt
+++ b/compiler/luci/partition/CMakeLists.txt
@@ -13,7 +13,7 @@ target_link_libraries(luci_partition PUBLIC luci_lang)
target_link_libraries(luci_partition PRIVATE luci_service)
target_link_libraries(luci_partition PRIVATE luci_log)
target_link_libraries(luci_partition PRIVATE luci_logex)
-target_link_libraries(luci_partition PRIVATE mio_circle)
+target_link_libraries(luci_partition PRIVATE mio_circle04)
target_link_libraries(luci_partition PRIVATE nncc_common)
target_link_libraries(luci_partition PRIVATE pepper_csv2vec)
target_link_libraries(luci_partition PRIVATE oops)
diff --git a/compiler/luci/partition/src/ConnectNode.h b/compiler/luci/partition/src/ConnectNode.h
index ebbff7a6a..e60567c69 100644
--- a/compiler/luci/partition/src/ConnectNode.h
+++ b/compiler/luci/partition/src/ConnectNode.h
@@ -161,6 +161,7 @@ public:
void visit(const luci::CircleSquaredDifference *) final;
void visit(const luci::CircleSqueeze *) final;
void visit(const luci::CircleStridedSlice *) final;
+ void visit(const luci::CircleSVDF *) final;
void visit(const luci::CircleSub *) final;
void visit(const luci::CircleSum *) final;
void visit(const luci::CircleTanh *) final;
@@ -197,6 +198,7 @@ public:
void visit(const luci::CircleTopKV2Out *) final;
void visit(const luci::CircleUniqueOut *) final;
void visit(const luci::CircleUnpackOut *) final;
+ void visit(const luci::CircleVariable *) final;
void visit(const luci::CircleWhileOut *) final;
public:
diff --git a/compiler/luci/partition/src/Nodes/CircleSVDF.cpp b/compiler/luci/partition/src/Nodes/CircleSVDF.cpp
new file mode 100644
index 000000000..f661a794c
--- /dev/null
+++ b/compiler/luci/partition/src/Nodes/CircleSVDF.cpp
@@ -0,0 +1,47 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleSVDF *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleSVDF *>(cn->find_clone(node));
+
+ luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input());
+ luci::CircleNode *weight_feature = loco::must_cast<luci::CircleNode *>(node->weight_feature());
+ luci::CircleNode *weight_time = loco::must_cast<luci::CircleNode *>(node->weight_time());
+ luci::CircleNode *bias = loco::must_cast<luci::CircleNode *>(node->bias());
+ luci::CircleNode *input_activation_state =
+ loco::must_cast<luci::CircleNode *>(node->input_activation_state());
+
+ cloned->input(cn->find_clone(input));
+ cloned->weight_feature(cn->find_clone(weight_feature));
+ cloned->weight_time(cn->find_clone(weight_time));
+ cloned->bias(cn->find_clone(bias));
+ cloned->input_activation_state(cn->find_clone(input_activation_state));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleSVDF *node) { connect(this, node); }
+
+} // namespace luci
diff --git a/compiler/luci/partition/src/Nodes/CircleSVDF.test.cpp b/compiler/luci/partition/src/Nodes/CircleSVDF.test.cpp
new file mode 100644
index 000000000..5fae5206e
--- /dev/null
+++ b/compiler/luci/partition/src/Nodes/CircleSVDF.test.cpp
@@ -0,0 +1,106 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleSVDF>
+{
+public:
+ NodeGraphlet() = default;
+
+public:
+ void init(loco::Graph *g)
+ {
+ NodeGraphletT<luci::CircleSVDF>::init(g);
+
+ _node->fusedActivationFunction(luci::FusedActFunc::RELU);
+ }
+};
+
+class TestNodeGraph : public TestIsOGraph<5>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<5>::init({shape, shape, shape, shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->input(input(0));
+ node()->weight_feature(input(1));
+ node()->weight_time(input(2));
+ node()->bias(input(3));
+ node()->input_activation_state(input(4));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_SVDF)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSVDF *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSVDF *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(5, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+ ASSERT_EQ(cth.inputs(2), clone->arg(2));
+ ASSERT_EQ(cth.inputs(3), clone->arg(3));
+ ASSERT_EQ(cth.inputs(4), clone->arg(4));
+}
+
+TEST(ConnectNodeTest, connect_SVDF_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSVDF *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSVDF *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
diff --git a/compiler/luci/partition/src/Nodes/CircleVariable.cpp b/compiler/luci/partition/src/Nodes/CircleVariable.cpp
new file mode 100644
index 000000000..f7f6f21fd
--- /dev/null
+++ b/compiler/luci/partition/src/Nodes/CircleVariable.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleVariable *)
+{
+ // Nothing to do
+}
+
+} // namespace luci
diff --git a/compiler/luci/partition/src/PartitionIRDump.cpp b/compiler/luci/partition/src/PartitionIRDump.cpp
index 4f2c26800..0fabfc416 100644
--- a/compiler/luci/partition/src/PartitionIRDump.cpp
+++ b/compiler/luci/partition/src/PartitionIRDump.cpp
@@ -32,18 +32,18 @@ void dump(std::ostream &os, const PNode *pnode)
void dump(std::ostream &os, const PGroup *pgroup)
{
os << "--- PGroup: " << pgroup->group << std::endl;
- os << "Input(s): ";
+ os << "Input(s): [ ";
for (auto &node_in : pgroup->inputs)
os << node_in->name() << " ";
- os << std::endl;
+ os << "]" << std::endl;
for (auto &pnode : pgroup->pnodes)
{
dump(os, pnode.get());
}
- os << "Output(s): ";
+ os << "Output(s): [ ";
for (auto &node_out : pgroup->outputs)
os << node_out->name() << " ";
- os << std::endl;
+ os << "]" << std::endl;
}
void dump(std::ostream &os, const PGroups *pgroups)
@@ -57,7 +57,8 @@ void dump(std::ostream &os, const PGroups *pgroups)
{
auto node = it->first;
auto group = it->second;
- os << " Node: " << node << "(" << node->name() << "): " << group << std::endl;
+ os << " Node: " << node << "(" << luci::opcode_name(node) << "," << node->name()
+ << "): " << group << std::endl;
}
}
diff --git a/compiler/luci/partition/src/PartitionMerge.cpp b/compiler/luci/partition/src/PartitionMerge.cpp
index c517bf93f..4c3971bd8 100644
--- a/compiler/luci/partition/src/PartitionMerge.cpp
+++ b/compiler/luci/partition/src/PartitionMerge.cpp
@@ -58,9 +58,6 @@ bool is_input_same(const luci::PGroup *pgroup, const luci::PGroups *pgroups)
// we need to clone this CircleConst for each graph of the group.
if (dynamic_cast<const luci::CircleConst *>(input) != nullptr)
continue;
- // Skip also for OutputExclude
- if (dynamic_cast<const luci::CircleOutputExclude *>(input) != nullptr)
- continue;
auto input_group = pgroups->group_of(input);
// NOTE: all the nodes should be registered and return should be valid group.
@@ -87,7 +84,7 @@ bool is_input_same(const luci::PGroup *pgroup, const luci::PGroups *pgroups)
input_pgroup = pgroup_input;
else
{
- if (input_pgroup != pgroup_input)
+ if (input_pgroup->group != pgroup_input->group)
return false;
}
}
@@ -96,6 +93,48 @@ bool is_input_same(const luci::PGroup *pgroup, const luci::PGroups *pgroups)
}
/**
+ * @brief return true if there is only one output and is fed to same group of nodes
+ * @note pgroups is used to find group of pgroup
+ * ex)
+ * /-- pgroup_user_1 (grp_1)
+ * --- pgroup
+ * \-- pgroup_user_2 (grp_2)
+ *
+ * return false if grp_1 != grp_2
+ */
+bool is_output_same(const luci::PGroup *pgroup, const luci::PGroups *pgroups)
+{
+ assert(pgroups != nullptr);
+ assert(pgroup != nullptr);
+
+ std::string group;
+ for (auto &output : pgroup->outputs)
+ {
+ // get output_group
+ auto output_group = pgroups->group_of(output);
+ assert(not output_group.empty());
+ if (output_group.empty())
+ output_group = pgroups->default_group;
+
+ // find all PGroup that uses output
+ for (auto &pgroup_user : pgroups->pgroups)
+ {
+ for (auto &user_inputs : pgroup_user->inputs)
+ {
+ if (output == user_inputs)
+ {
+ // OK, these are connected, check group is same
+ if (pgroup_user->group != output_group)
+ return false;
+ }
+ }
+ }
+ }
+
+ return true;
+}
+
+/**
* @brief merge pgroup into pgroup_i
* @note output of pgroup_i should be input of pgroup
*/
@@ -191,6 +230,9 @@ std::unique_ptr<luci::PGroups> merge_pgroups(const luci::PGroups *s_pgroups)
// skip if there are multiple inputs but inputs differ in group
if (!is_input_same(pgroup.get(), d_pgroups.get()))
continue;
+ // skip if pgroup has different group for other users of pgroup_i
+ if (!is_output_same(pgroup_i.get(), d_pgroups.get()))
+ continue;
// TODO add more condition may be needed
merge_into(pgroup.get(), pgroup_i.get());
diff --git a/compiler/luci/partition/src/PartitionPGroups.cpp b/compiler/luci/partition/src/PartitionPGroups.cpp
index 0080873e6..eaeacf9c4 100644
--- a/compiler/luci/partition/src/PartitionPGroups.cpp
+++ b/compiler/luci/partition/src/PartitionPGroups.cpp
@@ -46,6 +46,9 @@ public:
bool visit(const luci::CircleUniqueOut *) final { return true; }
bool visit(const luci::CircleUnpackOut *) final { return true; }
bool visit(const luci::CircleWhileOut *) final { return true; }
+ // For inputs not used
+ bool visit(const luci::CircleOutputExclude *) final { return true; }
+ bool visit(const luci::CircleVariable *) final { return true; }
// TODO add all virtual nodes
// default is false
@@ -69,59 +72,80 @@ bool check_allocate_partition(const luci::CircleNode *node)
return true;
}
-class FindGroupToFollow final : public luci::CircleNodeVisitor<const std::string &>
+} // namespace
+
+namespace
{
-public:
- FindGroupToFollow(const luci::PartitionTable &partition, luci::PGroups *pgroups)
- : _partition(partition), _pgroups(pgroups)
- {
- // NOTHING TODO
- }
-private:
- const std::string &groupof(const luci::CircleNode *input) const
+std::string group_from_partition(const luci::CircleNode *node,
+ const luci::PartitionTable &partition)
+{
+ LOGGER(l);
+
+ auto group = partition.default_group;
+
+ std::string opcodename; // opcodename or opname
+
+ switch (partition.comply)
{
- auto group = _pgroups->node2group[input];
- assert(not group.empty());
- if (group.empty())
- return _partition.default_group;
- return _pgroups->node2group[input];
+ case luci::PartitionTable::COMPLY::OPCODE:
+ {
+ opcodename = luci::opcode_name(node);
+ assert(!opcodename.empty());
+
+ auto it = partition.byopcodes.find(opcodename);
+ if (it != partition.byopcodes.end())
+ group = it->second;
+ break;
+ }
+ case luci::PartitionTable::COMPLY::OPNAME:
+ {
+ opcodename = node->name();
+ assert(!opcodename.empty());
+
+ auto it = partition.byopnames.find(opcodename);
+ if (it != partition.byopnames.end())
+ group = it->second;
+ break;
+ }
+
+ default:
+ throw std::runtime_error("Unsupported partition.comply");
}
+ INFO(l) << "Op: " << node->name() << ": " << opcodename << ", " << node << ", " << group
+ << std::endl;
+
+ return group;
+}
+
+class IsVirtualInputNode final : public luci::CircleNodeVisitor<bool>
+{
public:
-#define IMPLEMENT(CLASS) \
- const std::string &visit(const luci::CLASS *node) final \
- { \
- auto input = loco::must_cast<luci::CircleNode *>(node->input()); \
- return groupof(input); \
- }
+ // TODO check CircleOutputDummy
+ bool visit(const luci::CircleOutputExclude *) final { return true; }
+ bool visit(const luci::CircleVariable *) final { return true; }
- IMPLEMENT(CircleCustomOut);
- IMPLEMENT(CircleIfOut);
- IMPLEMENT(CircleNonMaxSuppressionV4Out);
- IMPLEMENT(CircleNonMaxSuppressionV5Out);
- IMPLEMENT(CircleSplitOut);
- IMPLEMENT(CircleSplitVOut);
- IMPLEMENT(CircleTopKV2Out);
- IMPLEMENT(CircleUniqueOut);
- IMPLEMENT(CircleUnpackOut);
- IMPLEMENT(CircleWhileOut);
-
-#undef IMPLEMENT
-
- // return empty for nothing to do
- const std::string &visit(const luci::CircleNode *) final { return _empty_str; }
-
-private:
- const luci::PartitionTable &_partition;
- luci::PGroups *_pgroups = nullptr;
- std::string _empty_str;
+ // default is false
+ bool visit(const luci::CircleNode *) final { return false; }
};
-} // namespace
-
-namespace
+class IsMultiOutputNode final : public luci::CircleNodeVisitor<bool>
{
+public:
+ bool visit(const luci::CircleCustom *) final { return true; }
+ bool visit(const luci::CircleIf *) final { return true; }
+ bool visit(const luci::CircleNonMaxSuppressionV4 *) final { return true; }
+ bool visit(const luci::CircleNonMaxSuppressionV5 *) final { return true; }
+ bool visit(const luci::CircleSplit *) final { return true; }
+ bool visit(const luci::CircleSplitV *) final { return true; }
+ bool visit(const luci::CircleTopKV2 *) final { return true; }
+ bool visit(const luci::CircleUnique *) final { return true; }
+ bool visit(const luci::CircleUnpack *) final { return true; }
+ bool visit(const luci::CircleWhile *) final { return true; }
+ // default is false
+ bool visit(const luci::CircleNode *) final { return false; }
+};
void append(luci::CircleNode *node, luci::PGroups *pgroups, const std::string &group, uint32_t idx)
{
@@ -136,17 +160,56 @@ void append(luci::CircleNode *node, luci::PGroups *pgroups, const std::string &g
pgroup->pnodes.push_back(std::move(pnode));
+ IsVirtualInputNode queryvi;
// Set input of PGroup
for (uint32_t in = 0; in < node->arity(); ++in)
{
auto input = loco::must_cast<luci::CircleNode *>(node->arg(in));
- // this input maybe CircleInput in source graph
- // --> not confident this is safe
- pgroup->inputs.push_back(input);
+ if (input->accept(&queryvi))
+ {
+ auto pnode = std::make_unique<luci::PNode>();
+ pnode->node = input;
+ pnode->group = group;
+ pnode->pgroup = pgroup.get();
+
+ pgroup->pnodes.push_back(std::move(pnode));
+
+ pgroups->node2group[input] = group;
+ }
+ else
+ {
+ // this input maybe CircleInput in source graph
+ // --> not confident this is safe
+ pgroup->inputs.push_back(input);
+ }
+ }
+
+ IsMultiOutputNode query;
+ if (node->accept(&query))
+ {
+ // Include CircleXXXOut virtual nodes in this group
+ auto succs = loco::succs(node);
+ for (auto &succ_node : succs)
+ {
+ auto nodeout = loco::must_cast<luci::CircleNode *>(succ_node);
+
+ auto pnode = std::make_unique<luci::PNode>();
+ pnode->node = nodeout;
+ pnode->group = group;
+ pnode->pgroup = pgroup.get();
+
+ pgroup->pnodes.push_back(std::move(pnode));
+
+ pgroups->node2group[nodeout] = group;
+
+ pgroup->outputs.push_back(nodeout);
+ }
+ }
+ else
+ {
+ // Set output of PGroup: node itself
+ pgroup->outputs.push_back(node);
}
- // Set output of PGroup: node itself or multiple virtual outputs
- // TODO support multiple virtual outputs
- pgroup->outputs.push_back(node);
pgroups->node2group[node] = group;
pgroups->id2pgroup[pgroup->id] = pgroup.get();
@@ -182,70 +245,9 @@ std::unique_ptr<luci::PGroups> produce_pgroups(const luci::Module *source,
// check if node is normal node that we are interested
if (check_allocate_partition(node))
{
- auto group = partition.default_group;
-
- std::string opcodename; // opcodename or opname
-
- switch (partition.comply)
- {
- case luci::PartitionTable::COMPLY::OPCODE:
- {
- opcodename = luci::opcode_name(node);
- assert(!opcodename.empty());
-
- auto it = partition.byopcodes.find(opcodename);
- if (it != partition.byopcodes.end())
- group = it->second;
- break;
- }
- case luci::PartitionTable::COMPLY::OPNAME:
- {
- opcodename = node->name();
- assert(!opcodename.empty());
-
- auto it = partition.byopnames.find(opcodename);
- if (it != partition.byopnames.end())
- group = it->second;
- break;
- }
-
- default:
- throw std::runtime_error("Unsupported partition.comply");
- }
-
- INFO(l) << "Op: " << node->name() << ": " << opcodename << ", " << node << ", " << group
- << std::endl;
+ auto group = group_from_partition(node, partition);
append(node, pgroups.get(), group, idx);
-#if 0
- auto pgroup = std::make_unique<luci::PGroup>();
- pgroup->group = group;
- pgroup->id = idx + 1;
-
- auto pnode = std::make_unique<luci::PNode>();
- pnode->node = node;
- pnode->group = group;
- pnode->pgroup = pgroup.get();
-
- pgroup->pnodes.push_back(std::move(pnode));
-
- // Set input of PGroup
- for (uint32_t in = 0; in < node->arity(); ++in)
- {
- auto input = loco::must_cast<luci::CircleNode *>(node->arg(in));
- // this input maybe CircleInput in source graph
- // --> not confident this is safe
- pgroup->inputs.push_back(input);
- }
- // Set output of PGroup: node itself or multiple virtual outputs
- // TODO support multiple virtual outputs
- pgroup->outputs.push_back(node);
-
- pgroups->node2group[node] = group;
- pgroups->id2pgroup[pgroup->id] = pgroup.get();
-
- pgroups->pgroups.push_back(std::move(pgroup));
-#endif
}
else
{
@@ -255,22 +257,6 @@ std::unique_ptr<luci::PGroups> produce_pgroups(const luci::Module *source,
}
}
- // handle for virtual nodes like multiple outputs
- // these nodes should follow group of the input
- for (uint32_t idx = 0; idx < nodes->size(); ++idx)
- {
- auto node = loco::must_cast<luci::CircleNode *>(nodes->at(idx));
-
- // for virtual nodes like CircleUnpackOut should follow it's input (owner)
- // or just set to default
- FindGroupToFollow query(partition, pgroups.get());
- const auto &group = node->accept(&query);
- if (not group.empty())
- {
- append(node, pgroups.get(), group, idx);
- }
- }
-
return std::move(pgroups);
}
diff --git a/compiler/luci/pass/CMakeLists.txt b/compiler/luci/pass/CMakeLists.txt
index b8b406a38..5237c6d3f 100644
--- a/compiler/luci/pass/CMakeLists.txt
+++ b/compiler/luci/pass/CMakeLists.txt
@@ -1,4 +1,4 @@
-nnas_find_package(FlatBuffers EXACT 1.12 QUIET)
+nnas_find_package(FlatBuffers EXACT 2.0 QUIET)
if(NOT FlatBuffers_FOUND)
message(STATUS "FlatBuffers NOT FOUND")
return()
@@ -23,11 +23,11 @@ target_link_libraries(luci_pass PRIVATE luci_log)
target_link_libraries(luci_pass PRIVATE luci_service)
target_link_libraries(luci_pass PRIVATE luci_logex)
target_link_libraries(luci_pass PRIVATE luci_profile)
-target_link_libraries(luci_pass PRIVATE mio_tflite260_inc)
+target_link_libraries(luci_pass PRIVATE mio_tflite280_inc)
target_link_libraries(luci_pass PRIVATE nncc_common)
target_link_libraries(luci_pass PRIVATE pepper_csv2vec)
target_link_libraries(luci_pass PRIVATE oops)
-target_link_libraries(luci_pass PRIVATE flatbuffers-1.12)
+target_link_libraries(luci_pass PRIVATE flatbuffers-2.0)
install(TARGETS luci_pass DESTINATION lib)
install(DIRECTORY include/ DESTINATION include
FILES_MATCHING PATTERN "*.h")
@@ -43,5 +43,5 @@ target_include_directories(luci_pass_test PRIVATE src)
target_link_libraries(luci_pass_test luci_pass)
target_link_libraries(luci_pass_test luci_lang)
target_link_libraries(luci_pass_test luci_testhelper)
-target_link_libraries(luci_pass_test flatbuffers-1.12)
+target_link_libraries(luci_pass_test flatbuffers-2.0)
#target_link_libraries(luci_pass_test oops)
diff --git a/compiler/luci/pass/include/luci/CircleOptimizer.h b/compiler/luci/pass/include/luci/CircleOptimizer.h
index 658563ecf..c803898f6 100644
--- a/compiler/luci/pass/include/luci/CircleOptimizer.h
+++ b/compiler/luci/pass/include/luci/CircleOptimizer.h
@@ -47,15 +47,12 @@ public:
ResolveCustomOpBatchMatMul,
ResolveCustomOpMatMul,
ResolveCustomOpMaxPoolWithArgmax,
- QuantizeDequantizeWeights,
- QuantizeWithMinMax,
- Requantize,
FoldAddV2,
FoldCast,
FoldDepthwiseConv2D,
FoldDequantize,
+ FoldGather,
FoldSparseToDense,
- ForceQuantParam,
ForwardReshapeToUnaryOp,
SparsifyTensorPass,
FusePreActivationBatchNorm,
@@ -79,6 +76,7 @@ public:
TransformMinReluToRelu6Pass,
SubstituteStridedSliceToReshape,
SubstituteTransposeToReshape,
+ RemoveRedundantQuantize,
RemoveRedundantReshape,
RemoveFakeQuant,
RemoveQuantDequantSeq,
@@ -86,16 +84,6 @@ public:
enum AlgorithmParameters
{
- // quantize
- Quantize_input_model_dtype,
- Quantize_output_model_dtype,
- Quantize_granularity, // layer-wise or channel-wise
- Quantize_tensor_names,
- Quantize_scales,
- Quantize_zero_points,
- Quantize_input_type,
- Quantize_output_type,
-
// sparsify
Sparsify_tensor_name,
Sparsify_traversal_order,
@@ -114,8 +102,6 @@ public:
virtual bool query(Algorithm) = 0;
virtual void param(AlgorithmParameters, const std::string &) = 0;
virtual const std::string param(AlgorithmParameters) const = 0;
- virtual void params(AlgorithmParameters, std::vector<std::string> &) = 0;
- virtual std::vector<std::string> params(AlgorithmParameters) const = 0;
};
public:
@@ -127,8 +113,6 @@ public:
void optimize(loco::Graph *) const;
- void quantize(loco::Graph *) const;
-
void sparsify(loco::Graph *) const;
private:
diff --git a/compiler/luci/pass/include/luci/CircleQuantizer.h b/compiler/luci/pass/include/luci/CircleQuantizer.h
new file mode 100644
index 000000000..4e7074d98
--- /dev/null
+++ b/compiler/luci/pass/include/luci/CircleQuantizer.h
@@ -0,0 +1,97 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_CIRCLE_QUANTIZER_H__
+#define __LUCI_CIRCLE_QUANTIZER_H__
+
+#include <loco.h>
+
+#include <string>
+#include <vector>
+
+namespace luci
+{
+
+class CircleQuantizer final
+{
+public:
+ struct Options
+ {
+ struct LayerParam
+ {
+ std::string name;
+ std::string dtype;
+ std::string granularity;
+ };
+
+ enum Algorithm
+ {
+ QuantizeDequantizeWeights,
+ QuantizeWithMinMax,
+ Requantize,
+ CopyQuantParam,
+ ForceQuantParam,
+ ConvertToFakeQuantizedModel,
+ };
+
+ enum AlgorithmParameters
+ {
+ // quantize
+ Quantize_input_model_dtype,
+ Quantize_output_model_dtype,
+ Quantize_granularity, // layer-wise or channel-wise
+ Quantize_tensor_names,
+ Quantize_scales,
+ Quantize_zero_points,
+ Quantize_layer_params,
+
+ // copy_quantparam
+ Quantize_src_tensor_names,
+ Quantize_dst_tensor_names,
+
+ Quantize_input_type,
+ Quantize_output_type,
+ Quantize_TF_style_maxpool,
+ };
+
+ virtual ~Options() = default;
+
+ virtual void enable(Algorithm) = 0;
+ virtual bool query(Algorithm) = 0;
+ virtual void param(AlgorithmParameters, const std::string &) = 0;
+ virtual const std::string param(AlgorithmParameters) const = 0;
+ virtual void params(AlgorithmParameters, std::vector<std::string> &) = 0;
+ virtual std::vector<std::string> params(AlgorithmParameters) const = 0;
+
+ // Quantization parameters for multiple layers
+ virtual void layer_params(AlgorithmParameters, std::vector<std::shared_ptr<LayerParam>> &) = 0;
+ virtual std::vector<std::shared_ptr<LayerParam>> layer_params(AlgorithmParameters) const = 0;
+ };
+
+public:
+ // TODO maybe caller can provide Options as ctor parameters
+ Options *options(void);
+
+public:
+ void quantize(loco::Graph *) const;
+
+private:
+ std::unique_ptr<Options> _options;
+};
+
+} // namespace luci
+
+#endif // __LUCI_CIRCLE_QUANTIZER_H__
diff --git a/compiler/luci/pass/include/luci/Pass/ConvertToFakeQuantizedModelPass.h b/compiler/luci/pass/include/luci/Pass/ConvertToFakeQuantizedModelPass.h
new file mode 100644
index 000000000..91dd2300e
--- /dev/null
+++ b/compiler/luci/pass/include/luci/Pass/ConvertToFakeQuantizedModelPass.h
@@ -0,0 +1,39 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_CONVERT_TO_FAKE_QUANTIZED_MODEL_PASS_H__
+#define __LUCI_CONVERT_TO_FAKE_QUANTIZED_MODEL_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace luci
+{
+
+/**
+ * @brief Class to convert a quantized model to a fake-quantized fp32 model.
+ */
+struct ConvertToFakeQuantizedModelPass final : public logo::Pass
+{
+ ConvertToFakeQuantizedModelPass() {}
+
+ const char *name(void) const final { return "luci::ConvertToFakeQuantizedModelPass"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace luci
+
+#endif // __LUCI_CONVERT_TO_FAKE_QUANTIZED_MODEL_PASS_H__
diff --git a/compiler/luci/pass/include/luci/Pass/CopyQuantParamPass.h b/compiler/luci/pass/include/luci/Pass/CopyQuantParamPass.h
new file mode 100644
index 000000000..18c9cd56a
--- /dev/null
+++ b/compiler/luci/pass/include/luci/Pass/CopyQuantParamPass.h
@@ -0,0 +1,53 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_COPY_QUANT_PARAM_PASS_H__
+#define __LUCI_COPY_QUANT_PARAM_PASS_H__
+
+#include <loco.h>
+
+#include <logo/Pass.h>
+
+namespace luci
+{
+
+/**
+ * @brief Pass to copy quantparam (scale, zerop) of a tensor to another tensor
+ */
+class CopyQuantParamPass : public logo::Pass
+{
+public:
+ using TensorVector = std::vector<std::string>;
+
+public:
+ CopyQuantParamPass(TensorVector &src_tensors, TensorVector &dst_tensors)
+ : _src_tensors{src_tensors}, _dst_tensors{dst_tensors}
+ {
+ // DO NOTHING
+ }
+ virtual const char *name(void) const { return "luci::CopyQuantParamPass"; }
+
+public:
+ bool run(loco::Graph *graph);
+
+private:
+ TensorVector _src_tensors;
+ TensorVector _dst_tensors;
+};
+
+} // namespace luci
+
+#endif //__LUCI_COPY_QUANT_PARAM_PASS_H__
diff --git a/compiler/luci/pass/include/luci/Pass/FoldGatherPass.h b/compiler/luci/pass/include/luci/Pass/FoldGatherPass.h
new file mode 100644
index 000000000..de08c8845
--- /dev/null
+++ b/compiler/luci/pass/include/luci/Pass/FoldGatherPass.h
@@ -0,0 +1,38 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_FOLD_GATHER_PASS_H__
+#define __LUCI_FOLD_GATHER_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace luci
+{
+
+/**
+ * @brief Class to fold Gather to a constant tensor
+ *
+ */
+struct FoldGatherPass final : public logo::Pass
+{
+ const char *name(void) const final { return "luci::FoldGatherPass"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace luci
+
+#endif // __LUCI_FOLD_GATHER_PASS_H__
diff --git a/compiler/luci/pass/include/luci/Pass/PropagateQParamBackwardPass.h b/compiler/luci/pass/include/luci/Pass/PropagateQParamBackwardPass.h
new file mode 100644
index 000000000..0c489fc30
--- /dev/null
+++ b/compiler/luci/pass/include/luci/Pass/PropagateQParamBackwardPass.h
@@ -0,0 +1,42 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_PROPAGATE_QPARAM_BACKWARD_PASS_H__
+#define __LUCI_PROPAGATE_QPARAM_BACKWARD_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace luci
+{
+
+/**
+ * @brief Class to propagate quantization parameters of an operator's output to input
+ */
+struct PropagateQParamBackwardPass final : public logo::Pass
+{
+ PropagateQParamBackwardPass(loco::DataType output) : _output_model_dtype(output) {}
+
+ const char *name(void) const final { return "luci::PropagateQParamBackwardPass"; }
+
+ bool run(loco::Graph *g) final;
+
+private:
+ loco::DataType _output_model_dtype;
+};
+
+} // namespace luci
+
+#endif // __LUCI_PROPAGATE_QPARAM_BACKWARD_PASS_H__
diff --git a/compiler/luci/pass/include/luci/Pass/PropagateQuantParamPass.h b/compiler/luci/pass/include/luci/Pass/PropagateQParamForwardPass.h
index 7e0c44b8c..952bd9614 100644
--- a/compiler/luci/pass/include/luci/Pass/PropagateQuantParamPass.h
+++ b/compiler/luci/pass/include/luci/Pass/PropagateQParamForwardPass.h
@@ -14,8 +14,8 @@
* limitations under the License.
*/
-#ifndef __LUCI_PROPAGATE_QUANT_PARAM_PASS_H__
-#define __LUCI_PROPAGATE_QUANT_PARAM_PASS_H__
+#ifndef __LUCI_PROPAGATE_QPARAM_FORWARD_PASS_H__
+#define __LUCI_PROPAGATE_QPARAM_FORWARD_PASS_H__
#include <logo/Pass.h>
@@ -23,15 +23,22 @@ namespace luci
{
/**
- * @brief Class to propagate quantization parameters of an operator's output to input
+ * @brief Class to propagate quantization parameters of an operator's input to output
*/
-struct PropagateQuantParamPass final : public logo::Pass
+struct PropagateQParamForwardPass final : public logo::Pass
{
- const char *name(void) const final { return "luci::PropagateQuantParamPass"; }
+ PropagateQParamForwardPass(bool TF_style_maxpool) : _TF_style_maxpool(TF_style_maxpool) {}
+
+ PropagateQParamForwardPass() {}
+
+ const char *name(void) const final { return "luci::PropagateQParamForwardPass"; }
bool run(loco::Graph *g) final;
+
+private:
+ bool _TF_style_maxpool = false;
};
} // namespace luci
-#endif // __LUCI_PROPAGATE_QUANT_PARAM_PASS_H__
+#endif // __LUCI_PROPAGATE_QPARAM_FORWARD_PASS_H__
diff --git a/compiler/luci/pass/include/luci/Pass/QuantizationParameters.h b/compiler/luci/pass/include/luci/Pass/QuantizationParameters.h
index 5c9cd427f..30c8db058 100644
--- a/compiler/luci/pass/include/luci/Pass/QuantizationParameters.h
+++ b/compiler/luci/pass/include/luci/Pass/QuantizationParameters.h
@@ -17,6 +17,10 @@
#ifndef __LUCI_QUANTIZATION_PARAMETERS_H__
#define __LUCI_QUANTIZATION_PARAMETERS_H__
+#include <loco.h>
+
+#include <string>
+
namespace luci
{
@@ -26,6 +30,13 @@ enum QuantizationGranularity
ChannelWise = 1,
};
+struct LayerInfo
+{
+ std::string name;
+ loco::DataType dtype;
+ QuantizationGranularity granularity;
+};
+
} // namespace luci
#endif // __LUCI_QUANTIZATION_PARAMETERS_H__
diff --git a/compiler/luci/pass/include/luci/Pass/QuantizeDequantizeWeightsPass.h b/compiler/luci/pass/include/luci/Pass/QuantizeDequantizeWeightsPass.h
index 68765ec5b..1825ee1aa 100644
--- a/compiler/luci/pass/include/luci/Pass/QuantizeDequantizeWeightsPass.h
+++ b/compiler/luci/pass/include/luci/Pass/QuantizeDequantizeWeightsPass.h
@@ -32,12 +32,30 @@ namespace luci
class QuantizeDequantizeWeightsPass : public logo::Pass
{
public:
+ struct Context
+ {
+ loco::DataType input_model_dtype = loco::DataType::Unknown;
+ loco::DataType output_model_dtype = loco::DataType::Unknown;
+ QuantizationGranularity granularity = QuantizationGranularity::ChannelWise;
+ std::vector<LayerInfo> layers_info;
+ };
+
+public:
+ QuantizeDequantizeWeightsPass(std::unique_ptr<Context> &&ctx) : _ctx{std::move(ctx)}
+ {
+ // DO NOTHING
+ }
+
+public:
QuantizeDequantizeWeightsPass(loco::DataType input_model_dtype, loco::DataType output_model_dtype,
QuantizationGranularity granularity)
- : _input_model_dtype{input_model_dtype}, _output_model_dtype{output_model_dtype}, _granularity{
- granularity}
{
- // DO NOTHING
+ _ctx = std::make_unique<Context>();
+ {
+ _ctx->input_model_dtype = input_model_dtype;
+ _ctx->output_model_dtype = output_model_dtype;
+ _ctx->granularity = granularity;
+ }
}
virtual const char *name(void) const { return "luci::QuantizeDequantizeWeightsPass"; }
@@ -45,9 +63,7 @@ public:
bool run(loco::Graph *graph);
private:
- loco::DataType _input_model_dtype;
- loco::DataType _output_model_dtype;
- QuantizationGranularity _granularity;
+ std::unique_ptr<Context> _ctx;
};
} // namespace luci
diff --git a/compiler/luci/pass/include/luci/Pass/QuantizePreCheckerPass.h b/compiler/luci/pass/include/luci/Pass/QuantizePreCheckerPass.h
new file mode 100644
index 000000000..c852f88e0
--- /dev/null
+++ b/compiler/luci/pass/include/luci/Pass/QuantizePreCheckerPass.h
@@ -0,0 +1,39 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_QUANTIZE_PRE_CHECKER_PASS_H__
+#define __LUCI_QUANTIZE_PRE_CHECKER_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace luci
+{
+
+/**
+ * @brief Pass to verify the input model has the form acceptable by quantizer
+ */
+class QuantizePreCheckerPass : public logo::Pass
+{
+public:
+ const char *name(void) const final { return "luci::QuantizePreCheckerPass"; }
+
+public:
+ bool run(loco::Graph *graph) final;
+};
+
+} // namespace luci
+
+#endif //__LUCI_QUANTIZE_PRE_CHECKER_PASS_H__
diff --git a/compiler/luci/pass/include/luci/Pass/QuantizeWithMinMaxPass.h b/compiler/luci/pass/include/luci/Pass/QuantizeWithMinMaxPass.h
index 648abad70..ea6db85d1 100644
--- a/compiler/luci/pass/include/luci/Pass/QuantizeWithMinMaxPass.h
+++ b/compiler/luci/pass/include/luci/Pass/QuantizeWithMinMaxPass.h
@@ -23,6 +23,8 @@
#include <luci/Pass/QuantizationParameters.h>
+#include <vector>
+
namespace luci
{
@@ -31,26 +33,41 @@ namespace luci
*/
class QuantizeWithMinMaxPass : public logo::Pass
{
+public:
+ struct Context
+ {
+ loco::DataType input_model_dtype = loco::DataType::Unknown;
+ loco::DataType output_model_dtype = loco::DataType::Unknown;
+ QuantizationGranularity granularity = QuantizationGranularity::ChannelWise;
+ loco::DataType input_type = loco::DataType::Unknown;
+ loco::DataType output_type = loco::DataType::Unknown;
+ bool TF_style_maxpool = false;
+ std::vector<LayerInfo> layers_info;
+ };
+
// For backward-compatibility
// TODO Remove this constructor
public:
QuantizeWithMinMaxPass(loco::DataType input_model_dtype, loco::DataType output_model_dtype,
QuantizationGranularity granularity)
- : _input_model_dtype{input_model_dtype}, _output_model_dtype{output_model_dtype},
- _granularity{granularity}, _input_type{output_model_dtype}, _output_type{output_model_dtype}
{
- // DO NOTHING
+ _ctx = std::make_unique<Context>();
+ {
+ _ctx->input_model_dtype = input_model_dtype;
+ _ctx->output_model_dtype = output_model_dtype;
+ _ctx->granularity = granularity;
+ _ctx->input_type = output_model_dtype;
+ _ctx->output_type = output_model_dtype;
+ _ctx->TF_style_maxpool = false;
+ }
}
public:
- QuantizeWithMinMaxPass(loco::DataType input_model_dtype, loco::DataType output_model_dtype,
- QuantizationGranularity granularity, loco::DataType input_type,
- loco::DataType output_type)
- : _input_model_dtype{input_model_dtype}, _output_model_dtype{output_model_dtype},
- _granularity{granularity}, _input_type{input_type}, _output_type{output_type}
+ QuantizeWithMinMaxPass(std::unique_ptr<Context> &&ctx) : _ctx{std::move(ctx)}
{
// DO NOTHING
}
+
virtual const char *name(void) const { return "luci::QuantizeWithMinMaxPass"; }
public:
@@ -61,11 +78,7 @@ private:
void set_output_type(loco::Graph *graph) const;
private:
- loco::DataType _input_model_dtype;
- loco::DataType _output_model_dtype;
- QuantizationGranularity _granularity;
- loco::DataType _input_type;
- loco::DataType _output_type;
+ std::unique_ptr<Context> _ctx;
};
} // namespace luci
diff --git a/compiler/luci/pass/include/luci/Pass/RemoveRedundantQuantizePass.h b/compiler/luci/pass/include/luci/Pass/RemoveRedundantQuantizePass.h
new file mode 100644
index 000000000..3e76bcdc3
--- /dev/null
+++ b/compiler/luci/pass/include/luci/Pass/RemoveRedundantQuantizePass.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_REMOVE_REDUNDANT_QUANTIZE_PASS_H__
+#define __LUCI_REMOVE_REDUNDANT_QUANTIZE_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace luci
+{
+
+/**
+ * @brief Class to remove redundant quantize operations
+ */
+struct RemoveRedundantQuantizePass final : public logo::Pass
+{
+ const char *name(void) const final { return "luci::RemoveRedundantQuantizePass"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace luci
+
+#endif // __LUCI_REMOVE_REDUNDANT_QUANTIZE_PASS_H__
diff --git a/compiler/luci/pass/src/BatchNormPatternFinder.cpp b/compiler/luci/pass/src/BatchNormPatternFinder.cpp
index c1a06bfda..e3f126b15 100644
--- a/compiler/luci/pass/src/BatchNormPatternFinder.cpp
+++ b/compiler/luci/pass/src/BatchNormPatternFinder.cpp
@@ -44,10 +44,26 @@ bool is_batchnorm_add(const luci::CircleAdd *add, luci::CircleMul *&mul, luci::C
return false;
}
- if (constant->rank() != 1)
+ uint32_t channel_dim = 0;
+
+ if (constant->rank() == 1)
+ {
+ channel_dim = constant->dim(0).value();
+ }
+ else if (constant->rank() == 4)
+ {
+ for (uint32_t i = 0; i < 3; i++)
+ {
+ if (constant->dim(i).value() != 1)
+ return false;
+ }
+ channel_dim = constant->dim(3).value();
+ }
+ else
+ {
return false;
+ }
- auto channel_dim = constant->dim(0);
// Assumption: Layout is channel-last
if (!(channel_dim == add->dim(add->rank() - 1)))
return false;
@@ -90,10 +106,26 @@ bool is_batchnorm_mul(const luci::CircleMul *mul, luci::CircleNode *&pred_node,
return false;
}
- if (constant->rank() != 1)
+ uint32_t channel_dim = 0;
+
+ if (constant->rank() == 1)
+ {
+ channel_dim = constant->dim(0).value();
+ }
+ else if (constant->rank() == 4)
+ {
+ for (uint32_t i = 0; i < 3; i++)
+ {
+ if (constant->dim(i).value() != 1)
+ return false;
+ }
+ channel_dim = constant->dim(3).value();
+ }
+ else
+ {
return false;
+ }
- auto channel_dim = constant->dim(0);
// Assumption: Layout is channel-last
if (!(channel_dim == mul->dim(mul->rank() - 1)))
return false;
diff --git a/compiler/luci/pass/src/BatchNormPatternFinder.test.cpp b/compiler/luci/pass/src/BatchNormPatternFinder.test.cpp
index 08e7fac1c..cc8c5615f 100644
--- a/compiler/luci/pass/src/BatchNormPatternFinder.test.cpp
+++ b/compiler/luci/pass/src/BatchNormPatternFinder.test.cpp
@@ -50,7 +50,7 @@ public:
auto channel_size = *last_it;
_add->shape(shape);
- _add_beta->shape({channel_size});
+ set_beta_shape(channel_size);
_add_beta->size<loco::DataType::FLOAT32>(channel_size);
for (uint32_t i = 0; i < channel_size; i++)
_add_beta->at<loco::DataType::FLOAT32>(i) = i;
@@ -63,10 +63,23 @@ public:
luci::CircleAdd *add() { return _add; }
protected:
+ virtual void set_beta_shape(uint32_t channel) = 0;
+
+protected:
luci::CircleAdd *_add = nullptr;
luci::CircleConst *_add_beta = nullptr;
};
+class AddRank1BetaGraphlet : public AddBetaGraphlet
+{
+ void set_beta_shape(uint32_t channel) final { _add_beta->shape({channel}); }
+};
+
+class AddRank4BetaGraphlet : public AddBetaGraphlet
+{
+ void set_beta_shape(uint32_t channel) final { _add_beta->shape({1, 1, 1, channel}); }
+};
+
/**
* @brief Graphlet with Mul and Const as gamma from BatchNorm
*/
@@ -90,7 +103,7 @@ public:
auto channel_size = *last_it;
_mul->shape(shape);
- _mul_gamma->shape({channel_size});
+ set_gamma_shape(channel_size);
_mul_gamma->size<loco::DataType::FLOAT32>(channel_size);
for (uint32_t i = 0; i < channel_size; i++)
_mul_gamma->at<loco::DataType::FLOAT32>(i) = i;
@@ -103,14 +116,27 @@ public:
luci::CircleMul *mul(void) { return _mul; }
protected:
+ virtual void set_gamma_shape(uint32_t channel) = 0;
+
+protected:
luci::CircleMul *_mul = nullptr;
luci::CircleConst *_mul_gamma = nullptr;
};
+class MulRank1GammaGraphlet : public MulGammaGraphlet
+{
+ void set_gamma_shape(uint32_t channel) final { _mul_gamma->shape({channel}); }
+};
+
+class MulRank4GammaGraphlet : public MulGammaGraphlet
+{
+ void set_gamma_shape(uint32_t channel) final { _mul_gamma->shape({1, 1, 1, channel}); }
+};
+
/**
* @brief Graph of Mul-Add pattern from BatchNorm
*/
-class MulAddGraph : public TestIOGraph, public AddBetaGraphlet, public MulGammaGraphlet
+class MulAddGraph : public TestIOGraph, public AddRank1BetaGraphlet, public MulRank1GammaGraphlet
{
public:
MulAddGraph() = default;
@@ -118,8 +144,30 @@ public:
void init(const ShapeU32 shape_in, const ShapeU32 shape_out)
{
TestIOGraph::init(shape_in, shape_out);
- MulGammaGraphlet::init(g(), shape_in, luci::FusedActFunc::NONE);
- AddBetaGraphlet::init(g(), shape_out, luci::FusedActFunc::RELU);
+ MulRank1GammaGraphlet::init(g(), shape_in, luci::FusedActFunc::NONE);
+ AddRank1BetaGraphlet::init(g(), shape_out, luci::FusedActFunc::RELU);
+
+ // connect network
+ _mul->x(input());
+ _mul->y(_mul_gamma);
+ _add->x(_mul);
+ _add->y(_add_beta);
+ output()->from(_add);
+ }
+};
+
+class MulAddRank4Graph : public TestIOGraph,
+ public AddRank4BetaGraphlet,
+ public MulRank4GammaGraphlet
+{
+public:
+ MulAddRank4Graph() = default;
+
+ void init(const ShapeU32 shape_in, const ShapeU32 shape_out)
+ {
+ TestIOGraph::init(shape_in, shape_out);
+ MulRank4GammaGraphlet::init(g(), shape_in, luci::FusedActFunc::NONE);
+ AddRank4BetaGraphlet::init(g(), shape_out, luci::FusedActFunc::RELU);
// connect network
_mul->x(input());
@@ -133,7 +181,7 @@ public:
/**
* @brief Graph of Add with Const
*/
-class AddGraph : public TestIOGraph, public AddBetaGraphlet
+class AddGraph : public TestIOGraph, public AddRank1BetaGraphlet
{
public:
AddGraph() = default;
@@ -141,7 +189,24 @@ public:
void init(const ShapeU32 shape_in, const ShapeU32 shape_out)
{
TestIOGraph::init(shape_in, shape_out);
- AddBetaGraphlet::init(g(), shape_in, luci::FusedActFunc::RELU);
+ AddRank1BetaGraphlet::init(g(), shape_in, luci::FusedActFunc::RELU);
+
+ // connect network
+ _add->x(input());
+ _add->y(_add_beta);
+ output()->from(_add);
+ }
+};
+
+class AddRank4Graph : public TestIOGraph, public AddRank4BetaGraphlet
+{
+public:
+ AddRank4Graph() = default;
+
+ void init(const ShapeU32 shape_in, const ShapeU32 shape_out)
+ {
+ TestIOGraph::init(shape_in, shape_out);
+ AddRank4BetaGraphlet::init(g(), shape_in, luci::FusedActFunc::RELU);
// connect network
_add->x(input());
@@ -160,6 +225,7 @@ public:
protected:
luci::test::MulAddGraph _mag;
+ luci::test::MulAddRank4Graph _mag_r4;
};
class BatchNormPatternFinderAddTest : public ::testing::Test
@@ -169,6 +235,7 @@ public:
protected:
luci::test::AddGraph _ag;
+ luci::test::AddRank4Graph _ag_r4;
};
TEST_F(BatchNormPatternFinderMulAddTest, is_batchnorm_add)
@@ -192,6 +259,19 @@ TEST_F(BatchNormPatternFinderMulAddTest, is_batchnorm_add2)
ASSERT_TRUE(res);
}
+TEST_F(BatchNormPatternFinderMulAddTest, is_batchnorm_add_rank4)
+{
+ _mag_r4.init({1, 16, 16, 4}, {1, 16, 16, 4});
+
+ luci::CircleMul *mul = nullptr;
+ luci::CircleConst *beta = nullptr;
+
+ auto res = luci::is_batchnorm_add(_mag_r4.add(), mul, beta);
+ ASSERT_TRUE(res);
+ ASSERT_NE(nullptr, mul);
+ ASSERT_NE(nullptr, beta);
+}
+
TEST_F(BatchNormPatternFinderAddTest, is_batchnorm_add_NEG)
{
_ag.init({1, 16, 16, 4}, {1, 16, 16, 4});
@@ -215,3 +295,16 @@ TEST_F(BatchNormPatternFinderMulAddTest, is_batchnorm_mul)
ASSERT_NE(nullptr, pred);
ASSERT_NE(nullptr, gamma);
}
+
+TEST_F(BatchNormPatternFinderMulAddTest, is_batchnorm_mul_rank4)
+{
+ _mag_r4.init({1, 16, 16, 4}, {1, 16, 16, 4});
+
+ luci::CircleNode *pred = nullptr;
+ luci::CircleConst *gamma = nullptr;
+
+ auto res = luci::is_batchnorm_mul(_mag_r4.mul(), pred, gamma);
+ ASSERT_TRUE(res);
+ ASSERT_NE(nullptr, pred);
+ ASSERT_NE(nullptr, gamma);
+}
diff --git a/compiler/luci/pass/src/CircleOptimizer.cpp b/compiler/luci/pass/src/CircleOptimizer.cpp
index 75f04b3b5..6dbb22d7c 100644
--- a/compiler/luci/pass/src/CircleOptimizer.cpp
+++ b/compiler/luci/pass/src/CircleOptimizer.cpp
@@ -22,9 +22,9 @@
#include "luci/Pass/FoldCastPass.h"
#include "luci/Pass/FoldDepthwiseConv2DPass.h"
#include "luci/Pass/FoldDequantizePass.h"
+#include "luci/Pass/FoldGatherPass.h"
#include "luci/Pass/FoldSparseToDensePass.h"
#include "luci/Pass/ForwardReshapeToUnaryOpPass.h"
-#include "luci/Pass/ForceQuantParamPass.h"
#include "luci/Pass/FuseActivationFunctionPass.h"
#include "luci/Pass/FuseAddWithFullyConnectedPass.h"
#include "luci/Pass/FuseAddWithTConvPass.h"
@@ -37,11 +37,11 @@
#include "luci/Pass/FusePreActivationBatchNormPass.h"
#include "luci/Pass/FuseTransposeWithMeanPass.h"
#include "luci/Pass/MakeBatchNormGammaPositivePass.h"
-#include "luci/Pass/PropagateQuantParamPass.h"
#include "luci/Pass/RemoveFakeQuantPass.h"
#include "luci/Pass/RemoveQuantDequantSeqPass.h"
#include "luci/Pass/RemoveRedundantReshapePass.h"
#include "luci/Pass/RemoveRedundantTransposePass.h"
+#include "luci/Pass/RemoveRedundantQuantizePass.h"
#include "luci/Pass/RemoveUnnecessaryReshapePass.h"
#include "luci/Pass/RemoveUnnecessarySlicePass.h"
#include "luci/Pass/RemoveUnnecessaryStridedSlicePass.h"
@@ -52,9 +52,6 @@
#include "luci/Pass/ResolveCustomOpBatchMatMulPass.h"
#include "luci/Pass/ResolveCustomOpMatMulPass.h"
#include "luci/Pass/ResolveCustomOpMaxPoolWithArgmaxPass.h"
-#include "luci/Pass/RequantizePass.h"
-#include "luci/Pass/QuantizeWithMinMaxPass.h"
-#include "luci/Pass/QuantizeDequantizeWeightsPass.h"
#include "luci/Pass/SparsifyTensorPass.h"
#include "luci/Pass/ShuffleWeightTo16x1Float32Pass.h"
#include "luci/Pass/SubstitutePackToReshapePass.h"
@@ -75,9 +72,6 @@
#include "ModulePhase.h"
#include "ProgressReporter.h"
-#include "helpers/Strings.h"
-
-#include "QuantizedModelVerifier.h"
#include <luci/IR/CircleNodes.h>
#include <logo/Phase.h>
@@ -91,37 +85,17 @@ namespace
using namespace luci;
-template <typename T> T lexical_cast(const std::string &str)
-{
- std::istringstream ss;
- ss.str(str);
- T data;
- ss >> data;
- return data;
-}
-
-template <typename T> std::vector<T> lexical_cast(std::vector<std::string> &sv)
-{
- std::vector<T> result;
- std::transform(sv.begin(), sv.end(), std::back_inserter(result),
- [](std::string str) -> T { return lexical_cast<T>(str); });
- return result;
-}
-
class OptimizeOptionsImpl final : public luci::CircleOptimizer::Options
{
public:
void enable(Algorithm) final;
void param(AlgorithmParameters, const std::string &) final;
const std::string param(AlgorithmParameters) const final;
- void params(AlgorithmParameters, std::vector<std::string> &) final;
- std::vector<std::string> params(AlgorithmParameters) const final;
bool query(Algorithm) final;
private:
std::vector<Algorithm> _algorithms;
std::map<AlgorithmParameters, const std::string> _algorithm_params;
- std::map<AlgorithmParameters, std::vector<std::string>> _multiple_params;
};
void OptimizeOptionsImpl::enable(Algorithm algo) { _algorithms.push_back(algo); }
@@ -144,24 +118,6 @@ const std::string OptimizeOptionsImpl::param(AlgorithmParameters param) const
}
}
-void OptimizeOptionsImpl::params(AlgorithmParameters param, std::vector<std::string> &vec)
-{
- _multiple_params[param] = vec;
-}
-
-std::vector<std::string> OptimizeOptionsImpl::params(AlgorithmParameters param) const
-{
- auto param_vec = _multiple_params.find(param);
- if (param_vec != _multiple_params.end())
- {
- return param_vec->second;
- }
- else
- {
- return std::vector<std::string>();
- }
-}
-
bool OptimizeOptionsImpl::query(Algorithm algo)
{
std::vector<Algorithm>::iterator it = std::find(_algorithms.begin(), _algorithms.end(), algo);
@@ -312,6 +268,10 @@ void CircleOptimizer::optimize(loco::Graph *g) const
{
phase.emplace_back(std::make_unique<luci::FoldDequantizePass>());
}
+ if (_options->query(Options::Algorithm::FoldGather))
+ {
+ phase.emplace_back(std::make_unique<luci::FoldGatherPass>());
+ }
if (_options->query(Options::Algorithm::FoldSparseToDense))
{
phase.emplace_back(std::make_unique<luci::FoldSparseToDensePass>());
@@ -368,6 +328,10 @@ void CircleOptimizer::optimize(loco::Graph *g) const
{
phase.emplace_back(std::make_unique<luci::RemoveRedundantTransposePass>());
}
+ if (_options->query(Options::Algorithm::RemoveRedundantQuantize))
+ {
+ phase.emplace_back(std::make_unique<luci::RemoveRedundantQuantizePass>());
+ }
if (_options->query(Options::Algorithm::ReplaceMulAddWithDepthwiseConv))
{
phase.emplace_back(std::make_unique<luci::ReplaceMulAddWithDepthwiseConvPass>());
@@ -417,174 +381,6 @@ void CircleOptimizer::optimize(loco::Graph *g) const
phase_runner.run(phase);
}
-void CircleOptimizer::quantize(loco::Graph *g) const
-{
- // Fake quantization of weights
- if (_options->query(Options::Algorithm::QuantizeDequantizeWeights))
- {
- static const std::vector<std::string> fakeq_supported_input_model_dtype{"float32"};
- static const std::vector<std::string> fakeq_supported_output_model_dtype{"uint8", "int16"};
- static const std::vector<std::string> fakeq_supported_granularity{"layer", "channel"};
-
- auto input_model_dtype =
- _options->param(Options::AlgorithmParameters::Quantize_input_model_dtype);
- auto output_model_dtype =
- _options->param(Options::AlgorithmParameters::Quantize_output_model_dtype);
- auto granularity = _options->param(Options::AlgorithmParameters::Quantize_granularity);
-
- if (!in_array(to_lower_case(input_model_dtype), fakeq_supported_input_model_dtype))
- throw std::runtime_error("Unsupported input type. List of supported input type: " +
- to_string(fakeq_supported_input_model_dtype));
-
- if (!in_array(to_lower_case(output_model_dtype), fakeq_supported_output_model_dtype))
- throw std::runtime_error("Unsupported output type. List of supported output type: " +
- to_string(fakeq_supported_output_model_dtype));
-
- if (!in_array(to_lower_case(granularity), fakeq_supported_granularity))
- throw std::runtime_error("Unsupported granularity. List of supported granularity: " +
- to_string(fakeq_supported_granularity));
-
- if (str_to_granularity(granularity) == QuantizationGranularity::LayerWise &&
- str_to_dtype(output_model_dtype) != loco::DataType::U8)
- throw std::runtime_error("Layer-wise quantization only supports uint8 dtype.");
-
- // Clear existing quantparams before doing fake quantization
- for (auto node : loco::active_nodes(loco::output_nodes(g)))
- {
- auto circle_node = loco::must_cast<luci::CircleNode *>(node);
- if (circle_node->quantparam() != nullptr)
- circle_node->quantparam(nullptr);
- }
-
- luci::QuantizeDequantizeWeightsPass fake_quantizer(str_to_dtype(input_model_dtype),
- str_to_dtype(output_model_dtype),
- str_to_granularity(granularity));
- fake_quantizer.run(g);
- }
-
- // Actual quantization of weights, bias, and activation
- if (_options->query(Options::Algorithm::QuantizeWithMinMax))
- {
- static const std::vector<std::string> qwmm_supported_input_model_dtype{"float32"};
- static const std::vector<std::string> qwmm_supported_output_model_dtype{"uint8", "int16"};
- static const std::vector<std::string> qwmm_supported_granularity{"layer", "channel"};
- static const std::vector<std::string> qwmm_supported_input_type{"uint8", "int16"};
- static const std::vector<std::string> qwmm_supported_output_type{"uint8", "int16"};
-
- auto input_model_dtype =
- _options->param(Options::AlgorithmParameters::Quantize_input_model_dtype);
- auto output_model_dtype =
- _options->param(Options::AlgorithmParameters::Quantize_output_model_dtype);
- auto granularity = _options->param(Options::AlgorithmParameters::Quantize_granularity);
- auto input_type = _options->param(Options::AlgorithmParameters::Quantize_input_type);
- if (input_type.empty())
- input_type = output_model_dtype;
- auto output_type = _options->param(Options::AlgorithmParameters::Quantize_output_type);
- if (output_type.empty())
- output_type = output_model_dtype;
-
- if (!in_array(to_lower_case(input_model_dtype), qwmm_supported_input_model_dtype))
- throw std::runtime_error("Unsupported input type. List of supported input types: " +
- to_string(qwmm_supported_input_model_dtype));
-
- if (!in_array(to_lower_case(output_model_dtype), qwmm_supported_output_model_dtype))
- throw std::runtime_error("Unsupported output type. List of supported output types: " +
- to_string(qwmm_supported_output_model_dtype));
-
- if (!in_array(to_lower_case(granularity), qwmm_supported_granularity))
- throw std::runtime_error("Unsupported granularity. List of supported granularity: " +
- to_string(qwmm_supported_granularity));
-
- if (!in_array(to_lower_case(input_type), qwmm_supported_input_type))
- throw std::runtime_error("Unsupported input type. List of supported input types: " +
- to_string(qwmm_supported_input_type));
-
- if (!in_array(to_lower_case(output_type), qwmm_supported_output_type))
- throw std::runtime_error("Unsupported output type. List of supported output types: " +
- to_string(qwmm_supported_output_type));
-
- if (str_to_granularity(granularity) == QuantizationGranularity::LayerWise &&
- str_to_dtype(output_model_dtype) != loco::DataType::U8)
- throw std::runtime_error("Layer-wise quantization only supports uint8 dtype.");
-
- luci::QuantizeWithMinMaxPass quantizer(
- str_to_dtype(input_model_dtype), str_to_dtype(output_model_dtype),
- str_to_granularity(granularity), str_to_dtype(input_type), str_to_dtype(output_type));
- quantizer.run(g);
-
- // Post-quantization optimizations
- logo::Phase phase;
-
- phase.emplace_back(std::make_unique<luci::PropagateQuantParamPass>());
-
- phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>());
- phase.emplace_back(std::make_unique<luci::CircleTypeInferencePass>());
- phase.emplace_back(std::make_unique<logo::RemoveDeadNodeWithQueryPass>());
-
- ProgressReporter prog(g, logo::PhaseStrategy::Saturate);
- logo::PhaseRunner<logo::PhaseStrategy::Saturate> phase_runner{g};
- phase_runner.attach(&prog);
- phase_runner.run(phase);
-
- // Verify the type/granularity of the quantized model
- luci::QuantizedModelVerifier verifier(str_to_dtype(output_model_dtype),
- str_to_granularity(granularity));
- verifier.verify(g);
- }
-
- // Requantize
- if (_options->query(Options::Algorithm::Requantize))
- {
- static const std::vector<std::string> rq_supported_input_model_dtype{"int8"};
- static const std::vector<std::string> rq_supported_output_model_dtype{"uint8"};
-
- auto input_model_dtype =
- _options->param(Options::AlgorithmParameters::Quantize_input_model_dtype);
- auto output_model_dtype =
- _options->param(Options::AlgorithmParameters::Quantize_output_model_dtype);
-
- if (!in_array(to_lower_case(input_model_dtype), rq_supported_input_model_dtype))
- throw std::runtime_error("Unsupported input type. List of supported input types: " +
- to_string(rq_supported_input_model_dtype));
-
- if (!in_array(to_lower_case(output_model_dtype), rq_supported_output_model_dtype))
- throw std::runtime_error("Unsupported output type. List of supported output types: " +
- to_string(rq_supported_output_model_dtype));
-
- luci::RequantizePass requantizer(str_to_dtype(input_model_dtype),
- str_to_dtype(output_model_dtype));
- requantizer.run(g);
- }
-
- // Force to write quantparam to specified tensors
- // NOTE Only per-tensor (not per-channel) qparam can be written
- if (_options->query(Options::Algorithm::ForceQuantParam))
- {
- ForceQuantParamPass::TensorVector tensors =
- _options->params(Options::AlgorithmParameters::Quantize_tensor_names);
- auto str_scales = _options->params(Options::AlgorithmParameters::Quantize_scales);
- auto str_zero_points = _options->params(Options::AlgorithmParameters::Quantize_zero_points);
-
- // Cast scales/zero_points to proper types
- ForceQuantParamPass::ScaleVector scales = lexical_cast<float>(str_scales);
- ForceQuantParamPass::ZPVector zero_points = lexical_cast<int64_t>(str_zero_points);
-
- ForceQuantParamPass fq(tensors, scales, zero_points);
- fq.run(g);
- }
-
- logo::Phase phase;
-
- // Do Shape/Type inference
- phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>());
- phase.emplace_back(std::make_unique<luci::CircleTypeInferencePass>());
-
- ProgressReporter prog(g, logo::PhaseStrategy::Saturate);
- logo::PhaseRunner<logo::PhaseStrategy::Saturate> phase_runner{g};
- phase_runner.attach(&prog);
- phase_runner.run(phase);
-}
-
void CircleOptimizer::sparsify(loco::Graph *g) const
{
if (_options->query(Options::Algorithm::SparsifyTensorPass))
diff --git a/compiler/luci/pass/src/CircleOptimizer.test.cpp b/compiler/luci/pass/src/CircleOptimizer.test.cpp
index a1b5c7f80..041fc7d75 100644
--- a/compiler/luci/pass/src/CircleOptimizer.test.cpp
+++ b/compiler/luci/pass/src/CircleOptimizer.test.cpp
@@ -71,171 +71,3 @@ TEST(CircleOptimizerTest, sparsify_simple)
SUCCEED();
}
-
-TEST(CircleOptimizerTest, quantize_quantdequant_simple)
-{
- loco::Graph g;
- luci::CircleOptimizer o;
-
- auto options = o.options();
-
- options->enable(Algorithms::QuantizeDequantizeWeights);
- options->param(AlgorithmParameters::Quantize_input_model_dtype, "float32");
- options->param(AlgorithmParameters::Quantize_output_model_dtype, "uint8");
- options->param(AlgorithmParameters::Quantize_granularity, "layer");
-
- o.quantize(&g);
-
- SUCCEED();
-}
-
-TEST(CircleOptimizerTest, quantize_quantdequant_input_NEG)
-{
- loco::Graph g;
- luci::CircleOptimizer o;
-
- auto options = o.options();
-
- options->enable(Algorithms::QuantizeDequantizeWeights);
- options->param(AlgorithmParameters::Quantize_input_model_dtype, "invalid");
- options->param(AlgorithmParameters::Quantize_output_model_dtype, "uint8");
- options->param(AlgorithmParameters::Quantize_granularity, "layer");
-
- EXPECT_THROW(o.quantize(&g), std::runtime_error);
-}
-
-TEST(CircleOptimizerTest, quantize_quantdequant_output_NEG)
-{
- loco::Graph g;
- luci::CircleOptimizer o;
-
- auto options = o.options();
-
- options->enable(Algorithms::QuantizeDequantizeWeights);
- options->param(AlgorithmParameters::Quantize_input_model_dtype, "float32");
- options->param(AlgorithmParameters::Quantize_output_model_dtype, "invalid");
- options->param(AlgorithmParameters::Quantize_granularity, "layer");
-
- EXPECT_THROW(o.quantize(&g), std::runtime_error);
-}
-
-TEST(CircleOptimizerTest, quantize_quantdequant_gran_NEG)
-{
- loco::Graph g;
- luci::CircleOptimizer o;
-
- auto options = o.options();
-
- options->enable(Algorithms::QuantizeDequantizeWeights);
- options->param(AlgorithmParameters::Quantize_input_model_dtype, "float32");
- options->param(AlgorithmParameters::Quantize_output_model_dtype, "uint8");
- options->param(AlgorithmParameters::Quantize_granularity, "invalid");
-
- EXPECT_THROW(o.quantize(&g), std::runtime_error);
-}
-
-TEST(CircleOptimizerTest, quantize_minmax_simple)
-{
- loco::Graph g;
- luci::CircleOptimizer o;
-
- auto options = o.options();
-
- options->enable(Algorithms::QuantizeWithMinMax);
- options->param(AlgorithmParameters::Quantize_input_model_dtype, "float32");
- options->param(AlgorithmParameters::Quantize_output_model_dtype, "uint8");
- options->param(AlgorithmParameters::Quantize_granularity, "layer");
-
- o.quantize(&g);
-
- SUCCEED();
-}
-
-TEST(CircleOptimizerTest, quantize_minmax_input_NEG)
-{
- loco::Graph g;
- luci::CircleOptimizer o;
-
- auto options = o.options();
-
- options->enable(Algorithms::QuantizeWithMinMax);
- options->param(AlgorithmParameters::Quantize_input_model_dtype, "invalid");
- options->param(AlgorithmParameters::Quantize_output_model_dtype, "uint8");
- options->param(AlgorithmParameters::Quantize_granularity, "layer");
-
- EXPECT_THROW(o.quantize(&g), std::runtime_error);
-}
-
-TEST(CircleOptimizerTest, quantize_minmax_output_NEG)
-{
- loco::Graph g;
- luci::CircleOptimizer o;
-
- auto options = o.options();
-
- options->enable(Algorithms::QuantizeWithMinMax);
- options->param(AlgorithmParameters::Quantize_input_model_dtype, "float32");
- options->param(AlgorithmParameters::Quantize_output_model_dtype, "invalid");
- options->param(AlgorithmParameters::Quantize_granularity, "layer");
-
- EXPECT_THROW(o.quantize(&g), std::runtime_error);
-}
-
-TEST(CircleOptimizerTest, quantize_minmax_gran_NEG)
-{
- loco::Graph g;
- luci::CircleOptimizer o;
-
- auto options = o.options();
-
- options->enable(Algorithms::QuantizeWithMinMax);
- options->param(AlgorithmParameters::Quantize_input_model_dtype, "float32");
- options->param(AlgorithmParameters::Quantize_output_model_dtype, "uint8");
- options->param(AlgorithmParameters::Quantize_granularity, "invalid");
-
- EXPECT_THROW(o.quantize(&g), std::runtime_error);
-}
-
-TEST(CircleOptimizerTest, quantize_requant_simple)
-{
- loco::Graph g;
- luci::CircleOptimizer o;
-
- auto options = o.options();
-
- options->enable(Algorithms::Requantize);
- options->param(AlgorithmParameters::Quantize_input_model_dtype, "int8");
- options->param(AlgorithmParameters::Quantize_output_model_dtype, "uint8");
-
- o.quantize(&g);
-
- SUCCEED();
-}
-
-TEST(CircleOptimizerTest, quantize_requant_input_NEG)
-{
- loco::Graph g;
- luci::CircleOptimizer o;
-
- auto options = o.options();
-
- options->enable(Algorithms::Requantize);
- options->param(AlgorithmParameters::Quantize_input_model_dtype, "invalid");
- options->param(AlgorithmParameters::Quantize_output_model_dtype, "uint8");
-
- EXPECT_THROW(o.quantize(&g), std::runtime_error);
-}
-
-TEST(CircleOptimizerTest, quantize_requant_output_NEG)
-{
- loco::Graph g;
- luci::CircleOptimizer o;
-
- auto options = o.options();
-
- options->enable(Algorithms::Requantize);
- options->param(AlgorithmParameters::Quantize_input_model_dtype, "int8");
- options->param(AlgorithmParameters::Quantize_output_model_dtype, "invalid");
-
- EXPECT_THROW(o.quantize(&g), std::runtime_error);
-}
diff --git a/compiler/luci/pass/src/CircleQuantizer.cpp b/compiler/luci/pass/src/CircleQuantizer.cpp
new file mode 100644
index 000000000..ce38a90b9
--- /dev/null
+++ b/compiler/luci/pass/src/CircleQuantizer.cpp
@@ -0,0 +1,458 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/CircleQuantizer.h"
+
+#include "luci/Pass/CopyQuantParamPass.h"
+#include "luci/Pass/ForceQuantParamPass.h"
+#include "luci/Pass/PropagateQParamForwardPass.h"
+#include "luci/Pass/RequantizePass.h"
+#include "luci/Pass/ConvertToFakeQuantizedModelPass.h"
+#include "luci/Pass/FoldDequantizePass.h"
+#include "luci/Pass/QuantizePreCheckerPass.h"
+#include "luci/Pass/QuantizeWithMinMaxPass.h"
+#include "luci/Pass/QuantizeDequantizeWeightsPass.h"
+
+#include "luci/Pass/CircleShapeInferencePass.h"
+#include "luci/Pass/CircleTypeInferencePass.h"
+
+// logo passes
+#include <logo/RemoveDeadNodeWithQueryPass.h>
+
+#include "ProgressReporter.h"
+#include "helpers/Strings.h"
+
+#include "QuantizedModelVerifier.h"
+
+#include <luci/IR/CircleNode.h>
+#include <logo/Phase.h>
+
+#include <memory>
+
+namespace
+{
+
+using namespace luci;
+using LayerParam = luci::CircleQuantizer::Options::LayerParam;
+
+template <typename T> T lexical_cast(const std::string &str)
+{
+ std::istringstream ss;
+ ss.str(str);
+ T data;
+ ss >> data;
+ return data;
+}
+
+template <typename T> std::vector<T> lexical_cast(std::vector<std::string> &sv)
+{
+ std::vector<T> result;
+ std::transform(sv.begin(), sv.end(), std::back_inserter(result),
+ [](std::string str) -> T { return lexical_cast<T>(str); });
+ return result;
+}
+
+class QuantizeOptionsImpl final : public luci::CircleQuantizer::Options
+{
+public:
+ void enable(Algorithm) final;
+ void param(AlgorithmParameters, const std::string &) final;
+ const std::string param(AlgorithmParameters) const final;
+ void params(AlgorithmParameters, std::vector<std::string> &) final;
+ std::vector<std::string> params(AlgorithmParameters) const final;
+ void layer_params(AlgorithmParameters, std::vector<std::shared_ptr<LayerParam>> &) final;
+ std::vector<std::shared_ptr<LayerParam>> layer_params(AlgorithmParameters) const final;
+ bool query(Algorithm) final;
+
+private:
+ std::vector<Algorithm> _algorithms;
+ std::map<AlgorithmParameters, const std::string> _algorithm_params;
+ std::map<AlgorithmParameters, std::vector<std::string>> _multiple_params;
+ std::map<AlgorithmParameters, std::vector<std::shared_ptr<LayerParam>>> _layer_params;
+};
+
+void QuantizeOptionsImpl::enable(Algorithm algo) { _algorithms.push_back(algo); }
+
+void QuantizeOptionsImpl::param(AlgorithmParameters param, const std::string &str)
+{
+ _algorithm_params.insert(std::pair<AlgorithmParameters, const std::string>(param, str));
+}
+
+const std::string QuantizeOptionsImpl::param(AlgorithmParameters param) const
+{
+ auto param_str = _algorithm_params.find(param);
+ if (param_str != _algorithm_params.end())
+ {
+ return param_str->second;
+ }
+ else
+ {
+ return std::string();
+ }
+}
+
+void QuantizeOptionsImpl::params(AlgorithmParameters param, std::vector<std::string> &vec)
+{
+ _multiple_params[param] = vec;
+}
+
+std::vector<std::string> QuantizeOptionsImpl::params(AlgorithmParameters param) const
+{
+ auto param_vec = _multiple_params.find(param);
+ if (param_vec != _multiple_params.end())
+ {
+ return param_vec->second;
+ }
+ else
+ {
+ return std::vector<std::string>();
+ }
+}
+
+void QuantizeOptionsImpl::layer_params(AlgorithmParameters param,
+ std::vector<std::shared_ptr<LayerParam>> &vec)
+{
+ _layer_params[param] = vec;
+}
+
+std::vector<std::shared_ptr<LayerParam>>
+QuantizeOptionsImpl::layer_params(AlgorithmParameters param) const
+{
+ auto param_vec = _layer_params.find(param);
+ if (param_vec != _layer_params.end())
+ {
+ return param_vec->second;
+ }
+ else
+ {
+ return std::vector<std::shared_ptr<LayerParam>>();
+ }
+}
+
+bool QuantizeOptionsImpl::query(Algorithm algo)
+{
+ std::vector<Algorithm>::iterator it = std::find(_algorithms.begin(), _algorithms.end(), algo);
+ if (it == _algorithms.end())
+ return false;
+
+ return true;
+}
+
+} // namespace
+
+namespace luci
+{
+
+CircleQuantizer::Options *CircleQuantizer::options(void)
+{
+ if (_options == nullptr)
+ {
+ _options = std::make_unique<QuantizeOptionsImpl>();
+ }
+
+ return _options.get();
+}
+
+void CircleQuantizer::quantize(loco::Graph *g) const
+{
+ // Fake quantization of weights
+ if (_options->query(Options::Algorithm::QuantizeDequantizeWeights))
+ {
+ static const std::vector<std::string> fakeq_supported_input_model_dtype{"float32"};
+ static const std::vector<std::string> fakeq_supported_output_model_dtype{"uint8", "int16"};
+ static const std::vector<std::string> fakeq_supported_granularity{"layer", "channel"};
+
+ auto input_model_dtype =
+ _options->param(Options::AlgorithmParameters::Quantize_input_model_dtype);
+ auto output_model_dtype =
+ _options->param(Options::AlgorithmParameters::Quantize_output_model_dtype);
+ auto granularity = _options->param(Options::AlgorithmParameters::Quantize_granularity);
+ auto layer_params = _options->layer_params(Options::AlgorithmParameters::Quantize_layer_params);
+
+ if (!in_array(to_lower_case(input_model_dtype), fakeq_supported_input_model_dtype))
+ throw std::runtime_error("Unsupported input type. List of supported input type: " +
+ to_string(fakeq_supported_input_model_dtype));
+
+ if (!in_array(to_lower_case(output_model_dtype), fakeq_supported_output_model_dtype))
+ throw std::runtime_error("Unsupported output type. List of supported output type: " +
+ to_string(fakeq_supported_output_model_dtype));
+
+ if (!in_array(to_lower_case(granularity), fakeq_supported_granularity))
+ throw std::runtime_error("Unsupported granularity. List of supported granularity: " +
+ to_string(fakeq_supported_granularity));
+
+ if (str_to_granularity(granularity) == QuantizationGranularity::LayerWise &&
+ str_to_dtype(output_model_dtype) != loco::DataType::U8)
+ throw std::runtime_error("Layer-wise quantization only supports uint8 dtype.");
+
+ // Check dtype/granularity of layer params
+ for (auto layer_param : layer_params)
+ {
+ auto name = layer_param->name;
+ if (!in_array(to_lower_case(layer_param->dtype), fakeq_supported_output_model_dtype))
+ {
+ throw std::runtime_error("Unsupported dtype in " + name + ". List of supported dtype: " +
+ to_string(fakeq_supported_output_model_dtype));
+ }
+ if (!in_array(to_lower_case(layer_param->granularity), fakeq_supported_granularity))
+ {
+ throw std::runtime_error(
+ "Unsupported granularity in " + name +
+ ". List of supported granularity: " + to_string(fakeq_supported_granularity));
+ }
+ }
+
+ // Clear existing quantparams before doing fake quantization
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+ if (circle_node->quantparam() != nullptr)
+ circle_node->quantparam(nullptr);
+ }
+
+ auto ctx = std::make_unique<luci::QuantizeDequantizeWeightsPass::Context>();
+ {
+ ctx->input_model_dtype = str_to_dtype(input_model_dtype);
+ ctx->output_model_dtype = str_to_dtype(output_model_dtype);
+ ctx->granularity = str_to_granularity(granularity);
+
+ for (auto layer_param : layer_params)
+ {
+ LayerInfo info;
+ {
+ info.name = layer_param->name;
+ info.dtype = str_to_dtype(layer_param->dtype);
+ info.granularity = str_to_granularity(layer_param->granularity);
+ }
+ ctx->layers_info.emplace_back(info);
+ }
+ }
+
+ luci::QuantizeDequantizeWeightsPass fake_quantizer(std::move(ctx));
+
+ fake_quantizer.run(g);
+ }
+
+ // Actual quantization of weights, bias, and activation
+ if (_options->query(Options::Algorithm::QuantizeWithMinMax))
+ {
+ static const std::vector<std::string> qwmm_supported_input_model_dtype{"float32"};
+ static const std::vector<std::string> qwmm_supported_output_model_dtype{"uint8", "int16"};
+ static const std::vector<std::string> qwmm_supported_granularity{"layer", "channel"};
+ static const std::vector<std::string> qwmm_supported_input_type{"uint8", "int16"};
+ static const std::vector<std::string> qwmm_supported_output_type{"uint8", "int16"};
+
+ auto input_model_dtype =
+ _options->param(Options::AlgorithmParameters::Quantize_input_model_dtype);
+ auto output_model_dtype =
+ _options->param(Options::AlgorithmParameters::Quantize_output_model_dtype);
+ auto granularity = _options->param(Options::AlgorithmParameters::Quantize_granularity);
+ auto input_type = _options->param(Options::AlgorithmParameters::Quantize_input_type);
+ if (input_type.empty())
+ input_type = output_model_dtype;
+ auto output_type = _options->param(Options::AlgorithmParameters::Quantize_output_type);
+ if (output_type.empty())
+ output_type = output_model_dtype;
+
+ bool TF_style_maxpool =
+ _options->param(Options::AlgorithmParameters::Quantize_TF_style_maxpool) == "True";
+
+ auto layer_params = _options->layer_params(Options::AlgorithmParameters::Quantize_layer_params);
+
+ if (!in_array(to_lower_case(input_model_dtype), qwmm_supported_input_model_dtype))
+ throw std::runtime_error("Unsupported input type. List of supported input types: " +
+ to_string(qwmm_supported_input_model_dtype));
+
+ if (!in_array(to_lower_case(output_model_dtype), qwmm_supported_output_model_dtype))
+ throw std::runtime_error("Unsupported output type. List of supported output types: " +
+ to_string(qwmm_supported_output_model_dtype));
+
+ if (!in_array(to_lower_case(granularity), qwmm_supported_granularity))
+ throw std::runtime_error("Unsupported granularity. List of supported granularity: " +
+ to_string(qwmm_supported_granularity));
+
+ if (!in_array(to_lower_case(input_type), qwmm_supported_input_type))
+ throw std::runtime_error("Unsupported input type. List of supported input types: " +
+ to_string(qwmm_supported_input_type));
+
+ if (!in_array(to_lower_case(output_type), qwmm_supported_output_type))
+ throw std::runtime_error("Unsupported output type. List of supported output types: " +
+ to_string(qwmm_supported_output_type));
+
+ if (str_to_granularity(granularity) == QuantizationGranularity::LayerWise &&
+ str_to_dtype(output_model_dtype) != loco::DataType::U8)
+ throw std::runtime_error("Layer-wise quantization only supports uint8 dtype.");
+
+ // Check dtype/granularity of layer params
+ for (auto layer_param : layer_params)
+ {
+ auto name = layer_param->name;
+ if (!in_array(to_lower_case(layer_param->dtype), qwmm_supported_output_model_dtype))
+ {
+ throw std::runtime_error("Unsupported dtype in " + name + ". List of supported dtype: " +
+ to_string(qwmm_supported_output_model_dtype));
+ }
+ if (!in_array(to_lower_case(layer_param->granularity), qwmm_supported_granularity))
+ {
+ throw std::runtime_error(
+ "Unsupported granularity in " + name +
+ ". List of supported granularity: " + to_string(qwmm_supported_granularity));
+ }
+ }
+
+ // Input model checker for quantization
+ luci::QuantizePreCheckerPass input_model_checker{};
+ input_model_checker.run(g);
+
+ auto ctx = std::make_unique<luci::QuantizeWithMinMaxPass::Context>();
+ {
+ ctx->input_model_dtype = str_to_dtype(input_model_dtype);
+ ctx->output_model_dtype = str_to_dtype(output_model_dtype);
+ ctx->granularity = str_to_granularity(granularity);
+ ctx->input_type = str_to_dtype(input_type);
+ ctx->output_type = str_to_dtype(output_type);
+ ctx->TF_style_maxpool = TF_style_maxpool;
+
+ for (auto layer_param : layer_params)
+ {
+ LayerInfo info;
+ {
+ info.name = layer_param->name;
+ info.dtype = str_to_dtype(layer_param->dtype);
+ info.granularity = str_to_granularity(layer_param->granularity);
+ }
+ ctx->layers_info.emplace_back(info);
+ }
+ }
+
+ luci::QuantizeWithMinMaxPass quantizer(std::move(ctx));
+
+ quantizer.run(g);
+
+ auto verify_ctx = std::make_unique<luci::QuantizedModelVerifier::Context>();
+ {
+ verify_ctx->output_model_dtype = str_to_dtype(output_model_dtype);
+ verify_ctx->granularity = str_to_granularity(granularity);
+ verify_ctx->input_type = str_to_dtype(input_type);
+ verify_ctx->output_type = str_to_dtype(output_type);
+ verify_ctx->TF_style_maxpool = TF_style_maxpool;
+
+ for (auto layer_param : layer_params)
+ {
+ LayerInfo info;
+ {
+ info.name = layer_param->name;
+ info.dtype = str_to_dtype(layer_param->dtype);
+ info.granularity = str_to_granularity(layer_param->granularity);
+ }
+ verify_ctx->layers_info.emplace_back(info);
+ }
+ }
+
+ // Verify the type/granularity of the quantized model
+ luci::QuantizedModelVerifier verifier(std::move(verify_ctx));
+
+ verifier.verify(g);
+ }
+
+ // Requantize
+ if (_options->query(Options::Algorithm::Requantize))
+ {
+ static const std::vector<std::string> rq_supported_input_model_dtype{"int8"};
+ static const std::vector<std::string> rq_supported_output_model_dtype{"uint8"};
+
+ auto input_model_dtype =
+ _options->param(Options::AlgorithmParameters::Quantize_input_model_dtype);
+ auto output_model_dtype =
+ _options->param(Options::AlgorithmParameters::Quantize_output_model_dtype);
+
+ if (!in_array(to_lower_case(input_model_dtype), rq_supported_input_model_dtype))
+ throw std::runtime_error("Unsupported input type. List of supported input types: " +
+ to_string(rq_supported_input_model_dtype));
+
+ if (!in_array(to_lower_case(output_model_dtype), rq_supported_output_model_dtype))
+ throw std::runtime_error("Unsupported output type. List of supported output types: " +
+ to_string(rq_supported_output_model_dtype));
+
+ luci::RequantizePass requantizer(str_to_dtype(input_model_dtype),
+ str_to_dtype(output_model_dtype));
+ requantizer.run(g);
+ }
+
+ // Force to write quantparam to specified tensors
+ // NOTE Only per-tensor (not per-channel) qparam can be written
+ if (_options->query(Options::Algorithm::ForceQuantParam))
+ {
+ ForceQuantParamPass::TensorVector tensors =
+ _options->params(Options::AlgorithmParameters::Quantize_tensor_names);
+ auto str_scales = _options->params(Options::AlgorithmParameters::Quantize_scales);
+ auto str_zero_points = _options->params(Options::AlgorithmParameters::Quantize_zero_points);
+
+ // Cast scales/zero_points to proper types
+ ForceQuantParamPass::ScaleVector scales = lexical_cast<float>(str_scales);
+ ForceQuantParamPass::ZPVector zero_points = lexical_cast<int64_t>(str_zero_points);
+
+ ForceQuantParamPass fq(tensors, scales, zero_points);
+ fq.run(g);
+ }
+
+ // Copy quantparam of a tensor to another tensor
+ if (_options->query(Options::Algorithm::CopyQuantParam))
+ {
+ CopyQuantParamPass::TensorVector src_tensors =
+ _options->params(Options::AlgorithmParameters::Quantize_src_tensor_names);
+ CopyQuantParamPass::TensorVector dst_tensors =
+ _options->params(Options::AlgorithmParameters::Quantize_dst_tensor_names);
+
+ CopyQuantParamPass cq(src_tensors, dst_tensors);
+ cq.run(g);
+ }
+
+ // Convert quantized model to fake-quantized model
+ if (_options->query(Options::Algorithm::ConvertToFakeQuantizedModel))
+ {
+ luci::ConvertToFakeQuantizedModelPass fake_quantizer;
+ fake_quantizer.run(g);
+
+ logo::Phase phase;
+
+ // Default passes
+ phase.emplace_back(std::make_unique<logo::RemoveDeadNodeWithQueryPass>());
+ phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>());
+ phase.emplace_back(std::make_unique<luci::CircleTypeInferencePass>());
+
+ // Fold Dequantize Ops generated during fake quantization
+ phase.emplace_back(std::make_unique<luci::FoldDequantizePass>());
+
+ ProgressReporter prog(g, logo::PhaseStrategy::Restart);
+ logo::PhaseRunner<logo::PhaseStrategy::Restart> phase_runner{g};
+ phase_runner.attach(&prog);
+ phase_runner.run(phase);
+ }
+
+ logo::Phase phase;
+
+ // Do Shape/Type inference
+ phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>());
+ phase.emplace_back(std::make_unique<luci::CircleTypeInferencePass>());
+
+ ProgressReporter prog(g, logo::PhaseStrategy::Saturate);
+ logo::PhaseRunner<logo::PhaseStrategy::Saturate> phase_runner{g};
+ phase_runner.attach(&prog);
+ phase_runner.run(phase);
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/CircleQuantizer.test.cpp b/compiler/luci/pass/src/CircleQuantizer.test.cpp
new file mode 100644
index 000000000..5766d5fe5
--- /dev/null
+++ b/compiler/luci/pass/src/CircleQuantizer.test.cpp
@@ -0,0 +1,191 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/CircleQuantizer.h"
+
+#include <gtest/gtest.h>
+
+using namespace luci;
+using Algorithms = luci::CircleQuantizer::Options::Algorithm;
+using AlgorithmParameters = luci::CircleQuantizer::Options::AlgorithmParameters;
+
+TEST(CircleQuantizerTest, quantize_quantdequant_simple)
+{
+ loco::Graph g;
+ luci::CircleQuantizer o;
+
+ auto options = o.options();
+
+ options->enable(Algorithms::QuantizeDequantizeWeights);
+ options->param(AlgorithmParameters::Quantize_input_model_dtype, "float32");
+ options->param(AlgorithmParameters::Quantize_output_model_dtype, "uint8");
+ options->param(AlgorithmParameters::Quantize_granularity, "layer");
+
+ o.quantize(&g);
+
+ SUCCEED();
+}
+
+TEST(CircleQuantizerTest, quantize_quantdequant_input_NEG)
+{
+ loco::Graph g;
+ luci::CircleQuantizer o;
+
+ auto options = o.options();
+
+ options->enable(Algorithms::QuantizeDequantizeWeights);
+ options->param(AlgorithmParameters::Quantize_input_model_dtype, "invalid");
+ options->param(AlgorithmParameters::Quantize_output_model_dtype, "uint8");
+ options->param(AlgorithmParameters::Quantize_granularity, "layer");
+
+ EXPECT_THROW(o.quantize(&g), std::runtime_error);
+}
+
+TEST(CircleQuantizerTest, quantize_quantdequant_output_NEG)
+{
+ loco::Graph g;
+ luci::CircleQuantizer o;
+
+ auto options = o.options();
+
+ options->enable(Algorithms::QuantizeDequantizeWeights);
+ options->param(AlgorithmParameters::Quantize_input_model_dtype, "float32");
+ options->param(AlgorithmParameters::Quantize_output_model_dtype, "invalid");
+ options->param(AlgorithmParameters::Quantize_granularity, "layer");
+
+ EXPECT_THROW(o.quantize(&g), std::runtime_error);
+}
+
+TEST(CircleQuantizerTest, quantize_quantdequant_gran_NEG)
+{
+ loco::Graph g;
+ luci::CircleQuantizer o;
+
+ auto options = o.options();
+
+ options->enable(Algorithms::QuantizeDequantizeWeights);
+ options->param(AlgorithmParameters::Quantize_input_model_dtype, "float32");
+ options->param(AlgorithmParameters::Quantize_output_model_dtype, "uint8");
+ options->param(AlgorithmParameters::Quantize_granularity, "invalid");
+
+ EXPECT_THROW(o.quantize(&g), std::runtime_error);
+}
+
+TEST(CircleQuantizerTest, quantize_minmax_simple)
+{
+ loco::Graph g;
+ luci::CircleQuantizer o;
+
+ auto options = o.options();
+
+ options->enable(Algorithms::QuantizeWithMinMax);
+ options->param(AlgorithmParameters::Quantize_input_model_dtype, "float32");
+ options->param(AlgorithmParameters::Quantize_output_model_dtype, "uint8");
+ options->param(AlgorithmParameters::Quantize_granularity, "layer");
+
+ o.quantize(&g);
+
+ SUCCEED();
+}
+
+TEST(CircleQuantizerTest, quantize_minmax_input_NEG)
+{
+ loco::Graph g;
+ luci::CircleQuantizer o;
+
+ auto options = o.options();
+
+ options->enable(Algorithms::QuantizeWithMinMax);
+ options->param(AlgorithmParameters::Quantize_input_model_dtype, "invalid");
+ options->param(AlgorithmParameters::Quantize_output_model_dtype, "uint8");
+ options->param(AlgorithmParameters::Quantize_granularity, "layer");
+
+ EXPECT_THROW(o.quantize(&g), std::runtime_error);
+}
+
+TEST(CircleQuantizerTest, quantize_minmax_output_NEG)
+{
+ loco::Graph g;
+ luci::CircleQuantizer o;
+
+ auto options = o.options();
+
+ options->enable(Algorithms::QuantizeWithMinMax);
+ options->param(AlgorithmParameters::Quantize_input_model_dtype, "float32");
+ options->param(AlgorithmParameters::Quantize_output_model_dtype, "invalid");
+ options->param(AlgorithmParameters::Quantize_granularity, "layer");
+
+ EXPECT_THROW(o.quantize(&g), std::runtime_error);
+}
+
+TEST(CircleQuantizerTest, quantize_minmax_gran_NEG)
+{
+ loco::Graph g;
+ luci::CircleQuantizer o;
+
+ auto options = o.options();
+
+ options->enable(Algorithms::QuantizeWithMinMax);
+ options->param(AlgorithmParameters::Quantize_input_model_dtype, "float32");
+ options->param(AlgorithmParameters::Quantize_output_model_dtype, "uint8");
+ options->param(AlgorithmParameters::Quantize_granularity, "invalid");
+
+ EXPECT_THROW(o.quantize(&g), std::runtime_error);
+}
+
+TEST(CircleQuantizerTest, quantize_requant_simple)
+{
+ loco::Graph g;
+ luci::CircleQuantizer o;
+
+ auto options = o.options();
+
+ options->enable(Algorithms::Requantize);
+ options->param(AlgorithmParameters::Quantize_input_model_dtype, "int8");
+ options->param(AlgorithmParameters::Quantize_output_model_dtype, "uint8");
+
+ o.quantize(&g);
+
+ SUCCEED();
+}
+
+TEST(CircleQuantizerTest, quantize_requant_input_NEG)
+{
+ loco::Graph g;
+ luci::CircleQuantizer o;
+
+ auto options = o.options();
+
+ options->enable(Algorithms::Requantize);
+ options->param(AlgorithmParameters::Quantize_input_model_dtype, "invalid");
+ options->param(AlgorithmParameters::Quantize_output_model_dtype, "uint8");
+
+ EXPECT_THROW(o.quantize(&g), std::runtime_error);
+}
+
+TEST(CircleQuantizerTest, quantize_requant_output_NEG)
+{
+ loco::Graph g;
+ luci::CircleQuantizer o;
+
+ auto options = o.options();
+
+ options->enable(Algorithms::Requantize);
+ options->param(AlgorithmParameters::Quantize_input_model_dtype, "int8");
+ options->param(AlgorithmParameters::Quantize_output_model_dtype, "invalid");
+
+ EXPECT_THROW(o.quantize(&g), std::runtime_error);
+}
diff --git a/compiler/luci/pass/src/ConvertNCHWToNHWCPass.cpp b/compiler/luci/pass/src/ConvertNCHWToNHWCPass.cpp
index 270714049..ce4f54035 100644
--- a/compiler/luci/pass/src/ConvertNCHWToNHWCPass.cpp
+++ b/compiler/luci/pass/src/ConvertNCHWToNHWCPass.cpp
@@ -228,6 +228,9 @@ bool check_4d_reshape(loco::Node *node, const std::vector<int32_t> indices)
if (input->shape_status() != luci::ShapeStatus::VALID)
return false;
+ if (input->rank() != 4)
+ return false;
+
if (reshape->shape_status() != luci::ShapeStatus::VALID)
return false;
@@ -804,6 +807,8 @@ class ConvertNCHWToNHWC final : public luci::CircleNodeMutableVisitor<bool>
return true;
}
+ bool visit(luci::CircleElu *node) { return convert_unary_features<luci::CircleElu>(node); }
+
bool visit(luci::CircleLeakyRelu *node)
{
return convert_unary_features<luci::CircleLeakyRelu>(node);
@@ -1240,6 +1245,7 @@ bool ConvertNCHWToNHWCPass::run(loco::Graph *g)
break;
case luci::CircleOpcode::ADD:
case luci::CircleOpcode::CONCATENATION:
+ case luci::CircleOpcode::ELU:
case luci::CircleOpcode::LEAKY_RELU:
case luci::CircleOpcode::LOGISTIC:
case luci::CircleOpcode::MAXIMUM:
diff --git a/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp b/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp
index c9412fbb1..dd81d1380 100644
--- a/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp
+++ b/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp
@@ -264,6 +264,22 @@ public:
luci::CircleConst *input2 = nullptr;
};
+class EluGraph final : public SimpleGraph
+{
+protected:
+ loco::Node *insertGraphBody(loco::Node *input) override
+ {
+ elu = g.nodes()->create<luci::CircleElu>();
+ elu->features(input);
+ elu->name("elu");
+
+ return elu;
+ }
+
+public:
+ luci::CircleElu *elu = nullptr;
+};
+
class LeakyReluGraph final : public SimpleGraph
{
protected:
@@ -941,6 +957,26 @@ TEST(ConvertNCHWToNHWC, Concatenation)
EXPECT_EQ(3, g.concat->axis());
}
+TEST(ConvertNCHWToNHWC, Elu)
+{
+ EluGraph g;
+ g.init();
+
+ run_phase(&g.g, true, true);
+
+ check_pre_trans(g.elu->features());
+
+ auto elu_succs = loco::succs(g.elu);
+ EXPECT_EQ(1, elu_succs.size());
+ check_post_trans(*elu_succs.begin());
+
+ // Check elu shape
+ EXPECT_EQ(1, g.elu->dim(0).value());
+ EXPECT_EQ(4, g.elu->dim(1).value());
+ EXPECT_EQ(4, g.elu->dim(2).value());
+ EXPECT_EQ(16, g.elu->dim(3).value());
+}
+
TEST(ConvertNCHWToNHWC, LeakyRelu)
{
LeakyReluGraph g;
diff --git a/compiler/luci/pass/src/ConvertToFakeQuantizedModelPass.cpp b/compiler/luci/pass/src/ConvertToFakeQuantizedModelPass.cpp
new file mode 100644
index 000000000..11970fff5
--- /dev/null
+++ b/compiler/luci/pass/src/ConvertToFakeQuantizedModelPass.cpp
@@ -0,0 +1,214 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/ConvertToFakeQuantizedModelPass.h"
+#include "luci/Pass/QuantizationParameters.h"
+
+#include "QuantizationUtils.h"
+
+#include <luci/Profile/CircleNodeOrigin.h>
+#include <luci/IR/CircleNodes.h>
+#include <luci/IR/CircleNodeVisitor.h>
+#include <luci/Log.h>
+
+namespace
+{
+
+// Create Quantize Op whose dtype/shape/qparam are the same with node
+luci::CircleQuantize *create_quantize(luci::CircleNode *node)
+{
+ auto quantize = node->graph()->nodes()->create<luci::CircleQuantize>();
+ quantize->name(node->name() + "_Quantize");
+ quantize->dtype(node->dtype());
+ quantize->rank(node->rank());
+ for (uint32_t i = 0; i < node->rank(); i++)
+ quantize->dim(i).set(node->dim(i).value());
+
+ quantize->shape_status(luci::ShapeStatus::VALID);
+
+ copy_quantparam(node, quantize);
+
+ luci::add_origin(quantize, luci::get_origin(node));
+
+ return quantize;
+}
+
+// Create Dequantize Op whose shape is the same with node
+luci::CircleDequantize *create_dequantize(luci::CircleNode *node)
+{
+ auto dequantize = node->graph()->nodes()->create<luci::CircleDequantize>();
+ dequantize->name(node->name() + "_Dequantize");
+ dequantize->dtype(loco::DataType::FLOAT32);
+ dequantize->rank(node->rank());
+ for (uint32_t i = 0; i < node->rank(); i++)
+ dequantize->dim(i).set(node->dim(i).value());
+
+ dequantize->shape_status(luci::ShapeStatus::VALID);
+
+ luci::add_origin(dequantize, luci::get_origin(node));
+
+ return dequantize;
+}
+
+// Return true if node is quantized activation
+// 1. dtype is u8 or s16
+// 2. node has qparam
+bool is_quant_act(const luci::CircleNode *node)
+{
+ if (node->dtype() != loco::DataType::U8 and node->dtype() != loco::DataType::S16)
+ return false;
+
+ if (not node->quantparam())
+ return false;
+
+ return true;
+}
+
+// Return true if node is quantized const
+// 1. dtype is not fp32
+// 2. node has qparam
+// NOTE Quantized const can have the following types
+// u8 (weights, activation), s16 (weights, activation), s32 (bias), s64 (bias)
+bool is_quant_const(const luci::CircleConst *node)
+{
+ if (node->dtype() == loco::DataType::FLOAT32)
+ return false;
+
+ if (not node->quantparam())
+ return false;
+
+ return true;
+}
+
+// Insert dequantize Op after node
+void insert_dequantize(loco::Node *lnode)
+{
+ auto node = loco::must_cast<luci::CircleNode *>(lnode);
+ auto dequant = create_dequantize(node);
+ loco::replace(node).with(dequant);
+ dequant->input(node);
+}
+
+// Insert quantize Op after node and return the quantize Op
+luci::CircleQuantize *insert_quantize(loco::Node *lnode)
+{
+ auto node = loco::must_cast<luci::CircleNode *>(lnode);
+ auto quant = create_quantize(node);
+ loco::replace(node).with(quant);
+ quant->input(node);
+ return quant;
+}
+
+// Dequantize node
+void dequantize(luci::CircleNode *node)
+{
+ node->dtype(loco::DataType::FLOAT32);
+ node->quantparam(nullptr);
+}
+
+// Do fake quantization on quantized activation
+// 1. Insert Quantize-Dequantize Ops
+// 2. Update dtype/quantparam of node
+void fq_activation(luci::CircleNode *node)
+{
+ if (not is_quant_act(node))
+ return;
+
+ auto quant = insert_quantize(node);
+ insert_dequantize(quant);
+
+ dequantize(node);
+}
+
+#define RETURN_UNLESS(COND) \
+ if (not(COND)) \
+ return;
+
+// Visitor to do fake quantization for each Op
+// For non-const activation, insert Quantize-Dequantize after the ofm
+// For quantized const, insert Dequantize after the const
+struct FakeQuantize final : public luci::CircleNodeMutableVisitor<void>
+{
+ void visit(luci::CircleNode *node)
+ {
+ throw std::runtime_error("Unsupported op for fake quantization in " + node->name());
+ }
+
+ void visit(luci::CircleInput *node)
+ {
+ RETURN_UNLESS(is_quant_act(node));
+
+ auto quant = insert_quantize(node);
+ insert_dequantize(quant);
+
+ dequantize(node);
+
+ // Update graph input
+ const auto inputs = node->graph()->inputs();
+ auto graph_input = inputs->at(node->index());
+ graph_input->dtype(loco::DataType::FLOAT32);
+ }
+
+ void visit(luci::CircleOutput *node)
+ {
+ RETURN_UNLESS(is_quant_act(node));
+
+ dequantize(node);
+
+ // Update graph output
+ const auto outputs = node->graph()->outputs();
+ auto graph_output = outputs->at(node->index());
+ graph_output->dtype(loco::DataType::FLOAT32);
+ }
+
+ // For quantized const, insert Dequantize Op
+ void visit(luci::CircleConst *node)
+ {
+ RETURN_UNLESS(is_quant_const(node));
+
+ insert_dequantize(node);
+ }
+
+ // For non-const activation, insert Quantize-Dequantize Ops
+ // and dequantize the node
+ void visit(luci::CircleConv2D *node) { fq_activation(node); }
+ void visit(luci::CircleAdd *node) { fq_activation(node); }
+};
+
+#undef RETURN_UNLESS
+
+} // namespace
+
+namespace luci
+{
+
+bool ConvertToFakeQuantizedModelPass::run(loco::Graph *g)
+{
+ LOGGER(l);
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+ INFO(l) << "ConvertToFakeQuantizedModelPass visit node: " << circle_node->name() << std::endl;
+
+ FakeQuantize fq;
+ circle_node->accept(&fq);
+ }
+
+ // One time run
+ return false;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/ConvertToFakeQuantizedModelPass.test.cpp b/compiler/luci/pass/src/ConvertToFakeQuantizedModelPass.test.cpp
new file mode 100644
index 000000000..560d68a74
--- /dev/null
+++ b/compiler/luci/pass/src/ConvertToFakeQuantizedModelPass.test.cpp
@@ -0,0 +1,277 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <logo/Phase.h>
+
+#include "luci/Pass/ConvertToFakeQuantizedModelPass.h"
+#include <luci/IR/CircleNodes.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+// Check the below pattern
+// Quantize (scale, zp) -> Dequantize (node)
+void check_q_dq(loco::Node *node, float scale, int64_t zp)
+{
+ auto dequant = dynamic_cast<luci::CircleDequantize *>(node);
+ EXPECT_TRUE(dequant != nullptr);
+ auto quant = dynamic_cast<luci::CircleQuantize *>(dequant->input());
+ EXPECT_TRUE(quant != nullptr);
+ auto qparam = quant->quantparam();
+ EXPECT_EQ(scale, qparam->scale[0]);
+ EXPECT_EQ(zp, qparam->zerop[0]);
+}
+
+// Check the below pattern
+// Dequantize (node)
+void check_dq(loco::Node *node)
+{
+ auto dequant = dynamic_cast<luci::CircleDequantize *>(node);
+ EXPECT_TRUE(dequant != nullptr);
+}
+
+void set_qparam(luci::CircleNode *node, float scale, int64_t zp)
+{
+ auto qparam = std::make_unique<luci::CircleQuantParam>();
+ {
+ qparam->scale.push_back(scale);
+ qparam->zerop.push_back(zp);
+ }
+ node->quantparam(std::move(qparam));
+}
+
+/**
+ * SimpleGraph for testing
+ * - Child class should implement insertGraphBody()
+ *
+ * Example (U8ConvGraph inherits SimpleGraph and create Conv2D Op)
+ *
+ * BEFORE
+ * - A model is quantized (ex: u8)
+ *
+ * [Input(u8)] [Filter(u8)] [Bias(s32)]
+ * \ | /
+ * \ | /
+ * \ | /
+ * [Conv2D(u8)]
+ * |
+ * [Output(u8)]
+ *
+ * AFTER
+ * - Ops are converted to fp32
+ * - Quantize/Dequantize Ops are inserted properly
+ * - Q-DQ is inserted after non-const activation
+ * - DQ is inserted after const
+ *
+ * [Input(u8)]
+ * |
+ * [Quant(u8)] [Filter(u8)] [Bias(s32)]
+ * | | |
+ * [Dequant(fp32)] [Dequant(fp32)] [Dequant(fp32)]
+ * \ | /
+ * \ | /
+ * \ | /
+ * [Conv2D(fp32)]
+ * |
+ * [Quant(u8)]
+ * |
+ * [Dequant(fp32)]
+ * |
+ * [Output(fp32)]
+ */
+template <loco::DataType T> class SimpleGraph
+{
+public:
+ void init()
+ {
+ input = g.nodes()->create<luci::CircleInput>();
+ output = g.nodes()->create<luci::CircleOutput>();
+ input->name("input");
+ output->name("output");
+
+ auto graph_input = g.inputs()->create();
+ input->index(graph_input->index());
+ auto graph_output = g.outputs()->create();
+ output->index(graph_output->index());
+
+ graph_input->dtype(T);
+ input->dtype(T);
+ output->dtype(T);
+ graph_output->dtype(T);
+
+ graph_input->shape({1, 4, 4, 4});
+ input->shape({1, 4, 4, 4});
+ output->shape({1, 4, 4, 4});
+ graph_output->shape({1, 4, 4, 4});
+
+ set_qparam(input, 1.0, 0);
+ set_qparam(output, 1.0, 0);
+
+ auto graph_body = insertGraphBody(input);
+ output->from(graph_body);
+ }
+
+ virtual ~SimpleGraph() = default;
+
+protected:
+ virtual loco::Node *insertGraphBody(loco::Node *input) = 0;
+
+public:
+ loco::Graph g;
+ luci::CircleInput *input = nullptr;
+ luci::CircleOutput *output = nullptr;
+};
+
+class U8ConvGraph final : public SimpleGraph<loco::DataType::U8>
+{
+protected:
+ loco::Node *insertGraphBody(loco::Node *input) override
+ {
+ conv = g.nodes()->create<luci::CircleConv2D>();
+ weights = g.nodes()->create<luci::CircleConst>();
+ bias = g.nodes()->create<luci::CircleConst>();
+
+ conv->dtype(loco::DataType::U8);
+ weights->dtype(loco::DataType::U8);
+ bias->dtype(loco::DataType::S32);
+
+ conv->shape({1, 4, 4, 4});
+ weights->shape({4, 1, 1, 4});
+ bias->shape({4});
+
+ weights->size<loco::DataType::U8>(16);
+ for (uint32_t i = 0; i < 16; i++)
+ weights->at<loco::DataType::U8>(i) = i;
+
+ bias->size<loco::DataType::S32>(4);
+ for (uint32_t i = 0; i < 4; i++)
+ bias->at<loco::DataType::S32>(i) = i;
+
+ set_qparam(conv, 2.0, 127);
+ set_qparam(weights, 2.0, 127);
+ set_qparam(bias, 2.0, 127);
+
+ conv->input(input);
+ conv->filter(weights);
+ conv->bias(bias);
+
+ conv->name("conv");
+ weights->name("weights");
+ bias->name("bias");
+
+ return conv;
+ }
+
+public:
+ luci::CircleConv2D *conv = nullptr;
+ luci::CircleConst *weights = nullptr;
+ luci::CircleConst *bias = nullptr;
+};
+
+class FP32ConvGraph final : public SimpleGraph<loco::DataType::FLOAT32>
+{
+protected:
+ loco::Node *insertGraphBody(loco::Node *input) override
+ {
+ conv = g.nodes()->create<luci::CircleConv2D>();
+ weights = g.nodes()->create<luci::CircleConst>();
+ bias = g.nodes()->create<luci::CircleConst>();
+
+ conv->dtype(loco::DataType::FLOAT32);
+ weights->dtype(loco::DataType::FLOAT32);
+ bias->dtype(loco::DataType::FLOAT32);
+
+ conv->shape({1, 4, 4, 4});
+ weights->shape({4, 1, 1, 4});
+ bias->shape({4});
+
+ weights->size<loco::DataType::FLOAT32>(16);
+ for (uint32_t i = 0; i < 16; i++)
+ weights->at<loco::DataType::FLOAT32>(i) = i;
+
+ bias->size<loco::DataType::FLOAT32>(4);
+ for (uint32_t i = 0; i < 4; i++)
+ bias->at<loco::DataType::FLOAT32>(i) = i;
+
+ conv->input(input);
+ conv->filter(weights);
+ conv->bias(bias);
+
+ conv->name("conv");
+ weights->name("weights");
+ bias->name("bias");
+
+ return conv;
+ }
+
+public:
+ luci::CircleConv2D *conv = nullptr;
+ luci::CircleConst *weights = nullptr;
+ luci::CircleConst *bias = nullptr;
+};
+
+} // namespace
+
+TEST(ConvertToFakeQuantizedModelTest, U8Conv2D)
+{
+ U8ConvGraph g;
+ g.init();
+
+ luci::ConvertToFakeQuantizedModelPass fq;
+ fq.run(&g.g);
+
+ // Check ifm
+ check_q_dq(g.conv->input(), 1.0, 0);
+
+ // Check weights
+ check_dq(g.conv->filter());
+
+ // Check bias
+ check_dq(g.conv->bias());
+
+ // Check ofm
+ check_q_dq(g.output->from(), 2.0, 127);
+
+ SUCCEED();
+}
+
+TEST(ConvertToFakeQuantizedModelTest, F32Conv2D_NEG)
+{
+ FP32ConvGraph g;
+ g.init();
+
+ luci::ConvertToFakeQuantizedModelPass fq;
+ fq.run(&g.g);
+
+ uint32_t dequant_count = 0;
+ uint32_t quant_count = 0;
+
+ for (auto node : loco::active_nodes(loco::output_nodes(&g.g)))
+ {
+ auto cnode = loco::must_cast<luci::CircleNode *>(node);
+ auto opcode = cnode->opcode();
+ if (opcode == luci::CircleOpcode::DEQUANTIZE)
+ dequant_count++;
+ if (opcode == luci::CircleOpcode::QUANTIZE)
+ quant_count++;
+ }
+
+ // Check no quant/dequant Op is inserted
+ EXPECT_EQ(0, quant_count);
+ EXPECT_EQ(0, dequant_count);
+}
diff --git a/compiler/luci/pass/src/CopyQuantParamPass.cpp b/compiler/luci/pass/src/CopyQuantParamPass.cpp
new file mode 100644
index 000000000..9b1bb0ea9
--- /dev/null
+++ b/compiler/luci/pass/src/CopyQuantParamPass.cpp
@@ -0,0 +1,82 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/CopyQuantParamPass.h"
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/Log.h>
+
+namespace luci
+{
+
+namespace
+{
+
+struct SrcDst
+{
+ CircleNode *src = nullptr;
+ CircleNode *dst = nullptr;
+};
+
+} // namespace
+
+bool CopyQuantParamPass::run(loco::Graph *g)
+{
+ LOGGER(l);
+
+ INFO(l) << "CopyQuantParamPass Start" << std::endl;
+
+ if (_src_tensors.size() != _dst_tensors.size())
+ throw std::runtime_error("The numbers of Source/Destination tensors do not match.");
+
+ // Return src/dst CircleNodes
+ auto get_src_dst = [&g](std::string src, std::string dst) {
+ SrcDst src_dst;
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ auto const cnode = loco::must_cast<CircleNode *>(node);
+ auto const name = cnode->name();
+ if (name == src)
+ src_dst.src = cnode;
+
+ if (name == dst)
+ src_dst.dst = cnode;
+ }
+ return src_dst;
+ };
+
+ for (uint32_t i = 0; i < _src_tensors.size(); i++)
+ {
+ auto src = _src_tensors[i];
+ auto dst = _dst_tensors[i];
+
+ auto nodes = get_src_dst(src, dst);
+ if (not nodes.src)
+ throw std::runtime_error("The tensor named " + src + " does not exist.");
+
+ if (not nodes.dst)
+ throw std::runtime_error("The tensor named " + dst + " does not exist.");
+
+ copy_quantparam(nodes.src, nodes.dst);
+
+ INFO(l) << "Quantparam of " << src << " is copied to " << dst << std::endl;
+ }
+
+ INFO(l) << "CopyQuantParamPass End" << std::endl;
+
+ return false; // one time run
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/FoldGatherPass.cpp b/compiler/luci/pass/src/FoldGatherPass.cpp
new file mode 100644
index 000000000..f179d74bd
--- /dev/null
+++ b/compiler/luci/pass/src/FoldGatherPass.cpp
@@ -0,0 +1,185 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/FoldGatherPass.h"
+#include "CircleOptimizerUtils.h"
+
+#include <luci/IR/CircleNodes.h>
+
+namespace
+{
+
+/**
+ * Fold to const if
+ *
+ * 1. params: const and dtype = S32 or S64
+ * 2. indices: const and dtype = S32 or S64
+ *
+ * BEFORE
+ *
+ * [CircleConst] [CircleConst]
+ * | |
+ * +---------[Gather]---------+
+ *
+ * AFTER
+ *
+ * [CircleConst]
+ *
+ **/
+template <loco::DataType InputT, loco::DataType IndexT>
+bool fold_gather(luci::CircleGather *gather_node)
+{
+ const auto params = loco::must_cast<luci::CircleConst *>(gather_node->params());
+ const auto indices = loco::must_cast<luci::CircleConst *>(gather_node->indices());
+
+ const auto rank = params->rank();
+ auto axis = gather_node->axis();
+ if (axis < 0)
+ {
+ axis += static_cast<int32_t>(rank);
+ }
+
+ if (axis < 0 or axis >= static_cast<int32_t>(rank))
+ throw std::runtime_error("Unsupported axis value");
+
+ const auto name = gather_node->name();
+ assert(name.length() > 0);
+
+ auto constant = gather_node->graph()->nodes()->create<luci::CircleConst>();
+ constant->dtype(InputT);
+ constant->name(name + "_folded");
+
+ constant->rank(rank + indices->rank() - 1);
+
+ assert(constant->rank() > 0);
+
+ std::vector<uint32_t> shape;
+ for (uint32_t i = 0; i < rank; ++i)
+ {
+ if (i != static_cast<uint32_t>(axis))
+ {
+ const auto dim = params->dim(i).value();
+ shape.push_back(dim);
+ }
+ else
+ {
+ for (uint32_t j = 0; j < indices->rank(); ++j)
+ {
+ const auto dim = indices->dim(j).value();
+ shape.push_back(dim);
+ }
+ }
+ }
+
+ uint32_t size = 1;
+ for (uint32_t i = 0; i < shape.size(); ++i)
+ {
+ constant->dim(i).set(shape.at(i));
+ size *= shape.at(i);
+ }
+
+ constant->size<InputT>(size);
+
+ uint32_t outer_size = 1;
+ for (uint32_t i = 0; i < static_cast<uint32_t>(axis); ++i)
+ {
+ outer_size *= params->dim(i).value();
+ }
+
+ uint32_t inner_size = 1;
+ for (uint32_t i = axis + 1; i < rank; ++i)
+ {
+ inner_size *= params->dim(i).value();
+ }
+
+ uint32_t coord_size = 1;
+ for (uint32_t i = 0; i < indices->rank(); ++i)
+ {
+ coord_size *= indices->dim(i).value();
+ }
+
+ const auto axis_size = params->dim(axis).value();
+
+ for (uint32_t outer = 0; outer < outer_size; ++outer)
+ {
+ for (uint32_t i = 0; i < coord_size; ++i)
+ {
+ constant->at<InputT>((outer * coord_size + i) * inner_size) =
+ params->at<InputT>((outer * axis_size + indices->at<IndexT>(i)) * inner_size);
+ }
+ }
+ loco::replace(gather_node).with(constant);
+
+ return true;
+}
+
+bool fold_gather(luci::CircleGather *gather_node)
+{
+ const auto params = dynamic_cast<luci::CircleConst *>(gather_node->params());
+ if (not params)
+ return false;
+
+ const auto indices = dynamic_cast<luci::CircleConst *>(gather_node->indices());
+ if (not indices)
+ return false;
+
+ // TODO: support more types
+ if (params->dtype() != loco::DataType::S32 and params->dtype() != loco::DataType::S64)
+ return false;
+
+ if (indices->dtype() != loco::DataType::S32 and indices->dtype() != loco::DataType::S64)
+ throw std::runtime_error("Unsupported type");
+
+ if (params->dtype() == loco::DataType::S64)
+ {
+ if (indices->dtype() == loco::DataType::S64)
+ return fold_gather<loco::DataType::S64, loco::DataType::S64>(gather_node);
+ else
+ return fold_gather<loco::DataType::S64, loco::DataType::S32>(gather_node);
+ }
+ else
+ {
+ if (indices->dtype() == loco::DataType::S64)
+ return fold_gather<loco::DataType::S32, loco::DataType::S64>(gather_node);
+ else
+ return fold_gather<loco::DataType::S32, loco::DataType::S32>(gather_node);
+ }
+}
+
+} // namespace
+
+namespace luci
+{
+
+/**
+ * Constant Folding for Gather Op
+ **/
+bool FoldGatherPass::run(loco::Graph *g)
+{
+ bool changed = false;
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ if (auto gather_node = dynamic_cast<luci::CircleGather *>(node))
+ {
+ if (fold_gather(gather_node))
+ changed = true;
+ }
+ }
+
+ return changed;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/FoldGatherPass.test.cpp b/compiler/luci/pass/src/FoldGatherPass.test.cpp
new file mode 100644
index 000000000..b02c034a5
--- /dev/null
+++ b/compiler/luci/pass/src/FoldGatherPass.test.cpp
@@ -0,0 +1,214 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/FoldGatherPass.h"
+#include "PassTestGraphs.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+/**
+ *
+ * Graph that has a Gather S64 Op with const inputs
+ *
+ * BEFORE
+ * params: [Const] (shape: [3], values: [1, 2, 3])
+ * indices: [Const] (shape: [1], values: [1])
+ *
+ * [params] [indices]
+ * | |
+ * ---[Gather]---
+ *
+ * AFTER
+ * [Const] (shape: [1], values: [2])
+ *
+ */
+class S64FoldGatherSimpleTest : public luci::ConstantFoldingAddTestGraph, public ::testing::Test
+{
+public:
+ S64FoldGatherSimpleTest() : luci::ConstantFoldingAddTestGraph({1}, loco::DataType::S64) {}
+
+ virtual void SetUp() { init(); }
+
+ loco::Node *createFoldedPattern() override
+ {
+ _gather = _g.nodes()->create<luci::CircleGather>();
+ _params = _g.nodes()->create<luci::CircleConst>();
+ _indices = _g.nodes()->create<luci::CircleConst>();
+
+ _gather->dtype(loco::DataType::S64);
+ _params->dtype(loco::DataType::S64);
+ _indices->dtype(loco::DataType::S64);
+
+ _params->shape({3});
+ _indices->shape({1});
+
+ _params->size<loco::DataType::S64>(3);
+ _params->at<loco::DataType::S64>(0) = 1;
+ _params->at<loco::DataType::S64>(1) = 2;
+ _params->at<loco::DataType::S64>(2) = 3;
+
+ _indices->size<loco::DataType::S64>(1);
+ _indices->at<loco::DataType::S64>(0) = 1;
+
+ _gather->params(_params);
+ _gather->indices(_indices);
+
+ _gather->name("gather");
+ _params->name("params");
+ _indices->name("indices");
+
+ return _gather;
+ }
+
+protected:
+ luci::CircleGather *_gather = nullptr;
+ luci::CircleConst *_params = nullptr;
+ luci::CircleConst *_indices = nullptr;
+};
+
+/**
+ *
+ * Graph that has a Gather S32 Op with axis = 1 and with const inputs
+ *
+ * BEFORE
+ * params: [Const] (shape: [2, 3], values: [0, 1, 2, 3, 4, 5])
+ * indices: [Const] (shape: [2], values: [2, 1])
+ *
+ * [params] [indices]
+ * | |
+ * ---[Gather]---
+ *
+ * AFTER
+ * [Const] (shape: [2, 2], values: [2, 1, 5, 4])
+ *
+ */
+
+class S32FoldGatherTwoDimsTest : public luci::ConstantFoldingAddTestGraph, public ::testing::Test
+{
+public:
+ S32FoldGatherTwoDimsTest() : luci::ConstantFoldingAddTestGraph({4, 2}, loco::DataType::S32) {}
+
+ virtual void SetUp() { init(); }
+
+ loco::Node *createFoldedPattern() override
+ {
+ _gather = _g.nodes()->create<luci::CircleGather>();
+ _params = _g.nodes()->create<luci::CircleConst>();
+ _indices = _g.nodes()->create<luci::CircleConst>();
+
+ _gather->dtype(loco::DataType::S32);
+ _params->dtype(loco::DataType::S32);
+ _indices->dtype(loco::DataType::S32);
+
+ _params->shape({2, 3});
+ _indices->shape({2});
+
+ _params->size<loco::DataType::S32>(6);
+ _params->at<loco::DataType::S32>(0) = 0;
+ _params->at<loco::DataType::S32>(1) = 1;
+ _params->at<loco::DataType::S32>(2) = 2;
+ _params->at<loco::DataType::S32>(3) = 3;
+ _params->at<loco::DataType::S32>(4) = 4;
+ _params->at<loco::DataType::S32>(5) = 5;
+
+ _indices->size<loco::DataType::S32>(2);
+ _indices->at<loco::DataType::S32>(0) = 2;
+ _indices->at<loco::DataType::S32>(1) = 1;
+
+ _gather->params(_params);
+ _gather->indices(_indices);
+
+ _gather->axis(1);
+
+ _gather->name("gather");
+ _params->name("params");
+ _indices->name("indices");
+
+ return _gather;
+ }
+
+protected:
+ luci::CircleGather *_gather = nullptr;
+ luci::CircleConst *_params = nullptr;
+ luci::CircleConst *_indices = nullptr;
+};
+
+} // namespace
+
+TEST(FoldGatherTest, name)
+{
+ luci::FoldGatherPass pass;
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
+
+TEST_F(S64FoldGatherSimpleTest, fold_gather_simple)
+{
+ luci::FoldGatherPass pass;
+ while (pass.run(graph()))
+ ;
+
+ auto folded_const = getFoldedPattern();
+ EXPECT_NE(nullptr, folded_const);
+
+ // Chec type, shape, values of folded const
+ EXPECT_EQ(loco::DataType::S64, folded_const->dtype());
+ EXPECT_EQ(1, folded_const->rank());
+ EXPECT_EQ(1, folded_const->dim(0).value());
+ EXPECT_EQ(2, folded_const->at<loco::DataType::S64>(0));
+}
+
+TEST_F(S32FoldGatherTwoDimsTest, fold_gather_with_two_dim)
+{
+ luci::FoldGatherPass pass;
+ while (pass.run(graph()))
+ ;
+
+ auto folded_const = getFoldedPattern();
+ EXPECT_NE(nullptr, folded_const);
+
+ // Chec type, shape, values of folded const
+ EXPECT_EQ(loco::DataType::S32, folded_const->dtype());
+ EXPECT_EQ(2, folded_const->rank());
+ EXPECT_EQ(2, folded_const->dim(0).value());
+ EXPECT_EQ(2, folded_const->dim(1).value());
+
+ EXPECT_EQ(2, folded_const->at<loco::DataType::S32>(0));
+ EXPECT_EQ(1, folded_const->at<loco::DataType::S32>(1));
+ EXPECT_EQ(5, folded_const->at<loco::DataType::S32>(2));
+ EXPECT_EQ(4, folded_const->at<loco::DataType::S32>(3));
+}
+
+TEST_F(S64FoldGatherSimpleTest, illegal_input_NEG)
+{
+ _indices->dtype(loco::DataType::FLOAT32);
+
+ luci::FoldGatherPass pass;
+ EXPECT_ANY_THROW(pass.run(graph()));
+}
+
+TEST_F(S64FoldGatherSimpleTest, illegal_axis_NEG)
+{
+ _gather->axis(1);
+
+ luci::FoldGatherPass pass;
+ EXPECT_ANY_THROW(pass.run(graph()));
+}
diff --git a/compiler/luci/pass/src/PropagateConcatenationQparam.test.cpp b/compiler/luci/pass/src/PropagateConcatenationQparam.test.cpp
index de973a431..68136b244 100644
--- a/compiler/luci/pass/src/PropagateConcatenationQparam.test.cpp
+++ b/compiler/luci/pass/src/PropagateConcatenationQparam.test.cpp
@@ -186,12 +186,12 @@ TEST(PropagateConcatenationQparam, propagate_concat_quantparam_u8)
// (1) normal case: qparam is propagated to input_1 and input_2
// (2) input used by other Op: input_1 is an input of input_2. qparam is propagated only to
// input_2
- // (3) subsequent concat: input_1 is concat. qparam is propagated only to input_2
+ // (3) subsequent concat: input_1 is concat. qparam is propagated to subsequent concat
// (4) const input: input_1 is const. constant values are quantized
// normal case: qparam of concat_node is propagated to input_1 and input_2
SimpleConcatGraph g(loco::DataType::U8);
- luci::propagate_concat_quantparam(&g.concat_node, loco::DataType::U8);
+ luci::propagate_concat_quantparam(&g.concat_node);
EXPECT_FLOAT_EQ(3.14, g.concat_node.quantparam()->scale[0]);
EXPECT_EQ(77, g.concat_node.quantparam()->zerop[0]);
EXPECT_FLOAT_EQ(3.14, g.input_1.quantparam()->scale[0]);
@@ -202,7 +202,7 @@ TEST(PropagateConcatenationQparam, propagate_concat_quantparam_u8)
// input_1 is an input of input_2. qparam is propagated only to input_2
SimpleConcatGraph g2(loco::DataType::U8);
g2.input_2.input(&g2.input_1);
- luci::propagate_concat_quantparam(&g2.concat_node, loco::DataType::U8);
+ luci::propagate_concat_quantparam(&g2.concat_node);
EXPECT_FLOAT_EQ(3.14, g2.concat_node.quantparam()->scale[0]);
EXPECT_EQ(77, g2.concat_node.quantparam()->zerop[0]);
EXPECT_FLOAT_EQ(1.0, g2.input_1.quantparam()->scale[0]);
@@ -210,19 +210,19 @@ TEST(PropagateConcatenationQparam, propagate_concat_quantparam_u8)
EXPECT_FLOAT_EQ(3.14, g2.input_2.quantparam()->scale[0]);
EXPECT_EQ(77, g2.input_2.quantparam()->zerop[0]);
- // input_1 is concat. qparam is propagated only to input_2
+ // input_1 is concat. qparam is propagated to subsequent concat
SubsequentConcatGraph sg(loco::DataType::U8);
- luci::propagate_concat_quantparam(&sg.concat_node, loco::DataType::U8);
+ luci::propagate_concat_quantparam(&sg.concat_node);
EXPECT_FLOAT_EQ(3.14, sg.concat_node.quantparam()->scale[0]);
EXPECT_EQ(77, sg.concat_node.quantparam()->zerop[0]);
- EXPECT_FLOAT_EQ(1.0, sg.input_1.quantparam()->scale[0]);
- EXPECT_EQ(1, sg.input_1.quantparam()->zerop[0]);
+ EXPECT_FLOAT_EQ(3.14, sg.input_1.quantparam()->scale[0]);
+ EXPECT_EQ(77, sg.input_1.quantparam()->zerop[0]);
EXPECT_FLOAT_EQ(3.14, sg.input_2.quantparam()->scale[0]);
EXPECT_EQ(77, sg.input_2.quantparam()->zerop[0]);
// input_1 is const. const values are quantized with the qparam of concat
ConstInputConcatGraph cg(loco::DataType::U8);
- luci::propagate_concat_quantparam(cg.concat_node, loco::DataType::U8);
+ luci::propagate_concat_quantparam(cg.concat_node);
EXPECT_FLOAT_EQ(0.1, cg.concat_node->quantparam()->scale[0]);
EXPECT_EQ(10, cg.concat_node->quantparam()->zerop[0]);
const auto cg_input_1 = loco::must_cast<luci::CircleConst *>(cg.concat_node->values(0));
@@ -248,7 +248,7 @@ TEST(PropagateConcatenationQparam, propagate_concat_quantparam_u8_NEG)
// concat has fused activation function
g.concat_node.fusedActivationFunction(luci::FusedActFunc::RELU);
- luci::propagate_concat_quantparam(&g.concat_node, loco::DataType::U8);
+ luci::propagate_concat_quantparam(&g.concat_node);
EXPECT_FLOAT_EQ(3.14, g.concat_node.quantparam()->scale[0]);
EXPECT_EQ(77, g.concat_node.quantparam()->zerop[0]);
EXPECT_FLOAT_EQ(1.0, g.input_1.quantparam()->scale[0]);
@@ -261,7 +261,7 @@ TEST(PropagateConcatenationQparam, propagate_concat_quantparam_u8_NEG)
// const values are quantized using its min/max
ConstInputConcatGraph cg(loco::DataType::U8);
cg.concat_node->fusedActivationFunction(luci::FusedActFunc::RELU);
- luci::propagate_concat_quantparam(cg.concat_node, loco::DataType::U8);
+ luci::propagate_concat_quantparam(cg.concat_node);
EXPECT_FLOAT_EQ(0.1, cg.concat_node->quantparam()->scale[0]);
EXPECT_EQ(10, cg.concat_node->quantparam()->zerop[0]);
const auto cg_input_1 = loco::must_cast<luci::CircleConst *>(cg.concat_node->values(0));
@@ -283,12 +283,12 @@ TEST(PropagateConcatenationQparam, propagate_concat_quantparam_i16)
// (1) normal case: qparam is propagated to input_1 and input_2
// (2) input used by other Op: input_1 is an input of input_2. qparam is propagated only to
// input_2
- // (3) subsequent concat: input_1 is concat. qparam is propagated only to input_2
+ // (3) subsequent concat: input_1 is concat. qparam is propagated to subsequent concat
// (4) const input: input_1 is const. constant values are quantized
// normal case: qparam of concat_node is propagated to input_1 and input_2
SimpleConcatGraph g(loco::DataType::S16);
- luci::propagate_concat_quantparam(&g.concat_node, loco::DataType::S16);
+ luci::propagate_concat_quantparam(&g.concat_node);
EXPECT_FLOAT_EQ(3.14, g.concat_node.quantparam()->scale[0]);
EXPECT_EQ(0, g.concat_node.quantparam()->zerop[0]);
EXPECT_FLOAT_EQ(3.14, g.input_1.quantparam()->scale[0]);
@@ -299,7 +299,7 @@ TEST(PropagateConcatenationQparam, propagate_concat_quantparam_i16)
// input_1 is an input of input_2. qparam is propagated only to input_2
SimpleConcatGraph g2(loco::DataType::S16);
g2.input_2.input(&g2.input_1);
- luci::propagate_concat_quantparam(&g2.concat_node, loco::DataType::S16);
+ luci::propagate_concat_quantparam(&g2.concat_node);
EXPECT_FLOAT_EQ(3.14, g2.concat_node.quantparam()->scale[0]);
EXPECT_EQ(0, g2.concat_node.quantparam()->zerop[0]);
EXPECT_FLOAT_EQ(1.0, g2.input_1.quantparam()->scale[0]);
@@ -309,17 +309,17 @@ TEST(PropagateConcatenationQparam, propagate_concat_quantparam_i16)
// input_1 is concat. qparam is propagated only to input_2
SubsequentConcatGraph sg(loco::DataType::S16);
- luci::propagate_concat_quantparam(&sg.concat_node, loco::DataType::S16);
+ luci::propagate_concat_quantparam(&sg.concat_node);
EXPECT_FLOAT_EQ(3.14, sg.concat_node.quantparam()->scale[0]);
EXPECT_EQ(0, sg.concat_node.quantparam()->zerop[0]);
- EXPECT_FLOAT_EQ(1.0, sg.input_1.quantparam()->scale[0]);
+ EXPECT_FLOAT_EQ(3.14, sg.input_1.quantparam()->scale[0]);
EXPECT_EQ(0, sg.input_1.quantparam()->zerop[0]);
EXPECT_FLOAT_EQ(3.14, sg.input_2.quantparam()->scale[0]);
EXPECT_EQ(0, sg.input_2.quantparam()->zerop[0]);
// input_1 is const. const values are quantized with the qparam of concat
ConstInputConcatGraph cg(loco::DataType::S16);
- luci::propagate_concat_quantparam(cg.concat_node, loco::DataType::S16);
+ luci::propagate_concat_quantparam(cg.concat_node);
EXPECT_FLOAT_EQ(0.1, cg.concat_node->quantparam()->scale[0]);
EXPECT_EQ(0, cg.concat_node->quantparam()->zerop[0]);
const auto cg_input_1 = loco::must_cast<luci::CircleConst *>(cg.concat_node->values(0));
@@ -345,7 +345,7 @@ TEST(PropagateConcatenationQparam, propagate_concat_quantparam_i16_NEG)
// concat has fused activation function
g.concat_node.fusedActivationFunction(luci::FusedActFunc::RELU);
- luci::propagate_concat_quantparam(&g.concat_node, loco::DataType::S16);
+ luci::propagate_concat_quantparam(&g.concat_node);
EXPECT_FLOAT_EQ(3.14, g.concat_node.quantparam()->scale[0]);
EXPECT_EQ(0, g.concat_node.quantparam()->zerop[0]);
EXPECT_FLOAT_EQ(1.0, g.input_1.quantparam()->scale[0]);
@@ -358,7 +358,7 @@ TEST(PropagateConcatenationQparam, propagate_concat_quantparam_i16_NEG)
// const values are quantized using its min/max
ConstInputConcatGraph cg(loco::DataType::S16);
cg.concat_node->fusedActivationFunction(luci::FusedActFunc::RELU);
- luci::propagate_concat_quantparam(cg.concat_node, loco::DataType::S16);
+ luci::propagate_concat_quantparam(cg.concat_node);
EXPECT_FLOAT_EQ(0.1, cg.concat_node->quantparam()->scale[0]);
EXPECT_EQ(0, cg.concat_node->quantparam()->zerop[0]);
const auto cg_input_1 = loco::must_cast<luci::CircleConst *>(cg.concat_node->values(0));
diff --git a/compiler/luci/pass/src/PropagateQParamBackwardPass.cpp b/compiler/luci/pass/src/PropagateQParamBackwardPass.cpp
new file mode 100644
index 000000000..b4975486d
--- /dev/null
+++ b/compiler/luci/pass/src/PropagateQParamBackwardPass.cpp
@@ -0,0 +1,482 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/PropagateQParamBackwardPass.h"
+#include "QuantizationUtils.h"
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/IR/CircleNodeVisitor.h>
+#include <luci/Service/Nodes/CircleConst.h>
+#include <luci/Log.h>
+
+#include <cmath>
+
+namespace
+{
+
+void quant_const_values(luci::CircleConst *const_node, float scaling_factor, float zerop,
+ loco::DataType quant_type)
+{
+ uint32_t size = const_node->size<loco::DataType::FLOAT32>();
+
+ const float scaling_factor_inv = 1.0 / scaling_factor;
+ std::vector<int32_t> quantized_values(size);
+ for (uint32_t i = 0; i < size; ++i)
+ {
+ auto data = static_cast<double>(const_node->at<loco::DataType::FLOAT32>(i));
+ double quantized_data = std::round(data * scaling_factor_inv) + zerop;
+ constexpr double int_max = static_cast<double>(std::numeric_limits<int32_t>::max());
+ constexpr double int_min = static_cast<double>(std::numeric_limits<int32_t>::min());
+ quantized_data = std::min(int_max, std::max(int_min, quantized_data));
+
+ quantized_values[i] = static_cast<int32_t>(quantized_data);
+ }
+
+ switch (quant_type)
+ {
+ case loco::DataType::U8:
+ const_node->dtype(loco::DataType::U8); // change the type of tensor
+ const_node->size<loco::DataType::U8>(size); // resize tensor
+ for (uint32_t i = 0; i < size; ++i)
+ const_node->at<loco::DataType::U8>(i) = std::min(255, std::max(0, quantized_values[i]));
+ break;
+ case loco::DataType::S16:
+ assert(zerop == 0);
+ const_node->dtype(loco::DataType::S16); // change the type of tensor
+ const_node->size<loco::DataType::S16>(size); // resize tensor
+ for (uint32_t i = 0; i < size; ++i)
+ const_node->at<loco::DataType::S16>(i) =
+ std::min(32767, std::max(-32767, quantized_values[i]));
+ break;
+ default:
+ throw std::runtime_error("Unsupported data type");
+ }
+}
+
+void overwrite_quantparam(const luci::CircleNode *source, luci::CircleNode *target)
+{
+ auto source_qparam = source->quantparam();
+ if (source_qparam == nullptr)
+ throw std::runtime_error("source quantparam is not found during overwrite");
+
+ auto target_qparam = target->quantparam();
+ if (target_qparam == nullptr)
+ {
+ auto quantparam = std::make_unique<luci::CircleQuantParam>();
+ target->quantparam(std::move(quantparam));
+ target_qparam = target->quantparam();
+
+ if (target_qparam == nullptr)
+ throw std::runtime_error("Creating new quant param failed");
+ }
+ target_qparam->min = source_qparam->min;
+ target_qparam->max = source_qparam->max;
+ target_qparam->scale = source_qparam->scale;
+ target_qparam->zerop = source_qparam->zerop;
+ target_qparam->quantized_dimension = source_qparam->quantized_dimension;
+}
+
+/**
+ * Tells if pad_v2 quantization should ignore padding value
+ * In that case padding const will be quantized with input parameters, and probably clipped
+ */
+bool ignore_pad_v2_const_quantization(const luci::CirclePadV2 *pad)
+{
+ // This is a workaround to quantize pad generated from MaxPoolWithArgmax operation properly
+ // TODO use metadata hints to detect this case
+ auto const_value_node = dynamic_cast<const luci::CircleConst *>(pad->arg(2));
+ if (!const_value_node)
+ return false;
+ if (const_value_node->dtype() == loco::DataType::FLOAT32)
+ {
+ float const_value = const_value_node->at<loco::DataType::FLOAT32>(0);
+ if (const_value == std::numeric_limits<float>::lowest())
+ return true;
+ }
+ return false;
+}
+
+/** EXAMPLE
+ *
+ * BEFORE
+ *
+ * [CircleNode] [CircleConst]
+ * (qparam1) (FP32)
+ * \ /
+ * \ /
+ * [CirclePack]
+ * (qparam2)
+ *
+ * AFTER
+ *
+ * [CircleNode] [CircleConst] [CircleConst] <- Dead node
+ * (qparam2) (qparam2) (FP32)
+ * \ /
+ * \ /
+ * [CirclePack]
+ * (qparam2)
+ *
+ * NOTE Quantization parameter of CirclePack (qparam2) is propagated to the inputs.
+ */
+void propagate_pack_quantparam(luci::CirclePack *pack)
+{
+ assert(pack->quantparam() != nullptr);
+
+ const auto num_inputs = pack->values_count();
+
+ for (uint32_t i = 0; i < num_inputs; i++)
+ {
+ auto node = loco::must_cast<luci::CircleNode *>(pack->arg(i));
+
+ // Quantize constant values
+ if (node->opcode() == luci::CircleOpcode::CIRCLECONST)
+ {
+ luci::CircleConst *const_node = loco::must_cast<luci::CircleConst *>(node);
+ if (const_node->dtype() != loco::DataType::FLOAT32)
+ throw std::runtime_error("Unsupported data type for constant input of pack Op");
+
+ const auto pack_qparam = pack->quantparam();
+ if (pack_qparam == nullptr)
+ throw std::runtime_error("quantparam of pack is not found during propagation");
+
+ assert(pack_qparam->scale.size() == 1);
+ assert(pack_qparam->zerop.size() == 1);
+ const auto scaling_factor = pack_qparam->scale[0];
+ const auto zerop = pack_qparam->zerop[0];
+
+ auto new_const = luci::clone(const_node);
+ quant_const_values(new_const, scaling_factor, zerop, pack->dtype());
+ pack->values(i, new_const);
+ overwrite_quantparam(pack, new_const);
+ }
+ else
+ {
+ const auto succs = loco::succs(node);
+ if (succs.size() > 1)
+ continue;
+
+ // Non-const input must have been quantized
+ assert(node->quantparam() != nullptr);
+ overwrite_quantparam(pack, node);
+ }
+ }
+}
+
+/** EXAMPLE
+ *
+ *
+ *
+ * BEFORE
+ *
+ * [CircleNode] [CircleConst] [CircleConst] [CircleNode]
+ * (S32) (S32) (FP32) (U8 qparam1)
+ * \ \ / /
+ * \ \ / /
+ * \ \ / /
+ * -------[CircleOneHot]-------
+ * (U8 qparam2)
+ *
+ * AFTER
+ *
+ * [CircleNode] [CircleConst] [CircleConst] [CircleNode] [CircleConst] <- Dead node
+ * (S32) (S32) (U8 qparam2) (U8 qparam2) (FP32)
+ * \ \ / /
+ * \ \ / /
+ * \ \ / /
+ * -------[CircleOneHot]-------
+ * (U8 qparam2)
+ *
+ * NOTE Quantization parameter of CircleOneHot (qparam2) is propagated to on_value/off_value.
+ */
+void propagate_one_hot_quantparam(luci::CircleOneHot *one_hot)
+{
+ assert(one_hot->quantparam() != nullptr);
+
+ // Propagate quantization parameters from output to inputs,
+ // to fit both input and counstant_value in one quant range.
+ auto quant_input = [one_hot](void (luci::CircleOneHot::*arg_setter)(loco::Node *),
+ loco::Node *(luci::CircleOneHot::*arg_getter)() const) {
+ auto node = loco::must_cast<luci::CircleNode *>((one_hot->*arg_getter)());
+
+ // Quantize constant values
+ if (node->opcode() == luci::CircleOpcode::CIRCLECONST)
+ {
+ luci::CircleConst *const_node = loco::must_cast<luci::CircleConst *>(node);
+ if (is_quantized(const_node))
+ return;
+
+ if (const_node->dtype() != loco::DataType::FLOAT32)
+ throw std::runtime_error("Unsupported data type for constant input of OneHot Op");
+
+ const auto qparam = one_hot->quantparam();
+ if (qparam == nullptr)
+ throw std::runtime_error("quantparam of OneHot is not found during propagation");
+
+ assert(qparam->scale.size() == 1);
+ const auto scaling_factor = qparam->scale.at(0);
+ const auto zerop = qparam->zerop.at(0);
+
+ auto new_const = luci::clone(const_node);
+ quant_const_values(new_const, scaling_factor, zerop, one_hot->dtype());
+ overwrite_quantparam(one_hot, new_const);
+ (one_hot->*arg_setter)(new_const);
+ }
+ else
+ {
+ const auto succs = loco::succs(node);
+ if (succs.size() > 1)
+ return;
+
+ // Non-const input must have been quantized
+ assert(node->quantparam() != nullptr);
+ overwrite_quantparam(one_hot, node);
+ }
+ };
+
+ quant_input(&luci::CircleOneHot::on_value, &luci::CircleOneHot::on_value);
+ quant_input(&luci::CircleOneHot::off_value, &luci::CircleOneHot::off_value);
+}
+
+} // namespace
+
+namespace luci
+{
+
+/** BEFORE
+ *
+ * [CircleNode] [CircleConst]
+ * (U8 qparam1) (FP32)
+ * \ /
+ * \ /
+ * [CircleConcatenation]
+ * (U8 qparam2)
+ *
+ * AFTER
+ * [CircleNode] [CircleConst] [CircleConst] <- Dead node
+ * (U8 qparam2) (U8 qparam2) (FP32)
+ * \ /
+ * \ /
+ * [CircleConcatenation]
+ * (U8 qparam2)
+ */
+void propagate_concat_quantparam(luci::CircleConcatenation *concat)
+{
+ assert(concat->quantparam() != nullptr);
+
+ const auto num_inputs = concat->numValues();
+
+ // Quantize const inputs using their values if concat has fused act function
+ if (concat->fusedActivationFunction() != luci::FusedActFunc::NONE)
+ {
+ for (uint32_t i = 0; i < num_inputs; i++)
+ {
+ auto node = concat->arg(i);
+ auto const_node = dynamic_cast<luci::CircleConst *>(node);
+ if (const_node != nullptr)
+ {
+ auto new_const = luci::clone(const_node);
+ quant_const(new_const, concat->dtype());
+ concat->values(i, new_const);
+ }
+ }
+ return;
+ }
+
+ for (uint32_t i = 0; i < num_inputs; i++)
+ {
+ auto node = loco::must_cast<luci::CircleNode *>(concat->arg(i));
+
+ // Quantize constant values
+ if (node->opcode() == luci::CircleOpcode::CIRCLECONST)
+ {
+ luci::CircleConst *const_node = loco::must_cast<luci::CircleConst *>(node);
+
+ const auto concat_qparam = concat->quantparam();
+ assert(concat_qparam->scale.size() == 1);
+ const auto scaling_factor = concat_qparam->scale[0];
+ const auto zerop = concat_qparam->zerop[0];
+
+ auto new_const = luci::clone(const_node);
+ quant_const_values(new_const, scaling_factor, zerop, concat->dtype());
+ concat->values(i, new_const);
+ overwrite_quantparam(concat, new_const);
+ }
+ else
+ {
+ const auto succs = loco::succs(node);
+ if (succs.size() > 1)
+ continue;
+
+ // Non-const input must have been quantized
+ assert(node->quantparam() != nullptr);
+ overwrite_quantparam(concat, node);
+ }
+ }
+}
+
+/** BEFORE
+ *
+ * [CircleNode] [CircleConst] [CircleConst]
+ * (U8 qparam1) (S32) (FP32)
+ * \ | /
+ * \ | /
+ * [CirclePadV2]
+ * (U8 qparam2)
+ *
+ * AFTER (case 1)
+ *
+ * By default qparam is propagated from output to inputs to meet backend requirements.
+ *
+ * [CircleNode] [CircleConst] [CircleConst] [CircleConst] <- Dead node
+ * (U8 qparam2) (S32) (U8 qparam2) (FP32)
+ * \ | /
+ * \ | /
+ * [CirclePadV2]
+ * (U8 qparam2)
+ *
+ * AFTER (case 2)
+ *
+ * In case padded value is the lowest float value
+ * Qparam is propagated from input to output and constant.
+ *
+ * This is a special case for optimization constructed pad, needed to guarantee that
+ * extremely large negative constant do not stretch output quantization range.
+ *
+ * [CircleNode] [CircleConst] [CircleConst] [CircleConst] <- Dead node
+ * (U8 qparam1) (S32) (U8 qparam1) (FP32)
+ * \ | /
+ * \ | /
+ * [CirclePadV2]
+ * (U8 qparam1)
+ */
+void propagate_pad_v2_quantparam(luci::CirclePadV2 *pad_v2)
+{
+ if (ignore_pad_v2_const_quantization(pad_v2))
+ {
+ // propagate input quantization paramters from input to output and padding const value
+ auto pad_v2_input = loco::must_cast<luci::CircleNode *>(pad_v2->arg(0));
+ overwrite_quantparam(pad_v2_input, pad_v2);
+
+ auto const_value_node = loco::must_cast<luci::CircleConst *>(
+ pad_v2->arg(2)); // FIX ignore_pad_v2_const_quantization UNLESS
+ auto new_const = luci::clone(const_value_node);
+
+ const auto pad_v2_input_qparam = pad_v2_input->quantparam();
+ assert(pad_v2_input_qparam != nullptr);
+ assert(pad_v2_input_qparam->scale.size() == 1);
+ const auto scaling_factor = pad_v2_input_qparam->scale.at(0);
+ const auto zerop = pad_v2_input_qparam->zerop.at(0);
+
+ quant_const_values(new_const, scaling_factor, zerop, pad_v2->dtype());
+ overwrite_quantparam(pad_v2_input, new_const);
+ pad_v2->constant_values(new_const);
+ return;
+ }
+
+ // Propagate quantization paramters from output to inputs,
+ // to fit both input and counstant_value in one quant range.
+ auto quant_input = [pad_v2](void (CirclePadV2::*arg_setter)(loco::Node *), uint32_t arg) {
+ auto node = loco::must_cast<luci::CircleNode *>(pad_v2->arg(arg));
+
+ // Quantize constant values
+ if (node->opcode() == luci::CircleOpcode::CIRCLECONST)
+ {
+ luci::CircleConst *const_node = loco::must_cast<luci::CircleConst *>(node);
+ if (is_quantized(const_node))
+ return;
+
+ if (const_node->dtype() != loco::DataType::FLOAT32)
+ throw std::runtime_error("Unsupported data type for constant input of PadV2 Op");
+
+ const auto pad_v2_qparam = pad_v2->quantparam();
+ if (pad_v2_qparam == nullptr)
+ throw std::runtime_error("quantparam of PadV2 is not found during propagation");
+
+ assert(pad_v2_qparam->scale.size() == 1);
+ const auto scaling_factor = pad_v2_qparam->scale.at(0);
+ const auto zerop = pad_v2_qparam->zerop.at(0);
+
+ auto new_const = luci::clone(const_node);
+ quant_const_values(new_const, scaling_factor, zerop, pad_v2->dtype());
+ overwrite_quantparam(pad_v2, new_const);
+ (pad_v2->*arg_setter)(new_const);
+ }
+ else
+ {
+ const auto succs = loco::succs(node);
+ if (succs.size() > 1)
+ return;
+
+ // Non-const input must have been quantized
+ assert(node->quantparam() != nullptr);
+ overwrite_quantparam(pad_v2, node);
+ }
+ };
+
+ quant_input(&CirclePadV2::input, 0);
+ quant_input(&CirclePadV2::constant_values, 2);
+}
+
+} // namespace luci
+
+namespace
+{
+
+// Visitor to propagate quantization parameters backwards
+struct PropagateQParamBackward final : public luci::CircleNodeMutableVisitor<void>
+{
+ void visit(luci::CircleNode *) {}
+
+ void visit(luci::CircleConcatenation *node) { propagate_concat_quantparam(node); }
+
+ void visit(luci::CircleOneHot *node) { propagate_one_hot_quantparam(node); }
+
+ void visit(luci::CirclePack *node) { propagate_pack_quantparam(node); }
+
+ void visit(luci::CirclePadV2 *node) { propagate_pad_v2_quantparam(node); }
+};
+
+} // namespace
+
+namespace luci
+{
+
+bool PropagateQParamBackwardPass::run(loco::Graph *g)
+{
+ LOGGER(l);
+
+ // We use reverse post-order traversal as qparam is propagated backward
+ auto nodes = loco::postorder_traversal(loco::output_nodes(g));
+ std::reverse(nodes.begin(), nodes.end());
+ for (auto node : nodes)
+ {
+ auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+ INFO(l) << "PropagateQParamBackwardPass visit node: " << circle_node->name() << std::endl;
+
+ // We can't propagate non-existent qparam
+ if (circle_node->quantparam() == nullptr)
+ continue;
+
+ PropagateQParamBackward pqb;
+ circle_node->accept(&pqb);
+ }
+
+ // This pass is only run once, so return false
+ // TODO Refactoring not to return meaningless value
+ return false;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/PropagateQParamBackwardPass.test.cpp b/compiler/luci/pass/src/PropagateQParamBackwardPass.test.cpp
new file mode 100644
index 000000000..33af70449
--- /dev/null
+++ b/compiler/luci/pass/src/PropagateQParamBackwardPass.test.cpp
@@ -0,0 +1,167 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/PropagateQParamBackwardPass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <gtest/gtest.h>
+
+using namespace luci;
+
+namespace
+{
+
+void set_qparam(luci::CircleNode *node, float scale, int64_t zp)
+{
+ auto qparam = std::make_unique<luci::CircleQuantParam>();
+ qparam->scale.emplace_back(scale);
+ qparam->zerop.emplace_back(zp);
+
+ node->quantparam(std::move(qparam));
+}
+
+/**
+ * @brief Base Test Graph
+ */
+struct TestGraph
+{
+public:
+ virtual void init(void) = 0;
+};
+
+/**
+ * Graph with two concats
+ *
+ * [CircleInput] [CircleConst]
+ * \ /
+ * [CircleConcatenation] [CircleConst]
+ * | |
+ * [CircleConcatenation]
+ * |
+ * [CircleOutput]
+ *
+ * BEFORE
+ * - Concat1 and Concat 2 have different qparams
+ *
+ * AFTER
+ * - All Ops have the same qparam
+ */
+struct SubsequentConcatGraph : public TestGraph
+{
+public:
+ void init(void) final
+ {
+ // graph input and output
+ auto graph_input = g.inputs()->create();
+ auto graph_output = g.outputs()->create();
+
+ // input
+ input = g.nodes()->create<luci::CircleInput>();
+ input->index(graph_input->index());
+ input->shape({1, 4, 4, 3});
+ input->dtype(loco::DataType::U8);
+ set_qparam(input, 1.0, 1);
+
+ // const1
+ const1 = g.nodes()->create<luci::CircleConst>();
+ const1->shape({1, 4, 4, 3});
+ const1->dtype(loco::DataType::FLOAT32);
+ const1->size<loco::DataType::FLOAT32>(48);
+ for (uint32_t i = 0; i < 48; i++)
+ const1->at<loco::DataType::FLOAT32>(i) = i;
+
+ // concat1
+ concat1 = g.nodes()->create<luci::CircleConcatenation>(2);
+ concat1->shape({1, 4, 4, 6});
+ concat1->dtype(loco::DataType::U8);
+ set_qparam(concat1, 2.0, 2);
+ concat1->values(0, input);
+ concat1->values(1, const1);
+ concat1->fusedActivationFunction(luci::FusedActFunc::NONE);
+
+ // const2
+ const2 = g.nodes()->create<luci::CircleConst>();
+ const2->shape({1, 4, 4, 3});
+ const2->dtype(loco::DataType::FLOAT32);
+ const2->size<loco::DataType::FLOAT32>(48);
+ for (uint32_t i = 0; i < 48; i++)
+ const2->at<loco::DataType::FLOAT32>(i) = i;
+
+ // concat2
+ concat2 = g.nodes()->create<luci::CircleConcatenation>(2);
+ concat2->shape({1, 4, 4, 9});
+ concat2->dtype(loco::DataType::U8);
+ set_qparam(concat2, 3.0, 3);
+ concat2->values(0, concat1);
+ concat2->values(1, const2);
+ concat2->fusedActivationFunction(luci::FusedActFunc::NONE);
+
+ // output
+ output = g.nodes()->create<luci::CircleOutput>();
+ output->index(graph_output->index());
+ output->from(concat2);
+ output->shape({1, 4, 4, 9});
+ output->dtype(loco::DataType::U8);
+ set_qparam(output, 3.0, 3);
+ }
+
+public:
+ loco::Graph g;
+ CircleInput *input = nullptr;
+ CircleConcatenation *concat1 = nullptr;
+ CircleConcatenation *concat2 = nullptr;
+ CircleConst *const1 = nullptr;
+ CircleConst *const2 = nullptr;
+ CircleOutput *output = nullptr;
+};
+
+} // namespace
+
+TEST(PropagateQParamBackwardPassTest, name)
+{
+ luci::PropagateQParamBackwardPass pass(loco::DataType::U8);
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
+
+TEST(PropagateQParamBackwardPassTest, subsequent_propagation)
+{
+ SubsequentConcatGraph graph;
+
+ graph.init();
+
+ luci::PropagateQParamBackwardPass pass(loco::DataType::U8);
+
+ pass.run(&graph.g);
+
+ EXPECT_EQ(3.0, graph.concat2->quantparam()->scale[0]);
+ EXPECT_EQ(3, graph.concat2->quantparam()->zerop[0]);
+
+ auto const2 = loco::must_cast<CircleNode *>(graph.concat2->values(1));
+ EXPECT_EQ(3.0, const2->quantparam()->scale[0]);
+ EXPECT_EQ(3, const2->quantparam()->zerop[0]);
+
+ EXPECT_EQ(3.0, graph.concat1->quantparam()->scale[0]);
+ EXPECT_EQ(3, graph.concat1->quantparam()->zerop[0]);
+
+ auto const1 = loco::must_cast<CircleNode *>(graph.concat1->values(1));
+ EXPECT_EQ(3.0, const1->quantparam()->scale[0]);
+ EXPECT_EQ(3, const1->quantparam()->zerop[0]);
+
+ EXPECT_EQ(3.0, graph.input->quantparam()->scale[0]);
+ EXPECT_EQ(3, graph.input->quantparam()->zerop[0]);
+}
diff --git a/compiler/luci/pass/src/PropagateQParamForwardPass.cpp b/compiler/luci/pass/src/PropagateQParamForwardPass.cpp
new file mode 100644
index 000000000..003e4c293
--- /dev/null
+++ b/compiler/luci/pass/src/PropagateQParamForwardPass.cpp
@@ -0,0 +1,194 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/PropagateQParamForwardPass.h"
+
+#include "QuantizationUtils.h"
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/IR/CircleNodeVisitor.h>
+#include <luci/Log.h>
+
+#include <iostream>
+
+namespace
+{
+
+bool copy_qparam(luci::CircleQuantParam *src, luci::CircleQuantParam *dst)
+{
+ assert(src->scale.size() == dst->scale.size());
+ assert(src->zerop.size() == dst->zerop.size());
+
+ // src and dst have the same qparam
+ if (std::equal(src->scale.begin(), src->scale.end(), dst->scale.begin()) &&
+ std::equal(src->zerop.begin(), src->zerop.end(), dst->zerop.begin()) &&
+ src->quantized_dimension == dst->quantized_dimension)
+ return false;
+
+ dst->scale.assign(src->scale.begin(), src->scale.end());
+ dst->zerop.assign(src->zerop.begin(), src->zerop.end());
+ dst->quantized_dimension = src->quantized_dimension;
+ return true;
+}
+
+bool copy_qparam(luci::CircleNode *src, luci::CircleNode *dst)
+{
+ // Skip nodes that do not have quantparams
+ auto src_qparam = src->quantparam();
+ if (not src_qparam)
+ return false;
+
+ auto dst_qparam = dst->quantparam();
+ if (not dst_qparam)
+ return false;
+
+ return copy_qparam(src_qparam, dst_qparam);
+}
+
+// Visitor to propagate quantization parameters
+struct PropagateQParamForward final : public luci::CircleNodeMutableVisitor<bool>
+{
+ PropagateQParamForward() = default;
+
+ bool visit(luci::CircleNode *) { return false; }
+
+ bool visit(luci::CircleGather *node)
+ {
+ auto input_node = loco::must_cast<luci::CircleNode *>(node->params());
+ return copy_qparam(input_node, node);
+ }
+
+ bool visit(luci::CircleReshape *node)
+ {
+ auto input_node = loco::must_cast<luci::CircleNode *>(node->tensor());
+ return copy_qparam(input_node, node);
+ }
+
+ bool visit(luci::CircleTranspose *node)
+ {
+ auto input_node = loco::must_cast<luci::CircleNode *>(node->a());
+ return copy_qparam(input_node, node);
+ }
+
+ bool visit(luci::CircleStridedSlice *node)
+ {
+ auto input_node = loco::must_cast<luci::CircleNode *>(node->input());
+ return copy_qparam(input_node, node);
+ }
+
+ bool visit(luci::CircleSplitOut *node)
+ {
+ auto split = loco::must_cast<luci::CircleSplit *>(node->input());
+ auto input_node = loco::must_cast<luci::CircleNode *>(split->input());
+ return copy_qparam(input_node, node);
+ }
+
+ bool visit(luci::CircleSplitVOut *node)
+ {
+ auto splitv = loco::must_cast<luci::CircleSplitV *>(node->input());
+ auto input_node = loco::must_cast<luci::CircleNode *>(splitv->input());
+ return copy_qparam(input_node, node);
+ }
+
+ bool visit(luci::CircleUnpackOut *node)
+ {
+ auto unpack = loco::must_cast<luci::CircleUnpack *>(node->input());
+ auto input_node = loco::must_cast<luci::CircleNode *>(unpack->value());
+ return copy_qparam(input_node, node);
+ }
+
+ // Propagate qparam across Quantize op to ensure
+ // special qparams (pre-defined values, integer scale)
+ bool visit(luci::CircleQuantize *node)
+ {
+ auto input_node = loco::must_cast<luci::CircleNode *>(node->input());
+
+ // Skip if input_node is not quantized activation
+ if (input_node->dtype() != loco::DataType::U8 and input_node->dtype() != loco::DataType::S16)
+ return false;
+
+ // If input_node and node have the same dtype, Quantize op
+ // will do rescale, not requantize for mixed-precision
+ if (input_node->dtype() == node->dtype())
+ return false;
+
+ assert(node->dtype() == loco::DataType::U8 or node->dtype() == loco::DataType::S16);
+
+ auto prev_qparam = node->quantparam();
+ assert(prev_qparam);
+ assert(prev_qparam->scale.size() == 1);
+ assert(prev_qparam->zerop.size() == 1);
+
+ const auto prev_scale = prev_qparam->scale[0];
+ const auto prev_zerop = prev_qparam->zerop[0];
+
+ auto qtype = luci::activation_qtype(input_node);
+ switch (qtype)
+ {
+ case luci::ActivationQType::PreDefinedValue:
+ node->quantparam(luci::make_predefined_qparam(input_node->opcode(), node->dtype()));
+ break;
+ case luci::ActivationQType::IntScale:
+ luci::set_int_scale(node);
+ break;
+ default:
+ break;
+ }
+
+ assert(node->quantparam());
+ assert(node->quantparam()->scale.size() == 1);
+ assert(node->quantparam()->zerop.size() == 1);
+
+ const auto scale = node->quantparam()->scale[0];
+ const auto zerop = node->quantparam()->zerop[0];
+
+ // Compare qparam with saved values to detect update
+ return scale != prev_scale or zerop != prev_zerop;
+ }
+};
+
+} // namespace
+
+namespace luci
+{
+
+bool PropagateQParamForwardPass::run(loco::Graph *g)
+{
+ bool changed = false;
+ LOGGER(l);
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+ INFO(l) << "PropagateQParamForwardPass visit node: " << circle_node->name() << std::endl;
+
+ PropagateQParamForward pqp;
+ if (circle_node->accept(&pqp))
+ changed = true;
+
+ if (_TF_style_maxpool)
+ {
+ if (auto maxpool = dynamic_cast<luci::CircleMaxPool2D *>(node))
+ {
+ auto input = loco::must_cast<luci::CircleNode *>(maxpool->value());
+ copy_qparam(input, maxpool);
+ }
+ }
+ }
+
+ return changed;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/PropagateQParamForwardPass.test.cpp b/compiler/luci/pass/src/PropagateQParamForwardPass.test.cpp
new file mode 100644
index 000000000..a734c0873
--- /dev/null
+++ b/compiler/luci/pass/src/PropagateQParamForwardPass.test.cpp
@@ -0,0 +1,260 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/PropagateQParamForwardPass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+void addQuantParam(luci::CircleNode *node, const std::vector<float> &scale,
+ const std::vector<int64_t> &zp)
+{
+ assert(node->quantparam() == nullptr);
+
+ auto quantparam = std::make_unique<luci::CircleQuantParam>();
+ quantparam->scale = scale;
+ quantparam->zerop = zp;
+ node->quantparam(std::move(quantparam));
+}
+
+/**
+ * Simple graph for test
+ *
+ * BEFORE
+ *
+ * [Conv] (qparam 1)
+ * |
+ * [Reshape] (qparam 2)
+ *
+ * AFTER
+ *
+ * [Conv] (qparam 2)
+ * |
+ * [Reshape] (qparam 2)
+ *
+ */
+class SimpleGraph
+{
+public:
+ SimpleGraph()
+ {
+ input = g.nodes()->create<luci::CircleInput>();
+ conv = g.nodes()->create<luci::CircleConv2D>();
+ reshape = g.nodes()->create<luci::CircleReshape>();
+ output = g.nodes()->create<luci::CircleOutput>();
+
+ auto graph_input = g.inputs()->create();
+ input->index(graph_input->index());
+ auto graph_output = g.outputs()->create();
+ output->index(graph_output->index());
+
+ addQuantParam(conv, {0.1, 0.2, 0.3}, {0, 10, 20});
+ addQuantParam(reshape, {0.2, 0.4, 0.6}, {-10, 0, 10});
+
+ conv->input(input);
+ reshape->tensor(conv);
+ output->from(reshape);
+ }
+
+public:
+ loco::Graph g;
+ luci::CircleInput *input = nullptr;
+ luci::CircleConv2D *conv = nullptr;
+ luci::CircleReshape *reshape = nullptr;
+ luci::CircleOutput *output = nullptr;
+};
+
+/**
+ * Test graph for forward propagation in Quantize Op
+ *
+ * BEFORE
+ *
+ * [Tanh U8] (qparam 1 - pre-defined for U8)
+ * |
+ * [Quantize S16] (qparam 2 - not pre-defined value)
+ *
+ * AFTER
+ *
+ * [Tanh U8] (qparam 1 - pre-defined for U8)
+ * |
+ * [Quantize S16] (qparam 3 - pre-defined for S16)
+ *
+ */
+class TanhQuantizeGraph
+{
+public:
+ TanhQuantizeGraph()
+ {
+ input = g.nodes()->create<luci::CircleInput>();
+ tanh = g.nodes()->create<luci::CircleTanh>();
+ quantize = g.nodes()->create<luci::CircleQuantize>();
+ output = g.nodes()->create<luci::CircleOutput>();
+
+ auto graph_input = g.inputs()->create();
+ input->index(graph_input->index());
+ auto graph_output = g.outputs()->create();
+ output->index(graph_output->index());
+
+ tanh->dtype(loco::DataType::U8);
+ quantize->dtype(loco::DataType::S16);
+
+ addQuantParam(tanh, {2.0f / 256.0f}, {128}); // pre-defined qparam for U8
+ addQuantParam(quantize, {1.0}, {0}); // not pre-defined values
+
+ tanh->x(input);
+ quantize->input(tanh);
+ output->from(quantize);
+ }
+
+public:
+ loco::Graph g;
+ luci::CircleInput *input = nullptr;
+ luci::CircleTanh *tanh = nullptr;
+ luci::CircleQuantize *quantize = nullptr;
+ luci::CircleOutput *output = nullptr;
+};
+
+/**
+ * Test graph for forward propagation in Quantize Op
+ *
+ * BEFORE
+ *
+ * [Floor U8] (qparam 1 - int scale)
+ * |
+ * [Quantize S16] (qparam 2 - not int scale)
+ *
+ * AFTER
+ *
+ * [Floor U8] (qparam 1 - int scale)
+ * |
+ * [Quantize S16] (qparam 3 - int scale)
+ *
+ */
+class FloorQuantizeGraph
+{
+public:
+ FloorQuantizeGraph()
+ {
+ input = g.nodes()->create<luci::CircleInput>();
+ floor = g.nodes()->create<luci::CircleFloor>();
+ quantize = g.nodes()->create<luci::CircleQuantize>();
+ output = g.nodes()->create<luci::CircleOutput>();
+
+ auto graph_input = g.inputs()->create();
+ input->index(graph_input->index());
+ auto graph_output = g.outputs()->create();
+ output->index(graph_output->index());
+
+ floor->dtype(loco::DataType::U8);
+ quantize->dtype(loco::DataType::S16);
+
+ addQuantParam(floor, {4.0f}, {128}); // int scale
+ addQuantParam(quantize, {0.3}, {0}); // not int scale
+
+ floor->x(input);
+ quantize->input(floor);
+ output->from(quantize);
+ }
+
+public:
+ loco::Graph g;
+ luci::CircleInput *input = nullptr;
+ luci::CircleFloor *floor = nullptr;
+ luci::CircleQuantize *quantize = nullptr;
+ luci::CircleOutput *output = nullptr;
+};
+
+} // namespace
+
+TEST(PropagateQParamForwardPassTest, name)
+{
+ luci::PropagateQParamForwardPass pass;
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
+
+TEST(PropagateQParamForward, simple)
+{
+ SimpleGraph g;
+
+ luci::PropagateQParamForwardPass pass;
+ while (pass.run(&g.g))
+ ;
+
+ EXPECT_FLOAT_EQ(0.1, g.reshape->quantparam()->scale[0]);
+ EXPECT_FLOAT_EQ(0.2, g.reshape->quantparam()->scale[1]);
+ EXPECT_FLOAT_EQ(0.3, g.reshape->quantparam()->scale[2]);
+ EXPECT_EQ(0, g.reshape->quantparam()->zerop[0]);
+ EXPECT_EQ(10, g.reshape->quantparam()->zerop[1]);
+ EXPECT_EQ(20, g.reshape->quantparam()->zerop[2]);
+}
+
+TEST(PropagateQParamForward, wrong_op_NEG)
+{
+ SimpleGraph g;
+ g.output->from(g.conv);
+ g.reshape->drop();
+
+ luci::PropagateQParamForwardPass pass;
+ while (pass.run(&g.g))
+ ;
+
+ EXPECT_FLOAT_EQ(0.1, g.conv->quantparam()->scale[0]);
+ EXPECT_FLOAT_EQ(0.2, g.conv->quantparam()->scale[1]);
+ EXPECT_FLOAT_EQ(0.3, g.conv->quantparam()->scale[2]);
+ EXPECT_EQ(0, g.conv->quantparam()->zerop[0]);
+ EXPECT_EQ(10, g.conv->quantparam()->zerop[1]);
+ EXPECT_EQ(20, g.conv->quantparam()->zerop[2]);
+}
+
+TEST(PropagateQParamForward, tanh_predefined_value)
+{
+ TanhQuantizeGraph g;
+
+ luci::PropagateQParamForwardPass pass;
+ while (pass.run(&g.g))
+ ;
+
+ EXPECT_FLOAT_EQ(1.0f / 32768.0f, g.quantize->quantparam()->scale[0]);
+}
+
+TEST(PropagateQParamForward, floor_int_scale)
+{
+ FloorQuantizeGraph g;
+
+ luci::PropagateQParamForwardPass pass;
+ while (pass.run(&g.g))
+ ;
+
+ EXPECT_FLOAT_EQ(1.0f, g.quantize->quantparam()->scale[0]);
+}
+
+TEST(PropagateQParamForward, same_dtype_NEG)
+{
+ FloorQuantizeGraph g;
+ g.quantize->dtype(loco::DataType::U8);
+
+ luci::PropagateQParamForwardPass pass;
+ while (pass.run(&g.g))
+ ;
+
+ // Qparam is not propagated as ifm/ofm of Quantize Op have the same dtype
+ EXPECT_FLOAT_EQ(0.3f, g.quantize->quantparam()->scale[0]);
+}
diff --git a/compiler/luci/pass/src/PropagateQuantParamPass.cpp b/compiler/luci/pass/src/PropagateQuantParamPass.cpp
deleted file mode 100644
index b1cb7a418..000000000
--- a/compiler/luci/pass/src/PropagateQuantParamPass.cpp
+++ /dev/null
@@ -1,107 +0,0 @@
-/*
- * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "luci/Pass/PropagateQuantParamPass.h"
-
-#include <luci/IR/CircleNodes.h>
-#include <luci/IR/CircleNodeVisitor.h>
-#include <luci/Log.h>
-
-#include <iostream>
-
-namespace
-{
-
-bool copy_qparam(luci::CircleQuantParam *src, luci::CircleQuantParam *dst)
-{
- assert(src->scale.size() == dst->scale.size());
- assert(src->zerop.size() == dst->zerop.size());
-
- // src and dst have the same qparam
- if (std::equal(src->scale.begin(), src->scale.end(), dst->scale.begin()) &&
- std::equal(src->zerop.begin(), src->zerop.end(), dst->zerop.begin()) &&
- src->quantized_dimension == dst->quantized_dimension)
- return false;
-
- dst->scale.assign(src->scale.begin(), src->scale.end());
- dst->zerop.assign(src->zerop.begin(), src->zerop.end());
- dst->quantized_dimension = src->quantized_dimension;
- return true;
-}
-
-bool copy_qparam(luci::CircleNode *src, luci::CircleNode *dst)
-{
- // Skip nodes that do not have quantparams
- auto src_qparam = src->quantparam();
- if (not src_qparam)
- return false;
-
- auto dst_qparam = dst->quantparam();
- if (not dst_qparam)
- return false;
-
- return copy_qparam(src_qparam, dst_qparam);
-}
-
-// Visitor to propagate quantization parameters
-struct PropagateQuantParam final : public luci::CircleNodeMutableVisitor<bool>
-{
- PropagateQuantParam() = default;
-
- bool visit(luci::CircleNode *) { return false; }
-
- bool visit(luci::CircleReshape *node)
- {
- auto input = node->tensor();
- if (loco::succs(input).size() != 1)
- return false;
-
- auto input_node = loco::must_cast<luci::CircleNode *>(input);
- return copy_qparam(input_node, node);
- }
-
- bool visit(luci::CircleTranspose *node)
- {
- auto input_node = loco::must_cast<luci::CircleNode *>(node->a());
- return copy_qparam(input_node, node);
- }
-
- // TODO : Add more Ops (e.g., layout-changing Ops)
-};
-
-} // namespace
-
-namespace luci
-{
-
-bool PropagateQuantParamPass::run(loco::Graph *g)
-{
- bool changed = false;
- LOGGER(l);
- for (auto node : loco::active_nodes(loco::output_nodes(g)))
- {
- auto circle_node = loco::must_cast<luci::CircleNode *>(node);
- INFO(l) << "PropagateQuantParamPass visit node: " << circle_node->name() << std::endl;
-
- PropagateQuantParam pqp;
- if (circle_node->accept(&pqp))
- changed = true;
- }
-
- return changed;
-}
-
-} // namespace luci
diff --git a/compiler/luci/pass/src/PropagateQuantParamPass.test.cpp b/compiler/luci/pass/src/PropagateQuantParamPass.test.cpp
deleted file mode 100644
index 0f1564223..000000000
--- a/compiler/luci/pass/src/PropagateQuantParamPass.test.cpp
+++ /dev/null
@@ -1,125 +0,0 @@
-/*
- * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "luci/Pass/PropagateQuantParamPass.h"
-
-#include <luci/IR/CircleNodes.h>
-
-#include <gtest/gtest.h>
-
-namespace
-{
-
-void addQuantParam(luci::CircleNode *node, const std::vector<float> &scale,
- const std::vector<int64_t> &zp)
-{
- assert(node->quantparam() == nullptr);
-
- auto quantparam = std::make_unique<luci::CircleQuantParam>();
- quantparam->scale = scale;
- quantparam->zerop = zp;
- node->quantparam(std::move(quantparam));
-}
-
-/**
- * Simple graph for test
- *
- * BEFORE
- *
- * [Conv] (qparam 1)
- * |
- * [Reshape] (qparam 2)
- *
- * AFTER
- *
- * [Conv] (qparam 2)
- * |
- * [Reshape] (qparam 2)
- *
- */
-class SimpleGraph
-{
-public:
- SimpleGraph()
- {
- input = g.nodes()->create<luci::CircleInput>();
- conv = g.nodes()->create<luci::CircleConv2D>();
- reshape = g.nodes()->create<luci::CircleReshape>();
- output = g.nodes()->create<luci::CircleOutput>();
-
- auto graph_input = g.inputs()->create();
- input->index(graph_input->index());
- auto graph_output = g.outputs()->create();
- output->index(graph_output->index());
-
- addQuantParam(conv, {0.1, 0.2, 0.3}, {0, 10, 20});
- addQuantParam(reshape, {0.2, 0.4, 0.6}, {-10, 0, 10});
-
- conv->input(input);
- reshape->tensor(conv);
- output->from(reshape);
- }
-
-public:
- loco::Graph g;
- luci::CircleInput *input;
- luci::CircleConv2D *conv;
- luci::CircleReshape *reshape;
- luci::CircleOutput *output;
-};
-
-} // namespace
-
-TEST(PropagateQuantParamPassTest, name)
-{
- luci::PropagateQuantParamPass pass;
- auto const name = pass.name();
- ASSERT_NE(nullptr, name);
-}
-
-TEST(PropagateQuantParam, simple)
-{
- SimpleGraph g;
-
- luci::PropagateQuantParamPass pass;
- while (pass.run(&g.g))
- ;
-
- EXPECT_FLOAT_EQ(0.1, g.reshape->quantparam()->scale[0]);
- EXPECT_FLOAT_EQ(0.2, g.reshape->quantparam()->scale[1]);
- EXPECT_FLOAT_EQ(0.3, g.reshape->quantparam()->scale[2]);
- EXPECT_EQ(0, g.reshape->quantparam()->zerop[0]);
- EXPECT_EQ(10, g.reshape->quantparam()->zerop[1]);
- EXPECT_EQ(20, g.reshape->quantparam()->zerop[2]);
-}
-
-TEST(PropagateQuantParam, wrong_op_NEG)
-{
- SimpleGraph g;
- g.output->from(g.conv);
- g.reshape->drop();
-
- luci::PropagateQuantParamPass pass;
- while (pass.run(&g.g))
- ;
-
- EXPECT_FLOAT_EQ(0.1, g.conv->quantparam()->scale[0]);
- EXPECT_FLOAT_EQ(0.2, g.conv->quantparam()->scale[1]);
- EXPECT_FLOAT_EQ(0.3, g.conv->quantparam()->scale[2]);
- EXPECT_EQ(0, g.conv->quantparam()->zerop[0]);
- EXPECT_EQ(10, g.conv->quantparam()->zerop[1]);
- EXPECT_EQ(20, g.conv->quantparam()->zerop[2]);
-}
diff --git a/compiler/luci/pass/src/QuantizationUtils.cpp b/compiler/luci/pass/src/QuantizationUtils.cpp
index 2f6fed46e..ad86cedf4 100644
--- a/compiler/luci/pass/src/QuantizationUtils.cpp
+++ b/compiler/luci/pass/src/QuantizationUtils.cpp
@@ -33,43 +33,6 @@ bool is_quantized(const CircleNode *node)
node->dtype() == loco::DataType::S64); // bias (int16 quant)
}
-// Check if node is weights of conv2d, depthwise_conv2d, or fully_connected layer
-bool is_weights(CircleNode *node)
-{
- auto circle_const = dynamic_cast<CircleConst *>(node);
- if (circle_const == nullptr)
- return false;
-
- auto succs = loco::succs(node);
-
- // Node is weights if it is the weights of all of its successors
- for (auto out : succs)
- {
- bool is_weights = false;
-
- auto conv = dynamic_cast<CircleConv2D *>(out);
- if (conv != nullptr && conv->filter() == circle_const)
- is_weights = true;
-
- auto dw_conv = dynamic_cast<CircleDepthwiseConv2D *>(out);
- if (dw_conv != nullptr && dw_conv->filter() == circle_const)
- is_weights = true;
-
- auto t_conv = dynamic_cast<CircleTransposeConv *>(out);
- if (t_conv != nullptr && t_conv->filter() == circle_const && circle_const->rank() == 4)
- is_weights = true;
-
- auto fc = dynamic_cast<CircleFullyConnected *>(out);
- if (fc != nullptr && fc->weights() == circle_const)
- is_weights = true;
-
- if (!is_weights)
- return false;
- }
-
- return true;
-}
-
uint8_t fp32_to_uint8_cast(float f)
{
assert(std::numeric_limits<uint8_t>::min() <= f);
@@ -77,7 +40,6 @@ uint8_t fp32_to_uint8_cast(float f)
return static_cast<uint8_t>(f);
}
-// Per-layer quantization of weights (const tensor) using given min/max values
void asymmetric_wquant_with_minmax_per_layer(CircleConst *node, float min, float max,
float &scaling_factor, int64_t &zp, float &nudged_min,
float &nudged_max)
@@ -107,7 +69,6 @@ void asymmetric_wquant_with_minmax_per_layer(CircleConst *node, float min, float
}
}
-// Per-layer quantization of weights (const tensor) using given min/max values
void symmetric_wquant_with_minmax_per_layer(CircleConst *node, float min, float max,
float &scaling_factor, int64_t &zp, float &nudged_min,
float &nudged_max)
@@ -315,4 +276,123 @@ uint32_t cal_offset(loco::TensorShape &dimension, uint32_t *indices)
indices[2] * dimension.dim(3).value() + indices[3];
}
+ActivationQType activation_qtype(const CircleNode *node)
+{
+ auto fused_act_node = dynamic_cast<const CircleNodeMixin<CircleNodeTrait::FusedActFunc> *>(node);
+ if (fused_act_node && fused_act_node->fusedActivationFunction() == FusedActFunc::TANH)
+ return ActivationQType::PreDefinedValue;
+
+ switch (node->opcode())
+ {
+ case CircleOpcode::LOGISTIC:
+ case CircleOpcode::TANH:
+ case CircleOpcode::SOFTMAX:
+ return ActivationQType::PreDefinedValue;
+ case CircleOpcode::FLOOR:
+ case CircleOpcode::FLOOR_DIV:
+ case CircleOpcode::FLOOR_MOD:
+ case CircleOpcode::CEIL:
+ return ActivationQType::IntScale;
+ default:
+ break;
+ }
+
+ return ActivationQType::MinMax;
+}
+
+std::unique_ptr<CircleQuantParam> make_predefined_qparam(CircleOpcode opcode, loco::DataType dtype)
+{
+ auto qparam = std::make_unique<CircleQuantParam>();
+
+ auto set_qparam = [&qparam](float scale, int64_t zp) {
+ qparam->scale.emplace_back(scale);
+ qparam->zerop.emplace_back(zp);
+ };
+
+ switch (opcode)
+ {
+ case CircleOpcode::LOGISTIC:
+ if (dtype == loco::DataType::U8)
+ set_qparam(1.0f / 256.0f, 0);
+ else
+ {
+ assert(dtype == loco::DataType::S16);
+ set_qparam(1.0f / 32768.0f, 0);
+ }
+ break;
+ case CircleOpcode::TANH:
+ if (dtype == loco::DataType::U8)
+ set_qparam(2.0f / 256.0f, 128);
+ else
+ {
+ assert(dtype == loco::DataType::S16);
+ set_qparam(1.0f / 32768.0f, 0);
+ }
+ break;
+ case CircleOpcode::SOFTMAX:
+ if (dtype == loco::DataType::U8)
+ set_qparam(1.0f / 255.0f, 0);
+ else
+ {
+ assert(dtype == loco::DataType::S16);
+ set_qparam(1.0f / 32767.0f, 0);
+ }
+ break;
+ default:
+ throw std::runtime_error("Unsupported opcode with pre-defined qparam");
+ }
+ return std::move(qparam);
+}
+
+// For nodes with integer output, we use integer scale
+void set_int_scale(luci::CircleNode *node)
+{
+ assert(node); // FIX_CALLER_UNLESS
+
+ auto qparam = node->quantparam();
+ assert(qparam); // FIX_CALLER_UNLESS
+ assert(qparam->scale.size() == 1); // FIX_CALLER_UNLESS
+
+ auto fp_scale = qparam->scale[0];
+ qparam->scale[0] = fp_scale < 1 ? 1.0f : std::round(fp_scale);
+}
+
+void quant_const(luci::CircleConst *node, loco::DataType quant_type)
+{
+ assert(node->dtype() == loco::DataType::FLOAT32);
+
+ float min = std::numeric_limits<float>::max();
+ float max = std::numeric_limits<float>::lowest();
+ for (uint32_t i = 0; i < node->size<loco::DataType::FLOAT32>(); i++)
+ {
+ auto data = node->at<loco::DataType::FLOAT32>(i);
+ min = data < min ? data : min;
+ max = data > max ? data : max;
+ }
+
+ float scaling_factor{0.0};
+ int64_t zp{0};
+ float nudged_min{0.0};
+ float nudged_max{0.0};
+
+ switch (quant_type)
+ {
+ case loco::DataType::U8:
+ asymmetric_wquant_with_minmax_per_layer(node, min, max, scaling_factor, zp, nudged_min,
+ nudged_max);
+ break;
+ case loco::DataType::S16:
+ symmetric_wquant_with_minmax_per_layer(node, min, max, scaling_factor, zp, nudged_min,
+ nudged_max);
+ break;
+ default:
+ throw std::runtime_error("Unsupported data type");
+ }
+
+ auto quantparam = std::make_unique<luci::CircleQuantParam>();
+ quantparam->scale.push_back(scaling_factor);
+ quantparam->zerop.push_back(zp);
+ node->quantparam(std::move(quantparam));
+}
+
} // namespace luci
diff --git a/compiler/luci/pass/src/QuantizationUtils.h b/compiler/luci/pass/src/QuantizationUtils.h
index 605f6a77e..cd8cec95a 100644
--- a/compiler/luci/pass/src/QuantizationUtils.h
+++ b/compiler/luci/pass/src/QuantizationUtils.h
@@ -23,33 +23,61 @@
namespace luci
{
+// Compute scale/zp using given min/max for symmetric quantization (int16)
void compute_sym_scale_zp(float min, float max, float &scaling_factor, int64_t &zp,
float &nudged_min, float &nudged_max);
+// Compute scale/zp using given min/max for asymmetric quantization (uint8)
void compute_asym_scale_zp(float min, float max, float &scaling_factor, int64_t &zp,
float &nudged_min, float &nudged_max);
+// Asymmetric per-layer quantization of weights (const tensor) using given min/max values
+// NOTE: in-place update of node data
void asymmetric_wquant_with_minmax_per_layer(CircleConst *node, float min, float max,
float &scaling_factor, int64_t &zp, float &nudged_min,
float &nudged_max);
+// Symmetric per-layer quantization of weights (const tensor) using given min/max values
+// NOTE: in-place update of node data
void symmetric_wquant_with_minmax_per_layer(CircleConst *node, float min, float max,
float &scaling_factor, int64_t &zp, float &nudged_min,
float &nudged_max);
+// Helper function to get channel dimension
+// TODO Embed this function into iterate_per_channel
bool get_channel_dim_index(CircleConst *node, loco::TensorShape &dimension,
int32_t &channel_dim_index);
+// Calculate offset of the given indices in dimension
uint32_t cal_offset(loco::TensorShape &dimension, uint32_t *indices);
-void propagate_concat_quantparam(luci::CircleConcatenation *concat, loco::DataType quant_type);
+// Backward propagation of concatenation qparam
+void propagate_concat_quantparam(luci::CircleConcatenation *concat);
-void propagate_pad_v2_quantparam(luci::CirclePadV2 *pad_v2, loco::DataType quant_type);
-
-bool is_weights(CircleNode *node);
+// Backward propagation of pad_v2 qparam
+void propagate_pad_v2_quantparam(luci::CirclePadV2 *pad_v2);
+// Return true if the node is quantized
bool is_quantized(const CircleNode *node);
+enum ActivationQType
+{
+ MinMax, // Quantize using recorded min/max
+ PreDefinedValue, // Quantize using pre-defined values
+ IntScale, // Round scale to a positive integer
+};
+
+ActivationQType activation_qtype(const CircleNode *node);
+
+// Create qparam with pre-defined values for speical operators
+std::unique_ptr<CircleQuantParam> make_predefined_qparam(CircleOpcode opcode, loco::DataType dtype);
+
+// Update node's scale to a positive integer (for special Ops e.g., Floor, Ceil)
+void set_int_scale(luci::CircleNode *node);
+
+// Quantize const tensor using its min/max values
+void quant_const(luci::CircleConst *node, loco::DataType quant_type);
+
} // namespace luci
#endif // __LUCI_QUANTIZATION_UTILS_H__
diff --git a/compiler/luci/pass/src/QuantizeActivation.cpp b/compiler/luci/pass/src/QuantizeActivation.cpp
new file mode 100644
index 000000000..149331824
--- /dev/null
+++ b/compiler/luci/pass/src/QuantizeActivation.cpp
@@ -0,0 +1,296 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "QuantizeActivation.h"
+#include "QuantizationUtils.h"
+
+#include <luci/Service/Nodes/CircleConst.h>
+#include <luci/Log.h>
+
+#include <algorithm>
+#include <cmath>
+
+using namespace luci;
+
+namespace
+{
+
+bool has_min_max(const CircleNode *node)
+{
+ return node->quantparam() && !node->quantparam()->min.empty() && !node->quantparam()->max.empty();
+}
+
+} // namespace
+
+// QuantizeActivation
+namespace luci
+{
+
+void QuantizeActivation::visit(luci::CircleNode *node)
+{
+ LOGGER(l);
+ INFO(l) << "QuantizeActivation visit node: " << node->name() << std::endl;
+
+ // Check if this is already quantized
+ if (is_quantized(node))
+ return;
+
+ // Check if this is bool type (bool type is not quantized)
+ if (node->dtype() == loco::DataType::BOOL)
+ return;
+
+ // Check if this is const (const activation is handled by QuantizeConstInputActivation)
+ // NOTE QuantizePreChecker guarantees weights/bias are const.
+ // Update this code when we accept non-const weights/bias.
+ if (node->opcode() == luci::CircleOpcode::CIRCLECONST)
+ return;
+
+ // Check if this is activation
+ // We assume min/max are recorded only for activations
+ if (has_min_max(node))
+ {
+ // Quantize using recorded min/max
+ auto quantparam = node->quantparam();
+ assert(quantparam);
+ assert(quantparam->min.size() == 1); // only support layer-wise quant
+ assert(quantparam->max.size() == 1); // only support layer-wise quant
+ auto min = quantparam->min[0];
+ auto max = quantparam->max[0];
+
+ float scaling_factor{0};
+ int64_t zp{0};
+ float nudged_min{0};
+ float nudged_max{0};
+
+ if (output_type == loco::DataType::U8)
+ {
+ compute_asym_scale_zp(min, max, scaling_factor, zp, nudged_min, nudged_max);
+ node->dtype(loco::DataType::U8);
+ }
+ else
+ {
+ compute_sym_scale_zp(min, max, scaling_factor, zp, nudged_min, nudged_max);
+ node->dtype(loco::DataType::S16);
+ }
+
+ node->quantparam()->scale.push_back(scaling_factor);
+ node->quantparam()->zerop.push_back(zp);
+ }
+ // Fix special attributes
+ if (node->opcode() == luci::CircleOpcode::CAST)
+ {
+ auto *cast = loco::must_cast<luci::CircleCast *>(node);
+ auto *cast_input = loco::must_cast<luci::CircleNode *>(cast->x());
+
+ // make sure that cast_input is already quantized
+ assert(cast_input->dtype() != loco::DataType::FLOAT32);
+ cast->in_data_type(cast_input->dtype());
+ cast->out_data_type(cast->dtype());
+ }
+}
+
+} // namespace luci
+
+// QuantizeSpecialActivation
+namespace luci
+{
+
+void QuantizeSpecialActivation::visit(luci::CircleNode *node)
+{
+ // Nodes fused with activation functions which need special quantization
+ auto fused_act_node = dynamic_cast<CircleNodeMixin<CircleNodeTrait::FusedActFunc> *>(node);
+ if (fused_act_node != nullptr && fused_act_node->fusedActivationFunction() == FusedActFunc::TANH)
+ {
+ auto qparam = make_predefined_qparam(luci::CircleOpcode::TANH, output_type);
+ node->quantparam(std::move(qparam));
+ }
+}
+
+void QuantizeSpecialActivation::visit(luci::CircleLogistic *node)
+{
+ assert(activation_qtype(node) == luci::ActivationQType::PreDefinedValue);
+ auto qparam = make_predefined_qparam(luci::CircleOpcode::LOGISTIC, output_type);
+ node->quantparam(std::move(qparam));
+}
+
+void QuantizeSpecialActivation::visit(luci::CircleTanh *node)
+{
+ assert(activation_qtype(node) == luci::ActivationQType::PreDefinedValue);
+ auto qparam = make_predefined_qparam(luci::CircleOpcode::TANH, output_type);
+ node->quantparam(std::move(qparam));
+}
+
+void QuantizeSpecialActivation::visit(luci::CircleSoftmax *node)
+{
+ assert(activation_qtype(node) == luci::ActivationQType::PreDefinedValue);
+ auto qparam = make_predefined_qparam(luci::CircleOpcode::SOFTMAX, output_type);
+ node->quantparam(std::move(qparam));
+}
+
+void QuantizeSpecialActivation::visit(luci::CircleFloor *node)
+{
+ assert(activation_qtype(node) == luci::ActivationQType::IntScale);
+ set_int_scale(node);
+}
+
+void QuantizeSpecialActivation::visit(luci::CircleFloorDiv *node)
+{
+ assert(activation_qtype(node) == luci::ActivationQType::IntScale);
+ set_int_scale(node);
+}
+
+void QuantizeSpecialActivation::visit(luci::CircleFloorMod *node)
+{
+ assert(activation_qtype(node) == luci::ActivationQType::IntScale);
+ set_int_scale(node);
+}
+
+void QuantizeSpecialActivation::visit(luci::CircleCeil *node)
+{
+ assert(activation_qtype(node) == luci::ActivationQType::IntScale);
+ set_int_scale(node);
+}
+
+} // namespace luci
+
+// QuantizeConstInputActivation
+namespace luci
+{
+
+// Default behavior (NYI)
+void QuantizeConstInputActivation::visit(luci::CircleNode *node)
+{
+ for (uint32_t i = 0; i < node->arity(); i++)
+ {
+ auto input_node = node->arg(i);
+ auto const_node = dynamic_cast<luci::CircleConst *>(input_node);
+ if (const_node != nullptr)
+ throw std::runtime_error("Unsupported Op for const inputs");
+ }
+}
+
+// INPUT_NAME is the only activation of NODE
+#define QUANTIZE_SINGLE_CONST_INPUT(NODE, INPUT_NAME) \
+ void QuantizeConstInputActivation::visit(NODE *node) \
+ { \
+ auto input = node->INPUT_NAME(); \
+ auto const_node = dynamic_cast<luci::CircleConst *>(input); \
+ if (const_node && !is_quantized(const_node)) \
+ { \
+ auto new_const = luci::clone(const_node); \
+ quant_const(new_const, _output_type); \
+ node->INPUT_NAME(new_const); \
+ } \
+ }
+
+// INPUT_NAME1 and INPUT_NAME2 are the only activations of NODE
+#define QUANTIZE_TWO_CONST_INPUTS(NODE, INPUT_NAME1, INPUT_NAME2) \
+ void QuantizeConstInputActivation::visit(NODE *node) \
+ { \
+ auto input1 = node->INPUT_NAME1(); \
+ auto const_node1 = dynamic_cast<luci::CircleConst *>(input1); \
+ if (const_node1 && !is_quantized(const_node1)) \
+ { \
+ auto new_const1 = luci::clone(const_node1); \
+ quant_const(new_const1, _output_type); \
+ node->INPUT_NAME1(new_const1); \
+ } \
+ auto input2 = node->INPUT_NAME2(); \
+ auto const_node2 = dynamic_cast<luci::CircleConst *>(input2); \
+ if (const_node2 && !is_quantized(const_node2)) \
+ { \
+ auto new_const2 = luci::clone(const_node2); \
+ quant_const(new_const2, _output_type); \
+ node->INPUT_NAME2(new_const2); \
+ } \
+ }
+
+// Ops that receive a single activation as an input
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleArgMax, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleArgMin, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleBatchToSpaceND, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleDepthToSpace, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleElu, features)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleExp, x)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleFloor, x)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleGather, params)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleLocalResponseNormalization, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleLogistic, x)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleMean, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleMirrorPad, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CirclePad, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleReduceAny, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleReduceProd, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleReduceMax, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleReduceMin, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleReshape, tensor)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleResizeBilinear, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleResizeNearestNeighbor, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleReverseSequence, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleRsqrt, x)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleSlice, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleSoftmax, logits)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleSpaceToBatchND, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleSpaceToDepth, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleSplit, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleSplitV, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleSqrt, x)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleStridedSlice, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleSum, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleTanh, x)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleTile, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleTopKV2, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleTranspose, a)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleUnpack, value)
+
+// Ops that receive two activations as inputs
+QUANTIZE_TWO_CONST_INPUTS(luci::CircleAdd, x, y)
+QUANTIZE_TWO_CONST_INPUTS(luci::CircleBatchMatMul, x, y)
+QUANTIZE_TWO_CONST_INPUTS(luci::CircleDiv, x, y)
+QUANTIZE_TWO_CONST_INPUTS(luci::CircleEqual, x, y)
+QUANTIZE_TWO_CONST_INPUTS(luci::CircleFloorDiv, x, y)
+QUANTIZE_TWO_CONST_INPUTS(luci::CircleGreater, x, y)
+QUANTIZE_TWO_CONST_INPUTS(luci::CircleGreaterEqual, x, y)
+QUANTIZE_TWO_CONST_INPUTS(luci::CircleLess, x, y)
+QUANTIZE_TWO_CONST_INPUTS(luci::CircleLessEqual, x, y)
+QUANTIZE_TWO_CONST_INPUTS(luci::CircleMaximum, x, y)
+QUANTIZE_TWO_CONST_INPUTS(luci::CircleMinimum, x, y)
+QUANTIZE_TWO_CONST_INPUTS(luci::CircleMul, x, y)
+QUANTIZE_TWO_CONST_INPUTS(luci::CircleNotEqual, x, y)
+QUANTIZE_TWO_CONST_INPUTS(luci::CirclePow, x, y)
+QUANTIZE_TWO_CONST_INPUTS(luci::CircleSub, x, y)
+
+// AddN has arbitrary number of inputs
+void QuantizeConstInputActivation::visit(luci::CircleAddN *node)
+{
+ auto arity = node->arity();
+ for (uint32_t i = 0; i < arity; i++)
+ {
+ auto input_node = node->inputs(i);
+ auto const_node = dynamic_cast<luci::CircleConst *>(input_node);
+ if (const_node && !is_quantized(const_node))
+ {
+ auto new_const = luci::clone(const_node);
+ quant_const(new_const, _output_type);
+ node->inputs(i, new_const);
+ }
+ }
+}
+
+#undef QUANTIZE_SINGLE_CONST_INPUT
+#undef QUANTIZE_TWO_CONST_INPUTS
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/QuantizeActivation.h b/compiler/luci/pass/src/QuantizeActivation.h
new file mode 100644
index 000000000..fc32d1cde
--- /dev/null
+++ b/compiler/luci/pass/src/QuantizeActivation.h
@@ -0,0 +1,165 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_QUANTIZATION_ACTIVATION_H__
+#define __LUCI_QUANTIZATION_ACTIVATION_H__
+
+#include <luci/IR/CircleNodeVisitor.h>
+
+namespace luci
+{
+
+/**
+ * @brief Quantize non-const activation using recorded min/max values
+ */
+struct QuantizeActivation final : public luci::CircleNodeMutableVisitor<void>
+{
+ QuantizeActivation(loco::DataType input, loco::DataType output)
+ : input_type(input), output_type(output)
+ {
+ }
+
+ loco::DataType input_type;
+ loco::DataType output_type;
+
+ // Quantize each node using recorded min/max
+ void visit(luci::CircleNode *node);
+};
+
+/**
+ * @brief Quantize non-const activaion using pre-defined scale/zp for special Ops
+ */
+struct QuantizeSpecialActivation final : public luci::CircleNodeMutableVisitor<void>
+{
+ QuantizeSpecialActivation(loco::DataType input, loco::DataType output)
+ : input_type(input), output_type(output)
+ {
+ }
+
+ loco::DataType input_type;
+ loco::DataType output_type;
+
+ void visit(luci::CircleNode *node);
+ void visit(luci::CircleLogistic *node);
+ void visit(luci::CircleTanh *node);
+ void visit(luci::CircleSoftmax *node);
+ void visit(luci::CircleFloor *node);
+ void visit(luci::CircleFloorDiv *node);
+ void visit(luci::CircleFloorMod *node);
+ void visit(luci::CircleCeil *node);
+};
+
+// Quantize constant input activation of a node
+// The input of a node is quantized if it is
+// 1. Constant (instance of CircleConst*)
+// 2. Activation (other inputs e.g., weights, bias, axis, etc should not be quantized here)
+struct QuantizeConstInputActivation final : public luci::CircleNodeMutableVisitor<void>
+{
+ QuantizeConstInputActivation(loco::DataType output_type) : _output_type(output_type) {}
+
+private:
+ loco::DataType _output_type;
+
+// Skip NODE
+#define SKIP(NODE) \
+ void visit(NODE *) {}
+
+ // Handled in QuantizeWeights and QuantizeBias
+ SKIP(luci::CircleConv2D)
+ SKIP(luci::CircleDepthwiseConv2D)
+ SKIP(luci::CircleFullyConnected)
+ SKIP(luci::CircleInstanceNorm)
+ SKIP(luci::CirclePRelu)
+ SKIP(luci::CircleTransposeConv)
+
+ // Handled in PropagateQParamBackwardPass
+ SKIP(luci::CircleConcatenation)
+ SKIP(luci::CirclePadV2)
+ SKIP(luci::CirclePack)
+ SKIP(luci::CircleOneHot)
+
+ // Inputs of logical Ops are bool, thus not quantized
+ SKIP(luci::CircleLogicalOr)
+ SKIP(luci::CircleLogicalAnd)
+ SKIP(luci::CircleLogicalNot)
+
+#undef SKIP
+
+ // Default behavior (NYI)
+ void visit(luci::CircleNode *node);
+
+ // Ops that receive a single activation as an input
+ void visit(luci::CircleArgMax *node);
+ void visit(luci::CircleArgMin *node);
+ void visit(luci::CircleBatchToSpaceND *node);
+ void visit(luci::CircleDepthToSpace *node);
+ void visit(luci::CircleElu *node);
+ void visit(luci::CircleExp *node);
+ void visit(luci::CircleFloor *node);
+ void visit(luci::CircleGather *node);
+ void visit(luci::CircleLocalResponseNormalization *node);
+ void visit(luci::CircleLogistic *node);
+ void visit(luci::CircleMean *node);
+ void visit(luci::CircleMirrorPad *node);
+ void visit(luci::CirclePad *node);
+ void visit(luci::CircleReduceAny *node);
+ void visit(luci::CircleReduceProd *node);
+ void visit(luci::CircleReduceMax *node);
+ void visit(luci::CircleReduceMin *node);
+ void visit(luci::CircleReshape *node);
+ void visit(luci::CircleResizeBilinear *node);
+ void visit(luci::CircleResizeNearestNeighbor *node);
+ void visit(luci::CircleReverseSequence *node);
+ void visit(luci::CircleRsqrt *node);
+ void visit(luci::CircleSlice *node);
+ void visit(luci::CircleSoftmax *node);
+ void visit(luci::CircleSpaceToBatchND *node);
+ void visit(luci::CircleSpaceToDepth *node);
+ void visit(luci::CircleSplit *node);
+ void visit(luci::CircleSplitV *node);
+ void visit(luci::CircleSqrt *node);
+ void visit(luci::CircleStridedSlice *node);
+ void visit(luci::CircleSum *node);
+ void visit(luci::CircleTanh *node);
+ void visit(luci::CircleTile *node);
+ void visit(luci::CircleTopKV2 *node);
+ void visit(luci::CircleTranspose *node);
+ void visit(luci::CircleUnpack *node);
+
+ // Ops that receive two activations as inputs
+ void visit(luci::CircleAdd *node);
+ void visit(luci::CircleBatchMatMul *node);
+ void visit(luci::CircleDiv *node);
+ void visit(luci::CircleEqual *node);
+ void visit(luci::CircleFloorDiv *node);
+ void visit(luci::CircleGreater *node);
+ void visit(luci::CircleGreaterEqual *node);
+ void visit(luci::CircleLess *node);
+ void visit(luci::CircleLessEqual *node);
+ void visit(luci::CircleMaximum *node);
+ void visit(luci::CircleMinimum *node);
+ void visit(luci::CircleMul *node);
+ void visit(luci::CircleNotEqual *node);
+ void visit(luci::CirclePow *node);
+ void visit(luci::CircleSub *node);
+
+ // AddN has arbitrary number of inputs
+ void visit(luci::CircleAddN *node);
+};
+
+} // namespace luci
+
+#endif // __LUCI_QUANTIZATION_ACTIVATION_H__
diff --git a/compiler/luci/pass/src/QuantizeBias.cpp b/compiler/luci/pass/src/QuantizeBias.cpp
new file mode 100644
index 000000000..aa496232a
--- /dev/null
+++ b/compiler/luci/pass/src/QuantizeBias.cpp
@@ -0,0 +1,300 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "QuantizeBias.h"
+#include "QuantizationUtils.h"
+
+#include <luci/Service/Nodes/CircleConst.h>
+#include <luci/Log.h>
+
+#include <algorithm>
+#include <cmath>
+
+using namespace luci;
+
+namespace
+{
+
+// struct to carry Input/Weights/Bias
+struct IWB
+{
+ CircleNode *input = nullptr;
+ CircleNode *weights = nullptr;
+ CircleConst *bias = nullptr;
+
+ IWB(loco::Node *i, loco::Node *w, loco::Node *b)
+ {
+ input = dynamic_cast<luci::CircleNode *>(i);
+ weights = dynamic_cast<luci::CircleNode *>(w);
+ bias = dynamic_cast<luci::CircleConst *>(b);
+ }
+
+ // Return true if bias can be quantized with valid input an weights
+ operator bool()
+ {
+ if (bias == nullptr || is_quantized(bias))
+ return false;
+ if (input == nullptr || weights == nullptr)
+ return false;
+ return true;
+ }
+};
+
+// Create a new const node from an existing node.
+// The new node has the following characteristics
+// type: T
+// shape: same with 'node' (given as an argument)
+// buffer size: 'size' (given as an argument)
+// Note that contents are not filled in this function.
+template <loco::DataType T>
+luci::CircleConst *create_empty_const_from(luci::CircleConst *node, uint32_t size)
+{
+ auto new_node = node->graph()->nodes()->create<CircleConst>();
+ // TODO: We don't have any naming convention for quantized nodes yet.
+ // Fix this when we have one.
+ new_node->name(node->name());
+ new_node->dtype(T);
+ new_node->rank(node->rank());
+ for (uint32_t i = 0; i < node->rank(); i++)
+ new_node->dim(i).set(node->dim(i).value());
+
+ new_node->size<T>(size);
+ new_node->shape_status(luci::ShapeStatus::VALID);
+
+ return new_node;
+}
+
+CircleConst *asym_quant_bias_per_layer(CircleConst *node, float input_scale, float weight_scale,
+ float *scaling_factor, int64_t *zp)
+{
+ float scale = input_scale * weight_scale;
+ const float scaling_factor_inv = (scale == 0) ? 0 : 1.0 / scale;
+
+ uint32_t size = node->size<loco::DataType::FLOAT32>();
+ std::vector<int32_t> quantized_values(size);
+ for (uint32_t i = 0; i < size; ++i)
+ {
+ quantized_values[i] =
+ static_cast<int32_t>(std::round(node->at<loco::DataType::FLOAT32>(i) * scaling_factor_inv));
+ }
+
+ auto new_bias = create_empty_const_from<loco::DataType::S32>(node, size);
+
+ const int32_t kMinScale = std::numeric_limits<int32_t>::lowest();
+ const int32_t kMaxScale = std::numeric_limits<int32_t>::max();
+ for (uint32_t i = 0; i < size; ++i)
+ {
+ new_bias->at<loco::DataType::S32>(i) =
+ std::min(kMaxScale, std::max(kMinScale, quantized_values[i]));
+ }
+ *scaling_factor = scale;
+ *zp = 0;
+
+ return new_bias;
+}
+
+CircleConst *quant_bias_per_channel(CircleConst *node, float input_scale,
+ std::vector<float> &weight_scale,
+ std::vector<float> &scaling_factor, std::vector<int64_t> &zp)
+{
+ float scaling_factor_inv{0};
+
+ uint32_t size = node->size<loco::DataType::FLOAT32>();
+ std::vector<int32_t> quantized_values(size);
+
+ for (uint32_t i = 0; i < size; ++i)
+ {
+ scaling_factor[i] = input_scale * weight_scale[i];
+ scaling_factor_inv = (scaling_factor[i] == 0) ? 0 : 1.0 / scaling_factor[i];
+ quantized_values[i] =
+ static_cast<int32_t>(std::round(node->at<loco::DataType::FLOAT32>(i) * scaling_factor_inv));
+ zp[i] = 0;
+ }
+
+ auto new_bias = create_empty_const_from<loco::DataType::S32>(node, size);
+
+ const int32_t kMinScale = std::numeric_limits<int32_t>::lowest();
+ const int32_t kMaxScale = std::numeric_limits<int32_t>::max();
+ for (uint32_t i = 0; i < size; ++i)
+ {
+ new_bias->at<loco::DataType::S32>(i) =
+ std::min(kMaxScale, std::max(kMinScale, quantized_values[i]));
+ }
+
+ return new_bias;
+}
+
+CircleConst *int16_quant_bias_per_channel(CircleConst *node, float input_scale,
+ std::vector<float> &weight_scale,
+ std::vector<float> &scaling_factor,
+ std::vector<int64_t> &zp)
+{
+ float scaling_factor_inv{0};
+
+ uint32_t size = node->size<loco::DataType::FLOAT32>();
+ std::vector<int64_t> quantized_values(size);
+
+ for (uint32_t i = 0; i < size; ++i)
+ {
+ scaling_factor[i] = input_scale * weight_scale[i];
+ scaling_factor_inv = (scaling_factor[i] == 0) ? 0 : 1.0 / scaling_factor[i];
+ quantized_values[i] =
+ static_cast<int64_t>(std::round(node->at<loco::DataType::FLOAT32>(i) * scaling_factor_inv));
+ zp[i] = 0;
+ }
+
+ auto new_bias = create_empty_const_from<loco::DataType::S64>(node, size);
+
+ for (uint32_t i = 0; i < size; ++i)
+ {
+ new_bias->at<loco::DataType::S64>(i) = quantized_values[i];
+ }
+
+ return new_bias;
+}
+
+} // namespace
+
+namespace luci
+{
+
+// Return a quantized bias node
+CircleConst *QuantizeBias::quantized_bias(CircleNode *input, const CircleNode *weight,
+ CircleNode *bias)
+{
+ auto const_bias = loco::must_cast<luci::CircleConst *>(bias);
+ assert(const_bias->dtype() == loco::DataType::FLOAT32);
+
+ // If input is const, it is quantized here, not in QuantizeActivation
+ if (auto const_input = dynamic_cast<luci::CircleConst *>(input))
+ {
+ quant_const(const_input, output_type);
+ }
+
+ CircleConst *new_bias = nullptr;
+
+ if (granularity == QuantizationGranularity::ChannelWise)
+ {
+ auto input_q = input->quantparam();
+ assert(input_q);
+ assert(input_q->scale.size() == 1); // input scale's layer-wise
+ auto input_scale = input_q->scale[0];
+
+ assert(weight->quantparam() != nullptr); // weight scale's channel-wise
+ auto weight_scale = weight->quantparam()->scale;
+
+ uint32_t size = const_bias->size<loco::DataType::FLOAT32>();
+ assert(size == weight_scale.size());
+ std::vector<float> scaling_factor(size);
+ std::vector<int64_t> zp(size);
+
+ if (output_type == loco::DataType::U8)
+ {
+ new_bias = quant_bias_per_channel(const_bias, input_scale, weight_scale, scaling_factor, zp);
+ }
+ else if (output_type == loco::DataType::S16)
+ {
+ new_bias =
+ int16_quant_bias_per_channel(const_bias, input_scale, weight_scale, scaling_factor, zp);
+ }
+ else
+ {
+ throw std::runtime_error("Unsupported quantization type.");
+ }
+
+ auto quantparam = std::make_unique<CircleQuantParam>();
+ quantparam->scale = scaling_factor;
+ quantparam->zerop = zp;
+ assert(new_bias->quantparam() == nullptr); // bias should not be quantized before
+ new_bias->quantparam(std::move(quantparam));
+
+ return new_bias;
+ }
+ else
+ {
+ auto input_q = input->quantparam();
+ assert(input_q);
+ assert(input_q->scale.size() == 1); // Only support per-layer quant
+ auto input_scale = input_q->scale[0];
+
+ auto weight_q = weight->quantparam();
+ assert(weight_q);
+ assert(weight_q->scale.size() == 1); // Only support per-layer quant
+ auto weight_scale = weight_q->scale[0];
+
+ float scaling_factor{0};
+ int64_t zp{0};
+ new_bias =
+ asym_quant_bias_per_layer(const_bias, input_scale, weight_scale, &scaling_factor, &zp);
+ auto quantparam = std::make_unique<CircleQuantParam>();
+ quantparam->scale.push_back(scaling_factor);
+ quantparam->zerop.push_back(zp);
+ assert(new_bias->quantparam() == nullptr); // bias should not be quantized before
+ new_bias->quantparam(std::move(quantparam));
+
+ return new_bias;
+ }
+}
+
+void QuantizeBias::visit(luci::CircleConv2D *node)
+{
+ LOGGER(l);
+ INFO(l) << "QuantizeBias QuantizeBias::visit node: " << node->name() << std::endl;
+
+ if (auto iwb = IWB(node->input(), node->filter(), node->bias()))
+ {
+ auto new_bias = quantized_bias(iwb.input, iwb.weights, iwb.bias);
+ node->bias(new_bias);
+ }
+}
+
+void QuantizeBias::visit(luci::CircleDepthwiseConv2D *node)
+{
+ LOGGER(l);
+ INFO(l) << "QuantizeBias QuantizeBias::visit node: " << node->name() << std::endl;
+
+ if (auto iwb = IWB(node->input(), node->filter(), node->bias()))
+ {
+ auto new_bias = quantized_bias(iwb.input, iwb.weights, iwb.bias);
+ node->bias(new_bias);
+ }
+}
+
+void QuantizeBias::visit(luci::CircleTransposeConv *node)
+{
+ LOGGER(l);
+ INFO(l) << "QuantizeBias QuantizeBias::visit node: " << node->name() << std::endl;
+
+ if (auto iwb = IWB(node->outBackprop(), node->filter(), node->bias()))
+ {
+ auto new_bias = quantized_bias(iwb.input, iwb.weights, iwb.bias);
+ node->bias(new_bias);
+ }
+}
+
+void QuantizeBias::visit(luci::CircleFullyConnected *node)
+{
+ LOGGER(l);
+ INFO(l) << "QuantizeBias visit node: " << node->name() << std::endl;
+
+ if (auto iwb = IWB(node->input(), node->weights(), node->bias()))
+ {
+ auto new_bias = quantized_bias(iwb.input, iwb.weights, iwb.bias);
+ node->bias(new_bias);
+ }
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/QuantizeBias.h b/compiler/luci/pass/src/QuantizeBias.h
new file mode 100644
index 000000000..8de09df72
--- /dev/null
+++ b/compiler/luci/pass/src/QuantizeBias.h
@@ -0,0 +1,56 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_QUANTIZE_BIAS_H__
+#define __LUCI_QUANTIZE_BIAS_H__
+
+#include <luci/Pass/QuantizationParameters.h>
+#include <luci/IR/CircleNodeVisitor.h>
+
+namespace luci
+{
+
+/**
+ * @brief QuantizeBias quantizes tensors for bias
+ * @details Use input/weights scale to quantize values
+ */
+struct QuantizeBias final : public luci::CircleNodeMutableVisitor<void>
+{
+ QuantizeBias(loco::DataType input, loco::DataType output, QuantizationGranularity gr)
+ : input_type(input), output_type(output), granularity(gr)
+ {
+ }
+
+ loco::DataType input_type;
+ loco::DataType output_type;
+ QuantizationGranularity granularity;
+
+private:
+ // Return a quantized bias node
+ CircleConst *quantized_bias(CircleNode *input, const CircleNode *weight, CircleNode *bias);
+
+ void visit(luci::CircleConv2D *node);
+ void visit(luci::CircleDepthwiseConv2D *node);
+ void visit(luci::CircleTransposeConv *node);
+ void visit(luci::CircleFullyConnected *node);
+
+ // Default behavior
+ void visit(luci::CircleNode *) {}
+};
+
+} // namespace luci
+
+#endif // __LUCI_QUANTIZE_BIAS_H__
diff --git a/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.cpp b/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.cpp
index c8ad87e3d..c9b35e0be 100644
--- a/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.cpp
+++ b/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.cpp
@@ -16,9 +16,11 @@
#include "luci/Pass/QuantizeDequantizeWeightsPass.h"
#include "QuantizationUtils.h"
+#include "helpers/LayerInfoMap.h"
#include <luci/IR/CircleNodes.h>
#include <luci/IR/CircleNodeVisitor.h>
+#include <luci/Service/Nodes/CircleConst.h>
#include <luci/Log.h>
#include <loco/IR/TensorShape.h>
@@ -251,7 +253,7 @@ void asymmetric_wdequant_with_minmax_per_layer(CircleConst *node, float scaling_
* @brief QuantizeDequantizeWeights quantizes and dequantizes tensors for weights
* @details Find min/max values on the fly, quantize the model, and dequantize the model
*/
-struct QuantizeDequantizeWeights final : public luci::CircleNodeMutableVisitor<bool>
+struct QuantizeDequantizeWeights final : public luci::CircleNodeMutableVisitor<void>
{
QuantizeDequantizeWeights(loco::DataType input, loco::DataType output,
QuantizationGranularity granularity)
@@ -263,88 +265,164 @@ struct QuantizeDequantizeWeights final : public luci::CircleNodeMutableVisitor<b
loco::DataType output_type;
QuantizationGranularity granularity;
- // Quantize and dequantize input tensors of each node
- bool visit(luci::CircleNode *node)
+private:
+ // Fake quantize weights (Only u8 quantization is supported for LWQ)
+ void fake_quantize_lwq(luci::CircleConst *weights) const
{
- assert(output_type == loco::DataType::U8 || output_type == loco::DataType::S16);
- LOGGER(l);
- INFO(l) << "QuantizeDequantizeWeights visit node: " << node->name() << std::endl;
- auto arity = node->arity();
- for (uint32_t i = 0; i < arity; i++)
+ assert(output_type == loco::DataType::U8); // FIX_CALLER_UNLESS
+
+ // Find min/max per layer
+ float min = std::numeric_limits<float>::max();
+ float max = std::numeric_limits<float>::lowest();
+ for (uint32_t i = 0; i < weights->size<loco::DataType::FLOAT32>(); i++)
{
- auto input_node = node->arg(i);
- auto circle_node = loco::must_cast<luci::CircleNode *>(input_node);
+ auto data = weights->at<loco::DataType::FLOAT32>(i);
+ min = data < min ? data : min;
+ max = data > max ? data : max;
+ }
+ float scaling_factor{0};
+ int64_t zp{0};
+ float nudged_min{0};
+ float nudged_max{0};
+
+ asymmetric_wquant_with_minmax_per_layer(weights, min, max, scaling_factor, zp, nudged_min,
+ nudged_max);
+ asymmetric_wdequant_with_minmax_per_layer(weights, scaling_factor, nudged_min);
+ auto quantparam = std::make_unique<CircleQuantParam>();
+ quantparam->min.push_back(nudged_min);
+ quantparam->max.push_back(nudged_max);
+ quantparam->scale.push_back(scaling_factor);
+ quantparam->zerop.push_back(zp);
+ weights->quantparam(std::move(quantparam));
+ }
- // Check if this is already quantized
- if (is_quantized(circle_node))
- continue;
+private:
+ // Fake quantize weights (u8/s16 quantization are supported for CWQ)
+ void fake_quantize_cwq(luci::CircleConst *weights) const
+ {
+ assert(output_type == loco::DataType::U8 ||
+ output_type == loco::DataType::S16); // FIX_CALLER_UNLESS
- if (is_weights(circle_node))
- {
- auto circle_const = loco::must_cast<luci::CircleConst *>(circle_node);
+ // Find min/max per channel
+ std::vector<float> min;
+ std::vector<float> max;
- // Find min/max per channel-wise
- if (granularity == QuantizationGranularity::ChannelWise)
- {
- std::vector<float> min;
- std::vector<float> max;
-
- cal_minmax_per_channel(circle_const, min, max);
-
- std::vector<float> nudged_min(min.size());
- std::vector<float> nudged_max(min.size());
- std::vector<float> scaling_factor(min.size());
- std::vector<int64_t> zp(min.size());
-
- if (output_type == loco::DataType::U8)
- {
- asymmetric_wquant_per_channel(circle_const, min, max, scaling_factor, zp, nudged_min,
- nudged_max);
- asymmetric_wdequant_per_channel(circle_const, scaling_factor, nudged_min);
- }
- else
- {
- sym_wquant_per_channel(circle_const, min, max, scaling_factor, zp, nudged_min,
- nudged_max);
- sym_wdequant_per_channel(circle_const, scaling_factor);
- }
-
- auto quantparam = std::make_unique<CircleQuantParam>();
- quantparam->min = nudged_min;
- quantparam->max = nudged_max;
- quantparam->scale = scaling_factor;
- quantparam->zerop = zp;
- circle_node->quantparam(std::move(quantparam));
- }
- // Find min/max per layer-wise
- else
- {
- float min = std::numeric_limits<float>::max();
- float max = std::numeric_limits<float>::lowest();
- for (uint32_t i = 0; i < circle_const->size<loco::DataType::FLOAT32>(); i++)
- {
- auto data = circle_const->at<loco::DataType::FLOAT32>(i);
- min = data < min ? data : min;
- max = data > max ? data : max;
- }
- float scaling_factor{0};
- int64_t zp{0};
- float nudged_min{0};
- float nudged_max{0};
-
- asymmetric_wquant_with_minmax_per_layer(circle_const, min, max, scaling_factor, zp,
- nudged_min, nudged_max);
- asymmetric_wdequant_with_minmax_per_layer(circle_const, scaling_factor, nudged_min);
- auto quantparam = std::make_unique<CircleQuantParam>();
- quantparam->min.push_back(nudged_min);
- quantparam->max.push_back(nudged_max);
- quantparam->scale.push_back(scaling_factor);
- quantparam->zerop.push_back(zp);
- circle_node->quantparam(std::move(quantparam));
- }
- }
+ cal_minmax_per_channel(weights, min, max);
+
+ std::vector<float> nudged_min(min.size());
+ std::vector<float> nudged_max(min.size());
+ std::vector<float> scaling_factor(min.size());
+ std::vector<int64_t> zp(min.size());
+
+ if (output_type == loco::DataType::U8)
+ {
+ asymmetric_wquant_per_channel(weights, min, max, scaling_factor, zp, nudged_min, nudged_max);
+ asymmetric_wdequant_per_channel(weights, scaling_factor, nudged_min);
+ }
+ else
+ {
+ sym_wquant_per_channel(weights, min, max, scaling_factor, zp, nudged_min, nudged_max);
+ sym_wdequant_per_channel(weights, scaling_factor);
}
- return false;
+
+ auto quantparam = std::make_unique<CircleQuantParam>();
+ quantparam->min = nudged_min;
+ quantparam->max = nudged_max;
+ quantparam->scale = scaling_factor;
+ quantparam->zerop = zp;
+ weights->quantparam(std::move(quantparam));
+ }
+
+private:
+ void fake_quantize(luci::CircleConst *weights) const
+ {
+ switch (granularity)
+ {
+ case luci::QuantizationGranularity::ChannelWise:
+ fake_quantize_cwq(weights);
+ break;
+ case luci::QuantizationGranularity::LayerWise:
+ fake_quantize_lwq(weights);
+ break;
+ default:
+ throw std::invalid_argument("Unsupported granularity");
+ }
+ }
+
+private:
+ // Check if
+ // 1. node is const
+ // 2. node was not quantized
+ bool is_quantizable(loco::Node *node)
+ {
+ auto const_node = dynamic_cast<luci::CircleConst *>(node);
+ if (not const_node)
+ return false;
+
+ // Skip if this is already quantized
+ if (is_quantized(const_node))
+ return false;
+
+ return true;
+ }
+
+ // Default behavior (Do nothing)
+ void visit(luci::CircleNode *) {}
+
+ void visit(luci::CircleConv2D *node)
+ {
+ LOGGER(l);
+ INFO(l) << "QuantizeDequantizeWeights visit node: " << node->name() << std::endl;
+
+ if (not is_quantizable(node->filter()))
+ return;
+
+ auto weights = loco::must_cast<luci::CircleConst *>(node->filter());
+ auto new_weights = luci::clone(weights);
+ node->filter(new_weights);
+ fake_quantize(new_weights);
+ }
+
+ void visit(luci::CircleDepthwiseConv2D *node)
+ {
+ LOGGER(l);
+ INFO(l) << "QuantizeDequantizeWeights visit node: " << node->name() << std::endl;
+
+ if (not is_quantizable(node->filter()))
+ return;
+
+ auto weights = loco::must_cast<luci::CircleConst *>(node->filter());
+ auto new_weights = luci::clone(weights);
+ node->filter(new_weights);
+ fake_quantize(new_weights);
+ }
+
+ void visit(luci::CircleTransposeConv *node)
+ {
+ LOGGER(l);
+ INFO(l) << "QuantizeDequantizeWeights visit node: " << node->name() << std::endl;
+
+ if (not is_quantizable(node->filter()))
+ return;
+
+ auto weights = loco::must_cast<luci::CircleConst *>(node->filter());
+ auto new_weights = luci::clone(weights);
+ node->filter(new_weights);
+ fake_quantize(new_weights);
+ }
+
+ void visit(luci::CircleFullyConnected *node)
+ {
+ LOGGER(l);
+ INFO(l) << "QuantizeDequantizeWeights visit node: " << node->name() << std::endl;
+
+ if (not is_quantizable(node->weights()))
+ return;
+
+ auto weights = loco::must_cast<luci::CircleConst *>(node->weights());
+ auto new_weights = luci::clone(weights);
+ node->weights(new_weights);
+ fake_quantize(new_weights);
}
};
@@ -355,11 +433,36 @@ bool QuantizeDequantizeWeightsPass::run(loco::Graph *g)
LOGGER(l);
INFO(l) << "QuantizeDequantizeWeightsPass Start" << std::endl;
+ auto info_by_name = layer_info_map(g, _ctx->layers_info);
+
+ auto quantize_dtype = [&](const luci::CircleNode *node) {
+ auto iter = info_by_name.find(node->name());
+
+ // Return designated quantization dtype
+ if (iter != info_by_name.end())
+ return iter->second.dtype;
+
+ // Return default quantization dtype
+ return _ctx->output_model_dtype;
+ };
+
+ auto quantize_granularity = [&](const luci::CircleNode *node) {
+ auto iter = info_by_name.find(node->name());
+
+ // Return designated quantization granularity
+ if (iter != info_by_name.end())
+ return iter->second.granularity;
+
+ // Return default quantization granularity
+ return _ctx->granularity;
+ };
+
// Quantize weights
for (auto node : loco::active_nodes(loco::output_nodes(g)))
{
- QuantizeDequantizeWeights qw(_input_model_dtype, _output_model_dtype, _granularity);
auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+ QuantizeDequantizeWeights qw(_ctx->input_model_dtype, quantize_dtype(circle_node),
+ quantize_granularity(circle_node));
circle_node->accept(&qw);
}
diff --git a/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.test.cpp b/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.test.cpp
index f226253c2..15f5ca7ac 100644
--- a/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.test.cpp
+++ b/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.test.cpp
@@ -25,3 +25,17 @@ TEST(QuantizeDequantizeWeightsPassTest, name)
auto const name = pass.name();
ASSERT_NE(nullptr, name);
}
+
+TEST(QuantizeDequantizeWeightsPassTest, name_ctx)
+{
+ auto ctx = std::make_unique<luci::QuantizeDequantizeWeightsPass::Context>();
+ {
+ ctx->input_model_dtype = loco::DataType::FLOAT32;
+ ctx->output_model_dtype = loco::DataType::U8;
+ ctx->granularity = luci::QuantizationGranularity::LayerWise;
+ }
+
+ luci::QuantizeDequantizeWeightsPass pass(std::move(ctx));
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
diff --git a/compiler/luci/pass/src/QuantizePreCheckerPass.cpp b/compiler/luci/pass/src/QuantizePreCheckerPass.cpp
new file mode 100644
index 000000000..4b3b7e330
--- /dev/null
+++ b/compiler/luci/pass/src/QuantizePreCheckerPass.cpp
@@ -0,0 +1,119 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/QuantizePreCheckerPass.h"
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/IR/CircleNodeVisitor.h>
+
+#include <luci/Log.h>
+
+namespace luci
+{
+
+namespace
+{
+
+void check_const_opcode(luci::CircleNode *node)
+{
+ if (node == nullptr)
+ return;
+
+ if (node->opcode() != luci::CircleOpcode::CIRCLECONST and
+ node->opcode() != luci::CircleOpcode::CIRCLEOUTPUTEXCLUDE)
+ {
+ throw std::runtime_error("Unsupported non const input " + node->name());
+ }
+}
+
+struct ConstInputChecker final : public luci::CircleNodeMutableVisitor<void>
+{
+// INPUT_NAME is name for input const for current NODE
+#define CHECK_NODE_WITH_ONE_INPUT_CONST(NODE, INPUT_NAME) \
+ void visit(NODE *node) \
+ { \
+ const auto input = dynamic_cast<luci::CircleNode *>(node->INPUT_NAME()); \
+ check_const_opcode(input); \
+ }
+
+// INPUT_NAME_1 and INPUT_NAME_2 are names for input const for current NODE
+#define CHECK_NODE_WITH_TWO_INPUT_CONST(NODE, INPUT_NAME_1, INPUT_NAME_2) \
+ void visit(NODE *node) \
+ { \
+ const auto input_1 = dynamic_cast<luci::CircleNode *>(node->INPUT_NAME_1()); \
+ const auto input_2 = dynamic_cast<luci::CircleNode *>(node->INPUT_NAME_2()); \
+ \
+ check_const_opcode(input_1); \
+ check_const_opcode(input_2); \
+ }
+
+// INPUT_NAME_1, INPUT_NAME_2 and INPUT_NAME_3 are names for input const for current NODE
+#define CHECK_NODE_WITH_THREE_INPUT_CONST(NODE, INPUT_NAME_1, INPUT_NAME_2, INPUT_NAME_3) \
+ void visit(NODE *node) \
+ { \
+ const auto input_1 = dynamic_cast<luci::CircleNode *>(node->INPUT_NAME_1()); \
+ const auto input_2 = dynamic_cast<luci::CircleNode *>(node->INPUT_NAME_2()); \
+ const auto input_3 = dynamic_cast<luci::CircleNode *>(node->INPUT_NAME_3()); \
+ \
+ check_const_opcode(input_1); \
+ check_const_opcode(input_2); \
+ check_const_opcode(input_3); \
+ }
+
+ // Skip other circle node
+ void visit(luci::CircleNode *) {}
+
+ // Ops that receive one const nodes as inputs
+ CHECK_NODE_WITH_ONE_INPUT_CONST(luci::CirclePRelu, alpha)
+
+ // Ops that receive two const node as an inputs
+ CHECK_NODE_WITH_TWO_INPUT_CONST(luci::CircleConv2D, filter, bias)
+ CHECK_NODE_WITH_TWO_INPUT_CONST(luci::CircleDepthwiseConv2D, filter, bias)
+ CHECK_NODE_WITH_TWO_INPUT_CONST(luci::CircleFullyConnected, weights, bias)
+ CHECK_NODE_WITH_TWO_INPUT_CONST(luci::CircleInstanceNorm, gamma, beta)
+
+ // Ops that receive three const nodes as an inputs
+ CHECK_NODE_WITH_THREE_INPUT_CONST(luci::CircleTransposeConv, inputSizes, filter, bias)
+
+#undef CHECK_NODE_WITH_ONE_INPUT_CONST
+#undef CHECK_NODE_WITH_TWO_INPUT_CONST
+#undef CHECK_NODE_WITH_THREE_INPUT_CONST
+};
+
+} // namespace
+
+/**
+ * Verify the input model has the form acceptable by quantizer
+ */
+bool QuantizePreCheckerPass::run(loco::Graph *g)
+{
+ LOGGER(l);
+ INFO(l) << "QuantizePreCheckerPass Start" << std::endl;
+
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ // Check const inputs
+ auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+ ConstInputChecker checker{};
+ circle_node->accept(&checker);
+ }
+
+ INFO(l) << "QuantizePreCheckerPass End" << std::endl;
+
+ return false; // one time run
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/QuantizePreCheckerPass.test.cpp b/compiler/luci/pass/src/QuantizePreCheckerPass.test.cpp
new file mode 100644
index 000000000..788353cd8
--- /dev/null
+++ b/compiler/luci/pass/src/QuantizePreCheckerPass.test.cpp
@@ -0,0 +1,401 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/QuantizePreCheckerPass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <gtest/gtest.h>
+
+class SimpleConv2DGraph
+{
+public:
+ SimpleConv2DGraph(bool make_valid)
+ {
+ conv2d_node = g.nodes()->create<luci::CircleConv2D>();
+ input_1 = g.nodes()->create<luci::CircleInput>();
+ filter = g.nodes()->create<luci::CircleConst>();
+
+ conv2d_node->input(input_1);
+ conv2d_node->filter(filter);
+
+ if (make_valid)
+ {
+ bias = g.nodes()->create<luci::CircleConst>();
+ conv2d_node->bias(bias);
+ }
+ else
+ {
+ input_2 = g.nodes()->create<luci::CircleInput>();
+ conv2d_node->bias(input_2);
+ }
+
+ output = g.nodes()->create<luci::CircleOutput>();
+
+ auto graph_output = g.outputs()->create();
+ output->index(graph_output->index());
+
+ output->from(conv2d_node);
+ }
+
+public:
+ loco::Graph g;
+
+private:
+ luci::CircleConv2D *conv2d_node = nullptr;
+ luci::CircleInput *input_1 = nullptr;
+ luci::CircleInput *input_2 = nullptr;
+ luci::CircleConst *filter = nullptr;
+ luci::CircleConst *bias = nullptr;
+ luci::CircleOutput *output = nullptr;
+};
+
+class SimpleDepthConv2DGraph
+{
+public:
+ SimpleDepthConv2DGraph(bool make_valid)
+ {
+ depth_conv2d_node = g.nodes()->create<luci::CircleDepthwiseConv2D>();
+ input_1 = g.nodes()->create<luci::CircleInput>();
+ filter = g.nodes()->create<luci::CircleConst>();
+
+ depth_conv2d_node->input(input_1);
+ depth_conv2d_node->filter(filter);
+
+ if (make_valid)
+ {
+ bias = g.nodes()->create<luci::CircleConst>();
+ depth_conv2d_node->bias(bias);
+ }
+ else
+ {
+ input_2 = g.nodes()->create<luci::CircleInput>();
+ depth_conv2d_node->bias(input_2);
+ }
+
+ output = g.nodes()->create<luci::CircleOutput>();
+
+ auto graph_output = g.outputs()->create();
+ output->index(graph_output->index());
+
+ output->from(depth_conv2d_node);
+ }
+
+public:
+ loco::Graph g;
+
+private:
+ luci::CircleDepthwiseConv2D *depth_conv2d_node = nullptr;
+ luci::CircleInput *input_1 = nullptr;
+ luci::CircleInput *input_2 = nullptr;
+ luci::CircleConst *filter = nullptr;
+ luci::CircleConst *bias = nullptr;
+ luci::CircleOutput *output = nullptr;
+};
+
+class SimpleFCGraph
+{
+public:
+ SimpleFCGraph(bool make_valid)
+ {
+ fc_node = g.nodes()->create<luci::CircleFullyConnected>();
+ input_1 = g.nodes()->create<luci::CircleInput>();
+ weights = g.nodes()->create<luci::CircleConst>();
+
+ fc_node->input(input_1);
+ fc_node->weights(weights);
+
+ if (make_valid)
+ {
+ bias = g.nodes()->create<luci::CircleConst>();
+ fc_node->bias(bias);
+ }
+ else
+ {
+ input_2 = g.nodes()->create<luci::CircleInput>();
+ fc_node->bias(input_2);
+ }
+
+ output = g.nodes()->create<luci::CircleOutput>();
+
+ auto graph_output = g.outputs()->create();
+ output->index(graph_output->index());
+
+ output->from(fc_node);
+ }
+
+public:
+ loco::Graph g;
+
+private:
+ luci::CircleFullyConnected *fc_node = nullptr;
+ luci::CircleInput *input_1 = nullptr;
+ luci::CircleInput *input_2 = nullptr;
+ luci::CircleConst *weights = nullptr;
+ luci::CircleConst *bias = nullptr;
+ luci::CircleOutput *output = nullptr;
+};
+
+class SimpleInstanceNormGraph
+{
+public:
+ SimpleInstanceNormGraph(bool make_valid)
+ {
+ instance_norm_node = g.nodes()->create<luci::CircleInstanceNorm>();
+ input_1 = g.nodes()->create<luci::CircleInput>();
+ gamma = g.nodes()->create<luci::CircleConst>();
+
+ instance_norm_node->input(input_1);
+ instance_norm_node->gamma(gamma);
+
+ if (make_valid)
+ {
+ beta = g.nodes()->create<luci::CircleConst>();
+ instance_norm_node->beta(beta);
+ }
+ else
+ {
+ input_2 = g.nodes()->create<luci::CircleInput>();
+ instance_norm_node->beta(input_2);
+ }
+
+ output = g.nodes()->create<luci::CircleOutput>();
+
+ auto graph_output = g.outputs()->create();
+ output->index(graph_output->index());
+
+ output->from(instance_norm_node);
+ }
+
+public:
+ loco::Graph g;
+
+private:
+ luci::CircleInstanceNorm *instance_norm_node = nullptr;
+ luci::CircleInput *input_1 = nullptr;
+ luci::CircleInput *input_2 = nullptr;
+ luci::CircleConst *gamma = nullptr;
+ luci::CircleConst *beta = nullptr;
+ luci::CircleOutput *output = nullptr;
+};
+
+class SimpleTransposeConvGraph
+{
+public:
+ SimpleTransposeConvGraph(bool make_valid)
+ {
+ transpose_conv = g.nodes()->create<luci::CircleTransposeConv>();
+ input_1 = g.nodes()->create<luci::CircleInput>();
+
+ input_sizes = g.nodes()->create<luci::CircleConst>();
+ filter = g.nodes()->create<luci::CircleConst>();
+
+ transpose_conv->outBackprop(input_1);
+ transpose_conv->filter(filter);
+ transpose_conv->inputSizes(input_sizes);
+
+ if (make_valid)
+ {
+ bias = g.nodes()->create<luci::CircleConst>();
+ transpose_conv->bias(bias);
+ }
+ else
+ {
+ input_2 = g.nodes()->create<luci::CircleInput>();
+ transpose_conv->bias(input_2);
+ }
+
+ output = g.nodes()->create<luci::CircleOutput>();
+
+ auto graph_output = g.outputs()->create();
+ output->index(graph_output->index());
+
+ output->from(transpose_conv);
+ }
+
+public:
+ loco::Graph g;
+
+private:
+ luci::CircleTransposeConv *transpose_conv = nullptr;
+ luci::CircleInput *input_1 = nullptr;
+ luci::CircleInput *input_2 = nullptr;
+ luci::CircleConst *input_sizes = nullptr;
+ luci::CircleConst *filter = nullptr;
+ luci::CircleConst *bias = nullptr;
+ luci::CircleOutput *output = nullptr;
+};
+
+class SimplePReluGraph
+{
+public:
+ SimplePReluGraph(bool make_valid)
+ {
+ prelu = g.nodes()->create<luci::CirclePRelu>();
+ input_1 = g.nodes()->create<luci::CircleInput>();
+
+ prelu->input(input_1);
+
+ if (make_valid)
+ {
+ alpha = g.nodes()->create<luci::CircleConst>();
+ prelu->alpha(alpha);
+ }
+ else
+ {
+ input_2 = g.nodes()->create<luci::CircleInput>();
+ prelu->alpha(input_2);
+ }
+
+ output = g.nodes()->create<luci::CircleOutput>();
+
+ auto graph_output = g.outputs()->create();
+ output->index(graph_output->index());
+
+ output->from(prelu);
+ }
+
+public:
+ loco::Graph g;
+
+private:
+ luci::CirclePRelu *prelu = nullptr;
+ luci::CircleInput *input_1 = nullptr;
+ luci::CircleInput *input_2 = nullptr;
+ luci::CircleConst *alpha = nullptr;
+ luci::CircleOutput *output = nullptr;
+};
+
+TEST(QuantizePreCheckerPassTest, name)
+{
+ luci::QuantizePreCheckerPass pass{};
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
+
+// Test Conv2d
+TEST(QuantizePreCheckerPassTest, conv2d)
+{
+ SimpleConv2DGraph valid_graph(true);
+
+ luci::QuantizePreCheckerPass checker{};
+
+ EXPECT_NO_THROW(checker.run(&valid_graph.g));
+}
+
+TEST(QuantizePreCheckerPassTest, conv2d_NEG)
+{
+ SimpleConv2DGraph invalid_graph(false);
+
+ luci::QuantizePreCheckerPass checker{};
+
+ EXPECT_ANY_THROW(checker.run(&invalid_graph.g));
+}
+
+// Test DepthwiseConv2d
+TEST(QuantizePreCheckerPassTest, depthwise_conv2d)
+{
+ SimpleDepthConv2DGraph valid_graph(true);
+
+ luci::QuantizePreCheckerPass checker{};
+
+ EXPECT_NO_THROW(checker.run(&valid_graph.g));
+}
+
+TEST(QuantizePreCheckerPassTest, depthwise_conv2d_NEG)
+{
+ SimpleDepthConv2DGraph invalid_graph(false);
+
+ luci::QuantizePreCheckerPass checker{};
+
+ EXPECT_ANY_THROW(checker.run(&invalid_graph.g));
+}
+
+// Test FullyConnected
+TEST(QuantizePreCheckerPassTest, fully_connected)
+{
+ SimpleFCGraph valid_graph(true);
+
+ luci::QuantizePreCheckerPass checker{};
+
+ EXPECT_NO_THROW(checker.run(&valid_graph.g));
+}
+
+TEST(QuantizePreCheckerPassTest, fully_connected_NEG)
+{
+ SimpleFCGraph invalid_graph(false);
+
+ luci::QuantizePreCheckerPass checker{};
+
+ EXPECT_ANY_THROW(checker.run(&invalid_graph.g));
+}
+
+// Test InstanceNorm
+TEST(QuantizePreCheckerPassTest, instance_norm)
+{
+ SimpleInstanceNormGraph valid_graph(true);
+
+ luci::QuantizePreCheckerPass checker{};
+
+ EXPECT_NO_THROW(checker.run(&valid_graph.g));
+}
+
+TEST(QuantizePreCheckerPassTest, instance_norm_NEG)
+{
+ SimpleInstanceNormGraph invalid_graph(false);
+
+ luci::QuantizePreCheckerPass checker{};
+
+ EXPECT_ANY_THROW(checker.run(&invalid_graph.g));
+}
+
+// Test TransposeConv
+TEST(QuantizePreCheckerPassTest, transpose_conv)
+{
+ SimpleTransposeConvGraph valid_graph(true);
+
+ luci::QuantizePreCheckerPass checker{};
+
+ EXPECT_NO_THROW(checker.run(&valid_graph.g));
+}
+
+TEST(QuantizePreCheckerPassTest, transpose_conv_NEG)
+{
+ SimpleTransposeConvGraph invalid_graph(false);
+
+ luci::QuantizePreCheckerPass checker{};
+
+ EXPECT_ANY_THROW(checker.run(&invalid_graph.g));
+}
+
+// Test PRelu
+TEST(QuantizePreCheckerPassTest, prelu)
+{
+ SimplePReluGraph valid_graph(true);
+
+ luci::QuantizePreCheckerPass checker{};
+
+ EXPECT_NO_THROW(checker.run(&valid_graph.g));
+}
+
+TEST(QuantizePreCheckerPassTest, prelu_NEG)
+{
+ SimplePReluGraph invalid_graph(false);
+
+ luci::QuantizePreCheckerPass checker{};
+
+ EXPECT_ANY_THROW(checker.run(&invalid_graph.g));
+}
diff --git a/compiler/luci/pass/src/QuantizeWeights.cpp b/compiler/luci/pass/src/QuantizeWeights.cpp
new file mode 100644
index 000000000..11322ab44
--- /dev/null
+++ b/compiler/luci/pass/src/QuantizeWeights.cpp
@@ -0,0 +1,394 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "QuantizeWeights.h"
+#include "QuantizationUtils.h"
+
+#include <luci/Service/Nodes/CircleConst.h>
+#include <luci/Log.h>
+
+#include <cmath>
+#include <vector>
+#include <functional>
+
+using namespace luci;
+
+namespace
+{
+
+using IterFunc = std::function<void(uint32_t *, loco::TensorShape &, int32_t)>;
+
+void iterate_per_channel(CircleConst *node, int32_t &channel_dim_index, IterFunc func)
+{
+ loco::TensorShape dimension;
+ dimension.rank(4);
+ uint32_t indices[4] = {
+ 0,
+ };
+
+ if (!get_channel_dim_index(node, dimension, channel_dim_index))
+ {
+ assert(false);
+ return;
+ }
+
+ for (indices[0] = 0; indices[0] < dimension.dim(0).value(); indices[0]++)
+ {
+ for (indices[1] = 0; indices[1] < dimension.dim(1).value(); indices[1]++)
+ {
+ for (indices[2] = 0; indices[2] < dimension.dim(2).value(); indices[2]++)
+ {
+ for (indices[3] = 0; indices[3] < dimension.dim(3).value(); indices[3]++)
+ {
+ func(indices, dimension, channel_dim_index);
+ }
+ }
+ }
+ }
+}
+
+void asym_wquant_per_channel(CircleConst *node, std::vector<float> &min,
+ std::vector<float> &scaling_factor, int32_t &channel_dim_index)
+{
+ assert(node->dtype() == loco::DataType::FLOAT32);
+
+ const int32_t kMinScale = 0;
+ const int32_t kMaxScale = 255;
+
+ uint32_t size = node->size<loco::DataType::FLOAT32>();
+ std::vector<int32_t> quantized_values(size);
+
+ auto quantize = [&](uint32_t *indices, loco::TensorShape &dimension, int32_t channel_dim_index) {
+ int channel_idx = indices[channel_dim_index];
+ const float scaling_factor_inv = 1.0 / scaling_factor[channel_idx];
+ auto data = node->at<loco::DataType::FLOAT32>(cal_offset(dimension, indices));
+ quantized_values[cal_offset(dimension, indices)] =
+ static_cast<int32_t>(std::round((data - min[channel_idx]) * scaling_factor_inv));
+ };
+
+ iterate_per_channel(node, channel_dim_index, quantize);
+
+ node->dtype(loco::DataType::U8); // change the type of tensor
+ node->size<loco::DataType::U8>(size); // resize tensor
+ for (uint32_t i = 0; i < size; ++i)
+ {
+ node->at<loco::DataType::U8>(i) = std::min(kMaxScale, std::max(kMinScale, quantized_values[i]));
+ }
+}
+
+void sym_wquant_per_channel(CircleConst *node, std::vector<float> &scaling_factor,
+ int32_t &channel_dim_index)
+{
+ assert(node->dtype() == loco::DataType::FLOAT32);
+
+ const int32_t kMaxScale = std::numeric_limits<int16_t>::max();
+ const int32_t kMinScale = -kMaxScale;
+
+ uint32_t size = node->size<loco::DataType::FLOAT32>();
+ std::vector<int32_t> quantized_values(size);
+
+ auto quantize = [&](uint32_t *indices, loco::TensorShape &dimension, int32_t channel_dim_index) {
+ int channel_idx = indices[channel_dim_index];
+ const float scaling_factor_inv = 1.0 / scaling_factor[channel_idx];
+ auto data = node->at<loco::DataType::FLOAT32>(cal_offset(dimension, indices));
+ quantized_values[cal_offset(dimension, indices)] =
+ static_cast<int32_t>(std::round(data * scaling_factor_inv));
+ };
+
+ iterate_per_channel(node, channel_dim_index, quantize);
+
+ node->dtype(loco::DataType::S16); // change the type of tensor
+ node->size<loco::DataType::S16>(size); // resize tensor
+ for (uint32_t i = 0; i < size; ++i)
+ {
+ node->at<loco::DataType::S16>(i) =
+ std::min(kMaxScale, std::max(kMinScale, quantized_values[i]));
+ }
+}
+
+void asym_wquant_per_layer(CircleConst *node, float min, float scaling_factor)
+{
+ const int32_t kMinScale = 0;
+ const int32_t kMaxScale = 255;
+
+ uint32_t size = node->size<loco::DataType::FLOAT32>();
+
+ const float scaling_factor_inv = 1.0 / scaling_factor;
+ std::vector<int32_t> quantized_values(size);
+ for (uint32_t i = 0; i < size; ++i)
+ {
+ auto data = node->at<loco::DataType::FLOAT32>(i);
+ quantized_values[i] = static_cast<int32_t>(std::round((data - min) * scaling_factor_inv));
+ }
+
+ node->dtype(loco::DataType::U8); // change the type of tensor
+ node->size<loco::DataType::U8>(size); // resize tensor
+ for (uint32_t i = 0; i < size; ++i)
+ {
+ node->at<loco::DataType::U8>(i) = std::min(kMaxScale, std::max(kMinScale, quantized_values[i]));
+ }
+}
+
+// Quantize const per channel
+//
+// The last dimension of const is the same as the dimension of channel
+// And the rest of the const dimensions should be 1
+// So, a 'single value' is quantized per channel
+//
+// Quantization spec (f: fp value, q: quantized value)
+//
+// uint8
+// Positive f: f = f * (q - 0) [q = 1, scale = f, zp = 0]
+// Negative f: f = (-f) * (q - 1) [q = 0, scale = -f, zp = 1]
+//
+// int16
+// Positive f: f = f * (q - 0) [q = 1, scale = f, zp = 0]
+// Negative f: f = (-f) * (q - 0) [q = -1, scale = -f, zp = 0]
+void quant_const_per_channel(CircleConst *node, loco::DataType quant_type)
+{
+ assert(node->dtype() == loco::DataType::FLOAT32);
+ assert(node->rank() > 0);
+
+ for (uint32_t i = 0; i < node->rank() - 1; i++)
+ {
+ // Caller should call this function when the below condition is satisfied
+ if (node->dim(i).value() != 1)
+ throw std::runtime_error("Non-channel dimension of const node must be 1");
+ }
+
+ uint32_t size = node->size<loco::DataType::FLOAT32>();
+ assert(size == node->dim(node->rank() - 1).value());
+
+ auto quantparam = std::make_unique<CircleQuantParam>();
+ quantparam->quantized_dimension = node->rank() - 1;
+ std::vector<int32_t> quantized_data(size);
+
+ for (uint32_t i = 0; i < size; ++i)
+ {
+ auto data = node->at<loco::DataType::FLOAT32>(i);
+ if (quant_type == loco::DataType::U8)
+ {
+ if (data >= 0)
+ {
+ quantparam->scale.push_back(data);
+ quantparam->zerop.push_back(0);
+ quantized_data[i] = 1;
+ }
+ else
+ {
+ quantparam->scale.push_back(-data);
+ quantparam->zerop.push_back(1);
+ quantized_data[i] = 0;
+ }
+ }
+ else if (quant_type == loco::DataType::S16)
+ {
+ if (data >= 0)
+ {
+ quantparam->scale.push_back(data);
+ quantized_data[i] = 1;
+ }
+ else
+ {
+ quantparam->scale.push_back(-data);
+ quantized_data[i] = -1;
+ }
+ quantparam->zerop.push_back(0);
+ }
+ }
+ node->quantparam(std::move(quantparam));
+
+ switch (quant_type)
+ {
+ case loco::DataType::U8:
+ node->dtype(loco::DataType::U8);
+ node->size<loco::DataType::U8>(size);
+ for (uint32_t i = 0; i < size; ++i)
+ {
+ assert(quantized_data[i] == 0 || quantized_data[i] == 1);
+ node->at<loco::DataType::U8>(i) = quantized_data[i];
+ }
+ break;
+ case loco::DataType::S16:
+ node->dtype(loco::DataType::S16);
+ node->size<loco::DataType::S16>(size);
+ for (uint32_t i = 0; i < size; ++i)
+ {
+ assert(quantized_data[i] == -1 || quantized_data[i] == 1);
+ node->at<loco::DataType::S16>(i) = quantized_data[i];
+ }
+ break;
+ default:
+ throw std::runtime_error("Unsupported data type");
+ }
+}
+
+} // namespace
+
+namespace luci
+{
+
+void QuantizeWeights::quantize_weights(luci::CircleConst *weights)
+{
+ // Find min/max per channel-wise
+ if (granularity == QuantizationGranularity::ChannelWise)
+ {
+ auto quantparam = weights->quantparam();
+ if (quantparam == nullptr)
+ {
+ assert(false && "quantparam is nullptr");
+ return;
+ }
+
+ auto min = quantparam->min;
+ auto scaling_factor = quantparam->scale;
+ int32_t channel_dim_index = 0;
+
+ if (output_type == loco::DataType::U8)
+ {
+ asym_wquant_per_channel(weights, min, scaling_factor, channel_dim_index);
+ }
+ else
+ {
+ sym_wquant_per_channel(weights, scaling_factor, channel_dim_index);
+ }
+ quantparam->min.clear();
+ quantparam->max.clear();
+ quantparam->quantized_dimension = channel_dim_index;
+ }
+ // Find min/max per layer-wise
+ else
+ {
+ // Quantize using recorded quantparam
+ auto quantparam = weights->quantparam();
+ assert(quantparam != nullptr);
+ assert(quantparam->min.size() == 1); // only support layer-wise quant
+ assert(quantparam->scale.size() == 1); // only support layer-wise quant
+ auto min = quantparam->min[0];
+ auto scaling_factor = quantparam->scale[0];
+ asym_wquant_per_layer(weights, min, scaling_factor);
+ quantparam->min.clear();
+ quantparam->max.clear();
+ }
+}
+void QuantizeWeights::visit(luci::CircleConv2D *node)
+{
+ LOGGER(l);
+ INFO(l) << "QuantizeWeights QuantizeWeights::visit node: " << node->name() << std::endl;
+
+ auto weights = loco::must_cast<luci::CircleConst *>(node->filter());
+ if (!is_quantized(weights))
+ {
+ auto new_weights = luci::clone(weights);
+ node->filter(new_weights);
+ quantize_weights(new_weights);
+ }
+}
+
+void QuantizeWeights::visit(luci::CircleDepthwiseConv2D *node)
+{
+ LOGGER(l);
+ INFO(l) << "QuantizeWeights QuantizeWeights::visit node: " << node->name() << std::endl;
+
+ auto weights = loco::must_cast<luci::CircleConst *>(node->filter());
+ if (!is_quantized(weights))
+ {
+ auto new_weights = luci::clone(weights);
+ node->filter(new_weights);
+ quantize_weights(new_weights);
+ }
+}
+
+void QuantizeWeights::visit(luci::CircleInstanceNorm *node)
+{
+ LOGGER(l);
+ INFO(l) << "QuantizeWeights QuantizeWeights::visit node: " << node->name() << std::endl;
+
+ auto gamma = loco::must_cast<luci::CircleConst *>(node->gamma());
+ auto beta = loco::must_cast<luci::CircleConst *>(node->beta());
+
+ if (!is_quantized(gamma))
+ {
+ assert(gamma->dtype() == loco::DataType::FLOAT32);
+ auto new_gamma = luci::clone(gamma);
+ if (granularity == QuantizationGranularity::LayerWise)
+ quant_const(new_gamma, output_type);
+ else if (granularity == QuantizationGranularity::ChannelWise)
+ quant_const_per_channel(new_gamma, output_type);
+ node->gamma(new_gamma);
+ }
+ if (!is_quantized(beta))
+ {
+ assert(beta->dtype() == loco::DataType::FLOAT32);
+ auto new_beta = luci::clone(beta);
+ if (granularity == QuantizationGranularity::LayerWise)
+ quant_const(new_beta, output_type);
+ else if (granularity == QuantizationGranularity::ChannelWise)
+ quant_const_per_channel(new_beta, output_type);
+ node->beta(new_beta);
+ }
+}
+
+void QuantizeWeights::visit(luci::CirclePRelu *node)
+{
+ LOGGER(l);
+ INFO(l) << "QuantizeWeights QuantizeWeights::visit node: " << node->name() << std::endl;
+
+ auto alpha = loco::must_cast<luci::CircleConst *>(node->alpha());
+
+ if (!is_quantized(alpha))
+ {
+ assert(alpha->dtype() == loco::DataType::FLOAT32);
+ auto new_alpha = luci::clone(alpha);
+ if (granularity == QuantizationGranularity::LayerWise)
+ quant_const(new_alpha, output_type);
+ else if (granularity == QuantizationGranularity::ChannelWise)
+ quant_const_per_channel(new_alpha, output_type);
+ node->alpha(new_alpha);
+ }
+}
+
+void QuantizeWeights::visit(luci::CircleTransposeConv *node)
+{
+ LOGGER(l);
+ INFO(l) << "QuantizeWeights QuantizeWeights::visit node: " << node->name() << std::endl;
+
+ auto weights = loco::must_cast<luci::CircleConst *>(node->filter());
+ if (!is_quantized(weights))
+ {
+ auto new_weights = luci::clone(weights);
+ node->filter(new_weights);
+ quantize_weights(new_weights);
+ }
+}
+
+void QuantizeWeights::visit(luci::CircleFullyConnected *node)
+{
+ LOGGER(l);
+ INFO(l) << "QuantizeWeights QuantizeWeights::visit node: " << node->name() << std::endl;
+
+ auto weights = loco::must_cast<luci::CircleConst *>(node->weights());
+ if (!is_quantized(weights))
+ {
+ auto new_weights = luci::clone(weights);
+ node->weights(new_weights);
+ quantize_weights(new_weights);
+ }
+}
+
+void QuantizeWeights::visit(luci::CircleNode *) {}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/QuantizeWeights.h b/compiler/luci/pass/src/QuantizeWeights.h
new file mode 100644
index 000000000..f62cd40f3
--- /dev/null
+++ b/compiler/luci/pass/src/QuantizeWeights.h
@@ -0,0 +1,55 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_QUANTIZE_WEIGHTS_H__
+#define __LUCI_QUANTIZE_WEIGHTS_H__
+
+#include <luci/Pass/QuantizationParameters.h>
+#include <luci/IR/CircleNodeVisitor.h>
+
+namespace luci
+{
+
+/**
+ * @brief QuantizeWeights quantizes tensors for weights
+ * @details Find min/max values on the fly and then quantize
+ */
+struct QuantizeWeights final : public luci::CircleNodeMutableVisitor<void>
+{
+ QuantizeWeights(loco::DataType input, loco::DataType output, QuantizationGranularity gr)
+ : input_type(input), output_type(output), granularity(gr)
+ {
+ }
+
+ loco::DataType input_type;
+ loco::DataType output_type;
+ QuantizationGranularity granularity;
+
+private:
+ void quantize_weights(luci::CircleConst *weights);
+
+ void visit(luci::CircleConv2D *node);
+ void visit(luci::CircleDepthwiseConv2D *node);
+ void visit(luci::CircleInstanceNorm *node);
+ void visit(luci::CirclePRelu *node);
+ void visit(luci::CircleTransposeConv *node);
+ void visit(luci::CircleFullyConnected *node);
+ void visit(luci::CircleNode *);
+};
+
+} // namespace luci
+
+#endif // __LUCI_QUANTIZE_WEIGHTS_H__
diff --git a/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp b/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp
index c3552ec52..d9a9d4db7 100644
--- a/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp
+++ b/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp
@@ -15,55 +15,32 @@
*/
#include "luci/Pass/QuantizeWithMinMaxPass.h"
+#include "luci/Pass/PropagateQParamForwardPass.h"
+#include "luci/Pass/PropagateQParamBackwardPass.h"
+#include "luci/Pass/RemoveRedundantQuantizePass.h"
+#include "QuantizeActivation.h"
+#include "QuantizeWeights.h"
+#include "QuantizeBias.h"
#include "QuantizationUtils.h"
+#include "ProgressReporter.h"
+#include "helpers/LayerInfoMap.h"
#include <luci/IR/CircleNodes.h>
#include <luci/IR/CircleNodeVisitor.h>
#include <luci/Service/Nodes/CircleConst.h>
#include <luci/Profile/CircleNodeOrigin.h>
#include <luci/Log.h>
+#include <logo/Phase.h>
#include <oops/UserExn.h>
#include <iostream>
#include <cmath>
-#include <functional>
namespace
{
using namespace luci;
-using IterFunc = std::function<void(uint32_t *, loco::TensorShape &, int32_t)>;
-
-void iterate_per_channel(CircleConst *node, int32_t &channel_dim_index, IterFunc func)
-{
- loco::TensorShape dimension;
- dimension.rank(4);
- uint32_t indices[4] = {
- 0,
- };
-
- if (!get_channel_dim_index(node, dimension, channel_dim_index))
- {
- assert(false);
- return;
- }
-
- for (indices[0] = 0; indices[0] < dimension.dim(0).value(); indices[0]++)
- {
- for (indices[1] = 0; indices[1] < dimension.dim(1).value(); indices[1]++)
- {
- for (indices[2] = 0; indices[2] < dimension.dim(2).value(); indices[2]++)
- {
- for (indices[3] = 0; indices[3] < dimension.dim(3).value(); indices[3]++)
- {
- func(indices, dimension, channel_dim_index);
- }
- }
- }
- }
-}
-
// Create a Quantize Op whose
// dtype is out_type
// shape is the same with node
@@ -80,7 +57,17 @@ luci::CircleQuantize *create_quantize_op(luci::CircleNode *node, loco::DataType
quantize->shape_status(luci::ShapeStatus::VALID);
auto qparam = node->quantparam();
- assert(qparam); // FIX_CALLER_UNLESS
+ assert(qparam); // FIX_CALLER_UNLESS
+
+ auto qtype = luci::activation_qtype(node);
+ if (qtype == ActivationQType::PreDefinedValue)
+ {
+ quantize->quantparam(luci::make_predefined_qparam(node->opcode(), out_type));
+ return quantize;
+ }
+
+ assert(qtype == ActivationQType::MinMax or qtype == ActivationQType::IntScale);
+
assert(qparam->min.size() == 1); // FIX_CALLER_UNLESS
assert(qparam->max.size() == 1); // FIX_CALLER_UNLESS
auto min = qparam->min[0];
@@ -104,9 +91,17 @@ luci::CircleQuantize *create_quantize_op(luci::CircleNode *node, loco::DataType
auto quantparam = std::make_unique<CircleQuantParam>();
quantparam->scale.push_back(scaling_factor);
quantparam->zerop.push_back(zp);
+ // Save original min/max (not nudged_min/max). Nudged min/max
+ // is different from the real min/max values, causing wrong
+ // qparam when quantization dtype is changed.
+ quantparam->min.push_back(min);
+ quantparam->max.push_back(max);
quantize->quantparam(std::move(quantparam));
+ if (qtype == ActivationQType::IntScale)
+ set_int_scale(quantize);
+
return quantize;
}
@@ -118,1412 +113,232 @@ namespace luci
namespace
{
-// Create a new const node from an existing node.
-// The new node has the following characteristics
-// type: T
-// shape: same with 'node' (given as an argument)
-// buffer size: 'size' (given as an argument)
-// Note that contents are not filled in this function.
-template <loco::DataType T>
-luci::CircleConst *create_empty_const_from(luci::CircleConst *node, uint32_t size)
-{
- auto new_node = node->graph()->nodes()->create<CircleConst>();
- // TODO: We don't have any naming convention for quantized nodes yet.
- // Fix this when we have one.
- new_node->name(node->name());
- new_node->dtype(T);
- new_node->rank(node->rank());
- for (uint32_t i = 0; i < node->rank(); i++)
- new_node->dim(i).set(node->dim(i).value());
-
- new_node->size<T>(size);
- new_node->shape_status(luci::ShapeStatus::VALID);
-
- return new_node;
-}
-
-void overwrite_quantparam(luci::CircleNode *source, luci::CircleNode *target)
-{
- auto source_qparam = source->quantparam();
- if (source_qparam == nullptr)
- throw std::runtime_error("source quantparam is not found during overwrite");
-
- auto target_qparam = target->quantparam();
- if (target_qparam == nullptr)
- {
- auto quantparam = std::make_unique<CircleQuantParam>();
- target->quantparam(std::move(quantparam));
- target_qparam = target->quantparam();
-
- if (target_qparam == nullptr)
- throw std::runtime_error("Creating new quant param failed");
- }
- target_qparam->min = source_qparam->min;
- target_qparam->max = source_qparam->max;
- target_qparam->scale = source_qparam->scale;
- target_qparam->zerop = source_qparam->zerop;
- target_qparam->quantized_dimension = source_qparam->quantized_dimension;
-}
-
-void quant_const_values(luci::CircleConst *const_node, float scaling_factor, float zerop,
- loco::DataType quant_type)
-{
- uint32_t size = const_node->size<loco::DataType::FLOAT32>();
-
- const float scaling_factor_inv = 1.0 / scaling_factor;
- std::vector<int32_t> quantized_values(size);
- for (uint32_t i = 0; i < size; ++i)
- {
- auto data = static_cast<double>(const_node->at<loco::DataType::FLOAT32>(i));
- double quantized_float = std::round(data * scaling_factor_inv) + zerop;
- constexpr auto int_max = static_cast<double>(std::numeric_limits<int32_t>::max());
- constexpr auto int_min = static_cast<double>(std::numeric_limits<int32_t>::min());
- quantized_float = std::min(int_max, std::max(int_min, quantized_float));
-
- quantized_values[i] = static_cast<int32_t>(quantized_float);
- }
-
- switch (quant_type)
- {
- case loco::DataType::U8:
- const_node->dtype(loco::DataType::U8); // change the type of tensor
- const_node->size<loco::DataType::U8>(size); // resize tensor
- for (uint32_t i = 0; i < size; ++i)
- const_node->at<loco::DataType::U8>(i) = std::min(255, std::max(0, quantized_values[i]));
- break;
- case loco::DataType::S16:
- assert(zerop == 0);
- const_node->dtype(loco::DataType::S16); // change the type of tensor
- const_node->size<loco::DataType::S16>(size); // resize tensor
- for (uint32_t i = 0; i < size; ++i)
- const_node->at<loco::DataType::S16>(i) =
- std::min(32767, std::max(-32767, quantized_values[i]));
- break;
- default:
- throw std::runtime_error("Unsupported data type");
- }
-}
-
-// Quantize const per channel
-//
-// The last dimension of const is the same as the dimension of channel
-// And the rest of the const dimensions should be 1
-// So, a 'single value' is quantized per channel
-//
-// Quantization spec (f: fp value, q: quantized value)
-//
-// uint8
-// Positive f: f = f * (q - 0) [q = 1, scale = f, zp = 0]
-// Negative f: f = (-f) * (q - 1) [q = 0, scale = -f, zp = 1]
-//
-// int16
-// Positive f: f = f * (q - 0) [q = 1, scale = f, zp = 0]
-// Negative f: f = (-f) * (q - 0) [q = -1, scale = -f, zp = 0]
-void quant_const_per_channel(CircleConst *node, loco::DataType quant_type)
-{
- assert(node->dtype() == loco::DataType::FLOAT32);
- assert(node->rank() > 0);
-
- for (uint32_t i = 0; i < node->rank() - 1; i++)
- {
- // Caller should call this function when the below condition is satisfied
- if (node->dim(i).value() != 1)
- throw std::runtime_error("Non-channel dimension of const node must be 1");
- }
-
- uint32_t size = node->size<loco::DataType::FLOAT32>();
- assert(size == node->dim(node->rank() - 1).value());
-
- auto quantparam = std::make_unique<CircleQuantParam>();
- quantparam->quantized_dimension = node->rank() - 1;
- std::vector<int32_t> quantized_data(size);
-
- for (uint32_t i = 0; i < size; ++i)
- {
- auto data = node->at<loco::DataType::FLOAT32>(i);
- if (quant_type == loco::DataType::U8)
- {
- if (data >= 0)
- {
- quantparam->scale.push_back(data);
- quantparam->zerop.push_back(0);
- quantized_data[i] = 1;
- }
- else
- {
- quantparam->scale.push_back(-data);
- quantparam->zerop.push_back(1);
- quantized_data[i] = 0;
- }
- }
- else if (quant_type == loco::DataType::S16)
- {
- if (data >= 0)
- {
- quantparam->scale.push_back(data);
- quantized_data[i] = 1;
- }
- else
- {
- quantparam->scale.push_back(-data);
- quantized_data[i] = -1;
- }
- quantparam->zerop.push_back(0);
- }
- }
- node->quantparam(std::move(quantparam));
-
- switch (quant_type)
- {
- case loco::DataType::U8:
- node->dtype(loco::DataType::U8);
- node->size<loco::DataType::U8>(size);
- for (uint32_t i = 0; i < size; ++i)
- {
- assert(quantized_data[i] == 0 || quantized_data[i] == 1);
- node->at<loco::DataType::U8>(i) = quantized_data[i];
- }
- break;
- case loco::DataType::S16:
- node->dtype(loco::DataType::S16);
- node->size<loco::DataType::S16>(size);
- for (uint32_t i = 0; i < size; ++i)
- {
- assert(quantized_data[i] == -1 || quantized_data[i] == 1);
- node->at<loco::DataType::S16>(i) = quantized_data[i];
- }
- break;
- default:
- throw std::runtime_error("Unsupported data type");
- }
-}
-
-void quant_const(CircleConst *node, loco::DataType quant_type)
-{
- assert(node->dtype() == loco::DataType::FLOAT32);
-
- float min = std::numeric_limits<float>::max();
- float max = std::numeric_limits<float>::lowest();
- for (uint32_t i = 0; i < node->size<loco::DataType::FLOAT32>(); i++)
- {
- auto data = node->at<loco::DataType::FLOAT32>(i);
- min = data < min ? data : min;
- max = data > max ? data : max;
- }
-
- float scaling_factor{0.0};
- int64_t zp{0};
- float nudged_min{0.0};
- float nudged_max{0.0};
-
- switch (quant_type)
- {
- case loco::DataType::U8:
- asymmetric_wquant_with_minmax_per_layer(node, min, max, scaling_factor, zp, nudged_min,
- nudged_max);
- break;
- case loco::DataType::S16:
- symmetric_wquant_with_minmax_per_layer(node, min, max, scaling_factor, zp, nudged_min,
- nudged_max);
- break;
- default:
- throw std::runtime_error("Unsupported data type");
- }
-
- auto quantparam = std::make_unique<CircleQuantParam>();
- quantparam->scale.push_back(scaling_factor);
- quantparam->zerop.push_back(zp);
- node->quantparam(std::move(quantparam));
-}
-
-// Check if the node is the bias of Conv2D, DepthwiseConv2D, FullyConnected, or TransposeConv layer
-// Returns a list of <input, weights, output> vectors for the above operators.
-// Note that it returns a 'list' because bias can be used by multiple operators.
-std::vector<std::vector<loco::Node *>> get_input_weight_output_of_bias(CircleNode *node)
-{
- std::vector<std::vector<loco::Node *>> result;
- auto circle_const = dynamic_cast<CircleConst *>(node);
- if (circle_const == nullptr)
- return result;
-
- auto succs = loco::succs(node);
-
- for (auto out : succs)
- {
- auto conv = dynamic_cast<CircleConv2D *>(out);
- if (conv != nullptr && conv->bias() == circle_const)
- {
- assert(conv->input() != nullptr);
- assert(conv->filter() != nullptr);
- result.push_back({conv->input(), conv->filter(), conv});
- continue;
- }
- auto dw_conv = dynamic_cast<CircleDepthwiseConv2D *>(out);
- if (dw_conv != nullptr && dw_conv->bias() == circle_const)
- {
- assert(dw_conv->input() != nullptr);
- assert(dw_conv->filter() != nullptr);
- result.push_back({dw_conv->input(), dw_conv->filter(), dw_conv});
- continue;
- }
- auto fc = dynamic_cast<CircleFullyConnected *>(out);
- if (fc != nullptr && fc->bias() == circle_const)
- {
- assert(fc->input() != nullptr);
- assert(fc->weights() != nullptr);
- result.push_back({fc->input(), fc->weights(), fc});
- continue;
- }
- auto tconv = dynamic_cast<CircleTransposeConv *>(out);
- if (tconv != nullptr && tconv->bias() == circle_const)
- {
- assert(tconv->outBackprop() != nullptr);
- assert(tconv->filter() != nullptr);
- result.push_back({tconv->outBackprop(), tconv->filter(), tconv});
- continue;
- }
- }
- return result;
-}
-
-CircleConst *asym_quant_bias_per_layer(CircleConst *node, float input_scale, float weight_scale,
- float *scaling_factor, int64_t *zp)
-{
- float scale = input_scale * weight_scale;
- const float scaling_factor_inv = (scale == 0) ? 0 : 1.0 / scale;
-
- uint32_t size = node->size<loco::DataType::FLOAT32>();
- std::vector<int32_t> quantized_values(size);
- for (uint32_t i = 0; i < size; ++i)
- {
- quantized_values[i] =
- static_cast<int32_t>(std::round(node->at<loco::DataType::FLOAT32>(i) * scaling_factor_inv));
- }
-
- auto new_bias = create_empty_const_from<loco::DataType::S32>(node, size);
-
- const int32_t kMinScale = std::numeric_limits<int32_t>::lowest();
- const int32_t kMaxScale = std::numeric_limits<int32_t>::max();
- for (uint32_t i = 0; i < size; ++i)
- {
- new_bias->at<loco::DataType::S32>(i) =
- std::min(kMaxScale, std::max(kMinScale, quantized_values[i]));
- }
- *scaling_factor = scale;
- *zp = 0;
-
- return new_bias;
-}
-
-CircleConst *quant_bias_per_channel(CircleConst *node, float input_scale,
- std::vector<float> &weight_scale,
- std::vector<float> &scaling_factor, std::vector<int64_t> &zp)
-{
- float scaling_factor_inv{0};
-
- uint32_t size = node->size<loco::DataType::FLOAT32>();
- std::vector<int32_t> quantized_values(size);
-
- for (uint32_t i = 0; i < size; ++i)
- {
- scaling_factor[i] = input_scale * weight_scale[i];
- scaling_factor_inv = (scaling_factor[i] == 0) ? 0 : 1.0 / scaling_factor[i];
- quantized_values[i] =
- static_cast<int32_t>(std::round(node->at<loco::DataType::FLOAT32>(i) * scaling_factor_inv));
- zp[i] = 0;
- }
-
- auto new_bias = create_empty_const_from<loco::DataType::S32>(node, size);
-
- const int32_t kMinScale = std::numeric_limits<int32_t>::lowest();
- const int32_t kMaxScale = std::numeric_limits<int32_t>::max();
- for (uint32_t i = 0; i < size; ++i)
- {
- new_bias->at<loco::DataType::S32>(i) =
- std::min(kMaxScale, std::max(kMinScale, quantized_values[i]));
- }
-
- return new_bias;
-}
-
-CircleConst *int16_quant_bias_per_channel(CircleConst *node, float input_scale,
- std::vector<float> &weight_scale,
- std::vector<float> &scaling_factor,
- std::vector<int64_t> &zp)
-{
- float scaling_factor_inv{0};
-
- uint32_t size = node->size<loco::DataType::FLOAT32>();
- std::vector<int64_t> quantized_values(size);
-
- for (uint32_t i = 0; i < size; ++i)
- {
- scaling_factor[i] = input_scale * weight_scale[i];
- scaling_factor_inv = (scaling_factor[i] == 0) ? 0 : 1.0 / scaling_factor[i];
- quantized_values[i] =
- static_cast<int64_t>(std::round(node->at<loco::DataType::FLOAT32>(i) * scaling_factor_inv));
- zp[i] = 0;
- }
-
- auto new_bias = create_empty_const_from<loco::DataType::S64>(node, size);
-
- for (uint32_t i = 0; i < size; ++i)
- {
- new_bias->at<loco::DataType::S64>(i) = quantized_values[i];
- }
-
- return new_bias;
-}
-
-bool has_min_max(const CircleNode *node)
-{
- return node->quantparam() && !node->quantparam()->min.empty() && !node->quantparam()->max.empty();
-}
-
-void sym_wquant_per_channel(CircleConst *node, std::vector<float> &scaling_factor,
- int32_t &channel_dim_index)
-{
- assert(node->dtype() == loco::DataType::FLOAT32);
-
- const int32_t kMaxScale = std::numeric_limits<int16_t>::max();
- const int32_t kMinScale = -kMaxScale;
-
- uint32_t size = node->size<loco::DataType::FLOAT32>();
- std::vector<int32_t> quantized_values(size);
-
- auto quantize = [&](uint32_t *indices, loco::TensorShape &dimension, int32_t channel_dim_index) {
- int channel_idx = indices[channel_dim_index];
- const float scaling_factor_inv = 1.0 / scaling_factor[channel_idx];
- auto data = node->at<loco::DataType::FLOAT32>(cal_offset(dimension, indices));
- quantized_values[cal_offset(dimension, indices)] =
- static_cast<int32_t>(std::round(data * scaling_factor_inv));
- };
-
- iterate_per_channel(node, channel_dim_index, quantize);
-
- node->dtype(loco::DataType::S16); // change the type of tensor
- node->size<loco::DataType::S16>(size); // resize tensor
- for (uint32_t i = 0; i < size; ++i)
- {
- node->at<loco::DataType::S16>(i) =
- std::min(kMaxScale, std::max(kMinScale, quantized_values[i]));
- }
-}
-
-void asym_wquant_per_channel(CircleConst *node, std::vector<float> &min,
- std::vector<float> &scaling_factor, int32_t &channel_dim_index)
-{
- assert(node->dtype() == loco::DataType::FLOAT32);
-
- const int32_t kMinScale = 0;
- const int32_t kMaxScale = 255;
-
- uint32_t size = node->size<loco::DataType::FLOAT32>();
- std::vector<int32_t> quantized_values(size);
-
- auto quantize = [&](uint32_t *indices, loco::TensorShape &dimension, int32_t channel_dim_index) {
- int channel_idx = indices[channel_dim_index];
- const float scaling_factor_inv = 1.0 / scaling_factor[channel_idx];
- auto data = node->at<loco::DataType::FLOAT32>(cal_offset(dimension, indices));
- quantized_values[cal_offset(dimension, indices)] =
- static_cast<int32_t>(std::round((data - min[channel_idx]) * scaling_factor_inv));
- };
-
- iterate_per_channel(node, channel_dim_index, quantize);
-
- node->dtype(loco::DataType::U8); // change the type of tensor
- node->size<loco::DataType::U8>(size); // resize tensor
- for (uint32_t i = 0; i < size; ++i)
- {
- node->at<loco::DataType::U8>(i) = std::min(kMaxScale, std::max(kMinScale, quantized_values[i]));
- }
-}
-
-void asym_wquant_per_layer(CircleConst *node, float min, float scaling_factor)
-{
- const int32_t kMinScale = 0;
- const int32_t kMaxScale = 255;
-
- uint32_t size = node->size<loco::DataType::FLOAT32>();
-
- const float scaling_factor_inv = 1.0 / scaling_factor;
- std::vector<int32_t> quantized_values(size);
- for (uint32_t i = 0; i < size; ++i)
- {
- auto data = node->at<loco::DataType::FLOAT32>(i);
- quantized_values[i] = static_cast<int32_t>(std::round((data - min) * scaling_factor_inv));
- }
-
- node->dtype(loco::DataType::U8); // change the type of tensor
- node->size<loco::DataType::U8>(size); // resize tensor
- for (uint32_t i = 0; i < size; ++i)
- {
- node->at<loco::DataType::U8>(i) = std::min(kMaxScale, std::max(kMinScale, quantized_values[i]));
- }
-}
-
-void set_bias(luci::CircleNode *node, luci::CircleConst *bias)
-{
- if (auto conv = dynamic_cast<CircleConv2D *>(node))
- conv->bias(bias);
- else if (auto dconv = dynamic_cast<CircleDepthwiseConv2D *>(node))
- dconv->bias(bias);
- else if (auto tconv = dynamic_cast<CircleTransposeConv *>(node))
- tconv->bias(bias);
- else if (auto fc = dynamic_cast<CircleFullyConnected *>(node))
- fc->bias(bias);
- else
- throw std::runtime_error("Only convolution, depthwise convolution, transposed convolution, and "
- "fully-connected layer have bias");
-}
-
-void set_act_qparam(luci::CircleNode *node, float scale, int64_t zp)
-{
- assert(node); // FIX_CALLER_UNLESS
- assert(node->quantparam()); // FIX_CALLER_UNLESS
-
- auto qparam = node->quantparam();
- assert(qparam->scale.size() == 1); // FIX_CALLER_UNLESS
- assert(qparam->zerop.size() == 1); // FIX_CALLER_UNLESS
- qparam->scale[0] = scale;
- qparam->zerop[0] = zp;
-}
-
-/**
- * @brief Manually set scale/zp of output tensor of special Ops
- */
-struct QuantizeSpecialActivation final : public luci::CircleNodeMutableVisitor<void>
-{
- QuantizeSpecialActivation(loco::DataType input, loco::DataType output)
- : input_type(input), output_type(output)
- {
- }
-
- loco::DataType input_type;
- loco::DataType output_type;
-
- void visit(luci::CircleNode *)
- {
- // Do nothing by default
- }
-
- void visit(luci::CircleLogistic *node)
- {
- if (output_type == loco::DataType::U8)
- set_act_qparam(node, 1.0f / 256.0f, 0);
- else
- {
- assert(output_type == loco::DataType::S16);
- set_act_qparam(node, 1.0f / 32768.0f, 0);
- }
- }
-
- void visit(luci::CircleTanh *node)
- {
- if (output_type == loco::DataType::U8)
- set_act_qparam(node, 2.0f / 256.0f, 128);
- else
- {
- assert(output_type == loco::DataType::S16);
- set_act_qparam(node, 1.0f / 32768.0f, 0);
- }
- }
-
- void visit(luci::CircleStridedSlice *node)
- {
- auto input = loco::must_cast<luci::CircleNode *>(node->input());
- auto i_qparam = input->quantparam();
- assert(i_qparam);
- assert(i_qparam->scale.size() == 1); // FIX_CALLER_UNLESS
- assert(i_qparam->zerop.size() == 1); // FIX_CALLER_UNLESS
- auto i_scale = i_qparam->scale[0];
- auto i_zp = i_qparam->zerop[0];
-
- set_act_qparam(node, i_scale, i_zp);
- }
-
- void visit(luci::CircleSplitOut *node)
- {
- auto split = loco::must_cast<luci::CircleSplit *>(node->input());
- auto input = loco::must_cast<luci::CircleNode *>(split->input());
- auto i_qparam = input->quantparam();
- assert(i_qparam);
- assert(i_qparam->scale.size() == 1); // FIX_CALLER_UNLESS
- assert(i_qparam->zerop.size() == 1); // FIX_CALLER_UNLESS
- auto i_scale = i_qparam->scale[0];
- auto i_zp = i_qparam->zerop[0];
-
- set_act_qparam(node, i_scale, i_zp);
- }
-
- void visit(luci::CircleSplitVOut *node)
- {
- auto splitv = loco::must_cast<luci::CircleSplitV *>(node->input());
- auto input = loco::must_cast<luci::CircleNode *>(splitv->input());
- auto i_qparam = input->quantparam();
- assert(i_qparam);
- assert(i_qparam->scale.size() == 1); // FIX_CALLER_UNLESS
- assert(i_qparam->zerop.size() == 1); // FIX_CALLER_UNLESS
- auto i_scale = i_qparam->scale[0];
- auto i_zp = i_qparam->zerop[0];
-
- set_act_qparam(node, i_scale, i_zp);
- }
-
- void visit(luci::CircleUnpackOut *node)
- {
- auto unpack = loco::must_cast<luci::CircleUnpack *>(node->input());
- auto input = loco::must_cast<luci::CircleNode *>(unpack->value());
- auto i_qparam = input->quantparam();
- assert(i_qparam);
- assert(i_qparam->scale.size() == 1); // FIX_CALLER_UNLESS
- assert(i_qparam->zerop.size() == 1); // FIX_CALLER_UNLESS
- auto i_scale = i_qparam->scale[0];
- auto i_zp = i_qparam->zerop[0];
-
- set_act_qparam(node, i_scale, i_zp);
- }
-
- // TODO Move Softmax, Floor, Ceil from QuantizeActivation to here
-};
-
/**
- * @brief QuantizeActivation quantizes tensors for activations
- * @details Quantize using recorded min/max values
+ * Insert Quantize operator for mixed-precision quantization
+ * 1. Before input feature map (only for non-const)
+ * 2. After output feature map
+ *
+ * For example, if default_dtype = U8 and op_dtype = S16,
+ * 1. Quantize Op for U8->S16 is inserted before ifm
+ * 2. Quantize Op for S16->U8 is inserted after ofm
+ *
+ * Why not insert Quantize Op for const ifm?
+ * We quantize const tensor at once to preserve precision.
+ * For example, if default dtype = U8, op_dtype = S16, and op is CONV2D,
+ * We directly quantize weights to 16 bits, not 8->16 bits.
*/
-struct QuantizeActivation final : public luci::CircleNodeMutableVisitor<bool>
+struct InsertQuantizeOp final : public luci::CircleNodeMutableVisitor<void>
{
- QuantizeActivation(loco::DataType input, loco::DataType output)
- : input_type(input), output_type(output)
+ InsertQuantizeOp(loco::DataType default_dtype, loco::DataType op_dtype)
+ : _default_dtype(default_dtype), _op_dtype(op_dtype)
{
+ assert(default_dtype != op_dtype); // FIX_CALLER_UNLESS
}
- loco::DataType input_type;
- loco::DataType output_type;
+private:
+ loco::DataType _default_dtype;
+ loco::DataType _op_dtype;
- // Quantize input tensors of each node
- bool visit(luci::CircleNode *node)
+private:
+ luci::CircleQuantize *create_in_quantize(loco::Node *in, loco::Node *origin)
+ {
+ auto input = loco::must_cast<luci::CircleNode *>(in);
+ if (input->opcode() == luci::CircleOpcode::CIRCLECONST)
+ return nullptr;
+
+ auto input_quant = create_quantize_op(input, _op_dtype);
+ input_quant->input(input);
+ auto origin_node = loco::must_cast<luci::CircleNode *>(origin);
+ luci::add_origin(input_quant, luci::get_origin(origin_node));
+ return input_quant;
+ }
+
+ void insert_out_quantize(loco::Node *node)
+ {
+ auto output = loco::must_cast<luci::CircleNode *>(node);
+ assert(output->opcode() != luci::CircleOpcode::CIRCLECONST); // FIX_CALLER_UNLESS
+ auto output_quant = create_quantize_op(output, _default_dtype);
+
+ luci::add_origin(output_quant, luci::get_origin(output));
+ loco::replace(node).with(output_quant);
+ output_quant->input(node);
+ }
+
+// INPUT_NAME is the only activation of NODE
+#define INSERT_QUANTIZE_TO_UNARY_OP(NODE, INPUT_NAME) \
+ void visit(NODE *node) \
+ { \
+ if (auto input_quant = create_in_quantize(node->INPUT_NAME(), node)) \
+ node->INPUT_NAME(input_quant); \
+ \
+ insert_out_quantize(node); \
+ }
+
+// INPUT_NAME is the only activation of NODE
+#define INSERT_QUANTIZE_TO_UNARY_MULTI_OUTPUT_OP(NODE, INPUT_NAME, OUT_NAME) \
+ void visit(NODE *node) \
+ { \
+ if (auto input_quant = create_in_quantize(node->INPUT_NAME(), node)) \
+ node->INPUT_NAME(input_quant); \
+ \
+ auto out_nodes = loco::succs(node); \
+ for (auto out_node : out_nodes) \
+ { \
+ auto out_circle = loco::must_cast<OUT_NAME *>(out_node); \
+ insert_out_quantize(out_circle); \
+ } \
+ }
+
+// INPUT_NAME1 and INPUT_NAME2 are the only activations of NODE
+#define INSERT_QUANTIZE_TO_BINARY_OP(NODE, INPUT_NAME1, INPUT_NAME2) \
+ void visit(NODE *node) \
+ { \
+ if (auto input1_quant = create_in_quantize(node->INPUT_NAME1(), node)) \
+ node->INPUT_NAME1(input1_quant); \
+ \
+ if (auto input2_quant = create_in_quantize(node->INPUT_NAME2(), node)) \
+ node->INPUT_NAME2(input2_quant); \
+ \
+ insert_out_quantize(node); \
+ }
+
+ // Default behavior (NYI)
+ void visit(luci::CircleNode *node)
+ {
+ throw std::runtime_error("Unsupported Op for mixed-precision quantization. Layer name: " +
+ node->name());
+ }
+
+ // Skip output layer
+ void visit(luci::CircleOutput *) {}
+ void visit(luci::CircleSplitVOut *) {}
+ void visit(luci::CircleSplitOut *) {}
+ void visit(luci::CircleTopKV2Out *) {}
+ void visit(luci::CircleUniqueOut *) {}
+ void visit(luci::CircleUnpackOut *) {}
+
+ // Ops that receive a single activation as an input
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleAveragePool2D, value)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleBatchToSpaceND, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleConv2D, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleDepthToSpace, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleDepthwiseConv2D, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleElu, features)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleExp, x)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleFloor, x)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleFullyConnected, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleGather, params)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleInstanceNorm, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleLocalResponseNormalization, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleLogistic, x)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleMaxPool2D, value)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleMean, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleMirrorPad, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CirclePad, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CirclePadV2, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CirclePRelu, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleReduceProd, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleReduceMax, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleReduceMin, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleRelu, features)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleReshape, tensor)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleResizeBilinear, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleResizeNearestNeighbor, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleReverseSequence, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleRsqrt, x)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleSlice, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleSoftmax, logits)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleSpaceToBatchND, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleSpaceToDepth, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleSqrt, x)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleStridedSlice, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleSum, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleTanh, x)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleTile, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleTranspose, a)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleTransposeConv, outBackprop)
+
+ // Ops that receive two activations as inputs
+ INSERT_QUANTIZE_TO_BINARY_OP(luci::CircleAdd, x, y)
+ INSERT_QUANTIZE_TO_BINARY_OP(luci::CircleBatchMatMul, x, y)
+ INSERT_QUANTIZE_TO_BINARY_OP(luci::CircleDiv, x, y)
+ INSERT_QUANTIZE_TO_BINARY_OP(luci::CircleFloorDiv, x, y)
+ INSERT_QUANTIZE_TO_BINARY_OP(luci::CircleMaximum, x, y)
+ INSERT_QUANTIZE_TO_BINARY_OP(luci::CircleMinimum, x, y)
+ INSERT_QUANTIZE_TO_BINARY_OP(luci::CircleMul, x, y)
+ INSERT_QUANTIZE_TO_BINARY_OP(luci::CircleOneHot, on_value, off_value)
+ INSERT_QUANTIZE_TO_BINARY_OP(luci::CirclePow, x, y)
+ INSERT_QUANTIZE_TO_BINARY_OP(luci::CircleSub, x, y)
+
+ // Multiple-output ops that receive one activation as inputs
+ INSERT_QUANTIZE_TO_UNARY_MULTI_OUTPUT_OP(luci::CircleSplit, input, luci::CircleSplitOut)
+ INSERT_QUANTIZE_TO_UNARY_MULTI_OUTPUT_OP(luci::CircleSplitV, input, luci::CircleSplitVOut)
+ INSERT_QUANTIZE_TO_UNARY_MULTI_OUTPUT_OP(luci::CircleTopKV2, input, luci::CircleTopKV2Out)
+ INSERT_QUANTIZE_TO_UNARY_MULTI_OUTPUT_OP(luci::CircleUnique, input, luci::CircleUniqueOut)
+ INSERT_QUANTIZE_TO_UNARY_MULTI_OUTPUT_OP(luci::CircleUnpack, value, luci::CircleUnpackOut)
+
+ // AddN has arbitrary number of inputs
+ void visit(luci::CircleAddN *node)
{
- LOGGER(l);
- INFO(l) << "QuantizeActivation visit node: " << node->name() << std::endl;
auto arity = node->arity();
for (uint32_t i = 0; i < arity; i++)
{
- auto input_node = node->arg(i);
- auto circle_node = loco::must_cast<luci::CircleNode *>(input_node);
-
- // Check if this is already quantized
- if (is_quantized(circle_node))
- continue;
-
- // Check if this is bias (bias is quantized later)
- auto iwo = get_input_weight_output_of_bias(circle_node);
- if (iwo.size() > 0)
- continue;
-
- // Check if this is bool type (bool type is not quantized)
- if (circle_node->dtype() == loco::DataType::BOOL)
- continue;
-
- // Check if this is activation
- // We assume min/max are recorded only for activations
- if (has_min_max(circle_node) && !is_weights(circle_node))
- {
- // Quantize using recorded min/max
- auto quantparam = circle_node->quantparam();
- assert(quantparam);
- assert(quantparam->min.size() == 1); // only support layer-wise quant
- assert(quantparam->max.size() == 1); // only support layer-wise quant
- auto min = quantparam->min[0];
- auto max = quantparam->max[0];
-
- // Special values
- if (circle_node->opcode() == luci::CircleOpcode::SOFTMAX)
- {
- min = 0.0f;
- max = 1.0f;
- }
-
- float scaling_factor{0};
- int64_t zp{0};
- float nudged_min{0};
- float nudged_max{0};
-
- if (output_type == loco::DataType::U8)
- {
- compute_asym_scale_zp(min, max, scaling_factor, zp, nudged_min, nudged_max);
- circle_node->dtype(loco::DataType::U8);
- }
- else
- {
- compute_sym_scale_zp(min, max, scaling_factor, zp, nudged_min, nudged_max);
- circle_node->dtype(loco::DataType::S16);
- }
-
- // Nodes fused with activation functions which need special quantization
- auto fused_act_node =
- dynamic_cast<CircleNodeMixin<CircleNodeTrait::FusedActFunc> *>(circle_node);
- if (fused_act_node != nullptr &&
- fused_act_node->fusedActivationFunction() == FusedActFunc::TANH)
- {
- if (output_type == loco::DataType::U8)
- {
- scaling_factor = 2.0f / 256.0f;
- zp = 128;
- }
- else
- {
- assert(output_type == loco::DataType::S16);
- scaling_factor = 1.0f / 32768.0f;
- zp = 0;
- }
- }
-
- // The output of these Ops should be integer, so scale should be integer
- // TODO Handle cases where the integer scale needs to be propagated
- if (circle_node->opcode() == CircleOpcode::FLOOR ||
- circle_node->opcode() == CircleOpcode::FLOOR_DIV ||
- circle_node->opcode() == CircleOpcode::FLOOR_MOD ||
- circle_node->opcode() == CircleOpcode::CEIL)
- {
- assert(scaling_factor >= 0); // FIX_ME_UNLESS
- scaling_factor = scaling_factor < 1 ? 1.0f : std::round(scaling_factor);
- }
-
- circle_node->quantparam()->scale.push_back(scaling_factor);
- circle_node->quantparam()->zerop.push_back(zp);
- }
- // Fix special attributes
- if (circle_node->opcode() == luci::CircleOpcode::CAST)
- {
- auto *cast = loco::must_cast<luci::CircleCast *>(circle_node);
- auto *cast_input = loco::must_cast<luci::CircleNode *>(cast->x());
-
- // make sure that cast_input is already quantized
- assert(cast_input->dtype() != loco::DataType::FLOAT32);
- cast->in_data_type(cast_input->dtype());
- cast->out_data_type(cast->dtype());
- }
- }
- return false;
- }
-};
-
-struct QuantizeBias final : public luci::CircleNodeMutableVisitor<bool>
-{
- QuantizeBias(loco::DataType input, loco::DataType output, QuantizationGranularity gr)
- : input_type(input), output_type(output), granularity(gr)
- {
- }
-
- loco::DataType input_type;
- loco::DataType output_type;
- QuantizationGranularity granularity;
-
- // Quantize bias node
- bool visit(luci::CircleNode *node)
- {
- // Check if this is already quantized
- if (is_quantized(node))
- return false;
-
- auto iwo_list = get_input_weight_output_of_bias(node);
-
- for (auto iwo : iwo_list)
- {
- assert(iwo.size() == 3);
-
- auto input = loco::must_cast<luci::CircleNode *>(iwo[0]);
- auto weight = loco::must_cast<luci::CircleNode *>(iwo[1]);
- auto output = loco::must_cast<luci::CircleNode *>(iwo[2]);
-
- auto const_bias = loco::must_cast<luci::CircleConst *>(node);
- assert(const_bias->dtype() == loco::DataType::FLOAT32);
-
- // If input is const, it is quantized here, not in QuantizeActivation
- if (auto const_input = dynamic_cast<luci::CircleConst *>(input))
- {
- quant_const(const_input, output_type);
- }
-
- CircleConst *new_bias = nullptr;
-
- if (granularity == QuantizationGranularity::ChannelWise)
- {
- auto input_q = input->quantparam();
- assert(input_q);
- assert(input_q->scale.size() == 1); // input scale's layer-wise
- auto input_scale = input_q->scale[0];
-
- assert(weight->quantparam() != nullptr); // weight scale's channel-wise
- auto weight_scale = weight->quantparam()->scale;
-
- uint32_t size = const_bias->size<loco::DataType::FLOAT32>();
- assert(size == weight_scale.size());
- std::vector<float> scaling_factor(size);
- std::vector<int64_t> zp(size);
-
- if (output_type == loco::DataType::U8)
- {
- new_bias =
- quant_bias_per_channel(const_bias, input_scale, weight_scale, scaling_factor, zp);
- }
- else if (output_type == loco::DataType::S16)
- {
- new_bias =
- int16_quant_bias_per_channel(const_bias, input_scale, weight_scale, scaling_factor, zp);
- }
- else
- {
- throw std::runtime_error("Unsupported quantization type.");
- }
-
- auto quantparam = std::make_unique<CircleQuantParam>();
- quantparam->scale = scaling_factor;
- quantparam->zerop = zp;
- assert(new_bias->quantparam() == nullptr); // bias should not be quantized before
- new_bias->quantparam(std::move(quantparam));
-
- set_bias(output, new_bias);
- }
- else
- {
- auto input_q = input->quantparam();
- assert(input_q);
- assert(input_q->scale.size() == 1); // Only support per-layer quant
- auto input_scale = input_q->scale[0];
-
- auto weight_q = weight->quantparam();
- assert(weight_q);
- assert(weight_q->scale.size() == 1); // Only support per-layer quant
- auto weight_scale = weight_q->scale[0];
-
- float scaling_factor{0};
- int64_t zp{0};
- new_bias =
- asym_quant_bias_per_layer(const_bias, input_scale, weight_scale, &scaling_factor, &zp);
- auto quantparam = std::make_unique<CircleQuantParam>();
- quantparam->scale.push_back(scaling_factor);
- quantparam->zerop.push_back(zp);
- assert(new_bias->quantparam() == nullptr); // bias should not be quantized before
- new_bias->quantparam(std::move(quantparam));
-
- set_bias(output, new_bias);
- }
- }
- return false;
- }
-};
-
-/**
- * @brief QuantizeWeights quantizes tensors for weights
- * @details Find min/max values on the fly and then quantize
- */
-struct QuantizeWeights final : public luci::CircleNodeMutableVisitor<bool>
-{
- QuantizeWeights(loco::DataType input, loco::DataType output, QuantizationGranularity gr)
- : input_type(input), output_type(output), granularity(gr)
- {
- }
-
- loco::DataType input_type;
- loco::DataType output_type;
- QuantizationGranularity granularity;
-
-private:
- void quantize_weights(luci::CircleConst *weights)
- {
- // Find min/max per channel-wise
- if (granularity == QuantizationGranularity::ChannelWise)
- {
- auto quantparam = weights->quantparam();
- if (quantparam == nullptr)
- {
- assert(false && "quantparam is nullptr");
- return;
- }
-
- auto min = quantparam->min;
- auto scaling_factor = quantparam->scale;
- int32_t channel_dim_index = 0;
-
- if (output_type == loco::DataType::U8)
- {
- asym_wquant_per_channel(weights, min, scaling_factor, channel_dim_index);
- }
- else
- {
- sym_wquant_per_channel(weights, scaling_factor, channel_dim_index);
- }
- quantparam->min.clear();
- quantparam->max.clear();
- quantparam->quantized_dimension = channel_dim_index;
- }
- // Find min/max per layer-wise
- else
- {
- // Quantize using recorded quantparam
- auto quantparam = weights->quantparam();
- assert(quantparam != nullptr);
- assert(quantparam->min.size() == 1); // only support layer-wise quant
- assert(quantparam->scale.size() == 1); // only support layer-wise quant
- auto min = quantparam->min[0];
- auto scaling_factor = quantparam->scale[0];
- asym_wquant_per_layer(weights, min, scaling_factor);
- quantparam->min.clear();
- quantparam->max.clear();
- }
- }
-
- bool visit(luci::CircleConv2D *node)
- {
- LOGGER(l);
- INFO(l) << "QuantizeWeights visit node: " << node->name() << std::endl;
-
- auto weights = loco::must_cast<luci::CircleConst *>(node->filter());
- if (!is_quantized(weights))
- {
- auto new_weights = luci::clone(weights);
- node->filter(new_weights);
- quantize_weights(new_weights);
- return true;
+ if (auto input_quant = create_in_quantize(node->inputs(i), node))
+ node->inputs(i, input_quant);
}
- return false;
- }
-
- bool visit(luci::CircleDepthwiseConv2D *node)
- {
- LOGGER(l);
- INFO(l) << "QuantizeWeights visit node: " << node->name() << std::endl;
- auto weights = loco::must_cast<luci::CircleConst *>(node->filter());
- if (!is_quantized(weights))
- {
- auto new_weights = luci::clone(weights);
- node->filter(new_weights);
- quantize_weights(new_weights);
- return true;
- }
- return false;
+ insert_out_quantize(node);
}
- bool visit(luci::CircleInstanceNorm *node)
+ // Concat has arbitrary number of inputs
+ void visit(luci::CircleConcatenation *node)
{
- LOGGER(l);
- INFO(l) << "QuantizeWeights visit node: " << node->name() << std::endl;
-
- auto gamma = loco::must_cast<luci::CircleConst *>(node->gamma());
- auto beta = loco::must_cast<luci::CircleConst *>(node->beta());
-
- bool changed = false;
- if (!is_quantized(gamma))
- {
- assert(gamma->dtype() == loco::DataType::FLOAT32);
- auto new_gamma = luci::clone(gamma);
- if (granularity == QuantizationGranularity::LayerWise)
- quant_const(new_gamma, output_type);
- else if (granularity == QuantizationGranularity::ChannelWise)
- quant_const_per_channel(new_gamma, output_type);
- node->gamma(new_gamma);
- changed = true;
- }
- if (!is_quantized(beta))
- {
- assert(beta->dtype() == loco::DataType::FLOAT32);
- auto new_beta = luci::clone(beta);
- if (granularity == QuantizationGranularity::LayerWise)
- quant_const(new_beta, output_type);
- else if (granularity == QuantizationGranularity::ChannelWise)
- quant_const_per_channel(new_beta, output_type);
- node->beta(new_beta);
- changed = true;
- }
-
- return changed;
- }
-
- bool visit(luci::CirclePRelu *node)
- {
- LOGGER(l);
- INFO(l) << "QuantizeWeights visit node: " << node->name() << std::endl;
-
- auto alpha = loco::must_cast<luci::CircleConst *>(node->alpha());
-
- if (!is_quantized(alpha))
+ auto arity = node->arity();
+ for (uint32_t i = 0; i < arity; i++)
{
- assert(alpha->dtype() == loco::DataType::FLOAT32);
- auto new_alpha = luci::clone(alpha);
- if (granularity == QuantizationGranularity::LayerWise)
- quant_const(new_alpha, output_type);
- else if (granularity == QuantizationGranularity::ChannelWise)
- quant_const_per_channel(new_alpha, output_type);
- node->alpha(new_alpha);
- return true;
+ if (auto input_quant = create_in_quantize(node->values(i), node))
+ node->values(i, input_quant);
}
- return false;
+ insert_out_quantize(node);
}
- bool visit(luci::CircleTransposeConv *node)
+ // Pack has arbitrary number of inputs
+ void visit(luci::CirclePack *node)
{
- LOGGER(l);
- INFO(l) << "QuantizeWeights visit node: " << node->name() << std::endl;
-
- auto weights = loco::must_cast<luci::CircleConst *>(node->filter());
- if (!is_quantized(weights))
+ auto arity = node->arity();
+ for (uint32_t i = 0; i < arity; i++)
{
- auto new_weights = luci::clone(weights);
- node->filter(new_weights);
- quantize_weights(new_weights);
- return true;
+ if (auto input_quant = create_in_quantize(node->values(i), node))
+ node->values(i, input_quant);
}
- return false;
- }
-
- bool visit(luci::CircleFullyConnected *node)
- {
- LOGGER(l);
- INFO(l) << "QuantizeWeights visit node: " << node->name() << std::endl;
- auto weights = loco::must_cast<luci::CircleConst *>(node->weights());
- if (!is_quantized(weights))
- {
- auto new_weights = luci::clone(weights);
- node->weights(new_weights);
- quantize_weights(new_weights);
- return true;
- }
- return false;
+ insert_out_quantize(node);
}
- bool visit(luci::CircleNode *) { return false; }
+#undef INSERT_QUANTIZE_TO_UNARY_OP
+#undef INSERT_QUANTIZE_TO_BINARY_OP
+#undef INSERT_QUANTIZE_TO_UNARY_MULTI_OUTPUT_OP
};
-/** EXAMPLE
- *
- * BEFORE
- *
- * [CircleNode] [CircleConst]
- * (qparam1) (FP32)
- * \ /
- * \ /
- * [CirclePack]
- * (qparam2)
- *
- * AFTER
- *
- * [CircleNode] [CircleConst] [CircleConst] <- Dead node
- * (qparam2) (qparam2) (FP32)
- * \ /
- * \ /
- * [CirclePack]
- * (qparam2)
- *
- * NOTE Quantization parameter of CirclePack (qparam2) is propagated to the inputs.
- */
-void propagate_pack_quantparam(luci::CirclePack *pack, loco::DataType quant_type)
-{
- assert(pack->quantparam() != nullptr);
-
- const auto num_inputs = pack->values_count();
-
- for (uint32_t i = 0; i < num_inputs; i++)
- {
- auto node = loco::must_cast<luci::CircleNode *>(pack->arg(i));
-
- // Skip if this input is PACK Op
- if (node->opcode() == luci::CircleOpcode::PACK)
- continue;
-
- // Quantize constant values
- if (node->opcode() == luci::CircleOpcode::CIRCLECONST)
- {
- luci::CircleConst *const_node = loco::must_cast<luci::CircleConst *>(node);
- if (const_node->dtype() != loco::DataType::FLOAT32)
- throw std::runtime_error("Unsupported data type for constant input of pack Op");
-
- const auto pack_qparam = pack->quantparam();
- if (pack_qparam == nullptr)
- throw std::runtime_error("quantparam of pack is not found during propagation");
-
- assert(pack_qparam->scale.size() == 1);
- assert(pack_qparam->zerop.size() == 1);
- const auto scaling_factor = pack_qparam->scale[0];
- const auto zerop = pack_qparam->zerop[0];
-
- auto new_const = luci::clone(const_node);
- quant_const_values(new_const, scaling_factor, zerop, quant_type);
- pack->values(i, new_const);
- overwrite_quantparam(pack, new_const);
- }
- else
- {
- const auto succs = loco::succs(node);
- if (succs.size() > 1)
- continue;
-
- // Non-const input must have been quantized
- assert(node->quantparam() != nullptr);
- overwrite_quantparam(pack, node);
- }
- }
-}
-
-/**
- * @brief Quantize const input tensors using min/max of const values
- */
-void quantize_const_inputs(luci::CircleNode *node, loco::DataType output_type)
-{
- auto opcode = node->opcode();
- auto arity = node->arity();
-
- loco::Node *input_node{nullptr};
- luci::CircleConst *const_node{nullptr};
-
- switch (opcode)
- {
- case luci::CircleOpcode::CONV_2D:
- case luci::CircleOpcode::DEPTHWISE_CONV_2D:
- case luci::CircleOpcode::FULLY_CONNECTED:
- case luci::CircleOpcode::INSTANCE_NORM:
- case luci::CircleOpcode::PRELU:
- case luci::CircleOpcode::TRANSPOSE_CONV:
- // Handled in QuantizeWeights and QuantizeBias
- break;
-
- case luci::CircleOpcode::CONCATENATION:
- // Handled in propagate_concat_quantparam
- break;
-
- case luci::CircleOpcode::LOGICAL_OR:
- // Inputs of logical Ops are bool, thus not quantized
- break;
-
- case luci::CircleOpcode::ARG_MAX:
- case luci::CircleOpcode::ARG_MIN:
- case luci::CircleOpcode::BATCH_TO_SPACE_ND:
- case luci::CircleOpcode::LOCAL_RESPONSE_NORMALIZATION:
- case luci::CircleOpcode::MEAN:
- case luci::CircleOpcode::MIRROR_PAD:
- case luci::CircleOpcode::PAD:
- case luci::CircleOpcode::REDUCE_ANY:
- case luci::CircleOpcode::REDUCE_PROD:
- case luci::CircleOpcode::REDUCE_MAX:
- case luci::CircleOpcode::REDUCE_MIN:
- case luci::CircleOpcode::RESHAPE:
- case luci::CircleOpcode::RESIZE_BILINEAR:
- case luci::CircleOpcode::RESIZE_NEAREST_NEIGHBOR:
- case luci::CircleOpcode::REVERSE_SEQUENCE:
- case luci::CircleOpcode::SLICE:
- case luci::CircleOpcode::SPACE_TO_BATCH_ND:
- case luci::CircleOpcode::SPLIT_V:
- case luci::CircleOpcode::STRIDED_SLICE:
- case luci::CircleOpcode::SUM:
- case luci::CircleOpcode::TILE:
- case luci::CircleOpcode::TOPK_V2:
- case luci::CircleOpcode::TRANSPOSE:
- // The second input of these Ops should not be quantized
- // Ex: axis, paddings
- input_node = node->arg(0);
- const_node = dynamic_cast<luci::CircleConst *>(input_node);
- if (const_node != nullptr && !is_quantized(const_node))
- quant_const(const_node, output_type);
- break;
-
- case luci::CircleOpcode::ADD:
- case luci::CircleOpcode::ADD_N:
- case luci::CircleOpcode::DEPTH_TO_SPACE:
- case luci::CircleOpcode::DIV:
- case luci::CircleOpcode::ELU:
- case luci::CircleOpcode::EQUAL:
- case luci::CircleOpcode::EXP:
- case luci::CircleOpcode::FLOOR:
- case luci::CircleOpcode::FLOOR_DIV:
- case luci::CircleOpcode::GREATER:
- case luci::CircleOpcode::GREATER_EQUAL:
- case luci::CircleOpcode::LESS:
- case luci::CircleOpcode::LESS_EQUAL:
- case luci::CircleOpcode::LOGISTIC:
- case luci::CircleOpcode::MAXIMUM:
- case luci::CircleOpcode::MINIMUM:
- case luci::CircleOpcode::MUL:
- case luci::CircleOpcode::NOT_EQUAL:
- case luci::CircleOpcode::POW:
- case luci::CircleOpcode::RSQRT:
- case luci::CircleOpcode::SOFTMAX:
- case luci::CircleOpcode::SPACE_TO_DEPTH:
- case luci::CircleOpcode::SQRT:
- case luci::CircleOpcode::SUB:
- case luci::CircleOpcode::TANH:
- case luci::CircleOpcode::UNPACK:
- // Quantize all const inputs using their values
- for (uint32_t i = 0; i < arity; i++)
- {
- input_node = node->arg(i);
- const_node = dynamic_cast<luci::CircleConst *>(input_node);
- if (const_node != nullptr && !is_quantized(const_node))
- quant_const(const_node, output_type);
- }
- break;
-
- case luci::CircleOpcode::SPLIT:
- // Only the second input is quantized
- // First input should not be quantized (e.g., split_dim)
- input_node = node->arg(1);
- const_node = dynamic_cast<luci::CircleConst *>(input_node);
- if (const_node != nullptr && !is_quantized(const_node))
- quant_const(const_node, output_type);
- break;
-
- case luci::CircleOpcode::PADV2:
- // First and third constant inputs are quantized
- // Second input should not be quantized (e.g., paddings)
- // Quant params are propagated either from output range to the non-constant input
- // or from input to output and constant values
- propagate_pad_v2_quantparam(loco::must_cast<CirclePadV2 *>(node), output_type);
- break;
-
- case luci::CircleOpcode::PACK:
- // Quant param is propagated from output to inputs
- propagate_pack_quantparam(loco::must_cast<CirclePack *>(node), output_type);
- break;
-
- default:
- for (uint32_t i = 0; i < arity; i++)
- {
- input_node = node->arg(i);
- const_node = dynamic_cast<luci::CircleConst *>(input_node);
- if (const_node != nullptr)
- throw std::runtime_error("Unsupported Op for const inputs");
- }
- break;
- }
-}
-
} // namespace
-/** BEFORE
- *
- * [CircleNode] [CircleConst]
- * (U8 qparam1) (FP32)
- * \ /
- * \ /
- * [CircleConcatenation]
- * (U8 qparam2)
- *
- * AFTER
- * [CircleNode] [CircleConst] [CircleConst] <- Dead node
- * (U8 qparam2) (U8 qparam2) (FP32)
- * \ /
- * \ /
- * [CircleConcatenation]
- * (U8 qparam2)
- */
-void propagate_concat_quantparam(luci::CircleConcatenation *concat, loco::DataType quant_type)
-{
- assert(concat->quantparam() != nullptr);
-
- const auto num_inputs = concat->numValues();
-
- // Quantize const inputs using their values if concat has fused act function
- if (concat->fusedActivationFunction() != luci::FusedActFunc::NONE)
- {
- for (uint32_t i = 0; i < num_inputs; i++)
- {
- auto node = concat->arg(i);
- auto const_node = dynamic_cast<luci::CircleConst *>(node);
- if (const_node != nullptr)
- {
- auto new_const = luci::clone(const_node);
- quant_const(new_const, quant_type);
- concat->values(i, new_const);
- }
- }
- return;
- }
-
- for (uint32_t i = 0; i < num_inputs; i++)
- {
- auto node = loco::must_cast<luci::CircleNode *>(concat->arg(i));
-
- // Skip if this input is CONCAT Op
- if (node->opcode() == luci::CircleOpcode::CONCATENATION)
- continue;
-
- // Quantize constant values
- if (node->opcode() == luci::CircleOpcode::CIRCLECONST)
- {
- luci::CircleConst *const_node = loco::must_cast<luci::CircleConst *>(node);
- if (const_node->dtype() != loco::DataType::FLOAT32)
- throw std::runtime_error("Unsupported data type for constant input of concatenation Op");
-
- const auto concat_qparam = concat->quantparam();
- if (concat_qparam == nullptr)
- throw std::runtime_error("quantparam of concat is not found during propagation");
-
- assert(concat_qparam->scale.size() == 1);
- const auto scaling_factor = concat_qparam->scale[0];
- const auto zerop = concat_qparam->zerop[0];
-
- auto new_const = luci::clone(const_node);
- quant_const_values(new_const, scaling_factor, zerop, quant_type);
- concat->values(i, new_const);
- overwrite_quantparam(concat, new_const);
- }
- else
- {
- const auto succs = loco::succs(node);
- if (succs.size() > 1)
- continue;
-
- // Non-const input must have been quantized
- assert(node->quantparam() != nullptr);
- overwrite_quantparam(concat, node);
- }
- }
-}
-
-/**
- * tells if pad_v2 quantization should ignore padding value
- * In that case padding const will be quantized with input parameters, and probably clipped
- */
-bool ignore_pad_v2_const_quantization(luci::CirclePadV2 *pad)
-{
- // This is a workaround to quantize pad generated from MaxPoolWithArgmax operation properly
- // TODO use metadata hints to detect this case
- auto const_value_node = dynamic_cast<luci::CircleConst *>(pad->arg(2));
- if (!const_value_node)
- return false;
- if (const_value_node->dtype() == loco::DataType::FLOAT32)
- {
- float const_value = const_value_node->at<loco::DataType::FLOAT32>(0);
- if (const_value == std::numeric_limits<float>::lowest())
- return true;
- }
- return false;
-}
-
-/** BEFORE
- *
- * [CircleNode] [CircleConst] [CircleConst]
- * (U8 qparam1) (S32) (FP32)
- * \ | /
- * \ | /
- * [CirclePadV2]
- * (U8 qparam2)
- *
- * AFTER (case 1)
- *
- * By default qparam is propagated from output to inputs to meet backend requirements.
- *
- * [CircleNode] [CircleConst] [CircleConst] [CircleConst] <- Dead node
- * (U8 qparam2) (S32) (U8 qparam2) (FP32)
- * \ | /
- * \ | /
- * [CirclePadV2]
- * (U8 qparam2)
- *
- * AFTER (case 2)
- *
- * In case padded value is the lowest float value
- * Qparam is propagated from input to output and constant.
- *
- * This is a special case for optimization constructed pad, needed to guarantee that
- * extremely large negative constant do not stretch output quantization range.
- *
- * [CircleNode] [CircleConst] [CircleConst] [CircleConst] <- Dead node
- * (U8 qparam1) (S32) (U8 qparam1) (FP32)
- * \ | /
- * \ | /
- * [CirclePadV2]
- * (U8 qparam1)
- */
-void propagate_pad_v2_quantparam(luci::CirclePadV2 *pad_v2, loco::DataType quant_type)
-{
- if (ignore_pad_v2_const_quantization(pad_v2))
- {
- // propagate input quantization paramters from input to output and padding const value
- auto pad_v2_input = loco::must_cast<luci::CircleNode *>(pad_v2->arg(0));
- overwrite_quantparam(pad_v2_input, pad_v2);
-
- auto const_value_node = loco::must_cast<luci::CircleConst *>(
- pad_v2->arg(2)); // FIX ignore_pad_v2_const_quantization UNLESS
- auto new_const = luci::clone(const_value_node);
-
- const auto pad_v2_input_qparam = pad_v2_input->quantparam();
- assert(pad_v2_input_qparam != nullptr);
- assert(pad_v2_input_qparam->scale.size() == 1);
- const auto scaling_factor = pad_v2_input_qparam->scale.at(0);
- const auto zerop = pad_v2_input_qparam->zerop.at(0);
-
- quant_const_values(new_const, scaling_factor, zerop, quant_type);
- overwrite_quantparam(pad_v2_input, new_const);
- pad_v2->constant_values(new_const);
- return;
- }
-
- // Propagate quantization paramters from output to inputs,
- // to fit both input and counstant_value in one quant range.
- auto quant_input = [pad_v2, quant_type](void (CirclePadV2::*arg_setter)(loco::Node *),
- uint32_t arg) {
- auto node = loco::must_cast<luci::CircleNode *>(pad_v2->arg(arg));
-
- // Quantize constant values
- if (node->opcode() == luci::CircleOpcode::CIRCLECONST)
- {
- luci::CircleConst *const_node = loco::must_cast<luci::CircleConst *>(node);
- if (is_quantized(const_node))
- return;
-
- if (const_node->dtype() != loco::DataType::FLOAT32)
- throw std::runtime_error("Unsupported data type for constant input of PadV2 Op");
-
- const auto pad_v2_qparam = pad_v2->quantparam();
- if (pad_v2_qparam == nullptr)
- throw std::runtime_error("quantparam of PadV2 is not found during propagation");
-
- assert(pad_v2_qparam->scale.size() == 1);
- const auto scaling_factor = pad_v2_qparam->scale.at(0);
- const auto zerop = pad_v2_qparam->zerop.at(0);
-
- auto new_const = luci::clone(const_node);
- quant_const_values(new_const, scaling_factor, zerop, quant_type);
- overwrite_quantparam(pad_v2, new_const);
- (pad_v2->*arg_setter)(new_const);
- }
- // Subsequent PadV2 Ops quant params are not propagated
- else if (node->opcode() == luci::CircleOpcode::PADV2)
- {
- return;
- }
- else
- {
- const auto succs = loco::succs(node);
- if (succs.size() > 1)
- return;
-
- // Non-const input must have been quantized
- assert(node->quantparam() != nullptr);
- overwrite_quantparam(pad_v2, node);
- }
- };
-
- quant_input(&CirclePadV2::input, 0);
- quant_input(&CirclePadV2::constant_values, 2);
-}
-
void QuantizeWithMinMaxPass::set_input_type(loco::Graph *g) const
{
auto inputs = g->inputs();
for (auto node : loco::input_nodes(g))
{
auto input = loco::must_cast<luci::CircleInput *>(node);
- if (input->dtype() == _input_type)
+ if (input->dtype() == _ctx->input_type)
continue;
// Bool type is not quantizable
if (input->dtype() == loco::DataType::BOOL)
continue;
+ if (input->dtype() == loco::DataType::S32)
+ continue;
+ if (input->dtype() == loco::DataType::S64)
+ continue;
// Insert Quantize Op
auto quant_op = create_quantize_op(input, input->dtype());
@@ -1552,22 +367,22 @@ void QuantizeWithMinMaxPass::set_input_type(loco::Graph *g) const
float nudged_min{0};
float nudged_max{0};
- if (_input_type == loco::DataType::U8)
+ if (_ctx->input_type == loco::DataType::U8)
{
compute_asym_scale_zp(min, max, scaling_factor, zp, nudged_min, nudged_max);
}
else
{
- assert(_input_type == loco::DataType::S16);
+ assert(_ctx->input_type == loco::DataType::S16);
compute_sym_scale_zp(min, max, scaling_factor, zp, nudged_min, nudged_max);
}
- input->dtype(_input_type);
+ input->dtype(_ctx->input_type);
input->quantparam()->scale[0] = scaling_factor;
input->quantparam()->zerop[0] = zp;
}
auto graph_input = inputs->at(input->index());
- graph_input->dtype(_input_type);
+ graph_input->dtype(_ctx->input_type);
}
}
@@ -1577,7 +392,7 @@ void QuantizeWithMinMaxPass::set_output_type(loco::Graph *g) const
for (auto node : loco::output_nodes(g))
{
auto output = loco::must_cast<luci::CircleOutput *>(node);
- if (output->dtype() == _output_type)
+ if (output->dtype() == _ctx->output_type)
continue;
// Bool type is not quantizable
@@ -1591,7 +406,7 @@ void QuantizeWithMinMaxPass::set_output_type(loco::Graph *g) const
continue;
// Insert Quantize Op
- auto quant_op = create_quantize_op(from, _output_type);
+ auto quant_op = create_quantize_op(from, _ctx->output_type);
loco::replace(from).with(quant_op);
quant_op->input(from);
@@ -1599,67 +414,165 @@ void QuantizeWithMinMaxPass::set_output_type(loco::Graph *g) const
luci::add_origin(quant_op, luci::get_origin(from));
auto graph_output = outputs->at(output->index());
- graph_output->dtype(_output_type);
+ graph_output->dtype(_ctx->output_type);
}
}
+/**
+ * How QuantizeWithMinMax works?
+ *
+ * We categorized tensors into four groups
+ * - Activation: Feature maps (both Const/Non-const)
+ * - Weights: Const tensors of specific Ops (Conv, FC, ...)
+ * - Bias: Const tensors of specific Ops (Conv, FC, ...)
+ * - Others: padding value, one_hot value, axis, ..
+ *
+ * Activation is quantized in different ways
+ * 1. For non-constant activation, quantize using recorded min/max
+ * 2. For constant activation, quantize using min/max of its value
+ * 3. For some Ops (ex: pad_v2), output qparam is used as input qparam (backward propagation)
+ * 4. For some Ops (ex: reshape), input qparam is used as output qparam (forward propagation)
+ * 5. For some Ops (ex: tanh), output qparam has pre-defined values
+ *
+ * Weights is quantized using min/max of its value
+ *
+ * Bias is quantized using input scale (s_i) and weights scale (s_w)
+ * - Activation and weights should be quantized earlier than bias
+ *
+ * Quantization Steps
+ * 1. Quantize Activation
+ * - Quantize using recorded min/max (QuantizeActivation)
+ * - Insert Quantize Ops for mixed-precision quantization (InsertQuantizeOp)
+ * - Remove redundant Quantize Ops (RemoveRedundantQuantizePass)
+ * - Propagate qparam backward (PropagateQParamBackwardPass)
+ * - Quantize const inputs (QuantizeConstInputActivation)
+ * - Quantize using pre-defined values (QuantizeSpecialActivation)
+ * - Propagate qparam forward (PropagateQParamForwardPass)
+ * 2. Quantize Weights
+ * 3. Quantize Bias
+ * 4. Set input dtype
+ * 5. Set output dtype
+ *
+ * Why quantization sequence was determined as above?
+ * - Activation and weights should be quantized before bias (1->2->3). Input/Output
+ * dtype can be updated at the end (4->5).
+ * - During activation quantization,
+ * - Backward propagation is performed earlier than forward propagation. This allows
+ * backward-propagated qpram to be overwritten during forward propagation.
+ * We made this decision as Ops for forward propagation (reshape, transpose, ..)
+ * are more common than backward propagation. TODO Check this decision is safe.
+ * - QuantizeSpecialActivation is called before forward propagation to make sure that
+ * the pre-defined qparam values are propagated.
+ */
bool QuantizeWithMinMaxPass::run(loco::Graph *g)
{
LOGGER(l);
INFO(l) << "QuantizeWithMinMaxPass Start" << std::endl;
+ auto info_by_name = layer_info_map(g, _ctx->layers_info);
+
+ auto quantize_dtype = [&](const luci::CircleNode *node) {
+ auto iter = info_by_name.find(node->name());
+
+ // Return designated quantization dtype
+ if (iter != info_by_name.end())
+ return iter->second.dtype;
+
+ // Return default quantization dtype
+ return _ctx->output_model_dtype;
+ };
+
+ auto quantize_granularity = [&](const luci::CircleNode *node) {
+ auto iter = info_by_name.find(node->name());
+
+ // Return designated quantization granularity
+ if (iter != info_by_name.end())
+ return iter->second.granularity;
+
+ // Return default quantization granularity
+ return _ctx->granularity;
+ };
+
// Quantize activation
for (auto node : loco::active_nodes(loco::output_nodes(g)))
{
- QuantizeActivation qa(_input_model_dtype, _output_model_dtype);
auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+ QuantizeActivation qa(_ctx->input_model_dtype, quantize_dtype(circle_node));
circle_node->accept(&qa);
}
- // Quantize weights
+ // Insert Quantize Op
for (auto node : loco::active_nodes(loco::output_nodes(g)))
{
- QuantizeWeights qw(_input_model_dtype, _output_model_dtype, _granularity);
auto circle_node = loco::must_cast<luci::CircleNode *>(node);
- circle_node->accept(&qw);
+ auto op_dtype = quantize_dtype(circle_node);
+ if (op_dtype != _ctx->output_model_dtype)
+ {
+ InsertQuantizeOp iqo(_ctx->output_model_dtype, op_dtype);
+ circle_node->accept(&iqo);
+ }
}
- // Quantize bias
+ // Remove redundant Quantize Op
+ {
+ logo::Phase phase;
+
+ phase.emplace_back(std::make_unique<luci::RemoveRedundantQuantizePass>());
+
+ ProgressReporter prog(g, logo::PhaseStrategy::Saturate);
+ logo::PhaseRunner<logo::PhaseStrategy::Saturate> phase_runner{g};
+ phase_runner.attach(&prog);
+ phase_runner.run(phase);
+ }
+
+ // Backward propagation of activation qparam
+ {
+ PropagateQParamBackwardPass pqbp(_ctx->output_model_dtype);
+ pqbp.run(g);
+ }
+
+ // Quantize const input activation
for (auto node : loco::active_nodes(loco::output_nodes(g)))
{
- QuantizeBias qb(_input_model_dtype, _output_model_dtype, _granularity);
auto circle_node = loco::must_cast<luci::CircleNode *>(node);
- circle_node->accept(&qb);
+ QuantizeConstInputActivation qcia(quantize_dtype(circle_node));
+ circle_node->accept(&qcia);
}
- // Propagate quantization parameters of concat Op
+ // Update qparam of output of special Ops
for (auto node : loco::active_nodes(loco::output_nodes(g)))
{
- auto concat = dynamic_cast<luci::CircleConcatenation *>(node);
- if (not concat)
- continue;
-
- // Propagate qparam of concat to its inputs if
- // (1) concat is uint8-quantized
- // (2) concat has no fused activation function
- // (3) the input is not concatenation Op
- // (4) the input is not produced to Ops other than concat
- propagate_concat_quantparam(concat, _output_model_dtype);
+ auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+ QuantizeSpecialActivation qsa(_ctx->input_model_dtype, quantize_dtype(circle_node));
+ circle_node->accept(&qsa);
}
- // Quantize const inputs other than weights and bias
+ // Forward propagation of activation qparam
+ logo::Phase phase;
+
+ phase.emplace_back(std::make_unique<luci::PropagateQParamForwardPass>(_ctx->TF_style_maxpool));
+
+ ProgressReporter prog(g, logo::PhaseStrategy::Saturate);
+ logo::PhaseRunner<logo::PhaseStrategy::Saturate> phase_runner{g};
+ phase_runner.attach(&prog);
+ phase_runner.run(phase);
+
+ // Quantize weights
for (auto node : loco::active_nodes(loco::output_nodes(g)))
{
auto circle_node = loco::must_cast<luci::CircleNode *>(node);
- quantize_const_inputs(circle_node, _output_model_dtype);
+ QuantizeWeights qw(_ctx->input_model_dtype, quantize_dtype(circle_node),
+ quantize_granularity(circle_node));
+ circle_node->accept(&qw);
}
- // Update qparam of output of special Ops
+ // Quantize bias
for (auto node : loco::active_nodes(loco::output_nodes(g)))
{
- QuantizeSpecialActivation qsa(_input_model_dtype, _output_model_dtype);
auto circle_node = loco::must_cast<luci::CircleNode *>(node);
- circle_node->accept(&qsa);
+ QuantizeBias qb(_ctx->input_model_dtype, quantize_dtype(circle_node),
+ quantize_granularity(circle_node));
+ circle_node->accept(&qb);
}
// Update output dtype
@@ -1667,11 +580,11 @@ bool QuantizeWithMinMaxPass::run(loco::Graph *g)
for (auto node : loco::output_nodes(g))
{
auto circle_node = loco::must_cast<luci::CircleOutput *>(node);
- if (static_cast<luci::CircleNode *>(circle_node->from())->dtype() == _output_model_dtype)
+ if (static_cast<luci::CircleNode *>(circle_node->from())->dtype() == _ctx->output_model_dtype)
{
- circle_node->dtype(_output_model_dtype);
+ circle_node->dtype(_ctx->output_model_dtype);
auto graph_output = graph_outputs->at(circle_node->index());
- graph_output->dtype(_output_model_dtype);
+ graph_output->dtype(_ctx->output_model_dtype);
}
}
diff --git a/compiler/luci/pass/src/QuantizeWithMinMaxPass.test.cpp b/compiler/luci/pass/src/QuantizeWithMinMaxPass.test.cpp
index 75ec0cfd8..d5fa21ffd 100644
--- a/compiler/luci/pass/src/QuantizeWithMinMaxPass.test.cpp
+++ b/compiler/luci/pass/src/QuantizeWithMinMaxPass.test.cpp
@@ -16,8 +16,41 @@
#include "luci/Pass/QuantizeWithMinMaxPass.h"
+#include <luci/IR/CircleNodes.h>
+
#include <gtest/gtest.h>
+class SimpleConcatGraph
+{
+public:
+ SimpleConcatGraph(loco::DataType quant_type)
+ {
+ concat_node = g.nodes()->create<luci::CircleConcatenation>(2);
+ input_1 = g.nodes()->create<luci::CircleConst>();
+ input_2 = g.nodes()->create<luci::CircleConst>();
+
+ concat_node->dtype(quant_type);
+ concat_node->fusedActivationFunction(luci::FusedActFunc::NONE);
+ input_1->dtype(quant_type);
+ input_2->dtype(quant_type);
+
+ concat_node->values(0, input_1);
+ concat_node->values(1, input_2);
+ }
+
+ ~SimpleConcatGraph()
+ {
+ concat_node->values(0, nullptr);
+ concat_node->values(1, nullptr);
+ }
+
+public:
+ loco::Graph g;
+ luci::CircleConcatenation *concat_node = nullptr;
+ luci::CircleConst *input_1 = nullptr;
+ luci::CircleConst *input_2 = nullptr;
+};
+
TEST(QuantizeWithMinMaxPassTest, name)
{
luci::QuantizeWithMinMaxPass pass(loco::DataType::FLOAT32, loco::DataType::U8,
@@ -25,3 +58,19 @@ TEST(QuantizeWithMinMaxPassTest, name)
auto const name = pass.name();
ASSERT_NE(nullptr, name);
}
+
+// Test concat of integer tensors
+// Integer tensors are not quantized
+TEST(QuantizeWithMinMaxPassTest, int_concat)
+{
+ SimpleConcatGraph g(loco::DataType::S32);
+
+ luci::QuantizeWithMinMaxPass qwmm(loco::DataType::FLOAT32, loco::DataType::U8,
+ luci::QuantizationGranularity::LayerWise);
+
+ qwmm.run(&g.g);
+
+ EXPECT_EQ(nullptr, g.concat_node->quantparam());
+ EXPECT_EQ(nullptr, g.input_1->quantparam());
+ EXPECT_EQ(nullptr, g.input_2->quantparam());
+}
diff --git a/compiler/luci/pass/src/QuantizedModelVerifier.cpp b/compiler/luci/pass/src/QuantizedModelVerifier.cpp
index f02301ed1..684d5d48a 100644
--- a/compiler/luci/pass/src/QuantizedModelVerifier.cpp
+++ b/compiler/luci/pass/src/QuantizedModelVerifier.cpp
@@ -15,10 +15,10 @@
#include "QuantizedModelVerifier.h"
-#include "VerifyQuantizedNodeLayerWiseGranularity.h"
-#include "VerifyQuantizedNodeChannelWiseGranularity.h"
-#include "VerifyQuantizedNodeU8Type.h"
-#include "VerifyQuantizedNodeS16Type.h"
+#include "VerifyQuantizedNodeGranularity.h"
+#include "VerifyQuantizedNodeType.h"
+#include "VerifyQuantizedBiasScale.h"
+#include "helpers/LayerInfoMap.h"
#include <luci/IR/CircleNodes.h>
#include <luci/IR/CircleNodeVisitor.h>
@@ -28,12 +28,33 @@ namespace luci
void QuantizedModelVerifier::verify(loco::Graph *g)
{
- if (_quantized_dtype != Type::U8 && _quantized_dtype != Type::S16)
- throw std::runtime_error("Unsupported quantized dtype");
-
- if (_granularity != Granularity::ChannelWise && _granularity != Granularity::LayerWise)
+ if (_ctx->granularity != Granularity::ChannelWise && _ctx->granularity != Granularity::LayerWise)
throw std::runtime_error("Unsupported granularity");
+ auto info_by_name = layer_info_map(g, _ctx->layers_info);
+
+ auto quantize_dtype = [&](const luci::CircleNode *node) {
+ auto iter = info_by_name.find(node->name());
+
+ // Return designated quantization dtype
+ if (iter != info_by_name.end())
+ return iter->second.dtype;
+
+ // Return default quantization dtype
+ return _ctx->output_model_dtype;
+ };
+
+ auto quantize_granularity = [&](const luci::CircleNode *node) {
+ auto iter = info_by_name.find(node->name());
+
+ // Return designated quantization granularity
+ if (iter != info_by_name.end())
+ return iter->second.granularity;
+
+ // Return default quantization granularity
+ return _ctx->granularity;
+ };
+
for (auto node : loco::active_nodes(loco::output_nodes(g)))
{
auto circle_node = loco::must_cast<luci::CircleNode *>(node);
@@ -46,32 +67,17 @@ void QuantizedModelVerifier::verify(loco::Graph *g)
};
// Verify Type
- if (_quantized_dtype == Type::U8)
- {
- VerifyQuantizedNodeU8Type vt;
- if (!circle_node->accept(&vt))
- throw std::runtime_error("Wrong data type detected in " + node_name());
- }
- else if (_quantized_dtype == Type::S16)
- {
- VerifyQuantizedNodeS16Type vt;
- if (!circle_node->accept(&vt))
- throw std::runtime_error("Wrong data type detected in " + node_name());
- }
+ if (!VerifyQuantizedNodeType::create(quantize_dtype(circle_node))->verify(circle_node))
+ throw std::runtime_error("Wrong data type detected in " + node_name());
// Verify Granularity
- if (_granularity == Granularity::LayerWise)
- {
- VerifyQuantizedNodeLayerWiseGranularity vg;
- if (!circle_node->accept(&vg))
- throw std::runtime_error("Wrong granularity detected in " + node_name());
- }
- else if (_granularity == Granularity::ChannelWise)
- {
- VerifyQuantizedNodeChannelWiseGranularity vg;
- if (!circle_node->accept(&vg))
- throw std::runtime_error("Wrong granularity detected in " + node_name());
- }
+ if (!circle_node->accept(
+ VerifyQuantizedNodeGranularity::create(quantize_granularity(circle_node)).get()))
+ throw std::runtime_error("Wrong granularity detected in " + node_name());
+
+ // Verify Bias scale
+ if (!VerifyQuantizedBiasScale::create()->verify(circle_node))
+ throw std::runtime_error("Wrong bias scale detected in " + node_name());
}
}
diff --git a/compiler/luci/pass/src/QuantizedModelVerifier.h b/compiler/luci/pass/src/QuantizedModelVerifier.h
index d5fbb8e74..7409a51d7 100644
--- a/compiler/luci/pass/src/QuantizedModelVerifier.h
+++ b/compiler/luci/pass/src/QuantizedModelVerifier.h
@@ -21,6 +21,8 @@
#include <loco.h>
+#include <memory>
+
namespace luci
{
@@ -31,18 +33,40 @@ namespace luci
*/
struct QuantizedModelVerifier
{
+public:
+ struct Context
+ {
+ loco::DataType output_model_dtype = loco::DataType::Unknown;
+ QuantizationGranularity granularity = QuantizationGranularity::ChannelWise;
+ loco::DataType input_type = loco::DataType::Unknown;
+ loco::DataType output_type = loco::DataType::Unknown;
+ bool TF_style_maxpool = false;
+ std::vector<LayerInfo> layers_info;
+ };
public:
QuantizedModelVerifier(loco::DataType quantized_dtype, QuantizationGranularity granularity)
- : _quantized_dtype(quantized_dtype), _granularity(granularity)
{
+ _ctx = std::make_unique<Context>();
+ {
+ _ctx->output_model_dtype = quantized_dtype;
+ _ctx->granularity = granularity;
+ _ctx->input_type = quantized_dtype;
+ _ctx->output_type = quantized_dtype;
+ _ctx->TF_style_maxpool = false;
+ }
+ }
+
+public:
+ QuantizedModelVerifier(std::unique_ptr<Context> &&ctx) : _ctx{std::move(ctx)}
+ {
+ // DO NOTHING
}
void verify(loco::Graph *g);
private:
- loco::DataType _quantized_dtype;
- QuantizationGranularity _granularity;
+ std::unique_ptr<Context> _ctx;
};
} // namespace luci
diff --git a/compiler/luci/pass/src/QuantizedModelVerifier.test.cpp b/compiler/luci/pass/src/QuantizedModelVerifier.test.cpp
index 3a6d86c33..cebafd32b 100644
--- a/compiler/luci/pass/src/QuantizedModelVerifier.test.cpp
+++ b/compiler/luci/pass/src/QuantizedModelVerifier.test.cpp
@@ -17,6 +17,7 @@
#include "QuantizedModelVerifier.h"
#include "luci/Pass/QuantizeWithMinMaxPass.h"
+#include "luci/Pass/QuantizationParameters.h"
#include <luci/test/TestIOGraph.h>
@@ -112,57 +113,77 @@ void quantize_and_verify(loco::Graph *g, Type quantized_dtype, Granularity granu
verifier.verify(g);
}
-// Helper function to reduce duplicate test codes
-// Assumption: g->output()->from() is the target node
-void quantize_and_verify_with_wrong_type(luci::test::TestIOGraph *g, Type quantized_dtype,
- Granularity granularity, Type wrong_dtype)
+void quantize_and_verify_with_layer_info(loco::Graph *g, Type quantized_dtype,
+ Granularity granularity)
{
- luci::QuantizeWithMinMaxPass pass(Type::FLOAT32, quantized_dtype, granularity);
- pass.run(g->g());
-
- auto node = loco::must_cast<luci::CircleNode *>(g->output()->from());
- node->dtype(wrong_dtype);
+ // A layer named "test" has dtype different from quantized_dtype
+ luci::LayerInfo info;
+ {
+ info.name = "test";
+ // dtype is different from quantized_dtype
+ info.dtype = quantized_dtype == Type::U8 ? Type::S16 : Type::U8;
+ info.granularity = Granularity::ChannelWise;
+ }
- luci::QuantizedModelVerifier verifier(quantized_dtype, granularity);
- verifier.verify(g->g());
-}
+ // Do quantization
+ {
+ auto ctx = std::make_unique<luci::QuantizeWithMinMaxPass::Context>();
+ {
+ ctx->input_model_dtype = Type::FLOAT32;
+ ctx->output_model_dtype = quantized_dtype;
+ ctx->granularity = granularity;
+ ctx->input_type = quantized_dtype;
+ ctx->output_type = quantized_dtype;
+ ctx->TF_style_maxpool = false;
+ ctx->layers_info.push_back(info);
+ }
-void quantize_and_verify_with_wrong_type(luci::test::TestIOGraph *g, Type quantized_dtype,
- Granularity granularity, Type wrong_dtype,
- luci::CircleNode *target)
-{
- luci::QuantizeWithMinMaxPass pass(Type::FLOAT32, quantized_dtype, granularity);
- pass.run(g->g());
+ luci::QuantizeWithMinMaxPass pass(std::move(ctx));
+ pass.run(g);
+ }
- target->dtype(wrong_dtype);
+ // Do verification
+ {
+ auto ctx = std::make_unique<luci::QuantizedModelVerifier::Context>();
+ {
+ ctx->output_model_dtype = quantized_dtype;
+ ctx->granularity = granularity;
+ ctx->input_type = quantized_dtype;
+ ctx->output_type = quantized_dtype;
+ ctx->TF_style_maxpool = false;
+ ctx->layers_info.push_back(info);
+ }
- luci::QuantizedModelVerifier verifier(quantized_dtype, granularity);
- verifier.verify(g->g());
+ luci::QuantizedModelVerifier verifier(std::move(ctx));
+ verifier.verify(g);
+ }
}
// Helper function to reduce duplicate test codes
// Assumption: g->output()->from() is the target node
-void quantize_and_verify_with_wrong_granularity(luci::test::TestIOGraph *g, Type quantized_dtype,
- Granularity granularity)
+void quantize_and_verify_with_wrong_type(luci::test::TestIOGraph *g, Type quantized_dtype,
+ Granularity granularity, Type wrong_dtype)
{
luci::QuantizeWithMinMaxPass pass(Type::FLOAT32, quantized_dtype, granularity);
pass.run(g->g());
auto node = loco::must_cast<luci::CircleNode *>(g->output()->from());
- insert_scale_zp(node, 1.0, 1);
+ node->dtype(wrong_dtype);
luci::QuantizedModelVerifier verifier(quantized_dtype, granularity);
verifier.verify(g->g());
}
// Helper function to reduce duplicate test codes
+// Assumption: g->output()->from() is the target node
void quantize_and_verify_with_wrong_granularity(luci::test::TestIOGraph *g, Type quantized_dtype,
- Granularity granularity, luci::CircleNode *target)
+ Granularity granularity)
{
luci::QuantizeWithMinMaxPass pass(Type::FLOAT32, quantized_dtype, granularity);
pass.run(g->g());
- insert_scale_zp(target, 1.0, 1);
+ auto node = loco::must_cast<luci::CircleNode *>(g->output()->from());
+ insert_scale_zp(node, 1.0, 1);
luci::QuantizedModelVerifier verifier(quantized_dtype, granularity);
verifier.verify(g->g());
@@ -230,6 +251,8 @@ public:
_instnorm->input(input());
_instnorm->gamma(_gamma);
_instnorm->beta(_beta);
+ _instnorm->fusedActivationFunction(luci::FusedActFunc::NONE);
+ _instnorm->name("test");
}
output()->from(_instnorm);
@@ -256,6 +279,7 @@ public:
_logistic = g()->nodes()->create<luci::CircleLogistic>();
{
_logistic->x(input());
+ _logistic->name("test");
}
output()->from(_logistic);
@@ -275,6 +299,7 @@ public:
_lrn = g()->nodes()->create<luci::CircleLocalResponseNormalization>();
{
_lrn->input(input());
+ _lrn->name("test");
}
output()->from(_lrn);
@@ -295,6 +320,7 @@ public:
{
_softmax->logits(input());
_softmax->beta(0.1);
+ _softmax->name("test");
}
output()->from(_softmax);
@@ -324,6 +350,7 @@ public:
_stob->input(input());
_stob->block_shape(_block_shape);
_stob->paddings(_paddings);
+ _stob->name("test");
}
output()->from(_stob);
@@ -346,6 +373,7 @@ public:
{
_stod->input(input());
_stod->block_size(2);
+ _stod->name("test");
}
output()->from(_stod);
@@ -375,6 +403,7 @@ public:
_slice->input(input());
_slice->begin(_begin);
_slice->size(_size);
+ _slice->name("test");
}
output()->from(_slice);
@@ -472,6 +501,7 @@ public:
_slice->begin(_begin);
_slice->end(_end);
_slice->strides(_strides);
+ _slice->name("test");
}
output()->from(_slice);
@@ -499,6 +529,7 @@ public:
{
_reshape->tensor(input());
_reshape->shape(_shape);
+ _reshape->name("test");
}
output()->from(_reshape);
@@ -519,6 +550,7 @@ public:
_tanh = g()->nodes()->create<luci::CircleTanh>();
{
_tanh->x(input());
+ _tanh->name("test");
}
output()->from(_tanh);
@@ -538,6 +570,7 @@ public:
_floor = g()->nodes()->create<luci::CircleFloor>();
{
_floor->x(input());
+ _floor->name("test");
}
output()->from(_floor);
@@ -601,6 +634,7 @@ public:
_btos->input(input());
_btos->block_shape(_block_shape);
_btos->crops(_crops);
+ _btos->name("test");
}
output()->from(_btos);
@@ -623,6 +657,7 @@ public:
{
_dtos->input(input());
_dtos->block_size(2);
+ _dtos->name("test");
}
output()->from(_dtos);
@@ -645,6 +680,7 @@ public:
_pack->values(0, input());
_pack->values(1, _param);
_pack->axis(0);
+ _pack->name("test");
}
output()->from(_pack);
@@ -680,6 +716,7 @@ public:
{
_pad->input(input());
_pad->paddings(_paddings);
+ _pad->name("test");
}
output()->from(_pad);
@@ -707,6 +744,7 @@ public:
_pad->input(input());
_pad->paddings(_paddings);
_pad->constant_values(_constant_values);
+ _pad->name("test");
}
output()->from(_pad);
@@ -735,6 +773,7 @@ public:
_mirror_pad->input(input());
_mirror_pad->paddings(_paddings);
_mirror_pad->mode(luci::MirrorPadMode::REFLECT);
+ _mirror_pad->name("test");
}
output()->from(_mirror_pad);
@@ -761,6 +800,7 @@ public:
{
_transpose->a(input());
_transpose->perm(_perm);
+ _transpose->name("test");
}
output()->from(_transpose);
@@ -784,6 +824,8 @@ public:
_concat->values(0, input());
_concat->values(1, _param);
_concat->axis(0);
+ _concat->fusedActivationFunction(luci::FusedActFunc::NONE);
+ _concat->name("test");
}
output()->from(_concat);
@@ -795,6 +837,54 @@ private:
luci::CircleConst *_param = nullptr;
};
+template <Type indexT> class OneHotTestGraph final : public SimpleTestGraph
+{
+public:
+ void init(void) override
+ {
+ TestIOGraph::init({32}, {32, 10});
+ {
+ // input dtype is float by default, but OneHot's input should have indexType (s32/s64)
+ input()->dtype(indexT);
+ }
+
+ _depth = g()->nodes()->template create<luci::CircleConst>();
+ {
+ _depth->dtype(loco::DataType::S32);
+ }
+
+ _on_value = g()->nodes()->template create<luci::CircleConst>();
+ {
+ _on_value->dtype(loco::DataType::FLOAT32);
+ }
+
+ _off_value = g()->nodes()->template create<luci::CircleConst>();
+ {
+ _off_value->dtype(loco::DataType::FLOAT32);
+ }
+
+ _one_hot = g()->nodes()->template create<luci::CircleOneHot>();
+ {
+ _one_hot->indices(input());
+ _one_hot->depth(_depth);
+ _one_hot->on_value(_on_value);
+ _one_hot->off_value(_off_value);
+ _one_hot->axis(-1);
+ _one_hot->dtype(loco::DataType::FLOAT32);
+ _one_hot->name("test");
+ }
+ output()->from(_one_hot);
+
+ set_minmax_to_non_const(g(), -1, 1);
+ }
+
+private:
+ luci::CircleOneHot *_one_hot = nullptr;
+ luci::CircleConst *_depth = nullptr;
+ luci::CircleConst *_on_value = nullptr;
+ luci::CircleConst *_off_value = nullptr;
+};
+
// Test graph for comparison Ops
// GREATER, GREATER_EQUAL, LESS, LESS_EQUAL, EQUAL, NOT_EQUAL
template <class Op> class ComparisonOpTestGraph final : public SimpleTestGraph
@@ -866,6 +956,7 @@ public:
{
_div->x(input());
_div->y(_const);
+ _div->name("test");
}
output()->from(_div);
@@ -893,6 +984,7 @@ public:
{
_floor_div->x(input());
_floor_div->y(_const);
+ _floor_div->name("test");
}
output()->from(_floor_div);
@@ -917,6 +1009,7 @@ public:
_rsqrt = g()->nodes()->create<luci::CircleRsqrt>();
{
_rsqrt->x(input());
+ _rsqrt->name("test");
}
output()->from(_rsqrt);
@@ -936,6 +1029,7 @@ public:
_sqrt = g()->nodes()->create<luci::CircleSqrt>();
{
_sqrt->x(input());
+ _sqrt->name("test");
}
output()->from(_sqrt);
@@ -955,6 +1049,7 @@ public:
_elu = g()->nodes()->create<luci::CircleElu>();
{
_elu->features(input());
+ _elu->name("test");
}
output()->from(_elu);
@@ -977,6 +1072,7 @@ public:
{
_pow->x(input());
_pow->y(_const);
+ _pow->name("test");
}
output()->from(_pow);
@@ -1004,6 +1100,7 @@ public:
{
_resize_bilinear->input(input());
_resize_bilinear->size(_size);
+ _resize_bilinear->name("test");
}
output()->from(_resize_bilinear);
@@ -1027,6 +1124,7 @@ public:
{
_resize_nearest_neighbor->input(input());
_resize_nearest_neighbor->size(_size);
+ _resize_nearest_neighbor->name("test");
}
output()->from(_resize_nearest_neighbor);
@@ -1067,6 +1165,62 @@ private:
luci::CircleConst *_unpack_dim = nullptr;
};
+class MulTestGraph final : public SimpleTestGraph
+{
+public:
+ void init(void) override
+ {
+ TestIOGraph::init({32}, {32});
+
+ _const = create_dummy_const<Type::FLOAT32>(g(), {32});
+ _mul = g()->nodes()->create<luci::CircleMul>();
+ {
+ _mul->x(input());
+ _mul->y(_const);
+ _mul->fusedActivationFunction(luci::FusedActFunc::NONE);
+ _mul->name("test");
+ }
+ output()->from(_mul);
+
+ set_minmax_to_non_const(g(), -1, 1);
+ }
+
+ loco::Node *x() { return _mul->x(); }
+ loco::Node *y() { return _mul->y(); }
+
+private:
+ luci::CircleMul *_mul = nullptr;
+ luci::CircleConst *_const = nullptr;
+};
+
+class AddTestGraph final : public SimpleTestGraph
+{
+public:
+ void init(void) override
+ {
+ TestIOGraph::init({32}, {32});
+
+ _const = create_dummy_const<Type::FLOAT32>(g(), {32});
+ _add = g()->nodes()->create<luci::CircleAdd>();
+ {
+ _add->x(input());
+ _add->y(_const);
+ _add->fusedActivationFunction(luci::FusedActFunc::NONE);
+ _add->name("test");
+ }
+ output()->from(_add);
+
+ set_minmax_to_non_const(g(), -1, 1);
+ }
+
+ loco::Node *x() { return _add->x(); }
+ loco::Node *y() { return _add->y(); }
+
+private:
+ luci::CircleAdd *_add = nullptr;
+ luci::CircleConst *_const = nullptr;
+};
+
} // namespace
// Quantize and verify with given configurations
@@ -1078,6 +1232,15 @@ private:
EXPECT_NO_THROW(quantize_and_verify(g.g(), type, granularity)); \
} while (0)
+// Quantize and verify with layer info
+#define TEST_WITH_LAYER_INFO(graph, type, granularity) \
+ do \
+ { \
+ graph g; \
+ g.init(); \
+ EXPECT_NO_THROW(quantize_and_verify_with_layer_info(g.g(), type, granularity)); \
+ } while (0)
+
// Quantize and verify with wrong type
#define TEST_WITH_WRONG_TYPE(graph, type, granularity, wrong_dtype) \
do \
@@ -1098,25 +1261,34 @@ private:
// Quantize and verify with wrong type
// Users can specify the test target
-#define TEST_WITH_WRONG_TYPE_TARGET(graph, type, granularity, wrong_dtype, target) \
- do \
- { \
- graph g; \
- g.init(); \
- auto node = loco::must_cast<luci::CircleNode *>(target); \
- EXPECT_ANY_THROW( \
- quantize_and_verify_with_wrong_type(&g, type, granularity, wrong_dtype, node)); \
+#define TEST_WITH_WRONG_TYPE_TARGET(graph, type, granularity, wrong_dtype, target) \
+ do \
+ { \
+ graph g; \
+ g.init(); \
+ auto node = loco::must_cast<luci::CircleNode *>(target); \
+ luci::QuantizeWithMinMaxPass pass(Type::FLOAT32, type, granularity); \
+ pass.run(g.g()); \
+ auto after_node = loco::must_cast<luci::CircleNode *>(target); \
+ after_node->dtype(wrong_dtype); \
+ luci::QuantizedModelVerifier verifier(type, granularity); \
+ EXPECT_ANY_THROW(verifier.verify(g.g())); \
} while (0)
// Quantize and verify with wrong granularity
// Users can specify the test target
-#define TEST_WITH_WRONG_GRANULARITY_TARGET(graph, type, granularity, target) \
- do \
- { \
- graph g; \
- g.init(); \
- auto node = loco::must_cast<luci::CircleNode *>(target); \
- EXPECT_ANY_THROW(quantize_and_verify_with_wrong_granularity(&g, type, granularity, node)); \
+#define TEST_WITH_WRONG_GRANULARITY_TARGET(graph, type, granularity, target) \
+ do \
+ { \
+ graph g; \
+ g.init(); \
+ auto node = loco::must_cast<luci::CircleNode *>(target); \
+ luci::QuantizeWithMinMaxPass pass(Type::FLOAT32, type, granularity); \
+ pass.run(g.g()); \
+ auto after_node = loco::must_cast<luci::CircleNode *>(target); \
+ insert_scale_zp(after_node, 1.0, 1); \
+ luci::QuantizedModelVerifier verifier(type, granularity); \
+ EXPECT_ANY_THROW(verifier.verify(g.g())); \
} while (0)
// Test a local helper function
@@ -1145,6 +1317,10 @@ TEST(QuantizedModelVerifierTest, InstanceNorm)
TEST_WITH_GRAPH(InstanceNormTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(InstanceNormTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(InstanceNormTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(InstanceNormTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(InstanceNormTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(InstanceNormTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -1169,6 +1345,10 @@ TEST(QuantizedModelVerifierTest, LocalResponseNormalization)
TEST_WITH_GRAPH(LocalResponseNormalizationTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(LocalResponseNormalizationTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(LocalResponseNormalizationTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(LocalResponseNormalizationTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(LocalResponseNormalizationTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(LocalResponseNormalizationTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -1199,6 +1379,10 @@ TEST(QuantizedModelVerifierTest, Logistic)
TEST_WITH_GRAPH(LogisticTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(LogisticTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(LogisticTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(LogisticTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(LogisticTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(LogisticTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -1223,6 +1407,10 @@ TEST(QuantizedModelVerifierTest, Softmax)
TEST_WITH_GRAPH(SoftmaxTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(SoftmaxTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(SoftmaxTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(SoftmaxTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(SoftmaxTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(SoftmaxTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -1247,6 +1435,10 @@ TEST(QuantizedModelVerifierTest, SpaceToBatchND)
TEST_WITH_GRAPH(SpaceToBatchNDTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(SpaceToBatchNDTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(SpaceToBatchNDTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(SpaceToBatchNDTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(SpaceToBatchNDTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(SpaceToBatchNDTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -1271,6 +1463,10 @@ TEST(QuantizedModelVerifierTest, SpaceToDepth)
TEST_WITH_GRAPH(SpaceToDepthTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(SpaceToDepthTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(SpaceToDepthTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(SpaceToDepthTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(SpaceToDepthTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(SpaceToDepthTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -1299,6 +1495,14 @@ TEST(QuantizedModelVerifierTest, Slice)
TEST_WITH_GRAPH(SliceTestGraph<Type::S64>, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(SliceTestGraph<Type::S64>, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(SliceTestGraph<Type::S64>, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(SliceTestGraph<Type::S32>, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(SliceTestGraph<Type::S32>, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(SliceTestGraph<Type::S32>, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(SliceTestGraph<Type::S64>, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(SliceTestGraph<Type::S64>, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(SliceTestGraph<Type::S64>, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -1379,6 +1583,10 @@ TEST(QuantizedModelVerifierTest, StridedSlice)
TEST_WITH_GRAPH(StridedSliceTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(StridedSliceTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(StridedSliceTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(StridedSliceTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(StridedSliceTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(StridedSliceTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -1463,6 +1671,10 @@ TEST(QuantizedModelVerifierTest, BatchToSpaceND)
TEST_WITH_GRAPH(BatchToSpaceNDTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(BatchToSpaceNDTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(BatchToSpaceNDTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(BatchToSpaceNDTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(BatchToSpaceNDTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(BatchToSpaceNDTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -1487,6 +1699,10 @@ TEST(QuantizedModelVerifierTest, DepthToSpace)
TEST_WITH_GRAPH(DepthToSpaceTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(DepthToSpaceTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(DepthToSpaceTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(DepthToSpaceTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(DepthToSpaceTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(DepthToSpaceTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -1511,6 +1727,10 @@ TEST(QuantizedModelVerifierTest, Concatenation)
TEST_WITH_GRAPH(ConcatenationTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(ConcatenationTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(ConcatenationTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(ConcatenationTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(ConcatenationTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(ConcatenationTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -1557,6 +1777,10 @@ TEST(QuantizedModelVerifierTest, Reshape)
TEST_WITH_GRAPH(ReshapeTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(ReshapeTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(ReshapeTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(ReshapeTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(ReshapeTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(ReshapeTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -1581,6 +1805,10 @@ TEST(QuantizedModelVerifierTest, Tanh)
TEST_WITH_GRAPH(TanhTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(TanhTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(TanhTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(TanhTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(TanhTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(TanhTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -1606,6 +1834,10 @@ TEST(QuantizedModelVerifierTest, Pack)
TEST_WITH_GRAPH(PackTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(PackTestGraph, Type::S16, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(PackTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(PackTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(PackTestGraph, Type::S16, Granularity::ChannelWise);
+
// Test if Pack's qparam is propagated to the input
{
PackTestGraph g;
@@ -1640,6 +1872,10 @@ TEST(QuantizedModelVerifierTest, Pad)
TEST_WITH_GRAPH(PadTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(PadTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(PadTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(PadTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(PadTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(PadTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -1664,6 +1900,10 @@ TEST(QuantizedModelVerifierTest, PadV2)
TEST_WITH_GRAPH(PadV2TestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(PadV2TestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(PadV2TestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(PadV2TestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(PadV2TestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(PadV2TestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -1688,6 +1928,10 @@ TEST(QuantizedModelVerifierTest, MirrorPad)
TEST_WITH_GRAPH(MirrorPadTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(MirrorPadTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(MirrorPadTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(MirrorPadTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(MirrorPadTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(MirrorPadTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -1712,6 +1956,10 @@ TEST(QuantizedModelVerifierTest, Transpose)
TEST_WITH_GRAPH(TransposeTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(TransposeTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(TransposeTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(TransposeTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(TransposeTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(TransposeTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -1736,6 +1984,10 @@ TEST(QuantizedModelVerifierTest, Floor)
TEST_WITH_GRAPH(FloorTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(FloorTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(FloorTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(FloorTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(FloorTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(FloorTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -1869,11 +2121,59 @@ TEST(QuantizedModelVerifierTest, NotEqual_wrong_granularity_NEG)
SUCCEED();
}
+TEST(QuantizedModelVerifierTest, OneHot)
+{
+ TEST_WITH_GRAPH(OneHotTestGraph<Type::S32>, Type::U8, Granularity::LayerWise);
+ TEST_WITH_GRAPH(OneHotTestGraph<Type::S32>, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_GRAPH(OneHotTestGraph<Type::S32>, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_GRAPH(OneHotTestGraph<Type::S64>, Type::U8, Granularity::LayerWise);
+ TEST_WITH_GRAPH(OneHotTestGraph<Type::S64>, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_GRAPH(OneHotTestGraph<Type::S64>, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(OneHotTestGraph<Type::S32>, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(OneHotTestGraph<Type::S32>, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(OneHotTestGraph<Type::S32>, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(OneHotTestGraph<Type::S64>, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(OneHotTestGraph<Type::S64>, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(OneHotTestGraph<Type::S64>, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, OneHot_wrong_input_type_NEG)
+{
+ TEST_WITH_WRONG_TYPE(OneHotTestGraph<Type::S32>, Type::U8, Granularity::LayerWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(OneHotTestGraph<Type::S32>, Type::U8, Granularity::ChannelWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(OneHotTestGraph<Type::S32>, Type::S16, Granularity::ChannelWise, Type::U8);
+
+ TEST_WITH_WRONG_TYPE(OneHotTestGraph<Type::S64>, Type::U8, Granularity::LayerWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(OneHotTestGraph<Type::S64>, Type::U8, Granularity::ChannelWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(OneHotTestGraph<Type::S64>, Type::S16, Granularity::ChannelWise, Type::U8);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, OneHot_wrong_granularity_NEG)
+{
+ TEST_WITH_WRONG_GRANULARITY(OneHotTestGraph<Type::S32>, Type::U8, Granularity::LayerWise);
+ TEST_WITH_WRONG_GRANULARITY(OneHotTestGraph<Type::S32>, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_WRONG_GRANULARITY(OneHotTestGraph<Type::S32>, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_WRONG_GRANULARITY(OneHotTestGraph<Type::S64>, Type::U8, Granularity::LayerWise);
+ TEST_WITH_WRONG_GRANULARITY(OneHotTestGraph<Type::S64>, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_WRONG_GRANULARITY(OneHotTestGraph<Type::S64>, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
TEST(QuantizedModelVerifierTest, Div)
{
TEST_WITH_GRAPH(DivTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(DivTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(DivTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(DivTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(DivTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(DivTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -1902,6 +2202,10 @@ TEST(QuantizedModelVerifierTest, FloorDiv)
TEST_WITH_GRAPH(FloorDivTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(FloorDivTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(FloorDivTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(FloorDivTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(FloorDivTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(FloorDivTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -1930,6 +2234,10 @@ TEST(QuantizedModelVerifierTest, Rsqrt)
TEST_WITH_GRAPH(RsqrtTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(RsqrtTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(RsqrtTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(RsqrtTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(RsqrtTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(RsqrtTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -1954,6 +2262,10 @@ TEST(QuantizedModelVerifierTest, Sqrt)
TEST_WITH_GRAPH(SqrtTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(SqrtTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(SqrtTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(SqrtTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(SqrtTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(SqrtTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -1978,6 +2290,10 @@ TEST(QuantizedModelVerifierTest, Elu)
TEST_WITH_GRAPH(EluTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(EluTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(EluTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(EluTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(EluTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(EluTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -2002,6 +2318,10 @@ TEST(QuantizedModelVerifierTest, Pow)
TEST_WITH_GRAPH(PowTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(PowTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(PowTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(PowTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(PowTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(PowTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -2030,6 +2350,10 @@ TEST(QuantizedModelVerifierTest, ResizeBilinear)
TEST_WITH_GRAPH(ResizeBilinearTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(ResizeBilinearTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(ResizeBilinearTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(ResizeBilinearTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(ResizeBilinearTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(ResizeBilinearTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -2054,6 +2378,10 @@ TEST(QuantizedModelVerifierTest, ResizeNearestNeighbor)
TEST_WITH_GRAPH(ResizeNearestNeighborTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(ResizeNearestNeighborTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(ResizeNearestNeighborTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(ResizeBilinearTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(ResizeBilinearTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(ResizeBilinearTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -2099,6 +2427,93 @@ TEST(QuantizedModelVerifierTest, Unpack_wrong_granularity_NEG)
SUCCEED();
}
+TEST(QuantizedModelVerifierTest, Add)
+{
+ TEST_WITH_GRAPH(AddTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_GRAPH(AddTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_GRAPH(AddTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(AddTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(AddTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(AddTestGraph, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Add_wrong_type_NEG)
+{
+ TEST_WITH_WRONG_TYPE(AddTestGraph, Type::U8, Granularity::LayerWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(AddTestGraph, Type::U8, Granularity::ChannelWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(AddTestGraph, Type::S16, Granularity::ChannelWise, Type::U8);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Add_wrong_granularity_NEG)
+{
+ TEST_WITH_WRONG_GRANULARITY_TARGET(AddTestGraph, Type::U8, Granularity::LayerWise, g.x());
+ TEST_WITH_WRONG_GRANULARITY_TARGET(AddTestGraph, Type::U8, Granularity::ChannelWise, g.x());
+ TEST_WITH_WRONG_GRANULARITY_TARGET(AddTestGraph, Type::S16, Granularity::ChannelWise, g.x());
+
+ TEST_WITH_WRONG_GRANULARITY_TARGET(AddTestGraph, Type::U8, Granularity::LayerWise, g.y());
+ TEST_WITH_WRONG_GRANULARITY_TARGET(AddTestGraph, Type::U8, Granularity::ChannelWise, g.y());
+ TEST_WITH_WRONG_GRANULARITY_TARGET(AddTestGraph, Type::S16, Granularity::ChannelWise, g.y());
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Mul)
+{
+ TEST_WITH_GRAPH(MulTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_GRAPH(MulTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_GRAPH(MulTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(MulTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(MulTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(MulTestGraph, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Mul_wrong_type_NEG)
+{
+ TEST_WITH_WRONG_TYPE(MulTestGraph, Type::U8, Granularity::LayerWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(MulTestGraph, Type::U8, Granularity::ChannelWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(MulTestGraph, Type::S16, Granularity::ChannelWise, Type::U8);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Mul_wrong_granularity_NEG)
+{
+ TEST_WITH_WRONG_GRANULARITY_TARGET(MulTestGraph, Type::U8, Granularity::LayerWise, g.x());
+ TEST_WITH_WRONG_GRANULARITY_TARGET(MulTestGraph, Type::U8, Granularity::ChannelWise, g.x());
+ TEST_WITH_WRONG_GRANULARITY_TARGET(MulTestGraph, Type::S16, Granularity::ChannelWise, g.x());
+
+ TEST_WITH_WRONG_GRANULARITY_TARGET(MulTestGraph, Type::U8, Granularity::LayerWise, g.y());
+ TEST_WITH_WRONG_GRANULARITY_TARGET(MulTestGraph, Type::U8, Granularity::ChannelWise, g.y());
+ TEST_WITH_WRONG_GRANULARITY_TARGET(MulTestGraph, Type::S16, Granularity::ChannelWise, g.y());
+ SUCCEED();
+}
+
+// TODO Add following testcases
+//
+// CircleConv2D
+//
+// CircleDepthwiseConv2D
+//
+// CirclePRelu
+//
+// CircleTransposeConv
+//
+// CircleFullyConnected
+//
+// CircleAveragePool2D
+//
+// CircleMaxPool2D
+//
+// CircleMean
+//
+// CircleRelu
+//
+// CircleCast
+//
+
#undef TEST_WITH_GRAPH
#undef TEST_WITH_WRONG_TYPE
#undef TEST_WITH_WRONG_GRANULARITY
diff --git a/compiler/luci/pass/src/RemoveRedundantQuantizePass.cpp b/compiler/luci/pass/src/RemoveRedundantQuantizePass.cpp
new file mode 100644
index 000000000..8a10ad4a0
--- /dev/null
+++ b/compiler/luci/pass/src/RemoveRedundantQuantizePass.cpp
@@ -0,0 +1,104 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/RemoveRedundantQuantizePass.h"
+
+#include <luci/IR/CircleNode.h>
+
+/**
+ * Remove redundant quantize operations. For subsequent Quantize Ops,
+ * only the last Quantize Op is valid, so we can remove the rest of the Quantize Op.
+ *
+ * BEFORE
+ * [CircleNode_1]
+ * |
+ * [CircleQuantize, dtype_1, scale_1, zero_point_1]
+ * |
+ * [CircleQuantize, dtype_2, scale_2, zero_point_2]
+ * |
+ * [CircleNode_2]
+ *
+ * AFTER
+ * [CircleNode_1]
+ * / \
+ * / \
+ * / \
+ * / \
+ * / \
+ * [CircleQuantize, dtype_2, scale_2, zero_point_2] [CircleQuantize, dtype_1, scale_1, zero_point_1]
+ * |
+ * [CircleNode_2]
+ *
+ */
+
+namespace
+{
+
+bool remove_redundant_quantize(luci::CircleQuantize *node)
+{
+ auto pred_node = loco::must_cast<luci::CircleNode *>(node->input());
+
+ if (node->quantparam() == nullptr or pred_node->quantparam() == nullptr)
+ return false;
+
+ if (node->quantparam()->scale.size() != 1 or node->quantparam()->zerop.size() != 1 or
+ pred_node->quantparam()->scale.size() != 1 or pred_node->quantparam()->zerop.size() != 1)
+ {
+ return false;
+ }
+
+ if (node->dtype() != pred_node->dtype() or
+ pred_node->quantparam()->scale.at(0) != node->quantparam()->scale.at(0) or
+ pred_node->quantparam()->zerop.at(0) != node->quantparam()->zerop.at(0))
+ {
+ return false;
+ }
+
+ replace(node).with(pred_node);
+
+ return true;
+}
+
+bool remove_redundant_subsequent_quantize(luci::CircleQuantize *node)
+{
+ auto pred_node = dynamic_cast<luci::CircleQuantize *>(node->input());
+ if (pred_node == nullptr)
+ return remove_redundant_quantize(node);
+
+ node->input(pred_node->input());
+ return true;
+}
+
+} // namespace
+
+namespace luci
+{
+
+bool RemoveRedundantQuantizePass::run(loco::Graph *g)
+{
+ bool changed = false;
+ for (auto node : loco::postorder_traversal(loco::output_nodes(g)))
+ {
+ if (auto quantize_node = dynamic_cast<luci::CircleQuantize *>(node))
+ {
+ if (remove_redundant_subsequent_quantize(quantize_node))
+ changed = true;
+ }
+ }
+ return changed;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/RemoveRedundantQuantizePass.test.cpp b/compiler/luci/pass/src/RemoveRedundantQuantizePass.test.cpp
new file mode 100644
index 000000000..d0166bd20
--- /dev/null
+++ b/compiler/luci/pass/src/RemoveRedundantQuantizePass.test.cpp
@@ -0,0 +1,166 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/RemoveRedundantQuantizePass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <luci/test/TestIOGraph.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class QuantizeGraphlet
+{
+public:
+ QuantizeGraphlet() = default;
+
+public:
+ void init(loco::Graph *g)
+ {
+ _first_quantize = g->nodes()->create<luci::CircleQuantize>();
+ _first_quantize->dtype(loco::DataType::U8);
+ {
+ auto quantize_param = std::make_unique<luci::CircleQuantParam>();
+ quantize_param->scale = {0.5};
+ quantize_param->zerop = {0};
+ _first_quantize->quantparam(std::move(quantize_param));
+ }
+ _first_quantize->name("first_quantize");
+
+ _second_quantize = g->nodes()->create<luci::CircleQuantize>();
+ _second_quantize->dtype(loco::DataType::U8);
+ {
+ auto quantize_param = std::make_unique<luci::CircleQuantParam>();
+ quantize_param->scale = {0.5};
+ quantize_param->zerop = {0};
+ _second_quantize->quantparam(std::move(quantize_param));
+ }
+ _second_quantize->name("second_quantize");
+ }
+
+protected:
+ luci::CircleQuantize *_first_quantize = nullptr;
+ luci::CircleQuantize *_second_quantize = nullptr;
+};
+
+class RedundantSubsequentQuantizeGraph : public TestIOGraph, public QuantizeGraphlet
+{
+public:
+ RedundantSubsequentQuantizeGraph() = default;
+
+public:
+ void init(void)
+ {
+ TestIOGraph::init({1}, {1});
+ QuantizeGraphlet::init(g());
+
+ input()->dtype(loco::DataType::U8);
+ {
+ auto quantize_param = std::make_unique<luci::CircleQuantParam>();
+ quantize_param->scale = {1};
+ quantize_param->zerop = {1};
+ input()->quantparam(std::move(quantize_param));
+ }
+
+ _first_quantize->input(input());
+ _second_quantize->input(_first_quantize);
+
+ output()->from(_second_quantize);
+ output()->dtype(loco::DataType::U8);
+ }
+};
+
+class RedundantQuantizeGraph : public TestIOGraph, public QuantizeGraphlet
+{
+public:
+ RedundantQuantizeGraph() = default;
+
+public:
+ void init(void)
+ {
+ TestIOGraph::init({1}, {1});
+ QuantizeGraphlet::init(g());
+
+ input()->dtype(loco::DataType::U8);
+ {
+ auto quantize_param = std::make_unique<luci::CircleQuantParam>();
+ quantize_param->scale = {0.5};
+ quantize_param->zerop = {0};
+ input()->quantparam(std::move(quantize_param));
+ }
+
+ _first_quantize->input(input());
+
+ output()->from(_first_quantize);
+ output()->dtype(loco::DataType::U8);
+ }
+};
+
+} // namespace
+
+TEST(RemoveRedundantQuantizePass, name)
+{
+ luci::RemoveRedundantQuantizePass pass;
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
+
+TEST(RemoveRedundantQuantizePass, remove_subsequent_quantize)
+{
+ RedundantSubsequentQuantizeGraph g;
+ luci::RemoveRedundantQuantizePass pass;
+
+ g.init();
+
+ EXPECT_TRUE(pass.run(g.g()));
+
+ int count = 0;
+ for (auto node : loco::active_nodes(loco::output_nodes(g.g())))
+ {
+ if (dynamic_cast<luci::CircleQuantize *>(node))
+ {
+ count++;
+ }
+ }
+
+ ASSERT_EQ(1, count);
+}
+
+TEST(RemoveRedundantQuantizePass, remove_quantize)
+{
+ RedundantQuantizeGraph g;
+ luci::RemoveRedundantQuantizePass pass;
+
+ g.init();
+
+ EXPECT_TRUE(pass.run(g.g()));
+
+ int count = 0;
+ for (auto node : loco::active_nodes(loco::output_nodes(g.g())))
+ {
+ if (dynamic_cast<luci::CircleQuantize *>(node))
+ {
+ count++;
+ }
+ }
+
+ ASSERT_EQ(0, count);
+}
diff --git a/compiler/luci/pass/src/RemoveRedundantTransposePass.cpp b/compiler/luci/pass/src/RemoveRedundantTransposePass.cpp
index 71c51ecda..75cf72795 100644
--- a/compiler/luci/pass/src/RemoveRedundantTransposePass.cpp
+++ b/compiler/luci/pass/src/RemoveRedundantTransposePass.cpp
@@ -71,7 +71,7 @@ bool remove_consecutive_transpose_function(luci::CircleTranspose *target_node)
for (uint32_t i = 0; i < pred_perm->size<loco::DataType::S32>(); i++)
{
new_const_node->at<loco::DataType::S32>(i) =
- target_perm->at<loco::DataType::S32>(pred_perm->at<loco::DataType::S32>(i));
+ pred_perm->at<loco::DataType::S32>(target_perm->at<loco::DataType::S32>(i));
}
new_const_node->name(name + "/Transpose/perm");
diff --git a/compiler/luci/pass/src/RemoveRedundantTransposePass.test.cpp b/compiler/luci/pass/src/RemoveRedundantTransposePass.test.cpp
index e80623499..bb8e292d4 100644
--- a/compiler/luci/pass/src/RemoveRedundantTransposePass.test.cpp
+++ b/compiler/luci/pass/src/RemoveRedundantTransposePass.test.cpp
@@ -271,6 +271,31 @@ TEST(RemoveRedundantTransposePass, remove_consecutive_transpose_function_type2)
ASSERT_EQ(2, perm->at<loco::DataType::S32>(3));
}
+TEST(RemoveRedundantTransposePass, remove_consecutive_transpose_function_type3)
+{
+ auto graph = loco::make_graph();
+ create_redundunt_transpose(graph.get(), {0, 3, 2, 1}, {0, 2, 3, 1});
+
+ luci::RemoveRedundantTransposePass pass;
+ while (pass.run(graph.get()))
+ ;
+ luci::CircleTranspose *transpose_node = nullptr;
+ for (auto node : loco::active_nodes(loco::output_nodes(graph.get())))
+ {
+ auto trans = dynamic_cast<luci::CircleTranspose *>(node);
+ if (not trans)
+ continue;
+ transpose_node = trans;
+ break;
+ }
+ ASSERT_NE(nullptr, transpose_node);
+ auto perm = loco::must_cast<luci::CircleConst *>(transpose_node->perm());
+ ASSERT_EQ(0, perm->at<loco::DataType::S32>(0));
+ ASSERT_EQ(2, perm->at<loco::DataType::S32>(1));
+ ASSERT_EQ(1, perm->at<loco::DataType::S32>(2));
+ ASSERT_EQ(3, perm->at<loco::DataType::S32>(3));
+}
+
/**
* @brief Test case that first transpose output become input of operations more than one.
*/
diff --git a/compiler/luci/pass/src/RemoveUnnecessaryReshapePass.cpp b/compiler/luci/pass/src/RemoveUnnecessaryReshapePass.cpp
index 3f0c4ee82..fb46f490d 100644
--- a/compiler/luci/pass/src/RemoveUnnecessaryReshapePass.cpp
+++ b/compiler/luci/pass/src/RemoveUnnecessaryReshapePass.cpp
@@ -58,6 +58,25 @@ bool remove_no_effect_reshape(luci::CircleNode *node)
namespace luci
{
+/**
+ * BEFORE
+ * [CircleNode]
+ * |
+ * [CircleReshape]
+ * |
+ * [CircleNode]
+ *
+ * AFTER
+ * [CircleNode]
+ * | \
+ * | [CircleReshape]
+ * |
+ * [CircleNode]
+ *
+ * NOTE
+ * This pass will remove Reshape when input and output has same shape
+ */
+
bool RemoveUnnecessaryReshapePass::run(loco::Graph *g)
{
bool changed = false;
diff --git a/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.cpp b/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.cpp
index a0cc0194f..bca0a9483 100644
--- a/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.cpp
+++ b/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.cpp
@@ -26,8 +26,17 @@ namespace
luci::CircleConst *create_weights_from_gamma(luci::CircleConst *gamma)
{
- assert(gamma->rank() == 1);
- auto channel_size = gamma->dim(0).value();
+ assert(gamma->rank() == 1 or gamma->rank() == 4);
+
+ uint32_t channel_idx = gamma->rank() - 1;
+ uint32_t channel_size = gamma->dim(channel_idx).value();
+
+ // Gamma should be broadcastable in the channel direction
+ for (uint32_t i = 0; i < gamma->rank(); i++)
+ {
+ if (i != channel_idx)
+ assert(gamma->dim(i).value() == 1); // FIX is_batchnorm_mul UNLESS
+ }
auto name = gamma->name();
assert(name.length() > 0);
@@ -53,8 +62,17 @@ luci::CircleConst *create_weights_from_gamma(luci::CircleConst *gamma)
luci::CircleConst *create_bias_from_beta(luci::CircleConst *beta)
{
- assert(beta->rank() == 1);
- auto channel_size = beta->dim(0).value();
+ assert(beta->rank() == 1 or beta->rank() == 4);
+
+ uint32_t channel_idx = beta->rank() - 1;
+ uint32_t channel_size = beta->dim(channel_idx).value();
+
+ // Beta should be broadcastable in the channel direction
+ for (uint32_t i = 0; i < beta->rank(); i++)
+ {
+ if (i != channel_idx)
+ assert(beta->dim(i).value() == 1); // FIX is_batchnorm_add UNLESS
+ }
auto name = beta->name();
assert(name.length() > 0);
diff --git a/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.test.cpp b/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.test.cpp
index 903d4dcc9..bac033112 100644
--- a/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.test.cpp
+++ b/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.test.cpp
@@ -141,6 +141,37 @@ TEST(ReplaceMulAddWithDepthwiseConv, simple)
}
}
+TEST(ReplaceMulAddWithDepthwiseConv, simple_rank4)
+{
+ SimpleGraph g;
+
+ const uint32_t channel_size = 16;
+ g.gamma->shape({1, 1, 1, channel_size});
+ g.beta->shape({1, 1, 1, channel_size});
+
+ luci::ReplaceMulAddWithDepthwiseConvPass pass;
+ while (pass.run(&g.g))
+ ;
+
+ auto dwconv = dynamic_cast<luci::CircleDepthwiseConv2D *>(g.output->from());
+ EXPECT_NE(nullptr, dwconv);
+
+ auto weights = dynamic_cast<luci::CircleConst *>(dwconv->filter());
+ auto bias = dynamic_cast<luci::CircleConst *>(dwconv->bias());
+ EXPECT_NE(nullptr, weights);
+ EXPECT_EQ(4, weights->rank());
+ EXPECT_EQ(channel_size, weights->dim(3).value());
+ EXPECT_NE(nullptr, bias);
+ EXPECT_EQ(1, bias->rank());
+ EXPECT_EQ(channel_size, bias->dim(0).value());
+
+ for (int i = 0; i < channel_size; i++)
+ {
+ EXPECT_FLOAT_EQ(i, weights->at<loco::DataType::FLOAT32>(i));
+ EXPECT_FLOAT_EQ(i, bias->at<loco::DataType::FLOAT32>(i));
+ }
+}
+
TEST(ReplaceMulAddWithDepthwiseConv, wrong_op_NEG)
{
SimpleGraph g;
@@ -154,3 +185,18 @@ TEST(ReplaceMulAddWithDepthwiseConv, wrong_op_NEG)
EXPECT_EQ(false, changed);
}
+
+TEST(ReplaceMulAddWithDepthwiseConv, rank3_NEG)
+{
+ SimpleGraph g;
+
+ g.input->shape({4, 4, 16});
+ g.mul->shape({4, 4, 16});
+ g.add->shape({4, 4, 16});
+ g.output->shape({4, 4, 16});
+
+ luci::ReplaceMulAddWithDepthwiseConvPass pass;
+ auto changed = pass.run(&g.g);
+
+ EXPECT_EQ(false, changed);
+}
diff --git a/compiler/luci/pass/src/SubstituteSplitVToSplitPass.cpp b/compiler/luci/pass/src/SubstituteSplitVToSplitPass.cpp
index 9cba9a9e7..57c386d99 100644
--- a/compiler/luci/pass/src/SubstituteSplitVToSplitPass.cpp
+++ b/compiler/luci/pass/src/SubstituteSplitVToSplitPass.cpp
@@ -24,15 +24,6 @@
namespace
{
-void copy_quantparam(luci::CircleNode *dst, const luci::CircleNode *src)
-{
- auto q = src->quantparam();
- if (q == nullptr)
- dst->quantparam(nullptr);
- else
- dst->quantparam(std::make_unique<luci::CircleQuantParam>(*q));
-}
-
// SplitV is substituted to Split if the contents of size_splits are all same
// For example,
// size_splits = [32, 32] -> substitute
@@ -67,7 +58,7 @@ bool resolve_splitv(luci::CircleSplitV *sv)
split_node->split_dim(sv->split_dim());
split_node->num_split(sv->num_split());
split_node->name(sv->name());
- copy_quantparam(split_node, sv);
+ copy_quantparam(sv, split_node);
luci::add_origin(split_node, luci::get_origin(sv));
auto succs = loco::succs(sv);
@@ -78,7 +69,7 @@ bool resolve_splitv(luci::CircleSplitV *sv)
so_node->input(split_node);
so_node->index(svo->index());
so_node->name(svo->name());
- copy_quantparam(so_node, svo);
+ copy_quantparam(svo, so_node);
luci::add_origin(so_node, luci::get_origin(svo));
replace(svo).with(so_node);
diff --git a/compiler/luci/pass/src/SubstituteSqueezeToReshapePass.cpp b/compiler/luci/pass/src/SubstituteSqueezeToReshapePass.cpp
index f48763782..df7266df9 100644
--- a/compiler/luci/pass/src/SubstituteSqueezeToReshapePass.cpp
+++ b/compiler/luci/pass/src/SubstituteSqueezeToReshapePass.cpp
@@ -76,18 +76,6 @@ std::vector<uint32_t> node_shape(const luci::CircleNode *input)
}
/**
- * @brief copy quantparam of src to dst
- */
-void copy_quantparam(luci::CircleNode *dst, const luci::CircleNode *src)
-{
- auto q = src->quantparam();
- if (q == nullptr)
- dst->quantparam(nullptr);
- else
- dst->quantparam(std::make_unique<luci::CircleQuantParam>(*q));
-}
-
-/**
* @brief return CircleConst ptr with values of new_shape
*/
luci::CircleConst *create_shape_const(loco::Graph *graph, const std::vector<uint32_t> &new_shape)
@@ -142,7 +130,7 @@ bool substitute_squeeze_to_reshape(luci::CircleSqueeze *squeeze)
auto graph = squeeze->graph();
auto reshape = graph->nodes()->create<luci::CircleReshape>();
auto shape_const = create_shape_const(graph, reshape_shape);
- copy_quantparam(reshape, squeeze);
+ copy_quantparam(squeeze, reshape);
reshape->name(name + "/Reshape");
luci::add_origin(reshape, luci::get_origin(squeeze));
shape_const->name(name + "/Reshape/shape");
diff --git a/compiler/luci/pass/src/SubstituteStridedSliceToReshapePass.cpp b/compiler/luci/pass/src/SubstituteStridedSliceToReshapePass.cpp
index f50f2f54f..9e1c5a4a3 100644
--- a/compiler/luci/pass/src/SubstituteStridedSliceToReshapePass.cpp
+++ b/compiler/luci/pass/src/SubstituteStridedSliceToReshapePass.cpp
@@ -124,7 +124,7 @@ bool substitute_strided_slice_to_reshape(luci::CircleStridedSlice *ss_node)
std::bitset<32> end_mask(ss_node->end_mask());
std::bitset<32> shrink_axis_mask(ss_node->shrink_axis_mask());
- uint input_rank = input_node->rank();
+ uint32_t input_rank = input_node->rank();
for (uint32_t i = 0; i < input_rank; i++)
{
if (!input_node->dim(i).known())
diff --git a/compiler/luci/pass/src/VerifyQuantizedBiasScale.cpp b/compiler/luci/pass/src/VerifyQuantizedBiasScale.cpp
new file mode 100644
index 000000000..e65d576cd
--- /dev/null
+++ b/compiler/luci/pass/src/VerifyQuantizedBiasScale.cpp
@@ -0,0 +1,105 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "VerifyQuantizedBiasScale.h"
+
+#include <cmath>
+
+// This macro is undef at the end of the file
+#define RETURN_FALSE_UNLESS(ARG) \
+ if (not(ARG)) \
+ { \
+ return false; \
+ }
+
+namespace
+{
+
+bool same(float a, float b)
+{
+ constexpr float epsilon = 1e-10;
+ return abs(a - b) < epsilon;
+}
+
+// Check bias scale = input scale * weight scale
+// This function checks both LWQ and CWQ
+bool check_bias_scale(const loco::Node *input, const loco::Node *weights, const loco::Node *bias)
+{
+ auto input_node = loco::must_cast<const luci::CircleNode *>(input);
+ auto input_qparam = input_node->quantparam();
+ RETURN_FALSE_UNLESS(input_qparam != nullptr);
+
+ auto weights_node = loco::must_cast<const luci::CircleNode *>(weights);
+ auto weights_qparam = weights_node->quantparam();
+ RETURN_FALSE_UNLESS(weights_qparam != nullptr);
+
+ auto bias_node = loco::must_cast<const luci::CircleNode *>(bias);
+ auto bias_qparam = bias_node->quantparam();
+ RETURN_FALSE_UNLESS(bias_qparam != nullptr);
+
+ RETURN_FALSE_UNLESS(input_qparam->scale.size() == 1);
+ RETURN_FALSE_UNLESS(weights_qparam->scale.size() == bias_qparam->scale.size());
+
+ auto input_scale = input_qparam->scale[0];
+ for (uint32_t i = 0; i < weights_qparam->scale.size(); i++)
+ {
+ auto weights_scale = weights_qparam->scale[i];
+ auto bias_scale = bias_qparam->scale[i];
+ RETURN_FALSE_UNLESS(same(bias_scale, input_scale * weights_scale));
+ }
+ return true;
+}
+
+} // namespace
+
+namespace luci
+{
+
+bool VerifyQuantizedBiasScale::visit(const luci::CircleConv2D *node)
+{
+ RETURN_FALSE_UNLESS(check_bias_scale(node->input(), node->filter(), node->bias()));
+ return true;
+}
+
+bool VerifyQuantizedBiasScale::visit(const luci::CircleDepthwiseConv2D *node)
+{
+ RETURN_FALSE_UNLESS(check_bias_scale(node->input(), node->filter(), node->bias()));
+ return true;
+}
+
+bool VerifyQuantizedBiasScale::visit(const luci::CircleFullyConnected *node)
+{
+ luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
+ if (bias != nullptr)
+ {
+ RETURN_FALSE_UNLESS(check_bias_scale(node->input(), node->weights(), node->bias()));
+ }
+ return true;
+}
+
+bool VerifyQuantizedBiasScale::visit(const luci::CircleTransposeConv *node)
+{
+ luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
+ if (bias != nullptr)
+ {
+ RETURN_FALSE_UNLESS(check_bias_scale(node->outBackprop(), node->filter(), node->bias()));
+ }
+ return true;
+}
+
+} // namespace luci
+
+#undef RETURN_FALSE_UNLESS
diff --git a/compiler/luci/pass/src/VerifyQuantizedBiasScale.h b/compiler/luci/pass/src/VerifyQuantizedBiasScale.h
new file mode 100644
index 000000000..b41f78eca
--- /dev/null
+++ b/compiler/luci/pass/src/VerifyQuantizedBiasScale.h
@@ -0,0 +1,59 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_VERIFY_QUANTIZED_BIAS_SCALE_H__
+#define __LUCI_VERIFY_QUANTIZED_BIAS_SCALE_H__
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/IR/CircleNodeVisitor.h>
+
+#include <memory>
+
+namespace luci
+{
+
+/**
+ * @brief Verify the scale of quantized bias node
+ * @details
+ *
+ * Bias of CONV, DCONV, TCONV, FC layers should meet the following condition.
+ *
+ * bias scale = input scale * weights scale
+ */
+class VerifyQuantizedBiasScale : public luci::CircleNodeVisitor<bool>
+{
+public:
+ static std::shared_ptr<VerifyQuantizedBiasScale> create()
+ {
+ return std::make_shared<VerifyQuantizedBiasScale>();
+ };
+
+public:
+ bool verify(luci::CircleNode *node) { return node->accept(this); }
+
+private:
+ // Operators with bias
+ bool visit(const luci::CircleConv2D *node);
+ bool visit(const luci::CircleDepthwiseConv2D *node);
+ bool visit(const luci::CircleFullyConnected *node);
+ bool visit(const luci::CircleTransposeConv *node);
+
+ bool visit(const luci::CircleNode *) { return true; }
+};
+
+} // namespace luci
+
+#endif // __LUCI_VERIFY_QUANTIZED_BIAS_SCALE_H__
diff --git a/compiler/luci/pass/src/VerifyQuantizedNodeGranularity.cpp b/compiler/luci/pass/src/VerifyQuantizedNodeGranularity.cpp
new file mode 100644
index 000000000..8697090a7
--- /dev/null
+++ b/compiler/luci/pass/src/VerifyQuantizedNodeGranularity.cpp
@@ -0,0 +1,38 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "VerifyQuantizedNodeGranularity.h"
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/Pass/QuantizationParameters.h>
+
+#include <memory>
+
+namespace luci
+{
+
+std::shared_ptr<VerifyQuantizedNodeGranularity>
+VerifyQuantizedNodeGranularity::create(Granularity granularity)
+{
+ if (granularity == Granularity::ChannelWise)
+ return std::make_shared<VerifyQuantizedNodeChannelWiseGranularity>();
+ else if (granularity == Granularity::LayerWise)
+ return std::make_shared<VerifyQuantizedNodeLayerWiseGranularity>();
+ else
+ throw std::domain_error("Not supported Granularity type");
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/VerifyQuantizedNodeChannelWiseGranularity.h b/compiler/luci/pass/src/VerifyQuantizedNodeGranularity.h
index bf3ff2e8a..442183c18 100644
--- a/compiler/luci/pass/src/VerifyQuantizedNodeChannelWiseGranularity.h
+++ b/compiler/luci/pass/src/VerifyQuantizedNodeGranularity.h
@@ -1,5 +1,6 @@
/*
- * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
@@ -13,13 +14,15 @@
* limitations under the License.
*/
-#ifndef __LUCI_VERIFY_QUANTIZED_NODE_CHANNELWISE_GRANULARITY_H__
-#define __LUCI_VERIFY_QUANTIZED_NODE_CHANNELWISE_GRANULARITY_H__
+#ifndef __LUCI_VERIFY_QUANTIZED_NODE_GRANULARITY_H__
+#define __LUCI_VERIFY_QUANTIZED_NODE_GRANULARITY_H__
#include <luci/IR/CircleNodes.h>
#include <luci/IR/CircleNodeVisitor.h>
#include <luci/Pass/QuantizationParameters.h>
+#include <memory>
+
using Granularity = luci::QuantizationGranularity;
// This macro is undef at the end of the file
@@ -33,16 +36,19 @@ namespace luci
{
/**
- * @brief Verify the granualrity of channel-wise quantized node
+ * @brief Verify the granualrity of quantized node
* @details
*
* Targets to verify
* - node's output (i.e., node itself)
* - node's inputs
*/
-struct VerifyQuantizedNodeChannelWiseGranularity final : public luci::CircleNodeVisitor<bool>
+class VerifyQuantizedNodeGranularity : public luci::CircleNodeVisitor<bool>
{
-private:
+public:
+ static std::shared_ptr<VerifyQuantizedNodeGranularity> create(Granularity granularity);
+
+protected:
bool is_lwq(const loco::Node *node)
{
auto circle_node = loco::must_cast<const luci::CircleNode *>(node);
@@ -59,48 +65,15 @@ private:
return true;
}
- uint32_t rank(const loco::Node *node)
- {
- auto circle_node = loco::must_cast<const luci::CircleNode *>(node);
- return circle_node->rank();
- }
-
- bool is_cwq_const(const loco::Node *node, uint32_t channel_dim)
- {
- auto circle_node = loco::must_cast<const luci::CircleConst *>(node);
-
- assert(channel_dim < circle_node->rank()); // FIX_CALLER_UNLESS
- auto channel_size = circle_node->dim(channel_dim).value();
-
- if (circle_node->quantparam() == nullptr)
- return false;
-
- if (circle_node->quantparam()->quantized_dimension != static_cast<int32_t>(channel_dim))
- return false;
-
- if (circle_node->quantparam()->scale.size() != channel_size)
- return false;
-
- if (circle_node->quantparam()->zerop.size() != channel_size)
- return false;
-
- return true;
- }
-
private:
- bool visit(const luci::CircleConv2D *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->input()))
- RETURN_FALSE_UNLESS(is_cwq_const(node->filter(), 0))
- luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
- if (bias != nullptr)
- RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1))
- return true;
- }
+ virtual bool visit(const luci::CircleConv2D *node) = 0;
bool visit(const luci::CircleConcatenation *node)
{
+ // Skip granularity check for concatenation of indices
+ if (node->dtype() == loco::DataType::S32 or node->dtype() == loco::DataType::S64)
+ return true;
+
RETURN_FALSE_UNLESS(is_lwq(node))
for (uint32_t i = 0; i < node->numValues(); i++)
{
@@ -116,25 +89,9 @@ private:
return true;
}
- bool visit(const luci::CircleDepthwiseConv2D *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->input()))
- RETURN_FALSE_UNLESS(is_cwq_const(node->filter(), 3))
- luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
- if (bias != nullptr)
- RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1))
- return true;
- }
+ virtual bool visit(const luci::CircleDepthwiseConv2D *node) = 0;
- bool visit(const luci::CircleInstanceNorm *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->input()))
- RETURN_FALSE_UNLESS(is_cwq_const(node->gamma(), rank(node->gamma()) - 1))
- RETURN_FALSE_UNLESS(is_cwq_const(node->beta(), rank(node->beta()) - 1))
- return true;
- }
+ virtual bool visit(const luci::CircleInstanceNorm *node) = 0;
bool visit(const luci::CirclePack *node)
{
@@ -168,37 +125,11 @@ private:
return true;
}
- bool visit(const luci::CirclePRelu *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->input()))
- RETURN_FALSE_UNLESS(is_cwq_const(node->alpha(), rank(node->alpha()) - 1))
- return true;
- }
-
- bool visit(const luci::CircleTransposeConv *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->outBackprop()))
- RETURN_FALSE_UNLESS(is_cwq_const(node->filter(), 0))
- luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
- if (bias != nullptr)
- RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1))
+ virtual bool visit(const luci::CirclePRelu *node) = 0;
- return true;
- }
+ virtual bool visit(const luci::CircleTransposeConv *node) = 0;
- bool visit(const luci::CircleFullyConnected *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->input()))
- RETURN_FALSE_UNLESS(is_cwq_const(node->weights(), 0))
- luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
- // Bias is optional (it can be CircleOutputExclude)
- if (bias != nullptr)
- RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1))
- return true;
- }
+ virtual bool visit(const luci::CircleFullyConnected *node) = 0;
bool visit(const luci::CircleAdd *node)
{
@@ -258,6 +189,14 @@ private:
return true;
}
+ bool visit(const luci::CircleOneHot *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node));
+ RETURN_FALSE_UNLESS(is_lwq(node->off_value()));
+ RETURN_FALSE_UNLESS(is_lwq(node->on_value()));
+ return true;
+ }
+
bool visit(const luci::CircleRelu *node)
{
RETURN_FALSE_UNLESS(is_lwq(node));
@@ -480,8 +419,186 @@ private:
bool visit(const luci::CircleNode *) { return true; }
};
+class VerifyQuantizedNodeChannelWiseGranularity final : public VerifyQuantizedNodeGranularity
+{
+private:
+ uint32_t rank(const loco::Node *node)
+ {
+ auto circle_node = loco::must_cast<const luci::CircleNode *>(node);
+ return circle_node->rank();
+ }
+
+ bool is_cwq_const(const loco::Node *node, uint32_t channel_dim)
+ {
+ auto circle_node = loco::must_cast<const luci::CircleConst *>(node);
+
+ assert(channel_dim < circle_node->rank()); // FIX_CALLER_UNLESS
+ auto channel_size = circle_node->dim(channel_dim).value();
+
+ if (circle_node->quantparam() == nullptr)
+ return false;
+
+ if (circle_node->quantparam()->quantized_dimension != static_cast<int32_t>(channel_dim))
+ return false;
+
+ if (circle_node->quantparam()->scale.size() != channel_size)
+ return false;
+
+ if (circle_node->quantparam()->zerop.size() != channel_size)
+ return false;
+
+ return true;
+ }
+
+private:
+ bool visit(const luci::CircleConv2D *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ RETURN_FALSE_UNLESS(is_lwq(node->input()))
+ RETURN_FALSE_UNLESS(is_cwq_const(node->filter(), 0))
+ luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
+ if (bias != nullptr)
+ RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1))
+ return true;
+ }
+
+ bool visit(const luci::CircleDepthwiseConv2D *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ RETURN_FALSE_UNLESS(is_lwq(node->input()))
+ RETURN_FALSE_UNLESS(is_cwq_const(node->filter(), 3))
+ luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
+ if (bias != nullptr)
+ RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1))
+ return true;
+ }
+
+ bool visit(const luci::CircleInstanceNorm *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ RETURN_FALSE_UNLESS(is_lwq(node->input()))
+ RETURN_FALSE_UNLESS(is_cwq_const(node->gamma(), rank(node->gamma()) - 1))
+ RETURN_FALSE_UNLESS(is_cwq_const(node->beta(), rank(node->beta()) - 1))
+ return true;
+ }
+
+ bool visit(const luci::CirclePRelu *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ RETURN_FALSE_UNLESS(is_lwq(node->input()))
+ RETURN_FALSE_UNLESS(is_cwq_const(node->alpha(), rank(node->alpha()) - 1))
+ return true;
+ }
+
+ bool visit(const luci::CircleTransposeConv *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ RETURN_FALSE_UNLESS(is_lwq(node->outBackprop()))
+ RETURN_FALSE_UNLESS(is_cwq_const(node->filter(), 0))
+ luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
+ if (bias != nullptr)
+ RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1))
+
+ return true;
+ }
+
+ bool visit(const luci::CircleFullyConnected *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ RETURN_FALSE_UNLESS(is_lwq(node->input()))
+ RETURN_FALSE_UNLESS(is_cwq_const(node->weights(), 0))
+ luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
+ // Bias is optional (it can be CircleOutputExclude)
+ if (bias != nullptr)
+ RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1))
+ return true;
+ }
+};
+
+class VerifyQuantizedNodeLayerWiseGranularity final : public VerifyQuantizedNodeGranularity
+{
+private:
+ bool is_lwq_const(const loco::Node *node)
+ {
+ auto circle_node = loco::must_cast<const luci::CircleConst *>(node);
+
+ if (circle_node->quantparam() == nullptr)
+ return false;
+
+ if (circle_node->quantparam()->scale.size() != 1)
+ return false;
+
+ if (circle_node->quantparam()->zerop.size() != 1)
+ return false;
+
+ return true;
+ }
+
+private:
+ bool visit(const luci::CircleConv2D *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ RETURN_FALSE_UNLESS(is_lwq(node->input()))
+ RETURN_FALSE_UNLESS(is_lwq_const(node->filter()))
+ luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
+ if (bias != nullptr)
+ RETURN_FALSE_UNLESS(is_lwq_const(node->bias()))
+ return true;
+ }
+
+ bool visit(const luci::CircleDepthwiseConv2D *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ RETURN_FALSE_UNLESS(is_lwq(node->input()))
+ RETURN_FALSE_UNLESS(is_lwq_const(node->filter()))
+ luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
+ if (bias != nullptr)
+ RETURN_FALSE_UNLESS(is_lwq_const(node->bias()))
+ return true;
+ }
+
+ bool visit(const luci::CircleInstanceNorm *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ RETURN_FALSE_UNLESS(is_lwq(node->input()))
+ RETURN_FALSE_UNLESS(is_lwq_const(node->gamma()))
+ RETURN_FALSE_UNLESS(is_lwq_const(node->beta()))
+ return true;
+ }
+
+ bool visit(const luci::CirclePRelu *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ RETURN_FALSE_UNLESS(is_lwq(node->input()))
+ RETURN_FALSE_UNLESS(is_lwq_const(node->alpha()))
+ return true;
+ }
+
+ bool visit(const luci::CircleTransposeConv *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ RETURN_FALSE_UNLESS(is_lwq(node->outBackprop()))
+ RETURN_FALSE_UNLESS(is_lwq_const(node->filter()))
+ luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
+ if (bias != nullptr)
+ RETURN_FALSE_UNLESS(is_lwq_const(node->bias()))
+ return true;
+ }
+
+ bool visit(const luci::CircleFullyConnected *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ RETURN_FALSE_UNLESS(is_lwq(node->input()))
+ RETURN_FALSE_UNLESS(is_lwq_const(node->weights()))
+ luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
+ if (bias != nullptr)
+ RETURN_FALSE_UNLESS(is_lwq_const(node->bias()))
+ return true;
+ }
+};
+
} // namespace luci
#undef RETURN_FALSE_UNLESS
-#endif // __LUCI_VERIFY_QUANTIZED_NODE_CHANNELWISE_GRANULARITY_H__
+#endif // __LUCI_VERIFY_QUANTIZED_NODE_GRANULARITY_H__
diff --git a/compiler/luci/pass/src/VerifyQuantizedNodeLayerWiseGranularity.h b/compiler/luci/pass/src/VerifyQuantizedNodeLayerWiseGranularity.h
deleted file mode 100644
index 9bc8b31df..000000000
--- a/compiler/luci/pass/src/VerifyQuantizedNodeLayerWiseGranularity.h
+++ /dev/null
@@ -1,473 +0,0 @@
-/*
- * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef __LUCI_VERIFY_QUANTIZED_NODE_LAYERWISE_GRANULARITY_H__
-#define __LUCI_VERIFY_QUANTIZED_NODE_LAYERWISE_GRANULARITY_H__
-
-#include <luci/IR/CircleNodes.h>
-#include <luci/IR/CircleNodeVisitor.h>
-#include <luci/Pass/QuantizationParameters.h>
-
-using Granularity = luci::QuantizationGranularity;
-
-// This macro is undef at the end of the file
-#define RETURN_FALSE_UNLESS(ARG) \
- if (not(ARG)) \
- { \
- return false; \
- }
-
-namespace luci
-{
-
-/**
- * @brief Verify the granualrity of layer-wise quantized node
- * @details
- *
- * Targets to verify
- * - node's output (i.e., node itself)
- * - node's inputs
- */
-struct VerifyQuantizedNodeLayerWiseGranularity final : public luci::CircleNodeVisitor<bool>
-{
-private:
- bool is_lwq(const loco::Node *node)
- {
- auto circle_node = loco::must_cast<const luci::CircleNode *>(node);
-
- if (circle_node->quantparam() == nullptr)
- return false;
-
- if (circle_node->quantparam()->scale.size() != 1)
- return false;
-
- if (circle_node->quantparam()->zerop.size() != 1)
- return false;
-
- return true;
- }
-
- bool is_lwq_const(const loco::Node *node)
- {
- auto circle_node = loco::must_cast<const luci::CircleConst *>(node);
-
- if (circle_node->quantparam() == nullptr)
- return false;
-
- if (circle_node->quantparam()->scale.size() != 1)
- return false;
-
- if (circle_node->quantparam()->zerop.size() != 1)
- return false;
-
- return true;
- }
-
-private:
- bool visit(const luci::CircleConv2D *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->input()))
- RETURN_FALSE_UNLESS(is_lwq_const(node->filter()))
- luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
- if (bias != nullptr)
- RETURN_FALSE_UNLESS(is_lwq_const(node->bias()))
- return true;
- }
-
- bool visit(const luci::CircleConcatenation *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- for (uint32_t i = 0; i < node->numValues(); i++)
- {
- RETURN_FALSE_UNLESS(is_lwq(node->values(i)));
- }
- return true;
- }
-
- bool visit(const luci::CircleDepthToSpace *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->input()))
- return true;
- }
-
- bool visit(const luci::CircleDepthwiseConv2D *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->input()))
- RETURN_FALSE_UNLESS(is_lwq_const(node->filter()))
- luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
- if (bias != nullptr)
- RETURN_FALSE_UNLESS(is_lwq_const(node->bias()))
- return true;
- }
-
- bool visit(const luci::CircleInstanceNorm *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->input()))
- RETURN_FALSE_UNLESS(is_lwq_const(node->gamma()))
- RETURN_FALSE_UNLESS(is_lwq_const(node->beta()))
- return true;
- }
-
- bool visit(const luci::CirclePack *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- for (uint32_t i = 0; i < node->values_count(); i++)
- {
- RETURN_FALSE_UNLESS(is_lwq(node->values(i)));
- }
- return true;
- }
-
- bool visit(const luci::CirclePad *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->input()))
- return true;
- }
-
- bool visit(const luci::CirclePadV2 *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->input()))
- RETURN_FALSE_UNLESS(is_lwq(node->constant_values()))
- return true;
- }
-
- bool visit(const luci::CircleMirrorPad *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->input()))
- return true;
- }
-
- bool visit(const luci::CirclePRelu *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->input()))
- RETURN_FALSE_UNLESS(is_lwq_const(node->alpha()))
- return true;
- }
-
- bool visit(const luci::CircleTransposeConv *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->outBackprop()))
- RETURN_FALSE_UNLESS(is_lwq_const(node->filter()))
- luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
- if (bias != nullptr)
- RETURN_FALSE_UNLESS(is_lwq_const(node->bias()))
- return true;
- }
-
- bool visit(const luci::CircleFullyConnected *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->input()))
- RETURN_FALSE_UNLESS(is_lwq_const(node->weights()))
- luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
- if (bias != nullptr)
- RETURN_FALSE_UNLESS(is_lwq_const(node->bias()))
- return true;
- }
-
- bool visit(const luci::CircleAdd *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->x()));
- RETURN_FALSE_UNLESS(is_lwq(node->y()));
- return true;
- }
-
- bool visit(const luci::CircleAveragePool2D *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->value()));
- return true;
- }
-
- bool visit(const luci::CircleLogicalOr *)
- {
- // Logical OR has bool-type inputs and output
- // Nothing to be checked
- return true;
- }
-
- bool visit(const luci::CircleMaxPool2D *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->value()));
- return true;
- }
-
- bool visit(const luci::CircleLocalResponseNormalization *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->input()));
- return true;
- }
-
- bool visit(const luci::CircleMean *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->input()));
- return true;
- }
-
- bool visit(const luci::CircleMul *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->x()));
- RETURN_FALSE_UNLESS(is_lwq(node->y()));
- return true;
- }
-
- bool visit(const luci::CircleNotEqual *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node->x()));
- RETURN_FALSE_UNLESS(is_lwq(node->y()));
- return true;
- }
-
- bool visit(const luci::CircleRelu *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->features()));
- return true;
- }
-
- bool visit(const luci::CircleReshape *node)
- {
- auto input = loco::must_cast<const luci::CircleNode *>(node->tensor());
- bool input_quantized = input->quantparam() != nullptr;
- bool node_quantized = node->quantparam() != nullptr;
- RETURN_FALSE_UNLESS(input_quantized == node_quantized);
- RETURN_FALSE_UNLESS(not node_quantized or is_lwq(node))
- RETURN_FALSE_UNLESS(not input_quantized or is_lwq(input));
- return true;
- }
-
- bool visit(const luci::CircleLogistic *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node));
- RETURN_FALSE_UNLESS(is_lwq(node->x()));
- return true;
- }
-
- bool visit(const luci::CircleSoftmax *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node));
- RETURN_FALSE_UNLESS(is_lwq(node->logits()));
- return true;
- }
-
- bool visit(const luci::CircleSpaceToBatchND *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node));
- RETURN_FALSE_UNLESS(is_lwq(node->input()));
- return true;
- }
-
- bool visit(const luci::CircleSpaceToDepth *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node));
- RETURN_FALSE_UNLESS(is_lwq(node->input()));
- return true;
- }
-
- bool visit(const luci::CircleSlice *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node));
- RETURN_FALSE_UNLESS(is_lwq(node->input()));
- return true;
- }
-
- bool visit(const luci::CircleSplit *node)
- {
- // node's output is the input of CircleSplitOut, thus not quantized
- RETURN_FALSE_UNLESS(is_lwq(node->input()));
- return true;
- }
-
- bool visit(const luci::CircleSplitOut *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node));
- return true;
- }
-
- bool visit(const luci::CircleSplitV *node)
- {
- // node's output is the input of CircleSplitVOut, thus not quantized
- RETURN_FALSE_UNLESS(is_lwq(node->input()));
- return true;
- }
-
- bool visit(const luci::CircleSplitVOut *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node));
- return true;
- }
-
- bool visit(const luci::CircleStridedSlice *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node));
- RETURN_FALSE_UNLESS(is_lwq(node->input()));
- return true;
- }
-
- bool visit(const luci::CircleArgMax *node)
- {
- // node's output is index, thus not quantized
- RETURN_FALSE_UNLESS(is_lwq(node->input()));
- return true;
- }
-
- bool visit(const luci::CircleBatchToSpaceND *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node));
- RETURN_FALSE_UNLESS(is_lwq(node->input()));
- return true;
- }
-
- bool visit(const luci::CircleTanh *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node));
- RETURN_FALSE_UNLESS(is_lwq(node->x()));
- return true;
- }
-
- bool visit(const luci::CircleTranspose *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node));
- RETURN_FALSE_UNLESS(is_lwq(node->a()));
- return true;
- }
-
- bool visit(const luci::CircleFloor *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node));
- RETURN_FALSE_UNLESS(is_lwq(node->x()));
- return true;
- }
-
- bool visit(const luci::CircleGreater *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node->x()));
- RETURN_FALSE_UNLESS(is_lwq(node->y()));
- return true;
- }
-
- bool visit(const luci::CircleGreaterEqual *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node->x()));
- RETURN_FALSE_UNLESS(is_lwq(node->y()));
- return true;
- }
-
- bool visit(const luci::CircleDiv *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node));
- RETURN_FALSE_UNLESS(is_lwq(node->x()));
- RETURN_FALSE_UNLESS(is_lwq(node->y()));
- return true;
- }
-
- bool visit(const luci::CircleFloorDiv *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node));
- RETURN_FALSE_UNLESS(is_lwq(node->x()));
- RETURN_FALSE_UNLESS(is_lwq(node->y()));
- return true;
- }
-
- bool visit(const luci::CircleRsqrt *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node));
- RETURN_FALSE_UNLESS(is_lwq(node->x()));
- return true;
- }
-
- bool visit(const luci::CircleSqrt *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node));
- RETURN_FALSE_UNLESS(is_lwq(node->x()));
- return true;
- }
-
- bool visit(const luci::CircleElu *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node));
- RETURN_FALSE_UNLESS(is_lwq(node->features()));
- return true;
- }
-
- bool visit(const luci::CirclePow *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node));
- RETURN_FALSE_UNLESS(is_lwq(node->x()));
- RETURN_FALSE_UNLESS(is_lwq(node->y()));
- return true;
- }
-
- bool visit(const luci::CircleResizeBilinear *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node));
- RETURN_FALSE_UNLESS(is_lwq(node->input()));
- return true;
- }
-
- bool visit(const luci::CircleResizeNearestNeighbor *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node));
- RETURN_FALSE_UNLESS(is_lwq(node->input()));
- return true;
- }
-
- bool visit(const luci::CircleUnpack *node)
- {
- // node's output is the input of CircleUnpackOut, thus not quantized
- RETURN_FALSE_UNLESS(is_lwq(node->value()));
- return true;
- }
-
- bool visit(const luci::CircleUnpackOut *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node));
- return true;
- }
-
- bool visit(const luci::CircleCast *node)
- {
- auto input = loco::must_cast<const luci::CircleNode *>(node->x());
- bool input_quantized = input->quantparam() != nullptr;
- bool node_quantized = node->quantparam() != nullptr;
- RETURN_FALSE_UNLESS(not input_quantized or is_lwq(input));
- RETURN_FALSE_UNLESS(not node_quantized or is_lwq(node));
- return true;
- }
-
- // TODO: Implement more Ops
-
- bool visit(const luci::CircleNode *) { return true; }
-};
-
-} // namespace luci
-
-#undef RETURN_FALSE_UNLESS
-
-#endif // __LUCI_VERIFY_QUANTIZED_NODE_LAYERWISE_GRANULARITY_H__
diff --git a/compiler/luci/pass/src/VerifyQuantizedNodeS16Type.h b/compiler/luci/pass/src/VerifyQuantizedNodeS16Type.h
deleted file mode 100644
index eeec7b82b..000000000
--- a/compiler/luci/pass/src/VerifyQuantizedNodeS16Type.h
+++ /dev/null
@@ -1,516 +0,0 @@
-/*
- * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef __LUCI_VERIFY_QUANTIZED_NODE_S16_TYPE_H__
-#define __LUCI_VERIFY_QUANTIZED_NODE_S16_TYPE_H__
-
-#include <luci/IR/CircleNodes.h>
-#include <luci/IR/CircleNodeVisitor.h>
-
-#include <cmath>
-
-using Type = loco::DataType;
-
-// This macro is undef at the end of the file
-#define RETURN_FALSE_UNLESS(ARG) \
- if (not(ARG)) \
- { \
- return false; \
- }
-
-namespace luci
-{
-
-/**
- * @brief Verify the data type of INT16 quantized node
- * @details
- *
- * Targets to verify
- * - node's output (i.e., node itself)
- * - node's inputs
- */
-struct VerifyQuantizedNodeS16Type final : public luci::CircleNodeVisitor<bool>
-{
-private:
- bool has_type(const loco::Node *node, Type dtype)
- {
- auto circle_node = loco::must_cast<const luci::CircleNode *>(node);
- return circle_node->dtype() == dtype;
- }
-
-private:
- bool visit(const luci::CircleConv2D *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->filter(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->bias(), Type::S64))
- return true;
- }
-
- bool visit(const luci::CircleConcatenation *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- for (uint32_t i = 0; i < node->numValues(); i++)
- {
- RETURN_FALSE_UNLESS(has_type(node->values(i), Type::S16))
- }
- return true;
- }
-
- bool visit(const luci::CircleDepthToSpace *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleDepthwiseConv2D *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->filter(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->bias(), Type::S64))
- return true;
- }
-
- bool visit(const luci::CircleInstanceNorm *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->gamma(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->beta(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CirclePack *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- for (uint32_t i = 0; i < node->values_count(); i++)
- {
- RETURN_FALSE_UNLESS(has_type(node->values(i), Type::S16))
- }
- return true;
- }
-
- bool visit(const luci::CirclePad *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->paddings(), Type::S32))
- return true;
- }
-
- bool visit(const luci::CirclePadV2 *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->paddings(), Type::S32))
- RETURN_FALSE_UNLESS(has_type(node->constant_values(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleMirrorPad *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->paddings(), Type::S32))
- return true;
- }
-
- bool visit(const luci::CirclePRelu *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->alpha(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleTransposeConv *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->outBackprop(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->filter(), Type::S16))
- luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
- if (bias != nullptr)
- RETURN_FALSE_UNLESS(has_type(bias, Type::S64))
- return true;
- }
-
- bool visit(const luci::CircleFullyConnected *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->weights(), Type::S16))
- luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
- if (bias != nullptr)
- RETURN_FALSE_UNLESS(has_type(bias, Type::S64))
- return true;
- }
-
- bool visit(const luci::CircleAdd *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->y(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleAveragePool2D *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->value(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleLogicalOr *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::BOOL))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::BOOL))
- RETURN_FALSE_UNLESS(has_type(node->y(), Type::BOOL))
- return true;
- }
-
- bool visit(const luci::CircleMaxPool2D *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->value(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleLocalResponseNormalization *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleMean *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->reduction_indices(), Type::S32))
- return true;
- }
-
- bool visit(const luci::CircleMul *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->y(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleNotEqual *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::BOOL))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->y(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleRelu *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->features(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleReshape *node)
- {
- if (node->quantparam())
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->tensor(), Type::S16))
- }
- else
- {
- RETURN_FALSE_UNLESS(has_type(node->tensor(), node->dtype()))
- }
- luci::CircleConst *shape = dynamic_cast<luci::CircleConst *>(node->shape());
- if (shape != nullptr)
- RETURN_FALSE_UNLESS(has_type(shape, Type::S32))
- return true;
- }
-
- bool visit(const luci::CircleLogistic *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16))
-
- RETURN_FALSE_UNLESS(node->quantparam());
- RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == 1.0f / 32768.0f);
- RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == 0);
- return true;
- }
-
- bool visit(const luci::CircleSoftmax *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->logits(), Type::S16))
-
- RETURN_FALSE_UNLESS(node->quantparam());
- RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == 1.0f / 32767.0f);
- RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == 0);
- return true;
- }
-
- bool visit(const luci::CircleSpaceToBatchND *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleSpaceToDepth *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleSlice *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->begin(), Type::S32) || has_type(node->begin(), Type::S64))
- RETURN_FALSE_UNLESS(has_type(node->size(), Type::S32) || has_type(node->size(), Type::S64))
- return true;
- }
-
- bool visit(const luci::CircleSplit *node)
- {
- // node's output is the input of CircleSplitOut, thus not quantized
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleSplitOut *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
-
- // SplitOut has the same qparam with the input of Split
- auto split = loco::must_cast<luci::CircleSplit *>(node->input());
- auto input = loco::must_cast<luci::CircleNode *>(split->input());
- RETURN_FALSE_UNLESS(node->quantparam());
- RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == input->quantparam()->scale[0]);
- RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == input->quantparam()->zerop[0]);
- return true;
- }
-
- bool visit(const luci::CircleSplitV *node)
- {
- // node's output is the input of CircleSplitVOut, thus not quantized
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleSplitVOut *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
-
- // SplitVOut has the same qparam with the input of SplitV
- auto splitv = loco::must_cast<luci::CircleSplitV *>(node->input());
- auto input = loco::must_cast<luci::CircleNode *>(splitv->input());
- RETURN_FALSE_UNLESS(node->quantparam());
- RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == input->quantparam()->scale[0]);
- RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == input->quantparam()->zerop[0]);
- return true;
- }
-
- bool visit(const luci::CircleStridedSlice *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
-
- auto input = loco::must_cast<luci::CircleNode *>(node->input());
- RETURN_FALSE_UNLESS(node->quantparam());
- RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == input->quantparam()->scale[0]);
- RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == input->quantparam()->zerop[0]);
- return true;
- }
-
- bool visit(const luci::CircleArgMax *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, node->output_type()))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->dimension(), Type::S32) ||
- has_type(node->dimension(), Type::S64))
- return true;
- }
-
- bool visit(const luci::CircleBatchToSpaceND *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleTanh *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16))
-
- RETURN_FALSE_UNLESS(node->quantparam());
- RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == 1.0f / 32768.0f);
- RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == 0);
- return true;
- }
-
- bool visit(const luci::CircleTranspose *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->a(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->perm(), Type::S32))
- return true;
- }
-
- bool visit(const luci::CircleFloor *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16))
-
- // This checks the value of scale is an integer
- RETURN_FALSE_UNLESS(node->quantparam());
- RETURN_FALSE_UNLESS(std::roundf(node->quantparam()->scale[0]) == node->quantparam()->scale[0]);
- return true;
- }
-
- bool visit(const luci::CircleGreater *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::BOOL))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->y(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleGreaterEqual *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::BOOL))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->y(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleDiv *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->y(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleFloorDiv *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->y(), Type::S16))
-
- // This checks the value of scale is an integer
- RETURN_FALSE_UNLESS(node->quantparam());
- RETURN_FALSE_UNLESS(std::roundf(node->quantparam()->scale[0]) == node->quantparam()->scale[0]);
- return true;
- }
-
- bool visit(const luci::CircleRsqrt *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleSqrt *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleElu *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->features(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CirclePow *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->y(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleResizeBilinear *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleResizeNearestNeighbor *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleUnpack *node)
- {
- // node's output is the input of CircleUnpackOut, thus not quantized
- RETURN_FALSE_UNLESS(has_type(node->value(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleUnpackOut *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
-
- // UnpackOut has the same qparam with the input of Unpack
- auto Unpack = loco::must_cast<luci::CircleUnpack *>(node->input());
- auto input = loco::must_cast<luci::CircleNode *>(Unpack->value());
- RETURN_FALSE_UNLESS(node->quantparam() && input->quantparam());
- RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == input->quantparam()->scale[0]);
- RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == input->quantparam()->zerop[0]);
- return true;
- }
-
- bool visit(const luci::CircleCast *node)
- {
- auto *input = loco::must_cast<luci::CircleNode *>(node->x());
- RETURN_FALSE_UNLESS(has_type(input, node->in_data_type()))
-
- bool input_quantized = input->quantparam() != nullptr;
- if (input_quantized)
- RETURN_FALSE_UNLESS(has_type(input, Type::S16))
-
- RETURN_FALSE_UNLESS(has_type(node, node->out_data_type()))
-
- bool node_quantized = node->quantparam() != nullptr;
- if (node_quantized)
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- return true;
- }
-
- // TODO: Implement more Ops
-
- bool visit(const luci::CircleNode *) { return true; }
-};
-
-} // namespace luci
-
-#undef RETURN_FALSE_UNLESS
-
-#endif // __LUCI_VERIFY_QUNTIZED_NODE_S16_TYPE_H__
diff --git a/compiler/luci/pass/src/VerifyQuantizedNodeType.cpp b/compiler/luci/pass/src/VerifyQuantizedNodeType.cpp
new file mode 100644
index 000000000..4e1c062c0
--- /dev/null
+++ b/compiler/luci/pass/src/VerifyQuantizedNodeType.cpp
@@ -0,0 +1,554 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "VerifyQuantizedNodeType.h"
+
+#include <cmath>
+#include <memory>
+
+// This macro is undef at the end of the file
+#define RETURN_FALSE_UNLESS(ARG) \
+ if (not(ARG)) \
+ { \
+ return false; \
+ }
+
+namespace luci
+{
+
+std::shared_ptr<VerifyQuantizedNodeType> VerifyQuantizedNodeType::create(loco::DataType dtype)
+{
+ if (dtype == loco::DataType::U8)
+ return std::make_shared<VerifyQuantizedNodeU8Type>();
+ else if (dtype == loco::DataType::S16)
+ return std::make_shared<VerifyQuantizedNodeS16Type>();
+ else
+ throw std::domain_error("Not supported Quantized type");
+}
+
+} // namespace luci
+
+namespace luci
+{
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleAdd *node)
+{
+ return group_has_type(node, Qtype);
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleArgMax *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, node->output_type()))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->dimension(), loco::DataType::S32) ||
+ has_type(node->dimension(), loco::DataType::S64))
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleAveragePool2D *node)
+{
+ return group_has_type(node, Qtype);
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleBatchToSpaceND *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Qtype))
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleCast *node)
+{
+ auto *input = loco::must_cast<luci::CircleNode *>(node->x());
+ bool input_quantized = input->quantparam() != nullptr;
+ if (input_quantized)
+ {
+ RETURN_FALSE_UNLESS(has_type(input, node->in_data_type()))
+ RETURN_FALSE_UNLESS(has_type(input, Qtype))
+ }
+
+ bool node_quantized = node->quantparam() != nullptr;
+ if (node_quantized)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, node->out_data_type()))
+ RETURN_FALSE_UNLESS(has_type(node, Qtype))
+ }
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleConv2D *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->filter(), Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->bias(), Btype))
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleConcatenation *node)
+{
+ // Allow concatenation of indices
+ if (group_has_type(node, loco::DataType::S32) or group_has_type(node, loco::DataType::S64))
+ return true;
+
+ return group_has_type(node, Qtype);
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleDepthToSpace *node)
+{
+ return group_has_type(node, Qtype);
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleDepthwiseConv2D *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->filter(), Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->bias(), Btype))
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleDiv *node)
+{
+ return group_has_type(node, Qtype);
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleElu *node)
+{
+ return group_has_type(node, Qtype);
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleFloor *node)
+{
+ RETURN_FALSE_UNLESS(group_has_type(node, Qtype));
+
+ // This checks the value of scale is an integer
+ RETURN_FALSE_UNLESS(node->quantparam());
+ RETURN_FALSE_UNLESS(std::roundf(node->quantparam()->scale[0]) == node->quantparam()->scale[0]);
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleFloorDiv *node)
+{
+ RETURN_FALSE_UNLESS(group_has_type(node, Qtype));
+
+ // This checks the value of scale is an integer
+ RETURN_FALSE_UNLESS(node->quantparam());
+ RETURN_FALSE_UNLESS(std::roundf(node->quantparam()->scale[0]) == node->quantparam()->scale[0]);
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleFullyConnected *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->weights(), Qtype))
+ luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
+ if (bias != nullptr)
+ RETURN_FALSE_UNLESS(has_type(bias, Btype))
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleGreater *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, loco::DataType::BOOL))
+ RETURN_FALSE_UNLESS(has_type(node->x(), Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->y(), Qtype))
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleGreaterEqual *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, loco::DataType::BOOL))
+ RETURN_FALSE_UNLESS(has_type(node->x(), Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->y(), Qtype))
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleInstanceNorm *node)
+{
+ return group_has_type(node, Qtype);
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(
+ const luci::CircleLocalResponseNormalization *node)
+{
+ return group_has_type(node, Qtype);
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleLogicalOr *node)
+{
+ return group_has_type(node, loco::DataType::BOOL);
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleMaxPool2D *node)
+{
+ return group_has_type(node, Qtype);
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleMean *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->reduction_indices(), loco::DataType::S32))
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleMirrorPad *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->paddings(), loco::DataType::S32))
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleMul *node)
+{
+ return group_has_type(node, Qtype);
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleNotEqual *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, loco::DataType::BOOL))
+ RETURN_FALSE_UNLESS(has_type(node->x(), Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->y(), Qtype))
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleOneHot *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, Qtype));
+ RETURN_FALSE_UNLESS(has_type(node->indices(), loco::DataType::S32) ||
+ has_type(node->indices(), loco::DataType::S64));
+ RETURN_FALSE_UNLESS(has_type(node->depth(), loco::DataType::S32));
+ RETURN_FALSE_UNLESS(has_type(node->on_value(), Qtype));
+ RETURN_FALSE_UNLESS(has_type(node->off_value(), Qtype));
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CirclePack *node)
+{
+ return group_has_type(node, Qtype);
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CirclePad *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->paddings(), loco::DataType::S32))
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CirclePadV2 *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->paddings(), loco::DataType::S32))
+ RETURN_FALSE_UNLESS(has_type(node->constant_values(), Qtype))
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CirclePRelu *node)
+{
+ return group_has_type(node, Qtype);
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CirclePow *node)
+{
+ return group_has_type(node, Qtype);
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleRelu *node)
+{
+ return group_has_type(node, Qtype);
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleReshape *node)
+{
+ if (node->quantparam())
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->tensor(), Qtype))
+ }
+ else
+ {
+ RETURN_FALSE_UNLESS(has_type(node->tensor(), node->dtype()))
+ }
+ luci::CircleConst *shape = dynamic_cast<luci::CircleConst *>(node->shape());
+ if (shape != nullptr)
+ RETURN_FALSE_UNLESS(has_type(shape, loco::DataType::S32))
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleResizeBilinear *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Qtype))
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleResizeNearestNeighbor *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Qtype))
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleRsqrt *node)
+{
+ return group_has_type(node, Qtype);
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleSlice *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->begin(), loco::DataType::S32) ||
+ has_type(node->begin(), loco::DataType::S64))
+ RETURN_FALSE_UNLESS(has_type(node->size(), loco::DataType::S32) ||
+ has_type(node->size(), loco::DataType::S64))
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleSpaceToBatchND *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Qtype))
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleSpaceToDepth *node)
+{
+ return group_has_type(node, Qtype);
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleSplit *node)
+{
+ // node's output is the input of CircleSplitOut, thus not quantized
+ RETURN_FALSE_UNLESS(has_type(node->input(), Qtype))
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleSplitOut *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, Qtype))
+
+ // SplitOut has the same qparam with the input of Split
+ auto split = loco::must_cast<luci::CircleSplit *>(node->input());
+ auto input = loco::must_cast<luci::CircleNode *>(split->input());
+ RETURN_FALSE_UNLESS(node->quantparam());
+ RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == input->quantparam()->scale[0]);
+ RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == input->quantparam()->zerop[0]);
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleSplitV *node)
+{
+ // node's output is the input of CircleSplitVOut, thus not quantized
+ RETURN_FALSE_UNLESS(has_type(node->input(), Qtype))
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleSplitVOut *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, Qtype))
+
+ // SplitVOut has the same qparam with the input of SplitV
+ auto splitv = loco::must_cast<luci::CircleSplitV *>(node->input());
+ auto input = loco::must_cast<luci::CircleNode *>(splitv->input());
+ RETURN_FALSE_UNLESS(node->quantparam());
+ RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == input->quantparam()->scale[0]);
+ RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == input->quantparam()->zerop[0]);
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleSqrt *node)
+{
+ return group_has_type(node, Qtype);
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleStridedSlice *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Qtype))
+
+ auto input = loco::must_cast<luci::CircleNode *>(node->input());
+ RETURN_FALSE_UNLESS(node->quantparam());
+ RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == input->quantparam()->scale[0]);
+ RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == input->quantparam()->zerop[0]);
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleTranspose *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->a(), Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->perm(), loco::DataType::S32))
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleTransposeConv *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->outBackprop(), Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->filter(), Qtype))
+ luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
+ if (bias != nullptr)
+ RETURN_FALSE_UNLESS(has_type(bias, Btype))
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleUnpack *node)
+{
+ // node's output is the input of CircleUnpackOut, thus not quantized
+ RETURN_FALSE_UNLESS(has_type(node->value(), Qtype))
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleUnpackOut *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, Qtype))
+
+ // UnpackOut has the same qparam with the input of Unpack
+ auto Unpack = loco::must_cast<luci::CircleUnpack *>(node->input());
+ auto input = loco::must_cast<luci::CircleNode *>(Unpack->value());
+ RETURN_FALSE_UNLESS(node->quantparam() && input->quantparam());
+ RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == input->quantparam()->scale[0]);
+ RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == input->quantparam()->zerop[0]);
+ return true;
+}
+
+} // namespace luci
+
+namespace luci
+{
+
+bool VerifyQuantizedNodeU8Type::visit(const luci::CircleTanh *node)
+{
+ RETURN_FALSE_UNLESS(group_has_type(node, loco::DataType::U8));
+
+ RETURN_FALSE_UNLESS(node->quantparam());
+ RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == 2.0f / 256.0f);
+ RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == 128);
+ return true;
+}
+
+bool VerifyQuantizedNodeU8Type::visit(const luci::CircleLogistic *node)
+{
+ RETURN_FALSE_UNLESS(group_has_type(node, loco::DataType::U8));
+
+ RETURN_FALSE_UNLESS(node->quantparam());
+ RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == 1.0f / 256.0f);
+ RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == 0);
+ return true;
+}
+
+bool VerifyQuantizedNodeU8Type::visit(const luci::CircleSoftmax *node)
+{
+ RETURN_FALSE_UNLESS(group_has_type(node, loco::DataType::U8));
+
+ RETURN_FALSE_UNLESS(node->quantparam());
+ RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == 1.0f / 255.0f);
+ RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == 0);
+ return true;
+}
+
+} // namespace luci
+
+namespace luci
+{
+
+bool VerifyQuantizedNodeS16Type::visit(const luci::CircleTanh *node)
+{
+ RETURN_FALSE_UNLESS(group_has_type(node, loco::DataType::S16));
+
+ RETURN_FALSE_UNLESS(node->quantparam());
+ RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == 1.0f / 32768.0f);
+ RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == 0);
+ return true;
+}
+
+bool VerifyQuantizedNodeS16Type::visit(const luci::CircleLogistic *node)
+{
+ RETURN_FALSE_UNLESS(group_has_type(node, loco::DataType::S16));
+
+ RETURN_FALSE_UNLESS(node->quantparam());
+ RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == 1.0f / 32768.0f);
+ RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == 0);
+ return true;
+}
+
+bool VerifyQuantizedNodeS16Type::visit(const luci::CircleSoftmax *node)
+{
+ RETURN_FALSE_UNLESS(group_has_type(node, loco::DataType::S16));
+
+ RETURN_FALSE_UNLESS(node->quantparam());
+ RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == 1.0f / 32767.0f);
+ RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == 0);
+ return true;
+}
+
+} // namespace luci
+
+#undef RETURN_FALSE_UNLESS
diff --git a/compiler/luci/pass/src/VerifyQuantizedNodeType.h b/compiler/luci/pass/src/VerifyQuantizedNodeType.h
new file mode 100644
index 000000000..ff1acbd6f
--- /dev/null
+++ b/compiler/luci/pass/src/VerifyQuantizedNodeType.h
@@ -0,0 +1,157 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_VERIFY_QUANTIZED_NODE_TYPE_H__
+#define __LUCI_VERIFY_QUANTIZED_NODE_TYPE_H__
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/IR/CircleNodeVisitor.h>
+
+namespace luci
+{
+
+/**
+ * @brief Verify the data type of quantized node
+ * @details
+ *
+ * Targets to verify
+ * - node's output (i.e., node itself)
+ * - node's inputs
+ */
+class VerifyQuantizedNodeType
+{
+public:
+ static std::shared_ptr<VerifyQuantizedNodeType> create(loco::DataType dtype);
+
+public:
+ virtual bool verify(luci::CircleNode *node) = 0;
+};
+
+/**
+ * @brief Verify using quantization type of a node and bias
+ *
+ * @tparam Qtype Quantization type for a node (e.g. Q8, Q16, ...)
+ * @tparam Btype Bias quantization type (e.g. For Q8, S32 is used)
+ */
+template <loco::DataType Qtype, loco::DataType Btype>
+class VerifyQuantizedNodeTypeBase : public luci::CircleNodeVisitor<bool>,
+ public VerifyQuantizedNodeType
+{
+public:
+ bool verify(luci::CircleNode *node) { return node->accept(this); }
+
+protected:
+ bool has_type(const loco::Node *node, loco::DataType dtype)
+ {
+ auto circle_node = loco::must_cast<const luci::CircleNode *>(node);
+ return circle_node->dtype() == dtype;
+ }
+
+ // Check whether a node and all of its inputs have dtype or not
+ bool group_has_type(const loco::Node *node, loco::DataType dtype)
+ {
+ if (!has_type(node, dtype))
+ return false;
+
+ for (uint32_t i = 0; i < node->arity(); ++i)
+ if (!has_type(node->arg(i), dtype))
+ return false;
+
+ return true;
+ }
+
+private:
+ bool visit(const luci::CircleAdd *node);
+ bool visit(const luci::CircleArgMax *node);
+ bool visit(const luci::CircleAveragePool2D *node);
+ bool visit(const luci::CircleBatchToSpaceND *node);
+ bool visit(const luci::CircleCast *node);
+ bool visit(const luci::CircleConv2D *node);
+ bool visit(const luci::CircleConcatenation *node);
+ bool visit(const luci::CircleDepthToSpace *node);
+ bool visit(const luci::CircleDepthwiseConv2D *node);
+ bool visit(const luci::CircleDiv *node);
+ bool visit(const luci::CircleElu *node);
+ bool visit(const luci::CircleFloor *node);
+ bool visit(const luci::CircleFloorDiv *node);
+ bool visit(const luci::CircleFullyConnected *node);
+ bool visit(const luci::CircleGreater *node);
+ bool visit(const luci::CircleGreaterEqual *node);
+ bool visit(const luci::CircleInstanceNorm *node);
+ bool visit(const luci::CircleLocalResponseNormalization *node);
+ bool visit(const luci::CircleLogicalOr *node);
+ bool visit(const luci::CircleMaxPool2D *node);
+ bool visit(const luci::CircleMean *node);
+ bool visit(const luci::CircleMirrorPad *node);
+ bool visit(const luci::CircleMul *node);
+ bool visit(const luci::CircleNotEqual *node);
+ bool visit(const luci::CircleOneHot *node);
+ bool visit(const luci::CirclePack *node);
+ bool visit(const luci::CirclePad *node);
+ bool visit(const luci::CirclePadV2 *node);
+ bool visit(const luci::CirclePRelu *node);
+ bool visit(const luci::CirclePow *node);
+ bool visit(const luci::CircleRelu *node);
+ bool visit(const luci::CircleReshape *node);
+ bool visit(const luci::CircleResizeBilinear *node);
+ bool visit(const luci::CircleResizeNearestNeighbor *node);
+ bool visit(const luci::CircleRsqrt *node);
+ bool visit(const luci::CircleSlice *node);
+ bool visit(const luci::CircleSpaceToBatchND *node);
+ bool visit(const luci::CircleSpaceToDepth *node);
+ bool visit(const luci::CircleSplit *node);
+ bool visit(const luci::CircleSplitOut *node);
+ bool visit(const luci::CircleSplitV *node);
+ bool visit(const luci::CircleSplitVOut *node);
+ bool visit(const luci::CircleSqrt *node);
+ bool visit(const luci::CircleStridedSlice *node);
+ bool visit(const luci::CircleTranspose *node);
+ bool visit(const luci::CircleTransposeConv *node);
+ bool visit(const luci::CircleUnpack *node);
+ bool visit(const luci::CircleUnpackOut *node);
+
+ // NOTE below nodes has differnent implementation for Qtype/Btype and
+ // implementations exist in VerifyQuantizedNodeU8Type, VerifyQuantizedNodeS16Type
+ // bool visit(const luci::CircleLogistic *node);
+ // bool visit(const luci::CircleSoftmax *node);
+ // bool visit(const luci::CircleTanh *node);
+
+ // TODO: Implement more Ops
+
+ bool visit(const luci::CircleNode *) { return true; }
+};
+
+class VerifyQuantizedNodeU8Type
+ : public VerifyQuantizedNodeTypeBase<loco::DataType::U8, loco::DataType::S32>
+{
+private:
+ bool visit(const luci::CircleLogistic *node);
+ bool visit(const luci::CircleSoftmax *node);
+ bool visit(const luci::CircleTanh *node);
+};
+
+class VerifyQuantizedNodeS16Type
+ : public VerifyQuantizedNodeTypeBase<loco::DataType::S16, loco::DataType::S64>
+{
+private:
+ bool visit(const luci::CircleLogistic *node);
+ bool visit(const luci::CircleSoftmax *node);
+ bool visit(const luci::CircleTanh *node);
+};
+
+} // namespace luci
+
+#endif // __LUCI_VERIFY_QUANTIZED_NODE_TYPE_H__
diff --git a/compiler/luci/pass/src/VerifyQuantizedNodeU8Type.h b/compiler/luci/pass/src/VerifyQuantizedNodeU8Type.h
deleted file mode 100644
index e7dd1b072..000000000
--- a/compiler/luci/pass/src/VerifyQuantizedNodeU8Type.h
+++ /dev/null
@@ -1,518 +0,0 @@
-/*
- * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef __LUCI_VERIFY_QUANTIZED_NODE_U8_TYPE_H__
-#define __LUCI_VERIFY_QUANTIZED_NODE_U8_TYPE_H__
-
-#include <luci/IR/CircleNodes.h>
-#include <luci/IR/CircleNodeVisitor.h>
-
-#include <cmath>
-
-using Type = loco::DataType;
-
-// This macro is undef at the end of the file
-#define RETURN_FALSE_UNLESS(ARG) \
- if (not(ARG)) \
- { \
- return false; \
- }
-
-namespace luci
-{
-
-/**
- * @brief Verify the data type of UINT8 quantized node
- * @details
- *
- * Targets to verify
- * - node's output (i.e., node itself)
- * - node's inputs
- */
-struct VerifyQuantizedNodeU8Type final : public luci::CircleNodeVisitor<bool>
-{
-private:
- bool has_type(const loco::Node *node, Type dtype)
- {
- auto circle_node = loco::must_cast<const luci::CircleNode *>(node);
- return circle_node->dtype() == dtype;
- }
-
-private:
- bool visit(const luci::CircleConv2D *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->filter(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->bias(), Type::S32))
- return true;
- }
-
- bool visit(const luci::CircleConcatenation *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- for (uint32_t i = 0; i < node->numValues(); i++)
- {
- RETURN_FALSE_UNLESS(has_type(node->values(i), Type::U8))
- }
- return true;
- }
-
- bool visit(const luci::CircleDepthToSpace *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleDepthwiseConv2D *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->filter(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->bias(), Type::S32))
- return true;
- }
-
- bool visit(const luci::CircleInstanceNorm *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->gamma(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->beta(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CirclePack *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- for (uint32_t i = 0; i < node->values_count(); i++)
- {
- RETURN_FALSE_UNLESS(has_type(node->values(i), Type::U8))
- }
- return true;
- }
-
- bool visit(const luci::CirclePad *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->paddings(), Type::S32))
- return true;
- }
-
- bool visit(const luci::CirclePadV2 *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->paddings(), Type::S32))
- RETURN_FALSE_UNLESS(has_type(node->constant_values(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleMirrorPad *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->paddings(), Type::S32))
- return true;
- }
-
- bool visit(const luci::CirclePRelu *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->alpha(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleTransposeConv *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->outBackprop(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->filter(), Type::U8))
- luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
- if (bias != nullptr)
- RETURN_FALSE_UNLESS(has_type(bias, Type::S32))
- return true;
- }
-
- bool visit(const luci::CircleFullyConnected *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->weights(), Type::U8))
- luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
- if (bias != nullptr)
- RETURN_FALSE_UNLESS(has_type(bias, Type::S32))
- return true;
- }
-
- bool visit(const luci::CircleAdd *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->y(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleAveragePool2D *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->value(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleBatchToSpaceND *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleLogicalOr *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::BOOL))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::BOOL))
- RETURN_FALSE_UNLESS(has_type(node->y(), Type::BOOL))
- return true;
- }
-
- bool visit(const luci::CircleMaxPool2D *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->value(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleLocalResponseNormalization *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleMean *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->reduction_indices(), Type::S32))
- return true;
- }
-
- bool visit(const luci::CircleMul *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->y(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleNotEqual *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::BOOL))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->y(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleRelu *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->features(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleReshape *node)
- {
- if (node->quantparam())
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->tensor(), Type::U8))
- }
- else
- {
- RETURN_FALSE_UNLESS(has_type(node->tensor(), node->dtype()))
- }
- luci::CircleConst *shape = dynamic_cast<luci::CircleConst *>(node->shape());
- if (shape != nullptr)
- RETURN_FALSE_UNLESS(has_type(shape, Type::S32))
- return true;
- }
-
- bool visit(const luci::CircleLogistic *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
-
- RETURN_FALSE_UNLESS(node->quantparam());
- RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == 1.0f / 256.0f);
- RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == 0);
- return true;
- }
-
- bool visit(const luci::CircleSoftmax *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->logits(), Type::U8))
-
- RETURN_FALSE_UNLESS(node->quantparam());
- RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == 1.0f / 255.0f);
- RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == 0);
- return true;
- }
-
- bool visit(const luci::CircleSpaceToBatchND *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleSpaceToDepth *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleSlice *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->begin(), Type::S32) || has_type(node->begin(), Type::S64))
- RETURN_FALSE_UNLESS(has_type(node->size(), Type::S32) || has_type(node->size(), Type::S64))
- return true;
- }
-
- bool visit(const luci::CircleSplit *node)
- {
- // node's output is the input of CircleSplitOut, thus not quantized
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleSplitOut *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
-
- // SplitOut has the same qparam with the input of Split
- auto split = loco::must_cast<luci::CircleSplit *>(node->input());
- auto input = loco::must_cast<luci::CircleNode *>(split->input());
- RETURN_FALSE_UNLESS(node->quantparam());
- RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == input->quantparam()->scale[0]);
- RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == input->quantparam()->zerop[0]);
- return true;
- }
-
- bool visit(const luci::CircleSplitV *node)
- {
- // node's output is the input of CircleSplitVOut, thus not quantized
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleSplitVOut *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
-
- // SplitVOut has the same qparam with the input of SplitV
- auto splitv = loco::must_cast<luci::CircleSplitV *>(node->input());
- auto input = loco::must_cast<luci::CircleNode *>(splitv->input());
- RETURN_FALSE_UNLESS(node->quantparam());
- RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == input->quantparam()->scale[0]);
- RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == input->quantparam()->zerop[0]);
- return true;
- }
-
- bool visit(const luci::CircleStridedSlice *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
-
- auto input = loco::must_cast<luci::CircleNode *>(node->input());
- RETURN_FALSE_UNLESS(node->quantparam());
- RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == input->quantparam()->scale[0]);
- RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == input->quantparam()->zerop[0]);
- return true;
- }
-
- bool visit(const luci::CircleArgMax *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, node->output_type()))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->dimension(), Type::S32) ||
- has_type(node->dimension(), Type::S64))
- return true;
- }
-
- bool visit(const luci::CircleTanh *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
-
- RETURN_FALSE_UNLESS(node->quantparam());
- RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == 2.0f / 256.0f);
- RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == 128);
- return true;
- }
-
- bool visit(const luci::CircleTranspose *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->a(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->perm(), Type::S32))
- return true;
- }
-
- bool visit(const luci::CircleFloor *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
-
- // This checks the value of scale is an integer
- RETURN_FALSE_UNLESS(node->quantparam());
- RETURN_FALSE_UNLESS(std::roundf(node->quantparam()->scale[0]) == node->quantparam()->scale[0]);
- return true;
- }
-
- bool visit(const luci::CircleGreater *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::BOOL))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->y(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleGreaterEqual *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::BOOL))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->y(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleDiv *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->y(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleFloorDiv *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->y(), Type::U8))
-
- // This checks the value of scale is an integer
- RETURN_FALSE_UNLESS(node->quantparam());
- RETURN_FALSE_UNLESS(std::roundf(node->quantparam()->scale[0]) == node->quantparam()->scale[0]);
- return true;
- }
-
- bool visit(const luci::CircleRsqrt *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleSqrt *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleElu *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->features(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CirclePow *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->y(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleResizeBilinear *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleResizeNearestNeighbor *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleUnpack *node)
- {
- // node's output is the input of CircleUnpackOut, thus not quantized
- RETURN_FALSE_UNLESS(has_type(node->value(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleUnpackOut *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
-
- // UnpackOut has the same qparam with the input of Unpack
- auto Unpack = loco::must_cast<luci::CircleUnpack *>(node->input());
- auto input = loco::must_cast<luci::CircleNode *>(Unpack->value());
- RETURN_FALSE_UNLESS(node->quantparam() && input->quantparam());
- RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == input->quantparam()->scale[0]);
- RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == input->quantparam()->zerop[0]);
- return true;
- }
-
- bool visit(const luci::CircleCast *node)
- {
- auto *input = loco::must_cast<luci::CircleNode *>(node->x());
- bool input_quantized = input->quantparam() != nullptr;
- if (input_quantized)
- {
- RETURN_FALSE_UNLESS(has_type(input, node->in_data_type()))
- RETURN_FALSE_UNLESS(has_type(input, Type::U8))
- }
-
- bool node_quantized = node->quantparam() != nullptr;
- if (node_quantized)
- {
- RETURN_FALSE_UNLESS(has_type(node, node->out_data_type()))
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- }
- return true;
- }
-
- // TODO: Implement more Ops
-
- bool visit(const luci::CircleNode *) { return true; }
-};
-
-} // namespace luci
-
-#undef RETURN_FALSE_UNLESS
-
-#endif // __LUCI_VERIFY_QUNTIZED_NODE_U8_TYPE_H__
diff --git a/compiler/luci/pass/src/helpers/LayerInfoMap.cpp b/compiler/luci/pass/src/helpers/LayerInfoMap.cpp
new file mode 100644
index 000000000..ac07f9ec9
--- /dev/null
+++ b/compiler/luci/pass/src/helpers/LayerInfoMap.cpp
@@ -0,0 +1,189 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "LayerInfoMap.h"
+
+#include <luci/IR/CircleNode.h>
+
+#include <cassert>
+
+namespace luci
+{
+namespace
+{
+
+bool is_multiple_output_node(const luci::CircleNode *node)
+{
+ switch (node->opcode())
+ {
+ // The following nodes have multiple outputs. Output tensors are not produced by themselves but
+ // by the corresponding *Out nodes.
+ case luci::CircleOpcode::SPLIT:
+ case luci::CircleOpcode::SPLIT_V:
+ case luci::CircleOpcode::TOPK_V2:
+ case luci::CircleOpcode::UNIQUE:
+ case luci::CircleOpcode::UNPACK:
+ return true;
+ // TODO: Support ops
+ case luci::CircleOpcode::BIDIRECTIONAL_SEQUENCE_LSTM:
+ case luci::CircleOpcode::CUSTOM:
+ case luci::CircleOpcode::IF:
+ case luci::CircleOpcode::NON_MAX_SUPPRESSION_V4:
+ case luci::CircleOpcode::NON_MAX_SUPPRESSION_V5:
+ case luci::CircleOpcode::WHILE:
+ throw std::runtime_error("Unsupported op now");
+ default:
+ return false;
+ }
+}
+
+const luci::CircleNode *get_multi_output_node(const luci::CircleNode *node)
+{
+ if (is_multiple_output_node(node))
+ return node;
+
+ switch (node->opcode())
+ {
+ // The following nodes denote outputs of multiple-output nodes.
+ case luci::CircleOpcode::CIRCLESPLITOUT:
+ {
+ const auto split_out = loco::must_cast<const CircleSplitOut *>(node);
+ return loco::must_cast<luci::CircleNode *>(split_out->input());
+ }
+ case luci::CircleOpcode::CIRCLESPLITVOUT:
+ {
+ const auto splitv_out = loco::must_cast<const CircleSplitVOut *>(node);
+ return loco::must_cast<luci::CircleNode *>(splitv_out->input());
+ }
+ case luci::CircleOpcode::CIRCLETOPKV2OUT:
+ {
+ const auto top_kv2_out = loco::must_cast<const CircleTopKV2Out *>(node);
+ return loco::must_cast<luci::CircleNode *>(top_kv2_out->input());
+ }
+ case luci::CircleOpcode::CIRCLEUNIQUEOUT:
+ {
+ const auto unique_out = loco::must_cast<const CircleUniqueOut *>(node);
+ return loco::must_cast<luci::CircleNode *>(unique_out->input());
+ }
+ case luci::CircleOpcode::CIRCLEUNPACKOUT:
+ {
+ const auto unpack_out = loco::must_cast<const CircleUnpackOut *>(node);
+ return loco::must_cast<luci::CircleNode *>(unpack_out->input());
+ }
+ // TODO: Support these ops
+ case luci::CircleOpcode::CIRCLEBIDIRECTIONAL_SEQUENCE_LSTM_OUT:
+ case luci::CircleOpcode::CIRCLECUSTOMOUT:
+ case luci::CircleOpcode::CIRCLEIFOUT:
+ case luci::CircleOpcode::CIRCLENONMAXSUPPRESSIONV4OUT:
+ case luci::CircleOpcode::CIRCLENONMAXSUPPRESSIONV5OUT:
+ case luci::CircleOpcode::CIRCLEWHILEOUT:
+ throw std::runtime_error("Unsupported op now");
+ default:
+ return nullptr;
+ }
+}
+
+bool same_setting(const LayerInfo &left, const LayerInfo &right)
+{
+ return left.dtype == right.dtype and left.granularity == right.granularity;
+}
+
+void add_multi_output_node(LayerInfoMap &info_by_name, LayerInfo &layer_info,
+ const luci::CircleNode *node)
+{
+ assert(is_multiple_output_node(node)); // FIX_CALLER_UNLESS
+
+ const auto succs_nodes = loco::succs(node);
+ const auto name = node->name();
+
+ if (info_by_name.find(name) != info_by_name.end())
+ {
+ // Check that all outputs have equal dtype and granularity
+ for (const auto succs_node : succs_nodes)
+ {
+ const auto succs_circle_node = loco::must_cast<luci::CircleNode *>(succs_node);
+
+ const auto it = info_by_name.find(succs_circle_node->name());
+ if (it != info_by_name.end() and not same_setting(layer_info, (it->second)))
+ throw std::runtime_error("Outputs of multiple-output nodes should have equal dtype and "
+ "granularity. Check the quantization configuration file");
+ }
+ return;
+ }
+
+ // Add multiple output node to info_by_name
+ info_by_name[name] = {name, layer_info.dtype, layer_info.granularity};
+
+ // Add outputs node to info_by_name
+ for (const auto succs_node : succs_nodes)
+ {
+ const auto succs_circle_node = loco::must_cast<luci::CircleNode *>(succs_node);
+ const auto succs_circle_node_name = succs_circle_node->name();
+ info_by_name[succs_circle_node_name] = {succs_circle_node_name, layer_info.dtype,
+ layer_info.granularity};
+ }
+}
+
+} // namespace
+
+LayerInfoMap layer_info_map(loco::Graph *g, std::vector<LayerInfo> &layers_info)
+{
+ LayerInfoMap info_by_name;
+
+ for (auto &&info : layers_info)
+ {
+ auto name = info.name;
+ bool found = false;
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ auto cnode = loco::must_cast<luci::CircleNode *>(node);
+ if (cnode->opcode() == luci::CircleOpcode::CIRCLEOUTPUT)
+ continue;
+
+ if (cnode->name() == name)
+ {
+ // Check and add multiple-output node and its outputs to info_by_name
+ if (const auto multi_output = get_multi_output_node(cnode))
+ {
+ add_multi_output_node(info_by_name, info, multi_output);
+ found = true;
+ continue;
+ }
+
+ if (info_by_name.find(name) != info_by_name.end())
+ {
+ throw std::runtime_error("Duplicate layer name " + name +
+ ". Check layer names in the quantization configuration file.");
+ }
+
+ info_by_name[name] = info;
+ found = true;
+ continue;
+ }
+ }
+
+ if (not found)
+ throw std::runtime_error("No such layer named " + name +
+ ". Check layer names in the quantization configuration file.");
+ }
+
+ // TODO Check all names in layers_info exist in the info_by_name
+ // TODO Check names in info_by_name but not in layers_info are from virtual outputs
+
+ return info_by_name;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/helpers/LayerInfoMap.h b/compiler/luci/pass/src/helpers/LayerInfoMap.h
new file mode 100644
index 000000000..bb4724a50
--- /dev/null
+++ b/compiler/luci/pass/src/helpers/LayerInfoMap.h
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_PASS_HELPERS_LAYER_INFO_MAP_H__
+#define __LUCI_PASS_HELPERS_LAYER_INFO_MAP_H__
+
+#include <luci/Pass/QuantizationParameters.h>
+
+#include <unordered_map>
+
+namespace luci
+{
+
+using LayerInfoMap = std::unordered_map<std::string, luci::LayerInfo>;
+
+LayerInfoMap layer_info_map(loco::Graph *g, std::vector<LayerInfo> &layers_info);
+
+} // namespace luci
+
+#endif // __LUCI_PASS_HELPERS_LAYER_INFO_MAP_H__
diff --git a/compiler/luci/pass/src/helpers/LayerInfoMap.test.cpp b/compiler/luci/pass/src/helpers/LayerInfoMap.test.cpp
new file mode 100644
index 000000000..2ed28eda4
--- /dev/null
+++ b/compiler/luci/pass/src/helpers/LayerInfoMap.test.cpp
@@ -0,0 +1,201 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "LayerInfoMap.h"
+
+#include <luci/IR/CircleNode.h>
+#include <luci/test/TestIOGraph.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+class SoftmaxTestGraph : public luci::test::TestIOGraph
+{
+public:
+ void init(void)
+ {
+ TestIOGraph::init({32}, {32});
+ _softmax = g()->nodes()->create<luci::CircleSoftmax>();
+ {
+ _softmax->logits(input());
+ _softmax->beta(0.1);
+ _softmax->name("test");
+ }
+ output()->from(_softmax);
+ }
+
+private:
+ luci::CircleSoftmax *_softmax = nullptr;
+};
+
+class SplitAddTestGraph : public luci::test::TestIOGraph
+{
+public:
+ void init(void)
+ {
+ TestIOGraph::init({6, 1, 2}, {3, 1, 2});
+ _split_dim = g()->nodes()->create<luci::CircleConst>();
+ {
+ _split_dim->rank(1);
+ _split_dim->dtype(loco::DataType::S32);
+ _split_dim->size<loco::DataType::S32>(1);
+ _split_dim->at<loco::DataType::S32>(0);
+ _split_dim->shape({1});
+ _split_dim->name("split_dim");
+ }
+
+ _split = g()->nodes()->create<luci::CircleSplit>();
+ {
+ _split->input(input());
+ _split->num_split(2);
+ _split->split_dim(_split_dim);
+ _split->name("split0");
+ }
+
+ _split_out_1 = g()->nodes()->create<luci::CircleSplitOut>();
+ {
+ _split_out_1->input(_split);
+ _split_out_1->index(0);
+ _split_out_1->name("split0");
+ }
+
+ _split_out_2 = g()->nodes()->create<luci::CircleSplitOut>();
+ {
+ _split_out_2->input(_split);
+ _split_out_2->index(1);
+ _split_out_2->name("split1");
+ }
+
+ _add = g()->nodes()->create<luci::CircleAdd>();
+ {
+ _add->x(_split_out_1);
+ _add->y(_split_out_2);
+ _add->name("add");
+ }
+ output()->from(_add);
+ }
+
+private:
+ luci::CircleSplit *_split = nullptr;
+ luci::CircleSplitOut *_split_out_1 = nullptr;
+ luci::CircleSplitOut *_split_out_2 = nullptr;
+ luci::CircleConst *_split_dim = nullptr;
+ luci::CircleAdd *_add = nullptr;
+};
+
+} // namespace
+
+TEST(LayerInfoMapTest, simple_test)
+{
+ SoftmaxTestGraph g;
+ g.init();
+
+ luci::LayerInfo info;
+ {
+ info.name = "test";
+ info.dtype = loco::DataType::U8;
+ info.granularity = luci::QuantizationGranularity::ChannelWise;
+ }
+ std::vector<luci::LayerInfo> v;
+ v.emplace_back(info);
+ auto map = luci::layer_info_map(g.g(), v);
+
+ EXPECT_EQ("test", map["test"].name);
+ EXPECT_EQ(loco::DataType::U8, map["test"].dtype);
+ EXPECT_EQ(luci::QuantizationGranularity::ChannelWise, map["test"].granularity);
+}
+
+TEST(LayerInfoMapTest, multiple_output_node_test)
+{
+ SplitAddTestGraph g;
+ g.init();
+
+ luci::LayerInfo info;
+ {
+ info.name = "split0";
+ info.dtype = loco::DataType::U8;
+ info.granularity = luci::QuantizationGranularity::ChannelWise;
+ }
+ std::vector<luci::LayerInfo> v;
+ v.emplace_back(info);
+ auto map = luci::layer_info_map(g.g(), v);
+
+ EXPECT_EQ(map.size(), 2);
+ EXPECT_EQ("split0", map["split0"].name);
+ EXPECT_EQ("split1", map["split1"].name);
+
+ EXPECT_EQ(loco::DataType::U8, map["split0"].dtype);
+ EXPECT_EQ(luci::QuantizationGranularity::ChannelWise, map["split0"].granularity);
+}
+
+TEST(LayerInfoMapTest, invalid_layer_info_multiple_output_node_NEG)
+{
+ SplitAddTestGraph g;
+ g.init();
+
+ luci::LayerInfo info_0;
+ {
+ info_0.name = "split0";
+ info_0.dtype = loco::DataType::U8;
+ info_0.granularity = luci::QuantizationGranularity::ChannelWise;
+ }
+ luci::LayerInfo info_1;
+ {
+ info_1.name = "split1";
+ info_1.dtype = loco::DataType::S16;
+ info_1.granularity = luci::QuantizationGranularity::ChannelWise;
+ }
+ std::vector<luci::LayerInfo> v;
+ v.emplace_back(info_0);
+ v.emplace_back(info_1);
+
+ EXPECT_ANY_THROW(luci::layer_info_map(g.g(), v));
+}
+
+TEST(LayerInfoMapTest, duplicate_name_NEG)
+{
+ SoftmaxTestGraph g;
+ g.init();
+ g.input()->name("test");
+
+ luci::LayerInfo info;
+ {
+ info.name = "test";
+ info.dtype = loco::DataType::U8;
+ info.granularity = luci::QuantizationGranularity::ChannelWise;
+ }
+ std::vector<luci::LayerInfo> v;
+ v.emplace_back(info);
+ EXPECT_ANY_THROW(luci::layer_info_map(g.g(), v));
+}
+
+TEST(LayerInfoMapTest, no_name_NEG)
+{
+ SoftmaxTestGraph g;
+ g.init();
+
+ luci::LayerInfo info;
+ {
+ info.name = "noname";
+ info.dtype = loco::DataType::U8;
+ info.granularity = luci::QuantizationGranularity::ChannelWise;
+ }
+ std::vector<luci::LayerInfo> v;
+ v.emplace_back(info);
+ EXPECT_ANY_THROW(luci::layer_info_map(g.g(), v));
+}
diff --git a/compiler/luci/requires.cmake b/compiler/luci/requires.cmake
index 3ccc58128..e896188be 100644
--- a/compiler/luci/requires.cmake
+++ b/compiler/luci/requires.cmake
@@ -4,8 +4,8 @@ require("loco")
require("locop")
require("logo")
require("logo-core")
-require("mio-circle")
-require("mio-tflite")
+require("mio-circle04")
+require("mio-tflite280")
require("oops")
require("hermes")
require("hermes-std")
diff --git a/compiler/luci/service/CMakeLists.txt b/compiler/luci/service/CMakeLists.txt
index 0e6097f96..24bdfc152 100644
--- a/compiler/luci/service/CMakeLists.txt
+++ b/compiler/luci/service/CMakeLists.txt
@@ -10,7 +10,6 @@ add_library(luci_service ${LUCI_LIBRARY_TYPE} ${SOURCES})
target_include_directories(luci_service PRIVATE src)
target_include_directories(luci_service PUBLIC include)
target_link_libraries(luci_service PUBLIC luci_lang)
-target_link_libraries(luci_service PUBLIC mio_circle)
target_link_libraries(luci_service PUBLIC logo_core)
target_link_libraries(luci_service PRIVATE luci_log)
target_link_libraries(luci_service PRIVATE luci_logex)
diff --git a/compiler/luci/service/include/luci/Service/CircleShapeInference.h b/compiler/luci/service/include/luci/Service/CircleShapeInference.h
index ead12d074..2c1120941 100644
--- a/compiler/luci/service/include/luci/Service/CircleShapeInference.h
+++ b/compiler/luci/service/include/luci/Service/CircleShapeInference.h
@@ -17,11 +17,12 @@
#ifndef __LUCI_CIRCLE_SHAPE_INFERENCE_H__
#define __LUCI_CIRCLE_SHAPE_INFERENCE_H__
-#include <loco/IR/Nodes.h>
-
+#include <luci/Service/CircleShapeInferenceRule.h>
#include <luci/IR/CircleNodes.h>
#include <luci/IR/CircleNodeVisitor.h>
-#include <luci/Service/CircleShapeInferenceRule.h>
+
+#include <loco/IR/NodeShape.h>
+#include <loco/IR/TensorShape.h>
namespace luci
{
diff --git a/compiler/luci/service/include/luci/Service/CircleTypeInference.h b/compiler/luci/service/include/luci/Service/CircleTypeInference.h
index d62731380..e0ceabeac 100644
--- a/compiler/luci/service/include/luci/Service/CircleTypeInference.h
+++ b/compiler/luci/service/include/luci/Service/CircleTypeInference.h
@@ -17,13 +17,11 @@
#ifndef __LUCI_CIRCLE_TYPE_INFERENCE_H__
#define __LUCI_CIRCLE_TYPE_INFERENCE_H__
-#include <loco/IR/Nodes.h>
-
-#include <mio/circle/schema_generated.h>
-
+#include <luci/Service/CircleTypeInferenceRule.h>
#include <luci/IR/CircleNodes.h>
#include <luci/IR/CircleNodeVisitor.h>
-#include <luci/Service/CircleTypeInferenceRule.h>
+
+#include <loco/IR/DataType.h>
namespace luci
{
diff --git a/compiler/luci/service/src/CircleCloneNode.h b/compiler/luci/service/src/CircleCloneNode.h
index 3926147f5..99e4561b3 100644
--- a/compiler/luci/service/src/CircleCloneNode.h
+++ b/compiler/luci/service/src/CircleCloneNode.h
@@ -208,6 +208,7 @@ public:
luci::CircleNode *visit(const luci::CircleSquaredDifference *) final;
luci::CircleNode *visit(const luci::CircleSqueeze *) final;
luci::CircleNode *visit(const luci::CircleStridedSlice *) final;
+ luci::CircleNode *visit(const luci::CircleSVDF *) final;
luci::CircleNode *visit(const luci::CircleSub *) final;
luci::CircleNode *visit(const luci::CircleSum *) final;
luci::CircleNode *visit(const luci::CircleTanh *) final;
@@ -269,6 +270,7 @@ public:
luci::CircleNode *visit(const luci::CircleTopKV2Out *) final;
luci::CircleNode *visit(const luci::CircleUniqueOut *) final;
luci::CircleNode *visit(const luci::CircleUnpackOut *) final;
+ luci::CircleNode *visit(const luci::CircleVariable *) final;
luci::CircleNode *visit(const luci::CircleWhileOut *) final;
// Handle in CircleNode
diff --git a/compiler/luci/service/src/CircleNodeClone.cpp b/compiler/luci/service/src/CircleNodeClone.cpp
index d2033dd0c..220c6096c 100644
--- a/compiler/luci/service/src/CircleNodeClone.cpp
+++ b/compiler/luci/service/src/CircleNodeClone.cpp
@@ -14,6 +14,7 @@
* limitations under the License.
*/
+#include "luci/IR/CircleQuantParam.h"
#include "luci/Service/CircleNodeClone.h"
#include "CircleCloneNode.h"
@@ -45,18 +46,7 @@ void copy_common_attributes(const luci::CircleNode *src, luci::CircleNode *dst)
dst->shape_status(src->shape_status());
// quantparam
- const auto *quantparam = src->quantparam();
- if (quantparam != nullptr)
- {
- auto qparam = std::make_unique<luci::CircleQuantParam>();
- qparam->scale = quantparam->scale;
- qparam->zerop = quantparam->zerop;
- qparam->min = quantparam->min;
- qparam->max = quantparam->max;
- qparam->quantized_dimension = quantparam->quantized_dimension;
-
- dst->quantparam(std::move(qparam));
- }
+ copy_quantparam(src, dst);
// sparsity
const auto *sparsity = src->sparsityparam();
diff --git a/compiler/luci/service/src/CircleShapeInferenceRule.cpp b/compiler/luci/service/src/CircleShapeInferenceRule.cpp
index 5d6a31050..9d156f3e2 100644
--- a/compiler/luci/service/src/CircleShapeInferenceRule.cpp
+++ b/compiler/luci/service/src/CircleShapeInferenceRule.cpp
@@ -1,5 +1,6 @@
/*
* Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -196,23 +197,18 @@ template <class CIRCLENODE> loco::NodeShape broadcast_xy(const CIRCLENODE *node)
return loco::NodeShape{output_shape};
}
-template <class CIRCLENODE> loco::NodeShape use_inputs(const CIRCLENODE *node)
-{
- auto inputs_shape = luci::shape_get(node->inputs()).template as<loco::TensorShape>();
- return loco::NodeShape{inputs_shape};
-}
+#define DECLARE_USE_SINGLE(NAME) \
+ template <class CIRCLENODE> loco::NodeShape use_##NAME(const CIRCLENODE *node) \
+ { \
+ auto inputs_shape = luci::shape_get(node->NAME()).template as<loco::TensorShape>(); \
+ return loco::NodeShape{inputs_shape}; \
+ }
-template <class CIRCLENODE> loco::NodeShape use_x(const CIRCLENODE *node)
-{
- auto x_shape = luci::shape_get(node->x()).template as<loco::TensorShape>();
- return loco::NodeShape{x_shape};
-}
+DECLARE_USE_SINGLE(inputs);
+DECLARE_USE_SINGLE(x);
+DECLARE_USE_SINGLE(logits);
-template <class CIRCLENODE> loco::NodeShape use_logits(const CIRCLENODE *node)
-{
- auto shape = luci::shape_get(node->logits()).template as<loco::TensorShape>();
- return loco::NodeShape{shape};
-}
+#undef DECLARE_USE_SINGLE
template <class CIRCLENODE>
loco::NodeShape use_paddings(const CIRCLENODE *node, const luci::CircleConst *paddings)
@@ -721,6 +717,8 @@ loco::NodeShape infer_fully_connected(const luci::CircleFullyConnected *node)
auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
auto weights_shape = luci::shape_get(node->weights()).as<loco::TensorShape>();
+// TODO Remove following unused code
+#if 0
// Checking shape capability for fully connected layer
// Input: a tensor of at least rank 2 [D1, D2, ... Dn]
// Weight: [# of units, K]
@@ -741,6 +739,40 @@ loco::NodeShape infer_fully_connected(const luci::CircleFullyConnected *node)
out_shape.rank(2);
out_shape.dim(0) = batch_size;
out_shape.dim(1) = weights_shape.dim(0);
+#endif
+
+ loco::TensorShape out_shape;
+
+ // NOTE Some recipes in some repositories are using rank 4 input for FullyConnected.
+ // Until they are all fixed, disable following assert.
+ // TODO Enable following assert after related fixes are applied
+ // https://github.com/tensorflow/tensorflow/blob/ea33c1e7a25d8025e8ee405ad8ab7be261798d76/tensorflow/lite/kernels/fully_connected.cc#L194
+ // LUCI_ASSERT(input_shape.rank() == 2 || input_shape.rank() == 3,
+ // "Input rank of FullyConnected should be 2 or 3");
+
+ // https://github.com/tensorflow/tensorflow/blob/ea33c1e7a25d8025e8ee405ad8ab7be261798d76/tensorflow/lite/kernels/fully_connected.cc#L225
+ LUCI_ASSERT(weights_shape.rank() == 2, "Weights of FullyConnected should be 2");
+
+ // https://github.com/tensorflow/tensorflow/blob/ea33c1e7a25d8025e8ee405ad8ab7be261798d76/tensorflow/lite/kernels/fully_connected.cc#L353-L367
+ if (node->keep_num_dims())
+ {
+ out_shape.rank(input_shape.rank());
+ for (uint32_t i = 0; i < input_shape.rank(); ++i)
+ out_shape.dim(i) = input_shape.dim(i);
+ out_shape.dim(out_shape.rank() - 1) = weights_shape.dim(0);
+ }
+ else
+ {
+ uint32_t input_size = 1;
+ for (uint32_t i = 0; i < input_shape.rank(); i++)
+ {
+ input_size = input_size * input_shape.dim(i).value();
+ }
+ const uint32_t batch_size = input_size / weights_shape.dim(1).value();
+ out_shape.rank(2);
+ out_shape.dim(0) = batch_size;
+ out_shape.dim(1) = weights_shape.dim(0);
+ }
return loco::NodeShape{out_shape};
}
@@ -1554,6 +1586,30 @@ loco::NodeShape infer_squeeze(const luci::CircleSqueeze *node)
return loco::NodeShape{output_shape};
}
+loco::NodeShape infer_svdf(const luci::CircleSVDF *node)
+{
+ const auto ifm_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
+ const auto weight_feature_shape = luci::shape_get(node->weight_feature()).as<loco::TensorShape>();
+
+ assert(ifm_shape.rank() == 2);
+ assert(weight_feature_shape.rank() == 2);
+
+ assert(ifm_shape.dim(1) == weight_feature_shape.dim(1));
+ assert(weight_feature_shape.dim(0).known());
+
+ const auto rank = node->svdf_rank();
+ const auto num_filters = weight_feature_shape.dim(0).value();
+ assert(num_filters % rank == 0);
+ const auto num_units = num_filters / rank;
+
+ loco::TensorShape ofm_shape;
+ ofm_shape.rank(2);
+ ofm_shape.dim(0) = ifm_shape.dim(0);
+ ofm_shape.dim(1) = num_units;
+
+ return loco::NodeShape{ofm_shape};
+}
+
loco::NodeShape infer_tile(const luci::CircleTile *node)
{
const loco::DataType S32 = loco::DataType::S32;
@@ -2393,6 +2449,8 @@ public:
return loco::NodeShape{output_shape};
}
+ loco::NodeShape visit(const luci::CircleSVDF *node) final { return infer_svdf(node); }
+
loco::NodeShape visit(const luci::CircleTanh *node) final { return use_x(node); }
loco::NodeShape visit(const luci::CircleTile *node) final { return infer_tile(node); }
@@ -2486,6 +2544,8 @@ public:
loco::NodeShape visit(const luci::CircleUnpackOut *node) final { return infer_unpack_out(node); }
+ loco::NodeShape visit(const luci::CircleVariable *node) final { return use_own(node); }
+
loco::NodeShape visit(const luci::CircleWhileOut *node) final { return infer_while_out(node); }
};
diff --git a/compiler/luci/service/src/CircleTypeInferenceRule.cpp b/compiler/luci/service/src/CircleTypeInferenceRule.cpp
index 5f6d46f2b..438c4a364 100644
--- a/compiler/luci/service/src/CircleTypeInferenceRule.cpp
+++ b/compiler/luci/service/src/CircleTypeInferenceRule.cpp
@@ -478,6 +478,11 @@ struct TypeInferenceAlgorithm final : public luci::CircleNodeVisitor<loco::DataT
loco::DataType visit(const luci::CircleSum *node) final { return luci::dtype_get(node->input()); }
+ loco::DataType visit(const luci::CircleSVDF *node) final
+ {
+ return luci::dtype_get(node->input());
+ }
+
loco::DataType visit(const luci::CircleTanh *node) final { return luci::dtype_get(node->x()); }
loco::DataType visit(const luci::CircleTile *node) final
@@ -605,6 +610,8 @@ struct TypeInferenceAlgorithm final : public luci::CircleNodeVisitor<loco::DataT
return loco::DataType::S32;
}
+ loco::DataType visit(const luci::CircleVariable *node) final { return node->dtype(); }
+
loco::DataType visit(const luci::CircleUniqueOut *node) final
{
if (node->index() == 0)
diff --git a/compiler/luci/service/src/Nodes/CircleSVDF.cpp b/compiler/luci/service/src/Nodes/CircleSVDF.cpp
new file mode 100644
index 000000000..d4c3ce88f
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleSVDF.cpp
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNodeLet<CN::STUV>::visit(const luci::CircleSVDF *node)
+{
+ if (node->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED)
+ return nullptr;
+
+ auto *cloned = _graph->nodes()->create<luci::CircleSVDF>();
+ if (cloned != nullptr)
+ {
+ cloned->fusedActivationFunction(node->fusedActivationFunction());
+ cloned->asymmetric_quantize_inputs(node->asymmetric_quantize_inputs());
+ cloned->svdf_rank(node->svdf_rank());
+ }
+ return cloned;
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleSVDF.test.cpp b/compiler/luci/service/src/Nodes/CircleSVDF.test.cpp
new file mode 100644
index 000000000..d6edaf1cc
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleSVDF.test.cpp
@@ -0,0 +1,47 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_SVDF)
+{
+ auto g = loco::make_graph();
+ auto node_svdf = g->nodes()->create<luci::CircleSVDF>();
+ node_svdf->fusedActivationFunction(luci::FusedActFunc::RELU);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_svdf, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_svdf = dynamic_cast<luci::CircleSVDF *>(cloned);
+ ASSERT_NE(nullptr, cloned_svdf);
+ ASSERT_EQ(node_svdf->asymmetric_quantize_inputs(), cloned_svdf->asymmetric_quantize_inputs());
+ ASSERT_EQ(node_svdf->svdf_rank(), cloned_svdf->svdf_rank());
+}
+
+TEST(CloneNodeTest, clone_SVDF_NEG)
+{
+ auto g = loco::make_graph();
+ auto node_svdf = g->nodes()->create<luci::CircleSVDF>();
+ node_svdf->fusedActivationFunction(luci::FusedActFunc::UNDEFINED);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_svdf, gc.get());
+ ASSERT_EQ(nullptr, cloned);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleVariable.cpp b/compiler/luci/service/src/Nodes/CircleVariable.cpp
new file mode 100644
index 000000000..c1430bd3a
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleVariable.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleVariable *)
+{
+ return _graph->nodes()->create<luci::CircleVariable>();
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleVariable.test.cpp b/compiler/luci/service/src/Nodes/CircleVariable.test.cpp
new file mode 100644
index 000000000..7d29438be
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleVariable.test.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_Variable)
+{
+ auto g = loco::make_graph();
+ auto node_dummy = g->nodes()->create<luci::CircleVariable>();
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_dummy, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_variable = dynamic_cast<luci::CircleVariable *>(cloned);
+ ASSERT_NE(nullptr, cloned_variable);
+}
diff --git a/compiler/luci/tests/CMakeLists.txt b/compiler/luci/tests/CMakeLists.txt
index c03835823..1333efb7d 100644
--- a/compiler/luci/tests/CMakeLists.txt
+++ b/compiler/luci/tests/CMakeLists.txt
@@ -1,3 +1,14 @@
+set(CIRCLECHEF_FILE_PATH $<TARGET_FILE:circlechef-file>)
+set(TFLCHEF_FILE_PATH $<TARGET_FILE:tflchef-file>)
+set(TFLITE2CIRCLE_PATH $<TARGET_FILE:tflite2circle>)
+if(DEFINED ENV{BUILD_HOST_EXEC})
+ # TODO use better way to represent path for host executable
+ set(CIRCLECHEF_FILE_PATH $ENV{BUILD_HOST_EXEC}/compiler/circlechef/tools/file/circlechef-file)
+ set(TFLCHEF_FILE_PATH $ENV{BUILD_HOST_EXEC}/compiler/tflchef/tools/file/tflchef-file)
+ set(TFLITE2CIRCLE_PATH $ENV{BUILD_HOST_EXEC}/compiler/tflite2circle/tflite2circle)
+ message(STATUS "TFLITE2CIRCLE_PATH = ${TFLITE2CIRCLE_PATH}")
+endif(DEFINED ENV{BUILD_HOST_EXEC})
+
# TODO use local test.recipe files for small networks
file(GLOB RECIPES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*/test.recipe")
@@ -17,14 +28,14 @@ foreach(RECIPE IN ITEMS ${RECIPES})
# Generate .tflite
add_custom_command(OUTPUT "${RECIPE_OUTPUT_FILE}"
- COMMAND tflchef-file "${RECIPE_SOURCE_FILE}" "${RECIPE_OUTPUT_FILE}"
- DEPENDS tflchef-file "${RECIPE_SOURCE_FILE}"
+ COMMAND ${TFLCHEF_FILE_PATH} "${RECIPE_SOURCE_FILE}" "${RECIPE_OUTPUT_FILE}"
+ DEPENDS ${TFLCHEF_FILE_PATH} "${RECIPE_SOURCE_FILE}"
COMMENT "Generating ${RECIPE_OUTPUT_FILE}")
# Generate .circle
add_custom_command(OUTPUT "${CIRCLE_OUTPUT_FILE}"
- COMMAND tflite2circle "${RECIPE_OUTPUT_FILE}" "${CIRCLE_OUTPUT_FILE}"
- DEPENDS tflite2circle "${RECIPE_OUTPUT_FILE}"
+ COMMAND ${TFLITE2CIRCLE_PATH} "${RECIPE_OUTPUT_FILE}" "${CIRCLE_OUTPUT_FILE}"
+ DEPENDS ${TFLITE2CIRCLE_PATH} "${RECIPE_OUTPUT_FILE}"
COMMENT "Generating ${CIRCLE_OUTPUT_FILE}")
list(APPEND TESTFILES "${CIRCLE_OUTPUT_FILE}")
@@ -52,14 +63,14 @@ foreach(RECIPE IN ITEMS ${RECIPES})
# Generate .tflite
add_custom_command(OUTPUT "${RECIPE_OUTPUT_FILE}"
- COMMAND tflchef-file "${RECIPE_SOURCE_FILE}" "${RECIPE_OUTPUT_FILE}"
- DEPENDS tflchef-file "${RECIPE_SOURCE_FILE}"
+ COMMAND ${TFLCHEF_FILE_PATH} "${RECIPE_SOURCE_FILE}" "${RECIPE_OUTPUT_FILE}"
+ DEPENDS ${TFLCHEF_FILE_PATH} "${RECIPE_SOURCE_FILE}"
COMMENT "Generating ${RECIPE_OUTPUT_FILE}")
# Generate .circle
add_custom_command(OUTPUT "${CIRCLE_OUTPUT_FILE}"
- COMMAND tflite2circle "${RECIPE_OUTPUT_FILE}" "${CIRCLE_OUTPUT_FILE}"
- DEPENDS tflite2circle "${RECIPE_OUTPUT_FILE}"
+ COMMAND ${TFLITE2CIRCLE_PATH} "${RECIPE_OUTPUT_FILE}" "${CIRCLE_OUTPUT_FILE}"
+ DEPENDS ${TFLITE2CIRCLE_PATH} "${RECIPE_OUTPUT_FILE}"
COMMENT "Generating ${CIRCLE_OUTPUT_FILE}")
list(APPEND TESTFILES "${CIRCLE_OUTPUT_FILE}")
@@ -87,8 +98,8 @@ foreach(RECIPE IN ITEMS ${RECIPES2})
# Generate .circle
add_custom_command(OUTPUT "${CIRCLE_OUTPUT_FILE}"
- COMMAND circlechef-file "${RECIPE_SOURCE_FILE}" "${CIRCLE_OUTPUT_FILE}"
- DEPENDS circlechef-file "${RECIPE_SOURCE_FILE}"
+ COMMAND ${CIRCLECHEF_FILE_PATH} "${RECIPE_SOURCE_FILE}" "${CIRCLE_OUTPUT_FILE}"
+ DEPENDS ${CIRCLECHEF_FILE_PATH} "${RECIPE_SOURCE_FILE}"
COMMENT "Generating ${CIRCLE_OUTPUT_FILE}")
list(APPEND TESTFILES "${CIRCLE_OUTPUT_FILE}")
@@ -111,6 +122,8 @@ include("test.lst")
# Read "test.local.lst" if exists
include("test.local.lst" OPTIONAL)
+# NOTE $<TARGET_FILE:luci_readtester> is used as-is as test itself should
+# run in target device for cross build also
add_test(NAME luci_unit_readtest
COMMAND "${CMAKE_CURRENT_SOURCE_DIR}/readverify.sh"
"${CMAKE_CURRENT_BINARY_DIR}"
diff --git a/compiler/luci/tests/test.lst b/compiler/luci/tests/test.lst
index 28ddcf672..94e723f21 100644
--- a/compiler/luci/tests/test.lst
+++ b/compiler/luci/tests/test.lst
@@ -180,6 +180,8 @@ addread(Sub_000)
addread(Sub_U8_000)
addread(Sum_000)
addread(Sum_001)
+addread(SVDF_000)
+addread(SVDF_001)
addread(Tanh_000)
addread(Tanh_U8_000)
addread(Tile_000)
@@ -403,6 +405,8 @@ addwrite(Sub_000)
addwrite(Sub_U8_000)
addwrite(Sum_000)
addwrite(Sum_001)
+addwrite(SVDF_000)
+addwrite(SVDF_001)
addwrite(Tanh_000)
addwrite(Tanh_U8_000)
addwrite(Tile_000)