summaryrefslogtreecommitdiff
path: root/compiler
diff options
context:
space:
mode:
authorChunseok Lee <chunseok.lee@samsung.com>2022-04-15 19:15:11 +0900
committerChunseok Lee <chunseok.lee@samsung.com>2022-04-15 19:15:11 +0900
commit3ad689f0803519e343c36d5700646e86059df961 (patch)
tree862346c401a5577518fa7f042532aa931b53aa0e /compiler
parentac6e4dd7b480e83b586ef533d7b29a8a97eb48fe (diff)
downloadnnfw-3ad689f0803519e343c36d5700646e86059df961.tar.gz
nnfw-3ad689f0803519e343c36d5700646e86059df961.tar.bz2
nnfw-3ad689f0803519e343c36d5700646e86059df961.zip
Imported Upstream version 1.20.0upstream/1.20.0submit/tizen/20220415.103159
Diffstat (limited to 'compiler')
-rw-r--r--compiler/angkor/CMakeLists.txt4
-rw-r--r--compiler/arser/tests/arser.test.cpp65
-rw-r--r--compiler/circle-eval-diff/CMakeLists.txt34
-rw-r--r--compiler/circle-eval-diff/README.md51
-rw-r--r--compiler/circle-eval-diff/driver/Driver.cpp156
-rw-r--r--compiler/circle-eval-diff/include/CircleEvalDiff.h74
-rw-r--r--compiler/circle-eval-diff/requires.cmake7
-rw-r--r--compiler/circle-eval-diff/src/CircleEvalDiff.cpp97
-rw-r--r--compiler/circle-eval-diff/src/MetricPrinter.cpp185
-rw-r--r--compiler/circle-eval-diff/src/MetricPrinter.h90
-rw-r--r--compiler/circle-eval-diff/src/MetricPrinter.test.cpp236
-rw-r--r--compiler/circle-eval-diff/src/ModuleEvalDiff.cpp216
-rw-r--r--compiler/circle-eval-diff/src/ModuleEvalDiff.h67
-rw-r--r--compiler/circle-eval-diff/src/Tensor.cpp72
-rw-r--r--compiler/circle-eval-diff/src/Tensor.h81
-rw-r--r--compiler/circle-eval-diff/src/Tensor.test.cpp101
-rw-r--r--compiler/circle-execution-plan/CMakeLists.txt6
-rw-r--r--compiler/circle-execution-plan/README.md5
-rw-r--r--compiler/circle-execution-plan/pal/IScratchpadHelper.h51
-rw-r--r--compiler/circle-execution-plan/pal/ScratchpadHelperCMSISNN.h187
-rw-r--r--compiler/circle-execution-plan/pal/ScratchpadHelperLinux.h137
-rw-r--r--compiler/circle-execution-plan/pal/ScratchpadHelperMCU.h88
-rw-r--r--compiler/circle-execution-plan/pal/TargetPlatform.h38
-rw-r--r--compiler/circle-execution-plan/src/CircleExecutionPlan.cpp47
-rw-r--r--compiler/circle-execution-plan/src/ExecutionPlanner.cpp174
-rw-r--r--compiler/circle-execution-plan/src/ExecutionPlanner.h67
-rw-r--r--compiler/circle-inspect/CMakeLists.txt7
-rw-r--r--compiler/circle-inspect/README.md16
-rw-r--r--compiler/circle-inspect/driver/Driver.cpp6
-rw-r--r--compiler/circle-inspect/requires.cmake2
-rw-r--r--compiler/circle-inspect/src/Dump.cpp25
-rw-r--r--compiler/circle-inspect/src/Dump.h9
-rw-r--r--compiler/circle-inspect/src/Reader.cpp72
-rw-r--r--compiler/circle-inspect/src/Reader.h8
-rw-r--r--compiler/circle-opselector/README.md42
-rw-r--r--compiler/circle-part-value-test/CMakeLists.txt6
-rwxr-xr-xcompiler/circle-part-value-test/part_eval_one.py86
-rw-r--r--compiler/circle-part-value-test/parts/Net_UnpackAdd_001.001.part7
-rw-r--r--compiler/circle-part-value-test/parts/Net_UnpackAdd_001.002.part7
-rw-r--r--compiler/circle-part-value-test/parts/Net_UnpackAdd_001.part7
-rw-r--r--compiler/circle-part-value-test/parts/Part_Mul_Sqrt_FC_nobias_000_000.part7
-rw-r--r--compiler/circle-part-value-test/parts/Part_Mul_Sqrt_FC_nobias_000_001.part7
-rw-r--r--compiler/circle-part-value-test/parts/Part_Mul_Sqrt_FC_nobias_000_002.part7
-rw-r--r--compiler/circle-part-value-test/parts/Part_Split_Add_000.part7
-rw-r--r--compiler/circle-part-value-test/parts/Part_Tanh_FC_nobias.part7
-rw-r--r--compiler/circle-part-value-test/parts/Part_Tanh_FC_nobias_001.part7
-rw-r--r--compiler/circle-part-value-test/parts/Part_Tanh_FC_nobias_002.part7
-rw-r--r--compiler/circle-part-value-test/parts/Part_Tanh_FC_nobias_003.part7
-rw-r--r--compiler/circle-part-value-test/parts/SignatureDef_MultiOut_000.part7
-rw-r--r--compiler/circle-part-value-test/parts/SignatureDef_MultiOut_001.part7
-rw-r--r--compiler/circle-part-value-test/test.lst21
-rw-r--r--compiler/circle-partitioner-test/CMakeLists.txt4
-rw-r--r--compiler/circle-partitioner-test/parts/Part_Add_SVDF_000.part7
-rw-r--r--compiler/circle-partitioner-test/test.lst4
-rw-r--r--compiler/circle-partitioner/CMakeLists.txt19
-rw-r--r--compiler/circle-partitioner/README.md130
-rw-r--r--compiler/circle-quantizer-dredd-recipe-test/CMakeLists.txt144
-rw-r--r--compiler/circle-quantizer-dredd-recipe-test/README.md37
-rw-r--r--compiler/circle-quantizer-dredd-recipe-test/requires.cmake6
-rw-r--r--compiler/circle-quantizer-dredd-recipe-test/test.lst15
-rwxr-xr-xcompiler/circle-quantizer-dredd-recipe-test/testall.sh100
-rw-r--r--compiler/circle-quantizer/CMakeLists.txt10
-rw-r--r--compiler/circle-quantizer/src/CircleQuantizer.cpp146
-rw-r--r--compiler/circle-tensordump/CMakeLists.txt7
-rw-r--r--compiler/circle-tensordump/requires.cmake2
-rw-r--r--compiler/circle-tensordump/src/Reader.cpp62
-rw-r--r--compiler/circle-tensordump/src/Reader.h6
-rw-r--r--compiler/circle-verify/CMakeLists.txt7
-rw-r--r--compiler/circle-verify/requires.cmake2
-rw-r--r--compiler/circle2circle-dredd-recipe-test/CMakeLists.txt4
-rw-r--r--compiler/circle2circle/CMakeLists.txt2
-rw-r--r--compiler/circle2circle/requires.cmake1
-rw-r--r--compiler/circle2circle/src/Circle2Circle.cpp16
-rw-r--r--compiler/circlechef/CMakeLists.txt6
-rw-r--r--compiler/circlechef/circle/CMakeLists.txt3
-rw-r--r--compiler/circlechef/circle/src/CircleImport.cpp35
-rw-r--r--compiler/circlechef/circle/src/CircleImport.h5
-rw-r--r--compiler/circlechef/circle/src/RecipeChef.cpp11
-rw-r--r--compiler/circlechef/core/CMakeLists.txt2
-rw-r--r--compiler/circlechef/core/src/ModelChef.cpp4
-rw-r--r--compiler/circlechef/requires.cmake3
-rw-r--r--compiler/circlechef/tests/CMakeLists.txt33
-rw-r--r--compiler/circledump/CMakeLists.txt9
-rw-r--r--compiler/circledump/README.md2
-rw-r--r--compiler/circledump/requires.cmake2
-rw-r--r--compiler/circledump/src/Dump.cpp58
-rw-r--r--compiler/circledump/src/Load.cpp2
-rw-r--r--compiler/circledump/src/OpPrinter.cpp36
-rw-r--r--compiler/circledump/src/Read.cpp61
-rw-r--r--compiler/circledump/src/Read.h9
-rw-r--r--compiler/cli/CMakeLists.txt8
-rw-r--r--compiler/common-artifacts/CMakeLists.txt117
-rw-r--r--compiler/common-artifacts/exclude.lst17
-rw-r--r--compiler/common-artifacts/options.lst6
-rw-r--r--compiler/common-artifacts/requires.cmake2
-rw-r--r--compiler/common-artifacts/src/TestDataGenerator.cpp90
-rw-r--r--compiler/dio-hdf5/CMakeLists.txt30
-rw-r--r--compiler/dio-hdf5/README.md29
-rw-r--r--compiler/dio-hdf5/include/dio_hdf5/HDF5Importer.h82
-rw-r--r--compiler/dio-hdf5/requires.cmake1
-rw-r--r--compiler/dio-hdf5/src/HDF5Importer.cpp (renamed from compiler/record-minmax/src/HDF5Importer.cpp)34
-rw-r--r--compiler/dio-hdf5/src/HDF5Importer.test.cpp134
-rwxr-xr-xcompiler/dredd-rule-lib/rule-lib.sh17
-rw-r--r--compiler/embedded-import-value-test/.gitignore1
-rw-r--r--compiler/embedded-import-value-test/CMakeLists.txt34
-rw-r--r--compiler/embedded-import-value-test/README.md13
-rwxr-xr-xcompiler/embedded-import-value-test/evalverify.sh58
-rw-r--r--compiler/embedded-import-value-test/requires.cmake6
-rw-r--r--compiler/embedded-import-value-test/src/TestDriver.cpp242
-rw-r--r--compiler/embedded-import-value-test/test.lst192
-rw-r--r--compiler/enco/CMakeLists.txt5
-rw-r--r--compiler/enco/core/CMakeLists.txt8
-rw-r--r--compiler/enco/frontend/caffe/CMakeLists.txt8
-rw-r--r--compiler/enco/frontend/tflite/CMakeLists.txt11
-rw-r--r--compiler/exo/CMakeLists.txt4
-rw-r--r--compiler/hermes-std/CMakeLists.txt4
-rw-r--r--compiler/hermes-std/include/hermes/ConsoleReporter.h4
-rw-r--r--compiler/hermes-std/src/ConsoleReporter.cpp52
-rw-r--r--compiler/hermes-std/src/ConsoleReporter.test.cpp165
-rw-r--r--compiler/hermes/CMakeLists.txt4
-rw-r--r--compiler/hermes/include/hermes/core/Message.h10
-rw-r--r--compiler/hermes/include/hermes/core/MessageBuffer.h3
-rw-r--r--compiler/hermes/src/core/MessageBuffer.cpp8
-rw-r--r--compiler/hermes/src/core/Source.cpp5
-rw-r--r--compiler/locomotiv/CMakeLists.txt4
-rw-r--r--compiler/locop/CMakeLists.txt4
-rw-r--r--compiler/logo-core/CMakeLists.txt4
-rw-r--r--compiler/logo-ex/CMakeLists.txt23
-rw-r--r--compiler/logo-ex/README.md6
-rw-r--r--compiler/logo-ex/include/logo/ConstantFoldingPass.h (renamed from compiler/logo/include/logo/ConstantFoldingPass.h)8
-rw-r--r--compiler/logo-ex/include/logo/PassesEx.h24
-rw-r--r--compiler/logo-ex/requires.cmake3
-rw-r--r--compiler/logo-ex/src/Passes/ConstantFoldingPass.cpp (renamed from compiler/logo/src/Passes/ConstantFoldingPass.cpp)2
-rw-r--r--compiler/logo-ex/src/Passes/ConstantFoldingPass.test.cpp (renamed from compiler/logo/src/Passes/ConstantFoldingPass.test.cpp)2
-rw-r--r--compiler/logo-ex/src/TestHelper.h44
-rw-r--r--compiler/logo/CMakeLists.txt5
-rw-r--r--compiler/logo/include/logo/Passes.h1
-rw-r--r--compiler/logo/requires.cmake1
-rw-r--r--compiler/luci-interpreter/README.md2
-rw-r--r--compiler/luci-interpreter/include/luci_interpreter/GraphBuilderRegistry.h35
-rw-r--r--compiler/luci-interpreter/include/luci_interpreter/Interpreter.h5
-rw-r--r--compiler/luci-interpreter/pal/cmsisnn/KernelsToBuild.lst4
-rw-r--r--compiler/luci-interpreter/pal/cmsisnn/PALAveragePool2d.h124
-rw-r--r--compiler/luci-interpreter/pal/cmsisnn/PALConv2d.h163
-rw-r--r--compiler/luci-interpreter/pal/cmsisnn/PALDepthwiseConv2d.h192
-rw-r--r--compiler/luci-interpreter/pal/cmsisnn/PALDequantize.h44
-rw-r--r--compiler/luci-interpreter/pal/cmsisnn/PALFullyConnected.h114
-rw-r--r--compiler/luci-interpreter/pal/cmsisnn/PALMul.h18
-rw-r--r--compiler/luci-interpreter/pal/cmsisnn/PALQuantize.h44
-rw-r--r--compiler/luci-interpreter/pal/cmsisnn/PALSVDF.h190
-rw-r--r--compiler/luci-interpreter/pal/cmsisnn/pal.cmake9
-rw-r--r--compiler/luci-interpreter/pal/linux/KernelsToBuild.lst7
-rw-r--r--compiler/luci-interpreter/pal/linux/PALAveragePool2d.h73
-rw-r--r--compiler/luci-interpreter/pal/linux/PALBatchMatMul.h67
-rw-r--r--compiler/luci-interpreter/pal/linux/PALConv2d.h72
-rw-r--r--compiler/luci-interpreter/pal/linux/PALDepthwiseConv2d.h91
-rw-r--r--compiler/luci-interpreter/pal/linux/PALDequantize.h34
-rw-r--r--compiler/luci-interpreter/pal/linux/PALFullyConnected.h61
-rw-r--r--compiler/luci-interpreter/pal/linux/PALGather.h35
-rw-r--r--compiler/luci-interpreter/pal/linux/PALMul.h28
-rw-r--r--compiler/luci-interpreter/pal/linux/PALQuantize.h44
-rw-r--r--compiler/luci-interpreter/pal/linux/PALSVDF.h90
-rw-r--r--compiler/luci-interpreter/pal/linux/pal.cmake30
-rw-r--r--compiler/luci-interpreter/pal/mcu/KernelsToBuild.lst4
-rw-r--r--compiler/luci-interpreter/pal/mcu/PALAveragePool2d.h73
-rw-r--r--compiler/luci-interpreter/pal/mcu/PALConv2d.h43
-rw-r--r--compiler/luci-interpreter/pal/mcu/PALDepthwiseConv2d.h91
-rw-r--r--compiler/luci-interpreter/pal/mcu/PALDequantize.h44
-rw-r--r--compiler/luci-interpreter/pal/mcu/PALFullyConnected.h61
-rw-r--r--compiler/luci-interpreter/pal/mcu/PALMul.h18
-rw-r--r--compiler/luci-interpreter/pal/mcu/PALQuantize.h44
-rw-r--r--compiler/luci-interpreter/pal/mcu/PALSVDF.h258
-rw-r--r--compiler/luci-interpreter/pal/mcu/pal.cmake4
-rw-r--r--compiler/luci-interpreter/src/CMakeLists.txt3
-rw-r--r--compiler/luci-interpreter/src/Interpreter.cpp27
-rw-r--r--compiler/luci-interpreter/src/core/CMakeLists.txt4
-rw-r--r--compiler/luci-interpreter/src/core/KernelParams.h25
-rw-r--r--compiler/luci-interpreter/src/import/CMakeLists.txt15
-rw-r--r--compiler/luci-interpreter/src/import/GraphBuilderRegistry.cpp33
-rw-r--r--compiler/luci-interpreter/src/import/Nodes/CircleReferencingConst.cpp113
-rw-r--r--compiler/luci-interpreter/src/import/Nodes/CircleReferencingConst.h39
-rw-r--r--compiler/luci-interpreter/src/kernels/Add.cpp38
-rw-r--r--compiler/luci-interpreter/src/kernels/Add.h1
-rw-r--r--compiler/luci-interpreter/src/kernels/Add.test.cpp93
-rw-r--r--compiler/luci-interpreter/src/kernels/ArgMax.test.cpp2
-rw-r--r--compiler/luci-interpreter/src/kernels/AveragePool2D.cpp21
-rw-r--r--compiler/luci-interpreter/src/kernels/AveragePool2D.h3
-rw-r--r--compiler/luci-interpreter/src/kernels/AveragePool2D.test.cpp29
-rw-r--r--compiler/luci-interpreter/src/kernels/BatchMatMul.cpp188
-rw-r--r--compiler/luci-interpreter/src/kernels/BatchMatMul.h49
-rw-r--r--compiler/luci-interpreter/src/kernels/BatchMatMul.test.cpp272
-rw-r--r--compiler/luci-interpreter/src/kernels/BatchToSpaceND.test.cpp2
-rw-r--r--compiler/luci-interpreter/src/kernels/CMakeLists.txt4
-rw-r--r--compiler/luci-interpreter/src/kernels/Cast.test.cpp2
-rw-r--r--compiler/luci-interpreter/src/kernels/Concatenation.cpp18
-rw-r--r--compiler/luci-interpreter/src/kernels/Concatenation.test.cpp55
-rw-r--r--compiler/luci-interpreter/src/kernels/Conv2D.cpp73
-rw-r--r--compiler/luci-interpreter/src/kernels/Conv2D.h3
-rw-r--r--compiler/luci-interpreter/src/kernels/DepthToSpace.test.cpp2
-rw-r--r--compiler/luci-interpreter/src/kernels/DepthwiseConv2D.cpp28
-rw-r--r--compiler/luci-interpreter/src/kernels/DepthwiseConv2D.h2
-rw-r--r--compiler/luci-interpreter/src/kernels/DepthwiseConv2D.test.cpp50
-rw-r--r--compiler/luci-interpreter/src/kernels/Dequantize.cpp79
-rw-r--r--compiler/luci-interpreter/src/kernels/Dequantize.h43
-rw-r--r--compiler/luci-interpreter/src/kernels/Dequantize.test.cpp149
-rw-r--r--compiler/luci-interpreter/src/kernels/Div.cpp36
-rw-r--r--compiler/luci-interpreter/src/kernels/Div.h1
-rw-r--r--compiler/luci-interpreter/src/kernels/Div.test.cpp69
-rw-r--r--compiler/luci-interpreter/src/kernels/Equal.cpp29
-rw-r--r--compiler/luci-interpreter/src/kernels/Equal.h1
-rw-r--r--compiler/luci-interpreter/src/kernels/Equal.test.cpp106
-rw-r--r--compiler/luci-interpreter/src/kernels/ExpandDims.cpp88
-rw-r--r--compiler/luci-interpreter/src/kernels/ExpandDims.h44
-rw-r--r--compiler/luci-interpreter/src/kernels/ExpandDims.test.cpp115
-rw-r--r--compiler/luci-interpreter/src/kernels/FullyConnected.cpp18
-rw-r--r--compiler/luci-interpreter/src/kernels/FullyConnected.test.cpp2
-rw-r--r--compiler/luci-interpreter/src/kernels/Gather.cpp139
-rw-r--r--compiler/luci-interpreter/src/kernels/Gather.h47
-rw-r--r--compiler/luci-interpreter/src/kernels/Gather.test.cpp137
-rw-r--r--compiler/luci-interpreter/src/kernels/Greater.cpp29
-rw-r--r--compiler/luci-interpreter/src/kernels/Greater.h1
-rw-r--r--compiler/luci-interpreter/src/kernels/Greater.test.cpp106
-rw-r--r--compiler/luci-interpreter/src/kernels/GreaterEqual.cpp29
-rw-r--r--compiler/luci-interpreter/src/kernels/GreaterEqual.h1
-rw-r--r--compiler/luci-interpreter/src/kernels/GreaterEqual.test.cpp105
-rw-r--r--compiler/luci-interpreter/src/kernels/L2Normalize.test.cpp2
-rw-r--r--compiler/luci-interpreter/src/kernels/L2Pool2D.test.cpp3
-rw-r--r--compiler/luci-interpreter/src/kernels/LeakyRelu.test.cpp2
-rw-r--r--compiler/luci-interpreter/src/kernels/Less.cpp29
-rw-r--r--compiler/luci-interpreter/src/kernels/Less.h1
-rw-r--r--compiler/luci-interpreter/src/kernels/Less.test.cpp106
-rw-r--r--compiler/luci-interpreter/src/kernels/LessEqual.cpp29
-rw-r--r--compiler/luci-interpreter/src/kernels/LessEqual.h1
-rw-r--r--compiler/luci-interpreter/src/kernels/LessEqual.test.cpp106
-rw-r--r--compiler/luci-interpreter/src/kernels/Logistic.test.cpp2
-rw-r--r--compiler/luci-interpreter/src/kernels/MirrorPad.cpp118
-rw-r--r--compiler/luci-interpreter/src/kernels/MirrorPad.test.cpp210
-rw-r--r--compiler/luci-interpreter/src/kernels/Mul.cpp37
-rw-r--r--compiler/luci-interpreter/src/kernels/Mul.h1
-rw-r--r--compiler/luci-interpreter/src/kernels/Mul.test.cpp126
-rw-r--r--compiler/luci-interpreter/src/kernels/NotEqual.cpp29
-rw-r--r--compiler/luci-interpreter/src/kernels/NotEqual.h1
-rw-r--r--compiler/luci-interpreter/src/kernels/NotEqual.test.cpp106
-rw-r--r--compiler/luci-interpreter/src/kernels/OneHot.cpp136
-rw-r--r--compiler/luci-interpreter/src/kernels/OneHot.h48
-rw-r--r--compiler/luci-interpreter/src/kernels/OneHot.test.cpp192
-rw-r--r--compiler/luci-interpreter/src/kernels/Pack.test.cpp2
-rw-r--r--compiler/luci-interpreter/src/kernels/Pad.cpp10
-rw-r--r--compiler/luci-interpreter/src/kernels/Pad.test.cpp26
-rw-r--r--compiler/luci-interpreter/src/kernels/Quantize.cpp160
-rw-r--r--compiler/luci-interpreter/src/kernels/Quantize.h43
-rw-r--r--compiler/luci-interpreter/src/kernels/Quantize.test.cpp254
-rw-r--r--compiler/luci-interpreter/src/kernels/ResizeBilinear.test.cpp2
-rw-r--r--compiler/luci-interpreter/src/kernels/ResizeNearestNeighbor.test.cpp2
-rw-r--r--compiler/luci-interpreter/src/kernels/ReverseV2.test.cpp2
-rw-r--r--compiler/luci-interpreter/src/kernels/SVDF.cpp241
-rw-r--r--compiler/luci-interpreter/src/kernels/SVDF.h56
-rw-r--r--compiler/luci-interpreter/src/kernels/SVDF.test.cpp341
-rw-r--r--compiler/luci-interpreter/src/kernels/Slice.cpp5
-rw-r--r--compiler/luci-interpreter/src/kernels/Slice.test.cpp4
-rw-r--r--compiler/luci-interpreter/src/kernels/Softmax.test.cpp2
-rw-r--r--compiler/luci-interpreter/src/kernels/SpaceToBatchND.test.cpp2
-rw-r--r--compiler/luci-interpreter/src/kernels/SpaceToDepth.test.cpp2
-rw-r--r--compiler/luci-interpreter/src/kernels/Split.test.cpp2
-rw-r--r--compiler/luci-interpreter/src/kernels/SplitV.test.cpp2
-rw-r--r--compiler/luci-interpreter/src/kernels/Squeeze.test.cpp2
-rw-r--r--compiler/luci-interpreter/src/kernels/Sub.cpp36
-rw-r--r--compiler/luci-interpreter/src/kernels/Sub.h1
-rw-r--r--compiler/luci-interpreter/src/kernels/Sub.test.cpp75
-rw-r--r--compiler/luci-interpreter/src/kernels/Transpose.test.cpp2
-rw-r--r--compiler/luci-interpreter/src/kernels/Unpack.test.cpp2
-rw-r--r--compiler/luci-interpreter/src/kernels/Utils.cpp22
-rw-r--r--compiler/luci-interpreter/src/kernels/Utils.h33
-rw-r--r--compiler/luci-interpreter/src/loader/CMakeLists.txt4
-rw-r--r--compiler/luci-interpreter/src/loader/GraphLoader.cpp94
-rw-r--r--compiler/luci-interpreter/src/loader/KernelBuilder.test.cpp48
-rw-r--r--compiler/luci-interpreter/src/loader/nodes/AveragePool2D.cpp22
-rw-r--r--compiler/luci-interpreter/src/loader/nodes/BatchMatMul.cpp72
-rw-r--r--compiler/luci-interpreter/src/loader/nodes/Conv2D.cpp17
-rw-r--r--compiler/luci-interpreter/src/loader/nodes/DepthwiseConv2D.cpp22
-rw-r--r--compiler/luci-interpreter/src/loader/nodes/Dequantize.cpp37
-rw-r--r--compiler/luci-interpreter/src/loader/nodes/ExpandDims.cpp37
-rw-r--r--compiler/luci-interpreter/src/loader/nodes/FullyConnected.cpp1
-rw-r--r--compiler/luci-interpreter/src/loader/nodes/Gather.cpp44
-rw-r--r--compiler/luci-interpreter/src/loader/nodes/OneHot.cpp42
-rw-r--r--compiler/luci-interpreter/src/loader/nodes/Quantize.cpp37
-rw-r--r--compiler/luci-interpreter/src/loader/nodes/SVDF.cpp93
-rw-r--r--compiler/luci-micro/CMakeLists.txt2
-rw-r--r--compiler/luci-pass-value-test/CMakeLists.txt6
-rw-r--r--compiler/luci-pass-value-test/eval_result_verifier.py56
-rw-r--r--compiler/luci-pass-value-test/test.lst4
-rw-r--r--compiler/luci-value-test/CMakeLists.txt74
-rwxr-xr-xcompiler/luci-value-test/evalverify.sh4
-rwxr-xr-xcompiler/luci-value-test/evalverify_ref.sh63
-rwxr-xr-xcompiler/luci-value-test/evalverifytol.sh71
-rwxr-xr-xcompiler/luci-value-test/evalverifytol_ref.sh70
-rwxr-xr-xcompiler/luci-value-test/luci_eval_verifier.py77
-rwxr-xr-xcompiler/luci-value-test/luci_eval_verifier_ref.py151
-rw-r--r--compiler/luci-value-test/test.lst106
-rw-r--r--compiler/luci/CMakeLists.txt4
-rw-r--r--compiler/luci/export/CMakeLists.txt4
-rw-r--r--compiler/luci/export/src/CircleBuiltinTypesExtractor.h539
-rw-r--r--compiler/luci/export/src/CircleBuiltinTypesMappingRule.h79
-rw-r--r--compiler/luci/export/src/CircleExporterImpl.cpp9
-rw-r--r--compiler/luci/export/src/CircleExporterUtils.cpp58
-rw-r--r--compiler/luci/export/src/CircleExporterUtils.h6
-rw-r--r--compiler/luci/export/src/CircleOperationExporter.cpp1696
-rw-r--r--compiler/luci/export/src/CircleOperationExporter.h2
-rw-r--r--compiler/luci/export/src/CircleOperationExporterRule.cpp277
-rw-r--r--compiler/luci/export/src/CircleOperationExporterRule.h76
-rw-r--r--compiler/luci/export/src/CircleOps.lst154
-rw-r--r--compiler/luci/export/src/CircleTensorExporter.cpp15
-rw-r--r--compiler/luci/export/src/SerializedData.h6
-rw-r--r--compiler/luci/import/CMakeLists.txt7
-rw-r--r--compiler/luci/import/include/luci/Import/CircleReader.h73
-rw-r--r--compiler/luci/import/include/luci/Import/GraphBuilderRegistry.h23
-rw-r--r--compiler/luci/import/include/luci/Import/NodeBuilder.h58
-rw-r--r--compiler/luci/import/include/luci/Import/Nodes.h2
-rw-r--r--compiler/luci/import/include/luci/Import/Nodes/CircleConst.h17
-rw-r--r--compiler/luci/import/include/luci/Import/Nodes/CircleSVDF.h37
-rw-r--r--compiler/luci/import/include/luci/Import/Nodes/CircleVariable.h37
-rw-r--r--compiler/luci/import/src/CircleImportMetadata.cpp43
-rw-r--r--compiler/luci/import/src/CircleReader.cpp186
-rw-r--r--compiler/luci/import/src/GraphBuilder.cpp15
-rw-r--r--compiler/luci/import/src/GraphBuilderMultiOutput.cpp20
-rw-r--r--compiler/luci/import/src/GraphBuilderRegistry.cpp9
-rw-r--r--compiler/luci/import/src/Importer.cpp78
-rw-r--r--compiler/luci/import/src/Importer.test.cpp50
-rw-r--r--compiler/luci/import/src/Nodes/CircleCast.cpp12
-rw-r--r--compiler/luci/import/src/Nodes/CircleConst.cpp34
-rw-r--r--compiler/luci/import/src/Nodes/CircleCustom.cpp8
-rw-r--r--compiler/luci/import/src/Nodes/CircleDepthToSpace.cpp5
-rw-r--r--compiler/luci/import/src/Nodes/CircleDepthwiseConv2D.cpp12
-rw-r--r--compiler/luci/import/src/Nodes/CircleElu.cpp10
-rw-r--r--compiler/luci/import/src/Nodes/CircleEqual.cpp5
-rw-r--r--compiler/luci/import/src/Nodes/CircleExp.cpp7
-rw-r--r--compiler/luci/import/src/Nodes/CircleExpandDims.cpp5
-rw-r--r--compiler/luci/import/src/Nodes/CircleFloorDiv.cpp17
-rw-r--r--compiler/luci/import/src/Nodes/CircleFloorMod.cpp9
-rw-r--r--compiler/luci/import/src/Nodes/CircleFullyConnected.cpp1
-rw-r--r--compiler/luci/import/src/Nodes/CircleGatherNd.cpp7
-rw-r--r--compiler/luci/import/src/Nodes/CircleGreater.cpp10
-rw-r--r--compiler/luci/import/src/Nodes/CircleGreaterEqual.cpp8
-rw-r--r--compiler/luci/import/src/Nodes/CircleIf.cpp9
-rw-r--r--compiler/luci/import/src/Nodes/CircleLess.cpp13
-rw-r--r--compiler/luci/import/src/Nodes/CircleLessEqual.cpp8
-rw-r--r--compiler/luci/import/src/Nodes/CircleLog.cpp7
-rw-r--r--compiler/luci/import/src/Nodes/CircleLogicalAnd.cpp7
-rw-r--r--compiler/luci/import/src/Nodes/CircleLogicalNot.cpp7
-rw-r--r--compiler/luci/import/src/Nodes/CircleLogicalOr.cpp7
-rw-r--r--compiler/luci/import/src/Nodes/CircleLogistic.cpp5
-rw-r--r--compiler/luci/import/src/Nodes/CircleMatrixDiag.cpp7
-rw-r--r--compiler/luci/import/src/Nodes/CircleMatrixSetDiag.cpp7
-rw-r--r--compiler/luci/import/src/Nodes/CircleNonMaxSuppressionV4.cpp22
-rw-r--r--compiler/luci/import/src/Nodes/CircleNonMaxSuppressionV5.cpp25
-rw-r--r--compiler/luci/import/src/Nodes/CircleNotEqual.cpp8
-rw-r--r--compiler/luci/import/src/Nodes/CircleOneHot.cpp24
-rw-r--r--compiler/luci/import/src/Nodes/CircleReduceAny.cpp17
-rw-r--r--compiler/luci/import/src/Nodes/CircleReduceProd.cpp7
-rw-r--r--compiler/luci/import/src/Nodes/CircleReshape.cpp7
-rw-r--r--compiler/luci/import/src/Nodes/CircleReverseSequence.cpp15
-rw-r--r--compiler/luci/import/src/Nodes/CircleReverseV2.cpp15
-rw-r--r--compiler/luci/import/src/Nodes/CircleRound.cpp12
-rw-r--r--compiler/luci/import/src/Nodes/CircleRsqrt.cpp7
-rw-r--r--compiler/luci/import/src/Nodes/CircleSVDF.cpp67
-rw-r--r--compiler/luci/import/src/Nodes/CircleScatterNd.cpp9
-rw-r--r--compiler/luci/import/src/Nodes/CircleSegmentSum.cpp15
-rw-r--r--compiler/luci/import/src/Nodes/CircleSelect.cpp7
-rw-r--r--compiler/luci/import/src/Nodes/CircleSelectV2.cpp14
-rw-r--r--compiler/luci/import/src/Nodes/CircleSin.cpp7
-rw-r--r--compiler/luci/import/src/Nodes/CircleSquare.cpp12
-rw-r--r--compiler/luci/import/src/Nodes/CircleSquaredDifference.cpp13
-rw-r--r--compiler/luci/import/src/Nodes/CircleTanh.cpp5
-rw-r--r--compiler/luci/import/src/Nodes/CircleTile.cpp10
-rw-r--r--compiler/luci/import/src/Nodes/CircleTopKV2.cpp7
-rw-r--r--compiler/luci/import/src/Nodes/CircleTransposeConv.cpp14
-rw-r--r--compiler/luci/import/src/Nodes/CircleUnpack.cpp11
-rw-r--r--compiler/luci/import/src/Nodes/CircleVariable.cpp80
-rw-r--r--compiler/luci/import/src/Nodes/CircleWhere.cpp12
-rw-r--r--compiler/luci/import/src/Nodes/CircleWhile.cpp13
-rw-r--r--compiler/luci/import/src/ValidateHelpers.cpp39
-rw-r--r--compiler/luci/lang/include/luci/IR/CircleNodes.h9
-rw-r--r--compiler/luci/lang/include/luci/IR/CircleNodes.lst5
-rw-r--r--compiler/luci/lang/include/luci/IR/CircleQuantParam.h4
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleFullyConnected.h4
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleSVDF.h70
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleVariable.h39
-rw-r--r--compiler/luci/lang/src/CircleQuantParam.cpp46
-rw-r--r--compiler/luci/lang/src/CircleQuantParam.test.cpp78
-rw-r--r--compiler/luci/lang/src/Nodes/CircleFullyConnected.test.cpp1
-rw-r--r--compiler/luci/lang/src/Nodes/CircleSVDF.test.cpp101
-rw-r--r--compiler/luci/lang/src/Nodes/CircleVariable.test.cpp61
-rw-r--r--compiler/luci/logex/CMakeLists.txt14
-rw-r--r--compiler/luci/logex/src/CircleNodeSummaryBuilder.cpp265
-rw-r--r--compiler/luci/logex/src/CircleNodeSummaryBuilder.h52
-rw-r--r--compiler/luci/logex/src/CircleNodeSummaryBuilder.test.cpp309
-rw-r--r--compiler/luci/logex/src/CircleNodeSummaryBuilders.cpp1128
-rw-r--r--compiler/luci/logex/src/CircleNodeSummaryBuilders.h821
-rw-r--r--compiler/luci/logex/src/FormattedGraph.cpp2194
-rw-r--r--compiler/luci/partition/CMakeLists.txt2
-rw-r--r--compiler/luci/partition/src/ConnectNode.h2
-rw-r--r--compiler/luci/partition/src/Nodes/CircleSVDF.cpp47
-rw-r--r--compiler/luci/partition/src/Nodes/CircleSVDF.test.cpp106
-rw-r--r--compiler/luci/partition/src/Nodes/CircleVariable.cpp27
-rw-r--r--compiler/luci/partition/src/PartitionIRDump.cpp11
-rw-r--r--compiler/luci/partition/src/PartitionMerge.cpp50
-rw-r--r--compiler/luci/partition/src/PartitionPGroups.cpp240
-rw-r--r--compiler/luci/pass/CMakeLists.txt8
-rw-r--r--compiler/luci/pass/include/luci/CircleOptimizer.h20
-rw-r--r--compiler/luci/pass/include/luci/CircleQuantizer.h97
-rw-r--r--compiler/luci/pass/include/luci/Pass/ConvertToFakeQuantizedModelPass.h39
-rw-r--r--compiler/luci/pass/include/luci/Pass/CopyQuantParamPass.h53
-rw-r--r--compiler/luci/pass/include/luci/Pass/FoldGatherPass.h38
-rw-r--r--compiler/luci/pass/include/luci/Pass/PropagateQParamBackwardPass.h42
-rw-r--r--compiler/luci/pass/include/luci/Pass/PropagateQParamForwardPass.h (renamed from compiler/luci/pass/include/luci/Pass/PropagateQuantParamPass.h)19
-rw-r--r--compiler/luci/pass/include/luci/Pass/QuantizationParameters.h11
-rw-r--r--compiler/luci/pass/include/luci/Pass/QuantizeDequantizeWeightsPass.h28
-rw-r--r--compiler/luci/pass/include/luci/Pass/QuantizePreCheckerPass.h39
-rw-r--r--compiler/luci/pass/include/luci/Pass/QuantizeWithMinMaxPass.h39
-rw-r--r--compiler/luci/pass/include/luci/Pass/RemoveRedundantQuantizePass.h37
-rw-r--r--compiler/luci/pass/src/BatchNormPatternFinder.cpp40
-rw-r--r--compiler/luci/pass/src/BatchNormPatternFinder.test.cpp107
-rw-r--r--compiler/luci/pass/src/CircleOptimizer.cpp224
-rw-r--r--compiler/luci/pass/src/CircleOptimizer.test.cpp168
-rw-r--r--compiler/luci/pass/src/CircleQuantizer.cpp458
-rw-r--r--compiler/luci/pass/src/CircleQuantizer.test.cpp191
-rw-r--r--compiler/luci/pass/src/ConvertNCHWToNHWCPass.cpp6
-rw-r--r--compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp36
-rw-r--r--compiler/luci/pass/src/ConvertToFakeQuantizedModelPass.cpp214
-rw-r--r--compiler/luci/pass/src/ConvertToFakeQuantizedModelPass.test.cpp277
-rw-r--r--compiler/luci/pass/src/CopyQuantParamPass.cpp82
-rw-r--r--compiler/luci/pass/src/FoldGatherPass.cpp185
-rw-r--r--compiler/luci/pass/src/FoldGatherPass.test.cpp214
-rw-r--r--compiler/luci/pass/src/PropagateConcatenationQparam.test.cpp36
-rw-r--r--compiler/luci/pass/src/PropagateQParamBackwardPass.cpp482
-rw-r--r--compiler/luci/pass/src/PropagateQParamBackwardPass.test.cpp167
-rw-r--r--compiler/luci/pass/src/PropagateQParamForwardPass.cpp194
-rw-r--r--compiler/luci/pass/src/PropagateQParamForwardPass.test.cpp260
-rw-r--r--compiler/luci/pass/src/PropagateQuantParamPass.cpp107
-rw-r--r--compiler/luci/pass/src/PropagateQuantParamPass.test.cpp125
-rw-r--r--compiler/luci/pass/src/QuantizationUtils.cpp158
-rw-r--r--compiler/luci/pass/src/QuantizationUtils.h36
-rw-r--r--compiler/luci/pass/src/QuantizeActivation.cpp296
-rw-r--r--compiler/luci/pass/src/QuantizeActivation.h165
-rw-r--r--compiler/luci/pass/src/QuantizeBias.cpp300
-rw-r--r--compiler/luci/pass/src/QuantizeBias.h56
-rw-r--r--compiler/luci/pass/src/QuantizeDequantizeWeightsPass.cpp259
-rw-r--r--compiler/luci/pass/src/QuantizeDequantizeWeightsPass.test.cpp14
-rw-r--r--compiler/luci/pass/src/QuantizePreCheckerPass.cpp119
-rw-r--r--compiler/luci/pass/src/QuantizePreCheckerPass.test.cpp401
-rw-r--r--compiler/luci/pass/src/QuantizeWeights.cpp394
-rw-r--r--compiler/luci/pass/src/QuantizeWeights.h55
-rw-r--r--compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp1773
-rw-r--r--compiler/luci/pass/src/QuantizeWithMinMaxPass.test.cpp49
-rw-r--r--compiler/luci/pass/src/QuantizedModelVerifier.cpp70
-rw-r--r--compiler/luci/pass/src/QuantizedModelVerifier.h30
-rw-r--r--compiler/luci/pass/src/QuantizedModelVerifier.test.cpp497
-rw-r--r--compiler/luci/pass/src/RemoveRedundantQuantizePass.cpp104
-rw-r--r--compiler/luci/pass/src/RemoveRedundantQuantizePass.test.cpp166
-rw-r--r--compiler/luci/pass/src/RemoveRedundantTransposePass.cpp2
-rw-r--r--compiler/luci/pass/src/RemoveRedundantTransposePass.test.cpp25
-rw-r--r--compiler/luci/pass/src/RemoveUnnecessaryReshapePass.cpp19
-rw-r--r--compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.cpp26
-rw-r--r--compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.test.cpp46
-rw-r--r--compiler/luci/pass/src/SubstituteSplitVToSplitPass.cpp13
-rw-r--r--compiler/luci/pass/src/SubstituteSqueezeToReshapePass.cpp14
-rw-r--r--compiler/luci/pass/src/SubstituteStridedSliceToReshapePass.cpp2
-rw-r--r--compiler/luci/pass/src/VerifyQuantizedBiasScale.cpp105
-rw-r--r--compiler/luci/pass/src/VerifyQuantizedBiasScale.h59
-rw-r--r--compiler/luci/pass/src/VerifyQuantizedNodeGranularity.cpp38
-rw-r--r--compiler/luci/pass/src/VerifyQuantizedNodeGranularity.h (renamed from compiler/luci/pass/src/VerifyQuantizedNodeChannelWiseGranularity.h)301
-rw-r--r--compiler/luci/pass/src/VerifyQuantizedNodeLayerWiseGranularity.h473
-rw-r--r--compiler/luci/pass/src/VerifyQuantizedNodeS16Type.h516
-rw-r--r--compiler/luci/pass/src/VerifyQuantizedNodeType.cpp554
-rw-r--r--compiler/luci/pass/src/VerifyQuantizedNodeType.h157
-rw-r--r--compiler/luci/pass/src/VerifyQuantizedNodeU8Type.h518
-rw-r--r--compiler/luci/pass/src/helpers/LayerInfoMap.cpp189
-rw-r--r--compiler/luci/pass/src/helpers/LayerInfoMap.h33
-rw-r--r--compiler/luci/pass/src/helpers/LayerInfoMap.test.cpp201
-rw-r--r--compiler/luci/requires.cmake4
-rw-r--r--compiler/luci/service/CMakeLists.txt1
-rw-r--r--compiler/luci/service/include/luci/Service/CircleShapeInference.h7
-rw-r--r--compiler/luci/service/include/luci/Service/CircleTypeInference.h8
-rw-r--r--compiler/luci/service/src/CircleCloneNode.h2
-rw-r--r--compiler/luci/service/src/CircleNodeClone.cpp14
-rw-r--r--compiler/luci/service/src/CircleShapeInferenceRule.cpp90
-rw-r--r--compiler/luci/service/src/CircleTypeInferenceRule.cpp7
-rw-r--r--compiler/luci/service/src/Nodes/CircleSVDF.cpp37
-rw-r--r--compiler/luci/service/src/Nodes/CircleSVDF.test.cpp47
-rw-r--r--compiler/luci/service/src/Nodes/CircleVariable.cpp27
-rw-r--r--compiler/luci/service/src/Nodes/CircleVariable.test.cpp33
-rw-r--r--compiler/luci/tests/CMakeLists.txt33
-rw-r--r--compiler/luci/tests/test.lst4
-rw-r--r--compiler/mio-circle/CMakeLists.txt12
-rw-r--r--compiler/mio-circle/include/mio_circle/Helper.h36
-rw-r--r--compiler/mio-circle/src/Helper.cpp81
-rw-r--r--compiler/mio-circle04/CMakeLists.txt52
-rw-r--r--compiler/mio-circle04/README.md3
-rw-r--r--compiler/mio-circle04/example.cpp41
-rw-r--r--compiler/mio-circle04/include/mio_circle/Helper.h37
-rw-r--r--compiler/mio-circle04/src/Helper.cpp110
-rw-r--r--compiler/mio-circle04/src/Helper.test.cpp153
-rw-r--r--compiler/mio-tflite/CMakeLists.txt2
-rw-r--r--compiler/mio-tflite260/CMakeLists.txt24
-rw-r--r--compiler/mio-tflite260/include/mio_tflite260/Helper.h37
-rw-r--r--compiler/mio-tflite260/src/Helper.cpp104
-rw-r--r--compiler/mio-tflite260/src/Helper.test.cpp159
-rw-r--r--compiler/mio-tflite280/CMakeLists.txt69
-rw-r--r--compiler/mio-tflite280/README.md3
-rw-r--r--compiler/mio-tflite280/example.cpp41
-rw-r--r--compiler/mio-tflite280/include/mio_tflite280/Helper.h37
-rw-r--r--compiler/mio-tflite280/src/Helper.cpp104
-rw-r--r--compiler/mio-tflite280/src/Helper.test.cpp159
-rw-r--r--compiler/mir/src/mir_onnx_importer/CMakeLists.txt4
-rw-r--r--compiler/mir/src/mir_tflite_importer/CMakeLists.txt2
-rw-r--r--compiler/mir2loco/CMakeLists.txt8
-rw-r--r--compiler/moco-tf/CMakeLists.txt2
-rw-r--r--compiler/moco-tf/requires.cmake1
-rw-r--r--compiler/moco-tf/src/Transforms.h1
-rw-r--r--compiler/morph/CMakeLists.txt8
-rw-r--r--compiler/nest/core/CMakeLists.txt8
-rw-r--r--compiler/nike/CMakeLists.txt8
-rw-r--r--compiler/nnc/unittests/soft_backend/ModelAnalyzer.cpp2
-rw-r--r--compiler/nnop/CMakeLists.txt8
-rw-r--r--compiler/one-cmds/CMakeLists.txt39
-rw-r--r--compiler/one-cmds/how-to-prepare-virtualenv.txt19
-rw-r--r--compiler/one-cmds/how-to-use-one-commands.txt2
-rw-r--r--compiler/one-cmds/one-build18
-rw-r--r--compiler/one-cmds/one-import-bcq9
-rw-r--r--compiler/one-cmds/one-import-onnx26
-rw-r--r--compiler/one-cmds/one-import-pytorch366
-rw-r--r--compiler/one-cmds/one-import-tf9
-rw-r--r--compiler/one-cmds/one-import-tflite7
-rw-r--r--compiler/one-cmds/one-optimize6
-rw-r--r--compiler/one-cmds/one-prepare-venv39
-rw-r--r--compiler/one-cmds/one-quantize82
-rw-r--r--compiler/one-cmds/onecc33
-rw-r--r--compiler/one-cmds/onelib/constant.py86
-rw-r--r--compiler/one-cmds/onelib/make_cmd.py100
-rwxr-xr-xcompiler/one-cmds/onnx_legalizer.py1065
-rw-r--r--compiler/one-cmds/tests/CMakeLists.txt26
-rw-r--r--compiler/one-cmds/tests/one-quantize_009.qconf.json36
-rw-r--r--compiler/one-cmds/tests/one-quantize_009.test55
-rw-r--r--compiler/one-cmds/tests/onnx-operations/CMakeLists.txt86
-rw-r--r--compiler/one-cmds/tests/onnx-operations/README.md28
-rw-r--r--compiler/one-cmds/tests/onnx-operations/prepare_test_materials.sh26
-rw-r--r--compiler/one-cmds/tests/onnx_legalize_run_compare.py129
-rw-r--r--compiler/one-cmds/tests/prepare_test_materials.sh33
-rw-r--r--compiler/one-cmds/tests/print_onnx_model.py20
-rw-r--r--compiler/one-cmds/tests/pytorch-operations/CMakeLists.txt109
-rw-r--r--compiler/one-cmds/tests/pytorch-operations/README.md28
-rw-r--r--compiler/one-cmds/tests/pytorch-operations/aux_generator.py83
-rw-r--r--compiler/one-cmds/tests/pytorch-operations/entire_model.test40
-rw-r--r--compiler/one-cmds/tests/pytorch-operations/example_generator.py116
-rw-r--r--compiler/one-cmds/tests/pytorch-operations/mar_state_dict_model.test40
-rw-r--r--compiler/one-cmds/tests/pytorch-operations/mar_torchscript_model.test40
-rw-r--r--compiler/one-cmds/tests/pytorch-operations/prepare_test_materials.sh26
-rw-r--r--compiler/one-cmds/tests/pytorch-operations/state_dict_model.test39
-rw-r--r--compiler/one-cmds/tests/pytorch-operations/torchscript_model.test39
-rw-r--r--compiler/one-cmds/utils.py184
-rw-r--r--compiler/oneco/CMakeLists.txt8
-rw-r--r--compiler/pepper-strcast/CMakeLists.txt4
-rw-r--r--compiler/pota-quantization-value-test/CMakeLists.txt30
-rw-r--r--compiler/pota-quantization-value-test/config_files/Add_002/channel/int16/qconf.json9
-rw-r--r--compiler/pota-quantization-value-test/config_files/Add_002/layer/uint8/qconf.json9
-rw-r--r--compiler/pota-quantization-value-test/config_files/AveragePool2D_000/channel/int16/qconf.json9
-rw-r--r--compiler/pota-quantization-value-test/config_files/AveragePool2D_000/layer/uint8/qconf.json9
-rw-r--r--compiler/pota-quantization-value-test/config_files/Concatenation_001/channel/int16/qconf.json9
-rw-r--r--compiler/pota-quantization-value-test/config_files/Concatenation_001/layer/uint8/qconf.json9
-rw-r--r--compiler/pota-quantization-value-test/config_files/Conv2D_004/channel/int16/qconf.json9
-rw-r--r--compiler/pota-quantization-value-test/config_files/Conv2D_004/layer/uint8/qconf.json9
-rw-r--r--compiler/pota-quantization-value-test/config_files/DepthwiseConv2D_002/channel/int16/qconf.json9
-rw-r--r--compiler/pota-quantization-value-test/config_files/DepthwiseConv2D_002/layer/uint8/qconf.json9
-rw-r--r--compiler/pota-quantization-value-test/config_files/FullyConnected_003/channel/int16/qconf.json9
-rw-r--r--compiler/pota-quantization-value-test/config_files/FullyConnected_003/layer/uint8/qconf.json9
-rw-r--r--compiler/pota-quantization-value-test/config_files/InstanceNorm_001/channel/int16/qconf.json9
-rw-r--r--compiler/pota-quantization-value-test/config_files/InstanceNorm_001/layer/uint8/qconf.json9
-rw-r--r--compiler/pota-quantization-value-test/config_files/MaxPool2D_000/channel/int16/qconf.json9
-rw-r--r--compiler/pota-quantization-value-test/config_files/MaxPool2D_000/layer/uint8/qconf.json9
-rw-r--r--compiler/pota-quantization-value-test/config_files/Mean_000/channel/int16/qconf.json9
-rw-r--r--compiler/pota-quantization-value-test/config_files/Mean_000/layer/uint8/qconf.json9
-rw-r--r--compiler/pota-quantization-value-test/config_files/Mul_001/channel/int16/qconf.json9
-rw-r--r--compiler/pota-quantization-value-test/config_files/Mul_001/layer/uint8/qconf.json9
-rw-r--r--compiler/pota-quantization-value-test/config_files/PRelu_001/channel/int16/qconf.json9
-rw-r--r--compiler/pota-quantization-value-test/config_files/PRelu_001/layer/uint8/qconf.json9
-rw-r--r--compiler/pota-quantization-value-test/config_files/ReLU_000/channel/int16/qconf.json9
-rw-r--r--compiler/pota-quantization-value-test/config_files/ReLU_000/layer/uint8/qconf.json9
-rw-r--r--compiler/pota-quantization-value-test/config_files/Split_000/channel/int16/qconf.json14
-rw-r--r--compiler/pota-quantization-value-test/config_files/Split_000/channel/uint8/qconf.json14
-rw-r--r--compiler/pota-quantization-value-test/config_files/TransposeConv_001/channel/int16/qconf.json9
-rw-r--r--compiler/pota-quantization-value-test/config_files/TransposeConv_001/layer/uint8/qconf.json9
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/Add_002_config/channel/int16/quantization/ifm1_Quantize.json4
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/Add_002_config/channel/int16/quantization/ifm2.json32
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/Add_002_config/channel/int16/quantization/ofm.json4
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/Add_002_config/layer/uint8/quantization/ifm1_Quantize.json4
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/Add_002_config/layer/uint8/quantization/ifm2.json32
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/Add_002_config/layer/uint8/quantization/ofm.json4
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/AveragePool2D_000_config/channel/int16/quantization/ifm_Quantize.json4
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/AveragePool2D_000_config/channel/int16/quantization/ofm.json4
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/AveragePool2D_000_config/layer/uint8/quantization/ifm_Quantize.json4
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/AveragePool2D_000_config/layer/uint8/quantization/ofm.json4
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/Concatenation_001_config/channel/int16/quantization/ifm1_Quantize.json4
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/Concatenation_001_config/channel/int16/quantization/ifm2.json28
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/Concatenation_001_config/channel/int16/quantization/ofm.json4
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/Concatenation_001_config/layer/uint8/quantization/ifm1_Quantize.json4
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/Concatenation_001_config/layer/uint8/quantization/ifm2.json28
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/Concatenation_001_config/layer/uint8/quantization/ofm.json4
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/Conv2D_004_config/channel/int16/fake_quantization/ker.json48
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/Conv2D_004_config/channel/int16/quantization/bias.json7
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/Conv2D_004_config/channel/int16/quantization/ifm_Quantize.json4
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/Conv2D_004_config/channel/int16/quantization/ker.json52
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/Conv2D_004_config/channel/int16/quantization/ofm.json4
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/Conv2D_004_config/layer/uint8/fake_quantization/ker.json48
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/Conv2D_004_config/layer/uint8/quantization/bias.json10
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/Conv2D_004_config/layer/uint8/quantization/ifm_Quantize.json4
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/Conv2D_004_config/layer/uint8/quantization/ker.json61
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/Conv2D_004_config/layer/uint8/quantization/ofm.json4
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/DepthwiseConv2D_002_config/channel/int16/uint8/fake_quantization/ker.json34
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/DepthwiseConv2D_002_config/channel/int16/uint8/quantization/bias.json9
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/DepthwiseConv2D_002_config/channel/int16/uint8/quantization/ifm_Quantize.json4
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/DepthwiseConv2D_002_config/channel/int16/uint8/quantization/ker.json38
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/DepthwiseConv2D_002_config/channel/int16/uint8/quantization/ofm.json4
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/DepthwiseConv2D_002_config/layer/uint8/fake_quantization/ker.json34
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/DepthwiseConv2D_002_config/layer/uint8/quantization/bias.json14
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/DepthwiseConv2D_002_config/layer/uint8/quantization/ifm_Quantize.json4
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/DepthwiseConv2D_002_config/layer/uint8/quantization/ker.json53
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/DepthwiseConv2D_002_config/layer/uint8/quantization/ofm.json4
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/FullyConnected_003_config/channel/int16/fake_quantization/weight.json76
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/FullyConnected_003_config/channel/int16/quantization/bias.json9
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/FullyConnected_003_config/channel/int16/quantization/in_Quantize.json4
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/FullyConnected_003_config/channel/int16/quantization/out.json4
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/FullyConnected_003_config/channel/int16/quantization/weight.json80
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/FullyConnected_003_config/layer/uint8/fake_quantization/weight.json76
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/FullyConnected_003_config/layer/uint8/quantization/bias.json14
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/FullyConnected_003_config/layer/uint8/quantization/in_Quantize.json4
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/FullyConnected_003_config/layer/uint8/quantization/out.json4
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/FullyConnected_003_config/layer/uint8/quantization/weight.json95
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/MaxPool2D_000_config/channel/int16/quantization/ifm_Quantize.json4
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/MaxPool2D_000_config/channel/int16/quantization/ofm.json4
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/MaxPool2D_000_config/layer/uint8/quantization/ifm_Quantize.json4
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/MaxPool2D_000_config/layer/uint8/quantization/ofm.json4
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/Mean_000_config/channel/int16/quantization/ifm_Quantize.json4
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/Mean_000_config/channel/int16/quantization/ofm.json4
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/Mean_000_config/layer/uint8/quantization/ifm_Quantize.json4
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/Mean_000_config/layer/uint8/quantization/ofm.json4
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/Mean_000_config/layer/uint8/quantization/reduction_indices.json5
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/Mul_001_config/channel/int16/quantization/ifm1_Quantize.json4
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/Mul_001_config/channel/int16/quantization/ifm2.json32
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/Mul_001_config/channel/int16/quantization/ofm.json4
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/Mul_001_config/layer/uint8/quantization/ifm1_Quantize.json4
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/Mul_001_config/layer/uint8/quantization/ifm2.json32
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/Mul_001_config/layer/uint8/quantization/ofm.json4
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/PRelu_001_config/channel/int16/quantization/alpha.json13
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/PRelu_001_config/channel/int16/quantization/ifm_Quantize.json4
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/PRelu_001_config/channel/int16/quantization/ofm.json4
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/PRelu_001_config/layer/uint8/quantization/alpha.json21
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/PRelu_001_config/layer/uint8/quantization/ifm_Quantize.json4
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/PRelu_001_config/layer/uint8/quantization/ofm.json4
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/ReLU_000_config/channel/int16/quantization/ifm_Quantize.json4
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/ReLU_000_config/channel/int16/quantization/ofm.json4
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/ReLU_000_config/layer/uint8/quantization/ifm_Quantize.json4
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/ReLU_000_config/layer/uint8/quantization/ofm.json4
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/Split_000_config/channel/int16/quantization/ifm_Quantize.json4
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/Split_000_config/channel/int16/quantization/ofm1.json4
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/Split_000_config/channel/int16/quantization/ofm2.json4
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/Split_000_config/channel/int16/quantization/split_dim.json5
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/Split_000_config/channel/uint8/quantization/ifm_Quantize.json4
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/Split_000_config/channel/uint8/quantization/ofm1.json4
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/Split_000_config/channel/uint8/quantization/ofm2.json4
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/Split_000_config/channel/uint8/quantization/split_dim.json5
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/TransposeConv_001_config/channel/int16/fake_quantization/ker.json48
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/TransposeConv_001_config/channel/int16/quantization/ifm_Quantize.json4
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/TransposeConv_001_config/channel/int16/quantization/ker.json52
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/TransposeConv_001_config/channel/int16/quantization/ofm.json4
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/TransposeConv_001_config/layer/uint8/fake_quantization/ker.json48
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/TransposeConv_001_config/layer/uint8/quantization/ifm_Quantize.json4
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/TransposeConv_001_config/layer/uint8/quantization/ker.json58
-rw-r--r--compiler/pota-quantization-value-test/expected_outputs/TransposeConv_001_config/layer/uint8/quantization/ofm.json4
-rw-r--r--compiler/pota-quantization-value-test/requires.cmake2
-rw-r--r--compiler/pota-quantization-value-test/test.lst29
-rwxr-xr-xcompiler/pota-quantization-value-test/test_fake_wquant_with_config.sh87
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/Add_002_config/channel/int16/0.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/Add_002_config/channel/int16/1.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/Add_002_config/channel/int16/2.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/Add_002_config/channel/int16/3.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/Add_002_config/channel/int16/4.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/Add_002_config/layer/uint8/0.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/Add_002_config/layer/uint8/1.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/Add_002_config/layer/uint8/2.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/Add_002_config/layer/uint8/3.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/Add_002_config/layer/uint8/4.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/AveragePool2D_000_config/channel/int16/0.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/AveragePool2D_000_config/channel/int16/1.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/AveragePool2D_000_config/channel/int16/2.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/AveragePool2D_000_config/channel/int16/3.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/AveragePool2D_000_config/channel/int16/4.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/AveragePool2D_000_config/layer/uint8/0.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/AveragePool2D_000_config/layer/uint8/1.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/AveragePool2D_000_config/layer/uint8/2.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/AveragePool2D_000_config/layer/uint8/3.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/AveragePool2D_000_config/layer/uint8/4.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/Concatenation_001_config/channel/int16/0.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/Concatenation_001_config/channel/int16/1.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/Concatenation_001_config/channel/int16/2.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/Concatenation_001_config/channel/int16/3.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/Concatenation_001_config/channel/int16/4.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/Concatenation_001_config/layer/uint8/0.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/Concatenation_001_config/layer/uint8/1.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/Concatenation_001_config/layer/uint8/2.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/Concatenation_001_config/layer/uint8/3.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/Concatenation_001_config/layer/uint8/4.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/Conv2D_004_config/channel/int16/0.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/Conv2D_004_config/channel/int16/1.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/Conv2D_004_config/channel/int16/2.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/Conv2D_004_config/channel/int16/3.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/Conv2D_004_config/channel/int16/4.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/Conv2D_004_config/layer/uint8/0.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/Conv2D_004_config/layer/uint8/1.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/Conv2D_004_config/layer/uint8/2.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/Conv2D_004_config/layer/uint8/3.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/Conv2D_004_config/layer/uint8/4.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/DepthwiseConv2D_002_config/channel/int16/0.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/DepthwiseConv2D_002_config/channel/int16/1.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/DepthwiseConv2D_002_config/channel/int16/2.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/DepthwiseConv2D_002_config/channel/int16/3.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/DepthwiseConv2D_002_config/channel/int16/4.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/DepthwiseConv2D_002_config/layer/uint8/0.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/DepthwiseConv2D_002_config/layer/uint8/1.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/DepthwiseConv2D_002_config/layer/uint8/2.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/DepthwiseConv2D_002_config/layer/uint8/3.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/DepthwiseConv2D_002_config/layer/uint8/4.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/FullyConnected_003_config/channel/int16/0.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/FullyConnected_003_config/channel/int16/1.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/FullyConnected_003_config/channel/int16/2.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/FullyConnected_003_config/channel/int16/3.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/FullyConnected_003_config/channel/int16/4.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/FullyConnected_003_config/layer/uint8/0.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/FullyConnected_003_config/layer/uint8/1.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/FullyConnected_003_config/layer/uint8/2.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/FullyConnected_003_config/layer/uint8/3.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/FullyConnected_003_config/layer/uint8/4.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/MaxPool2D_000_config/channel/int16/0.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/MaxPool2D_000_config/channel/int16/1.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/MaxPool2D_000_config/channel/int16/2.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/MaxPool2D_000_config/channel/int16/3.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/MaxPool2D_000_config/channel/int16/4.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/MaxPool2D_000_config/layer/uint8/0.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/MaxPool2D_000_config/layer/uint8/1.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/MaxPool2D_000_config/layer/uint8/2.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/MaxPool2D_000_config/layer/uint8/3.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/MaxPool2D_000_config/layer/uint8/4.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/Mean_000_config/channel/int16/0.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/Mean_000_config/channel/int16/1.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/Mean_000_config/channel/int16/2.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/Mean_000_config/channel/int16/3.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/Mean_000_config/channel/int16/4.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/Mean_000_config/layer/uint8/0.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/Mean_000_config/layer/uint8/1.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/Mean_000_config/layer/uint8/2.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/Mean_000_config/layer/uint8/3.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/Mean_000_config/layer/uint8/4.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/Mul_001_config/channel/int16/0.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/Mul_001_config/channel/int16/1.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/Mul_001_config/channel/int16/2.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/Mul_001_config/channel/int16/3.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/Mul_001_config/channel/int16/4.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/Mul_001_config/layer/uint8/0.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/Mul_001_config/layer/uint8/1.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/Mul_001_config/layer/uint8/2.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/Mul_001_config/layer/uint8/3.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/Mul_001_config/layer/uint8/4.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/PRelu_001_config/channel/int16/0.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/PRelu_001_config/channel/int16/1.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/PRelu_001_config/channel/int16/2.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/PRelu_001_config/channel/int16/3.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/PRelu_001_config/channel/int16/4.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/PRelu_001_config/layer/uint8/0.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/PRelu_001_config/layer/uint8/1.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/PRelu_001_config/layer/uint8/2.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/PRelu_001_config/layer/uint8/3.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/PRelu_001_config/layer/uint8/4.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/ReLU_000_config/channel/int16/0.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/ReLU_000_config/channel/int16/1.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/ReLU_000_config/channel/int16/2.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/ReLU_000_config/channel/int16/3.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/ReLU_000_config/channel/int16/4.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/ReLU_000_config/layer/uint8/0.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/ReLU_000_config/layer/uint8/1.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/ReLU_000_config/layer/uint8/2.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/ReLU_000_config/layer/uint8/3.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/ReLU_000_config/layer/uint8/4.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/Split_000_config/channel/int16/0.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/Split_000_config/channel/int16/1.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/Split_000_config/channel/int16/2.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/Split_000_config/channel/int16/3.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/Split_000_config/channel/int16/4.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/Split_000_config/channel/uint8/0.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/Split_000_config/channel/uint8/1.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/Split_000_config/channel/uint8/2.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/Split_000_config/channel/uint8/3.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/Split_000_config/channel/uint8/4.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/TransposeConv_001_config/channel/int16/0.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/TransposeConv_001_config/channel/int16/1.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/TransposeConv_001_config/channel/int16/2.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/TransposeConv_001_config/channel/int16/3.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/TransposeConv_001_config/channel/int16/4.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/TransposeConv_001_config/layer/uint8/0.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/TransposeConv_001_config/layer/uint8/1.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/TransposeConv_001_config/layer/uint8/2.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/TransposeConv_001_config/layer/uint8/3.txt1
-rw-r--r--compiler/pota-quantization-value-test/test_inputs/TransposeConv_001_config/layer/uint8/4.txt1
-rwxr-xr-xcompiler/pota-quantization-value-test/test_quantization_with_config.sh109
-rw-r--r--compiler/pp/CMakeLists.txt4
-rw-r--r--compiler/record-minmax-conversion-test/CMakeLists.txt2
-rw-r--r--compiler/record-minmax/CMakeLists.txt10
-rw-r--r--compiler/record-minmax/requires.cmake1
-rw-r--r--compiler/record-minmax/src/HDF5Importer.h87
-rw-r--r--compiler/record-minmax/src/MinMaxObserver.cpp13
-rw-r--r--compiler/record-minmax/src/RecordMinMax.cpp106
-rw-r--r--compiler/souschef/CMakeLists.txt2
-rw-r--r--compiler/tf2tfliteV2-conversion-test/CMakeLists.txt2
-rw-r--r--compiler/tfl-inspect/CMakeLists.txt3
-rw-r--r--compiler/tfl-inspect/requires.cmake2
-rw-r--r--compiler/tfl-inspect/src/Reader.cpp74
-rw-r--r--compiler/tfl-inspect/src/Reader.h7
-rw-r--r--compiler/tfl-verify/CMakeLists.txt2
-rw-r--r--compiler/tfl-verify/requires.cmake2
-rw-r--r--compiler/tflchef/CMakeLists.txt6
-rw-r--r--compiler/tflchef/core/CMakeLists.txt2
-rw-r--r--compiler/tflchef/core/src/ModelChef.cpp8
-rw-r--r--compiler/tflchef/core/src/Op/FullyConnected.cpp1
-rw-r--r--compiler/tflchef/core/src/Op/SVDF.cpp41
-rw-r--r--compiler/tflchef/core/src/Op/SVDF.h46
-rw-r--r--compiler/tflchef/core/src/OpChef.def1
-rw-r--r--compiler/tflchef/core/src/OpChefs.h1
-rw-r--r--compiler/tflchef/proto/tflchef.proto13
-rw-r--r--compiler/tflchef/requires.cmake2
-rw-r--r--compiler/tflchef/tests/CMakeLists.txt43
-rw-r--r--compiler/tflchef/tests/signature_def_index/test.recipe3
-rw-r--r--compiler/tflchef/tests/signature_def_name/test.recipe3
-rw-r--r--compiler/tflchef/tflite/CMakeLists.txt3
-rw-r--r--compiler/tflchef/tflite/src/Op/FullyConnected.cpp1
-rw-r--r--compiler/tflchef/tflite/src/Op/SVDF.cpp59
-rw-r--r--compiler/tflchef/tflite/src/Op/SVDF.h39
-rw-r--r--compiler/tflchef/tflite/src/RecipeChef.cpp11
-rw-r--r--compiler/tflchef/tflite/src/TFliteImport.cpp49
-rw-r--r--compiler/tflchef/tflite/src/TFliteImport.h6
-rw-r--r--compiler/tflchef/tflite/src/TFliteOpChefs.h1
-rw-r--r--compiler/tflchef/tflite/src/TFliteOpRegistry.h1
-rw-r--r--compiler/tfldump/CMakeLists.txt10
-rw-r--r--compiler/tfldump/requires.cmake2
-rw-r--r--compiler/tfldump/src/Dump.cpp26
-rw-r--r--compiler/tfldump/src/Load.cpp2
-rw-r--r--compiler/tfldump/src/OpPrinter.cpp18
-rw-r--r--compiler/tfldump/src/Read.cpp72
-rw-r--r--compiler/tfldump/src/Read.h7
-rw-r--r--compiler/tflite2circle/CMakeLists.txt9
-rw-r--r--compiler/tflite2circle/requires.cmake4
-rw-r--r--compiler/tflite2circle/src/BuildBuiltinOptions.h1
-rw-r--r--compiler/tflite2circle/src/BuildBuiltinOptions/FullyConnectedOptions.cpp1
-rw-r--r--compiler/tflite2circle/src/BuildBuiltinOptions/SVDFOptions.cpp41
-rw-r--r--compiler/tflite2circle/src/BuildBuiltinOptions/SVDFOptions.h31
-rw-r--r--compiler/tflite2circle/src/CircleModel.cpp42
-rw-r--r--compiler/tflite2circle/src/DataLookup.cpp16
-rw-r--r--compiler/tflite2circle/src/DataLookup.h2
-rw-r--r--compiler/tflite2circle/src/TFLBuiltinOptions.lst2
-rw-r--r--compiler/vconone/CMakeLists.txt2
867 files changed, 33719 insertions, 10210 deletions
diff --git a/compiler/angkor/CMakeLists.txt b/compiler/angkor/CMakeLists.txt
index 44b5e9058..7f5cb88c2 100644
--- a/compiler/angkor/CMakeLists.txt
+++ b/compiler/angkor/CMakeLists.txt
@@ -5,7 +5,9 @@ list(REMOVE_ITEM SOURCES ${TESTS})
# NOTE STATIC is deliberately used here to allow clients to use 'angkor' without installation
add_library(angkor STATIC ${HEADERS} ${SOURCES})
-set_target_properties(angkor PROPERTIES POSITION_INDEPENDENT_CODE ON)
+if (NOT NNCC_LIBRARY_NO_PIC)
+ set_target_properties(angkor PROPERTIES POSITION_INDEPENDENT_CODE ON)
+endif (NOT NNCC_LIBRARY_NO_PIC)
set_target_properties(angkor PROPERTIES LINKER_LANGUAGE CXX)
target_include_directories(angkor PUBLIC include)
target_link_libraries(angkor PRIVATE nncc_common)
diff --git a/compiler/arser/tests/arser.test.cpp b/compiler/arser/tests/arser.test.cpp
index 4e88f0cb7..63121b845 100644
--- a/compiler/arser/tests/arser.test.cpp
+++ b/compiler/arser/tests/arser.test.cpp
@@ -23,30 +23,9 @@
#include "arser/arser.h"
-using namespace arser;
+#include "Prompt.h"
-class Prompt
-{
-public:
- Prompt(const std::string &command)
- {
- std::istringstream iss(command);
- std::vector<std::string> token(std::istream_iterator<std::string>{iss},
- std::istream_iterator<std::string>());
- _arg = std::move(token);
- _argv.reserve(_arg.size());
- for (const auto &t : _arg)
- {
- _argv.push_back(const_cast<char *>(t.data()));
- }
- }
- int argc(void) const { return _argv.size(); }
- char **argv(void) { return _argv.data(); }
-
-private:
- std::vector<char *> _argv;
- std::vector<std::string> _arg;
-};
+using namespace arser;
TEST(BasicTest, option)
{
@@ -57,7 +36,7 @@ TEST(BasicTest, option)
.nargs(0)
.help("It provides additional details as to what the executable is doing");
- Prompt prompt("./executable --verbose");
+ test::Prompt prompt("./executable --verbose");
/* act */
arser.parse(prompt.argc(), prompt.argv());
/* assert */
@@ -79,7 +58,7 @@ TEST(BasicTest, OptionalArgument)
.type(arser::DataType::FLOAT)
.help("Set a frequency as you provided.");
- Prompt prompt("./radio --volume 5 --frequency 128.5");
+ test::Prompt prompt("./radio --volume 5 --frequency 128.5");
/* act */
arser.parse(prompt.argc(), prompt.argv());
/* assert */
@@ -103,7 +82,7 @@ TEST(BasicTest, NonRequiredOptionalArgument_NEG)
.type(arser::DataType::INT32)
.help("Set a volume as you provided.");
- Prompt prompt("./radio"); // empty argument
+ test::Prompt prompt("./radio"); // empty argument
/* act */
arser.parse(prompt.argc(), prompt.argv());
/* assert */
@@ -122,7 +101,7 @@ TEST(BasicTest, RequiredOptionalArgument_NEG)
.required()
.help("Set a volume as you provided.");
- Prompt prompt("./radio");
+ test::Prompt prompt("./radio");
/* act */ /* assert */
EXPECT_THROW(arser.parse(prompt.argc(), prompt.argv()), std::runtime_error);
}
@@ -134,7 +113,7 @@ TEST(BasicTest, OptionalMultipleArgument)
arser.add_argument("--add").nargs(2).type(arser::DataType::INT32_VEC).help("Add two numbers.");
- Prompt prompt("./calculator --add 3 5");
+ test::Prompt prompt("./calculator --add 3 5");
/* act */
arser.parse(prompt.argc(), prompt.argv());
/* assert */
@@ -167,8 +146,8 @@ TEST(BasicTest, MultipleOptionalArgument)
.help("give traning data to this program.")
.required();
- Prompt prompt("./ml --input_path /I/am/in.put --output_path I/am/out.put "
- "--training_data 2 43 234 3 334");
+ test::Prompt prompt("./ml --input_path /I/am/in.put --output_path I/am/out.put "
+ "--training_data 2 43 234 3 334");
/* act */
arser.parse(prompt.argc(), prompt.argv());
/* assert */
@@ -195,7 +174,7 @@ TEST(BasicTest, MultipleFloatValue)
.type(arser::DataType::FLOAT_VEC)
.help("Add two float numbers.");
- Prompt prompt("./calculator --add_float 3.2 5.4");
+ test::Prompt prompt("./calculator --add_float 3.2 5.4");
/* act */
arser.parse(prompt.argc(), prompt.argv());
/* assert */
@@ -217,7 +196,7 @@ TEST(BasicTest, MultipleStringValue)
.type(arser::DataType::STR_VEC)
.help("insert your three favorite color");
- Prompt prompt("./color_factory --three_color red blue yellow");
+ test::Prompt prompt("./color_factory --three_color red blue yellow");
/* act */
arser.parse(prompt.argc(), prompt.argv());
/* assert */
@@ -241,7 +220,7 @@ TEST(BasicTest, ExitWithFunctionCall)
arser.add_argument("--name").nargs(1).type(arser::DataType::STR).help("Name your hero");
- Prompt prompt("./hero --history");
+ test::Prompt prompt("./hero --history");
/* act */ /* assert */
EXPECT_EXIT(arser.parse(prompt.argc(), prompt.argv()), testing::ExitedWithCode(0),
"When I was young..");
@@ -258,7 +237,7 @@ TEST(BasicTest, ExitWithFunctionCallWithBind)
.help("Show version and exit")
.exit_with(std::bind(printVersion, "1.2.0"));
- Prompt prompt("./arser --version");
+ test::Prompt prompt("./arser --version");
/* act */ /* assert */
EXPECT_EXIT(arser.parse(prompt.argc(), prompt.argv()), testing::ExitedWithCode(0),
"arser version : 1.2.0");
@@ -275,7 +254,7 @@ TEST(BasicTest, ExitWithFunctionCallWithLamda)
arser.add_argument("OS").nargs(1).type(arser::DataType::STR).help("The OS you want to boot");
- Prompt prompt("./computer --shutdown");
+ test::Prompt prompt("./computer --shutdown");
/* act */ /* assert */
EXPECT_EXIT(arser.parse(prompt.argc(), prompt.argv()), testing::ExitedWithCode(0), "Good bye..");
}
@@ -315,7 +294,7 @@ TEST(BasicTest, DefaultValue)
.default_value("no name")
.help("Enter your name");
- Prompt prompt("/phone --time 1 52 34 --name arser");
+ test::Prompt prompt("/phone --time 1 52 34 --name arser");
/* act */
arser.parse(prompt.argc(), prompt.argv());
/* assert */
@@ -359,7 +338,7 @@ TEST(BasicTest, shortOption)
.help("output path of this program.")
.required(true);
- Prompt prompt("./driver -i /I/am/in.put --output_path I/am/out.put");
+ test::Prompt prompt("./driver -i /I/am/in.put --output_path I/am/out.put");
/* act */
arser.parse(prompt.argc(), prompt.argv());
/* assert */
@@ -385,7 +364,7 @@ TEST(BasicTest, shortMultipleOption)
.help("output path of this program.")
.required(true);
- Prompt prompt("./driver --in /I/am/in.put -o I/am/out.put");
+ test::Prompt prompt("./driver --in /I/am/in.put -o I/am/out.put");
/* act */
arser.parse(prompt.argc(), prompt.argv());
/* assert */
@@ -411,7 +390,7 @@ TEST(BasicTest, OptWithRequiredDuplicate_NEG)
.help("output path of this program.")
.required(true);
- Prompt prompt("./driver --in /I/am/in.put -o I/am/out.put -i /I/am/duplicate");
+ test::Prompt prompt("./driver --in /I/am/in.put -o I/am/out.put -i /I/am/duplicate");
/* act */ /* assert */
EXPECT_THROW(arser.parse(prompt.argc(), prompt.argv()), std::runtime_error);
}
@@ -432,7 +411,7 @@ TEST(BasicTest, OptWithNonRequiredDuplicate)
.help("output path of this program.")
.required(true);
- Prompt prompt("./driver --in /I/am/in.put -o I/am/out.put -i /I/am/duplicate");
+ test::Prompt prompt("./driver --in /I/am/in.put -o I/am/out.put -i /I/am/duplicate");
/* act */
arser.parse(prompt.argc(), prompt.argv());
/* assert */
@@ -449,7 +428,7 @@ TEST(BasicTest, AccumulateVectorOptions)
arser.add_argument("--specify").nargs(3).accumulated(true).type(arser::DataType::STR_VEC);
- Prompt prompt("./driver --specify a b c --specify 1 2 3");
+ test::Prompt prompt("./driver --specify a b c --specify 1 2 3");
/* act */
arser.parse(prompt.argc(), prompt.argv());
/* assert */
@@ -473,7 +452,7 @@ TEST(BasicTest, AccumulateScalarOptions)
arser.add_argument("--specify").nargs(1).accumulated(true).type(arser::DataType::FLOAT);
- Prompt prompt("./driver --specify 1 --specify 2");
+ test::Prompt prompt("./driver --specify 1 --specify 2");
/* act */
arser.parse(prompt.argc(), prompt.argv());
/* assert */
@@ -491,7 +470,7 @@ TEST(BasicTest, AccumulateScalarOptions_WrongType_NEG)
arser.add_argument("--specify").nargs(1).accumulated(true).type(arser::DataType::FLOAT);
- Prompt prompt("./driver --specify 1 --specify 2");
+ test::Prompt prompt("./driver --specify 1 --specify 2");
/* act */
arser.parse(prompt.argc(), prompt.argv());
/* assert */
diff --git a/compiler/circle-eval-diff/CMakeLists.txt b/compiler/circle-eval-diff/CMakeLists.txt
new file mode 100644
index 000000000..4d86f8097
--- /dev/null
+++ b/compiler/circle-eval-diff/CMakeLists.txt
@@ -0,0 +1,34 @@
+set(DRIVER "driver/Driver.cpp")
+
+file(GLOB_RECURSE SOURCES "src/*.cpp")
+file(GLOB_RECURSE TESTS "src/*.test.cpp")
+list(REMOVE_ITEM SOURCES ${TESTS})
+
+add_executable(circle-eval-diff ${DRIVER} ${SOURCES})
+target_include_directories(circle-eval-diff PRIVATE include)
+
+target_link_libraries(circle-eval-diff arser)
+target_link_libraries(circle-eval-diff safemain)
+target_link_libraries(circle-eval-diff foder)
+target_link_libraries(circle-eval-diff loco)
+target_link_libraries(circle-eval-diff luci_import)
+target_link_libraries(circle-eval-diff luci_lang)
+target_link_libraries(circle-eval-diff luci_interpreter)
+target_link_libraries(circle-eval-diff dio_hdf5)
+target_link_libraries(circle-eval-diff vconone)
+
+if(NOT ENABLE_TEST)
+ return()
+endif(NOT ENABLE_TEST)
+
+# circle-eval-diff is executable, so we do not link it to the test.
+# Instead, we use TEST_SOURCES to specify sources used for tests.
+set(TEST_SOURCES
+ "src/MetricPrinter.cpp"
+ "src/Tensor.cpp")
+
+nnas_find_package(GTest REQUIRED)
+GTest_AddTest(circle_eval_diff_test ${TESTS} ${TEST_SOURCES})
+target_include_directories(circle_eval_diff_test PRIVATE src)
+target_link_libraries(circle_eval_diff_test luci_testhelper)
+target_link_libraries(circle_eval_diff_test nncc_coverage)
diff --git a/compiler/circle-eval-diff/README.md b/compiler/circle-eval-diff/README.md
new file mode 100644
index 000000000..a3727cc6d
--- /dev/null
+++ b/compiler/circle-eval-diff/README.md
@@ -0,0 +1,51 @@
+# circle-eval-diff
+
+_circle-eval-diff_ compares inference results of two circle models.
+
+## Use cases
+
+1. _circle-eval-diff_ can be used to evaluate reconstruction errors of quantized models.
+2. _circle-eval-diff_ can be used to verify optimization (or any kind of value-preserving conversion) is safe.
+
+## Usage
+
+Run circle-eval-diff with the following arguments.
+
+--first_input_model: first model to compare (.circle).
+
+--second_input_model: second model to compare (.circle).
+
+--first_input_data: input data for the first model (.h5, directory). Random data will be used if this argument is not given.
+
+--second_input_data: input data for the second model (.h5, directory). Random data will be used if this argument is not given.
+
+--input_data_format: input data format (h5 (default), directory).
+
+--metric: metric to compare inference results (MAE (default), etc).
+
+```
+$ ./circle-eval-diff
+ --first_input_model <first_input_model>
+ --second_input_model <second_input_model>
+ --first_input_data <first_input_data>
+ --second_input_data <second_input_data>
+ --input_data_format <data_format>
+ --metric <metric>
+```
+
+For example,
+```
+$ ./circle-eval-diff
+ --first_input_model A.circle
+ --second_input_model B.circle
+ --first_input_data A.h5
+ --second_input_data B.h5
+ --input_data_format h5
+ --metric MAE
+```
+
+It will print MAE (Mean Absolute Error) between the inference result of A.circle with A.h5 and that of B.circle with B.h5.
+
+## Note
+
+Circle models are executed by _luci-interpreter_.
diff --git a/compiler/circle-eval-diff/driver/Driver.cpp b/compiler/circle-eval-diff/driver/Driver.cpp
new file mode 100644
index 000000000..f4a12a403
--- /dev/null
+++ b/compiler/circle-eval-diff/driver/Driver.cpp
@@ -0,0 +1,156 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleEvalDiff.h"
+
+#include <arser/arser.h>
+#include <vconone/vconone.h>
+
+using namespace circle_eval_diff;
+
+namespace
+{
+
+std::string to_lower_case(std::string s)
+{
+ std::transform(s.begin(), s.end(), s.begin(), [](unsigned char c) { return std::tolower(c); });
+ return s;
+}
+
+Metric to_metric(const std::string &str)
+{
+ if (to_lower_case(str).compare("mae") == 0)
+ return Metric::MAE;
+
+ throw std::runtime_error("Unsupported metric.");
+}
+
+InputFormat to_input_format(const std::string &str)
+{
+ if (to_lower_case(str).compare("h5") == 0)
+ return InputFormat::H5;
+
+ throw std::runtime_error("Unsupported input format.");
+}
+
+void print_version(void)
+{
+ std::cout << "circle-eval-diff version " << vconone::get_string() << std::endl;
+ std::cout << vconone::get_copyright() << std::endl;
+}
+
+} // namespace
+
+int entry(const int argc, char **argv)
+{
+ arser::Arser arser("Compare inference results of two circle models");
+
+ arser.add_argument("--version")
+ .nargs(0)
+ .required(false)
+ .default_value(false)
+ .help("Show version information and exit")
+ .exit_with(print_version);
+
+ arser.add_argument("--first_model")
+ .nargs(1)
+ .type(arser::DataType::STR)
+ .required(true)
+ .help("First input model filepath");
+
+ arser.add_argument("--second_model")
+ .nargs(1)
+ .type(arser::DataType::STR)
+ .required(true)
+ .help("Second input model filepath");
+
+ arser.add_argument("--first_input_data")
+ .nargs(1)
+ .type(arser::DataType::STR)
+ .required(false)
+ .help("Input data filepath for the first model. If not given, circle-eval-diff will run with "
+ "randomly generated data");
+
+ arser.add_argument("--second_input_data")
+ .nargs(1)
+ .type(arser::DataType::STR)
+ .required(false)
+ .help("Input data filepath for the second model. If not given, circle-eval-diff will run with "
+ "randomly generated data");
+
+ arser.add_argument("--metric")
+ .nargs(1)
+ .type(arser::DataType::STR)
+ .required(false)
+ .default_value("MAE")
+ .help("Metric for comparison (default: MAE)");
+
+ arser.add_argument("--input_data_format")
+ .nargs(1)
+ .type(arser::DataType::STR)
+ .required(false)
+ .default_value("h5")
+ .help("Input data format. h5/hdf5 (default) or directory");
+
+ try
+ {
+ arser.parse(argc, argv);
+ }
+ catch (const std::runtime_error &err)
+ {
+ std::cout << err.what() << std::endl;
+ std::cout << arser;
+ return 255;
+ }
+
+ const auto first_model_path = arser.get<std::string>("--first_model");
+ const auto second_model_path = arser.get<std::string>("--second_model");
+
+ // Default values
+ std::string first_input_data_path;
+ std::string second_input_data_path;
+ std::string metric;
+ std::string input_data_format;
+
+ if (arser["--first_input_data"])
+ first_input_data_path = arser.get<std::string>("--first_input_data");
+
+ if (arser["--second_input_data"])
+ second_input_data_path = arser.get<std::string>("--second_input_data");
+
+ if (arser["--first_input_data"] != arser["--second_input_data"])
+ throw std::runtime_error("Input data path should be given for both first_model and "
+ "second_model, or neither must be given.");
+
+ metric = arser.get<std::string>("--metric");
+ input_data_format = arser.get<std::string>("--input_data_format");
+
+ auto ctx = std::make_unique<CircleEvalDiff::Context>();
+ {
+ ctx->first_model_path = first_model_path;
+ ctx->second_model_path = second_model_path;
+ ctx->metric = to_metric(metric);
+ ctx->input_format = to_input_format(input_data_format);
+ }
+
+ CircleEvalDiff ced(std::move(ctx));
+
+ ced.init();
+
+ ced.evalDiff(first_input_data_path, second_input_data_path);
+
+ return EXIT_SUCCESS;
+}
diff --git a/compiler/circle-eval-diff/include/CircleEvalDiff.h b/compiler/circle-eval-diff/include/CircleEvalDiff.h
new file mode 100644
index 000000000..bf6aff46d
--- /dev/null
+++ b/compiler/circle-eval-diff/include/CircleEvalDiff.h
@@ -0,0 +1,74 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __CIRCLE_EVAL_DIFF_H__
+#define __CIRCLE_EVAL_DIFF_H__
+
+#include <luci/IR/Module.h>
+#include <luci_interpreter/Interpreter.h>
+
+#include <string>
+#include <memory>
+
+namespace circle_eval_diff
+{
+
+// Forward declaration
+class ModuleEvalDiff;
+
+enum class Metric
+{
+ Undefined, // For debugging
+ MAE,
+};
+
+enum class InputFormat
+{
+ Undefined, // For debugging
+ H5,
+ // TODO Implement Random, Directory
+};
+
+class CircleEvalDiff final
+{
+public:
+ struct Context
+ {
+ std::string first_model_path;
+ std::string second_model_path;
+ Metric metric = Metric::Undefined;
+ InputFormat input_format = InputFormat::Undefined;
+ };
+
+public:
+ CircleEvalDiff(std::unique_ptr<Context> &&ctx);
+
+ ~CircleEvalDiff();
+
+ void init();
+
+ // Evaluate two circle models for the given input data and compare the results
+ void evalDiff(const std::string &first_input_data_path,
+ const std::string &second_input_data_path) const;
+
+private:
+ std::unique_ptr<Context> _ctx;
+ std::unique_ptr<ModuleEvalDiff> _runner;
+};
+
+} // namespace circle_eval_diff
+
+#endif // __CIRCLE_EVAL_DIFF_H__
diff --git a/compiler/circle-eval-diff/requires.cmake b/compiler/circle-eval-diff/requires.cmake
new file mode 100644
index 000000000..cae9b7c62
--- /dev/null
+++ b/compiler/circle-eval-diff/requires.cmake
@@ -0,0 +1,7 @@
+require("loco")
+require("luci")
+require("luci-interpreter")
+require("dio-hdf5")
+require("safemain")
+require("arser")
+require("vconone")
diff --git a/compiler/circle-eval-diff/src/CircleEvalDiff.cpp b/compiler/circle-eval-diff/src/CircleEvalDiff.cpp
new file mode 100644
index 000000000..c39a11371
--- /dev/null
+++ b/compiler/circle-eval-diff/src/CircleEvalDiff.cpp
@@ -0,0 +1,97 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleEvalDiff.h"
+#include "ModuleEvalDiff.h"
+#include "MetricPrinter.h"
+
+#include <foder/FileLoader.h>
+#include <luci/Importer.h>
+
+#include <stdexcept>
+
+namespace
+{
+
+std::unique_ptr<luci::Module> import(const std::string &model_path)
+{
+ // Load model from the file
+ foder::FileLoader loader{model_path};
+ std::vector<char> model_data = loader.load();
+
+ // Verify flatbuffers
+ flatbuffers::Verifier verifier{reinterpret_cast<const uint8_t *>(model_data.data()),
+ model_data.size()};
+ if (not circle::VerifyModelBuffer(verifier))
+ {
+ throw std::runtime_error("Failed to verify circle '" + model_path + "'");
+ }
+
+ auto module = luci::Importer().importModule(circle::GetModel(model_data.data()));
+
+ if (not module)
+ throw std::runtime_error("Failed to load '" + model_path + "'");
+
+ return module;
+}
+
+} // namespace
+
+namespace circle_eval_diff
+{
+
+CircleEvalDiff::CircleEvalDiff(std::unique_ptr<Context> &&ctx)
+ : _ctx(std::move(ctx)), _runner(nullptr)
+{
+}
+
+CircleEvalDiff::~CircleEvalDiff() = default;
+
+void CircleEvalDiff::init()
+{
+ // Set metric
+ std::unique_ptr<MetricPrinter> metric;
+ switch (_ctx->metric)
+ {
+ case Metric::MAE:
+ metric = std::make_unique<MAEPrinter>();
+ break;
+ default:
+ throw std::runtime_error("Unsupported metric.");
+ }
+
+ auto first_module = import(_ctx->first_model_path);
+ auto second_module = import(_ctx->second_model_path);
+
+ // Set runner
+ switch (_ctx->input_format)
+ {
+ case InputFormat::H5:
+ _runner = std::make_unique<H5InputEvalDiff>(std::move(first_module), std::move(second_module),
+ std::move(metric));
+ break;
+ default:
+ throw std::runtime_error("Unsupported input format.");
+ }
+}
+
+void CircleEvalDiff::evalDiff(const std::string &first_input_data_path,
+ const std::string &second_input_data_path) const
+{
+ _runner->evalDiff(first_input_data_path, second_input_data_path);
+}
+
+} // namespace circle_eval_diff
diff --git a/compiler/circle-eval-diff/src/MetricPrinter.cpp b/compiler/circle-eval-diff/src/MetricPrinter.cpp
new file mode 100644
index 000000000..d65eb9b63
--- /dev/null
+++ b/compiler/circle-eval-diff/src/MetricPrinter.cpp
@@ -0,0 +1,185 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "MetricPrinter.h"
+
+#include <luci/IR/CircleNode.h>
+
+#include <iostream>
+#include <cassert>
+
+using Tensor = circle_eval_diff::Tensor;
+
+#define THROW_UNLESS(COND, MSG) \
+ if (not(COND)) \
+ throw std::runtime_error(MSG);
+
+namespace
+{
+
+template <typename T> bool same_shape(const T a, const T b)
+{
+ if (a->rank() != b->rank())
+ return false;
+
+ for (uint32_t i = 0; i < a->rank(); i++)
+ {
+ if (not(a->dim(i) == b->dim(i)))
+ return false;
+ }
+
+ return true;
+}
+
+template <loco::DataType DT> std::shared_ptr<Tensor> to_fp32(const std::shared_ptr<Tensor> &tensor)
+{
+ assert(tensor->dtype() == DT); // FIX_CALLER_UNLESS
+
+ auto fp32_tensor = std::make_shared<Tensor>();
+ {
+ fp32_tensor->dtype(loco::DataType::FLOAT32);
+ fp32_tensor->rank(tensor->rank());
+ for (uint32_t i = 0; i < tensor->rank(); i++)
+ fp32_tensor->dim(i) = tensor->dim(i);
+
+ const auto num_elems = tensor->size<DT>();
+ fp32_tensor->size<loco::DataType::FLOAT32>(num_elems);
+ for (uint32_t i = 0; i < num_elems; i++)
+ fp32_tensor->at<loco::DataType::FLOAT32>(i) = static_cast<float>(tensor->at<DT>(i));
+ }
+ return fp32_tensor;
+}
+
+std::shared_ptr<Tensor> fp32(const std::shared_ptr<Tensor> &tensor)
+{
+ switch (tensor->dtype())
+ {
+ case loco::DataType::FLOAT32:
+ return tensor;
+ case loco::DataType::U8:
+ return to_fp32<loco::DataType::U8>(tensor);
+ case loco::DataType::S16:
+ return to_fp32<loco::DataType::S16>(tensor);
+ default:
+ throw std::runtime_error("Unsupported data type.");
+ }
+}
+
+} // namespace
+
+namespace circle_eval_diff
+{
+
+void MAEPrinter::init(const luci::Module *first, const luci::Module *second)
+{
+ THROW_UNLESS(first != nullptr, "Invalid module.");
+ THROW_UNLESS(second != nullptr, "Invalid module.");
+
+ const auto first_output = loco::output_nodes(first->graph());
+ const auto second_output = loco::output_nodes(second->graph());
+
+ assert(first_output.size() == second_output.size()); // FIX_CALLER_UNLESS
+
+ for (uint32_t i = 0; i < first_output.size(); i++)
+ {
+ const auto first_node = loco::must_cast<luci::CircleNode *>(first_output[i]);
+ const auto second_node = loco::must_cast<luci::CircleNode *>(second_output[i]);
+ assert(same_shape(first_node, second_node)); // FIX_CALLER_UNLESS
+
+ // Create tensors to store intermediate results
+ _intermediate.emplace_back();
+ _intermediate.at(i).dtype(loco::DataType::FLOAT32);
+ // NOTE Use both first_node and second_node to avoid release build break
+ _intermediate.at(i).rank(first_node->rank());
+ uint32_t num_elems = 1;
+ for (uint32_t j = 0; j < second_node->rank(); j++)
+ {
+ _intermediate.at(i).dim(j) = second_node->dim(j);
+ num_elems *= second_node->dim(j).value();
+ }
+ _intermediate.at(i).size<loco::DataType::FLOAT32>(num_elems);
+
+    // Check the buffer is initialized with zero
+ for (uint32_t j = 0; j < num_elems; j++)
+ assert(_intermediate.at(i).at<loco::DataType::FLOAT32>(j) == 0.0);
+
+ // Save output names for logging
+ _output_names.emplace_back(first_node->name());
+ }
+}
+
+void MAEPrinter::accum_absolute_error(uint32_t output_idx, const std::shared_ptr<Tensor> &a,
+ const std::shared_ptr<Tensor> &b)
+{
+ assert(a->dtype() == loco::DataType::FLOAT32 and
+ b->dtype() == loco::DataType::FLOAT32); // FIX_CALLER_UNLESS
+ assert(same_shape(a.get(), b.get())); // FIX_CALLER_UNLESS
+ assert(output_idx < _intermediate.size()); // FIX_CALLER_UNLESS
+
+ for (uint32_t i = 0; i < a->size<loco::DataType::FLOAT32>(); i++)
+ {
+ _intermediate.at(output_idx).at<loco::DataType::FLOAT32>(i) +=
+ std::abs(a->at<loco::DataType::FLOAT32>(i) - b->at<loco::DataType::FLOAT32>(i));
+ }
+}
+
+void MAEPrinter::accumulate(const std::vector<std::shared_ptr<Tensor>> &first,
+ const std::vector<std::shared_ptr<Tensor>> &second)
+{
+ assert(first.size() == second.size()); // FIX_CALLER_UNLESS
+ assert(first.size() == _intermediate.size()); // FIX_CALLER_UNLESS
+
+ for (uint32_t output_idx = 0; output_idx < _intermediate.size(); output_idx++)
+ {
+ const auto first_output = first[output_idx];
+ const auto second_output = second[output_idx];
+
+ // Cast data to fp32 and then compute absolute error
+ const auto fp32_first_output = fp32(first_output);
+ const auto fp32_second_output = fp32(second_output);
+
+ accum_absolute_error(output_idx, fp32_first_output, fp32_second_output);
+ }
+
+ _num_data++;
+}
+
+void MAEPrinter::dump(std::ostream &os) const
+{
+ os << "Mean Absolute Error (MAE)" << std::endl;
+
+ for (uint32_t output_idx = 0; output_idx < _intermediate.size(); output_idx++)
+ {
+ const auto name = _output_names.at(output_idx);
+ const auto &inter = _intermediate.at(output_idx);
+ assert(inter.dtype() == loco::DataType::FLOAT32); // FIX_ME_UNLESS
+ const auto elem_count = inter.size<loco::DataType::FLOAT32>();
+
+ // Compute MAE
+ float mae = 0.0;
+ for (uint32_t elem_idx = 0; elem_idx < elem_count; elem_idx++)
+ mae += inter.at<loco::DataType::FLOAT32>(elem_idx);
+
+ mae = mae / elem_count;
+ mae = mae / _num_data;
+
+ os << "MAE for " << name << " is " << mae << std::endl;
+ }
+}
+
+} // namespace circle_eval_diff
+
+#undef THROW_UNLESS
diff --git a/compiler/circle-eval-diff/src/MetricPrinter.h b/compiler/circle-eval-diff/src/MetricPrinter.h
new file mode 100644
index 000000000..b51581c31
--- /dev/null
+++ b/compiler/circle-eval-diff/src/MetricPrinter.h
@@ -0,0 +1,90 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __CIRCLE_EVAL_DIFF_METRIC_PRINTER_H__
+#define __CIRCLE_EVAL_DIFF_METRIC_PRINTER_H__
+
+#include <luci/IR/Module.h>
+
+#include "Tensor.h"
+
+#include <vector>
+#include <iostream>
+
+namespace circle_eval_diff
+{
+
+// Class to print metrics
+// How to use?
+//
+// MetricPrinter metric;
+// metric.init(first_module, second_module); // optional initialization
+//
+// for (..) // Evaluate data one by one
+// {
+// ..
+// metric.accumulate(first_result, second_result); // accumulate results
+// }
+//
+// std::cout << &metric << std::endl; // print result
+class MetricPrinter
+{
+public:
+ virtual ~MetricPrinter() = default;
+
+ // Child class can implement this function if necessary
+ // NOTE init can be skipped
+ virtual void init(const luci::Module *, const luci::Module *) {}
+
+ // Accumulate results of comparing the first and the second model's outputs
+ virtual void accumulate(const std::vector<std::shared_ptr<Tensor>> &first,
+ const std::vector<std::shared_ptr<Tensor>> &second) = 0;
+
+ // Dump the final result of the corresponding metric
+ virtual void dump(std::ostream &os) const = 0;
+};
+
+static inline std::ostream &operator<<(std::ostream &os, const MetricPrinter *m)
+{
+ m->dump(os);
+ return os;
+}
+
+// Mean Absolute Error
+class MAEPrinter final : public MetricPrinter
+{
+public:
+ void init(const luci::Module *first, const luci::Module *second);
+
+ void accumulate(const std::vector<std::shared_ptr<Tensor>> &first,
+ const std::vector<std::shared_ptr<Tensor>> &second);
+
+ void dump(std::ostream &os) const;
+
+private:
+ void accum_absolute_error(uint32_t index, const std::shared_ptr<Tensor> &a,
+ const std::shared_ptr<Tensor> &b);
+
+private:
+ // Store accumulated sum of absolute error for each output
+ std::vector<Tensor> _intermediate;
+ std::vector<std::string> _output_names;
+ uint32_t _num_data = 0;
+};
+
+} // namespace circle_eval_diff
+
+#endif // __CIRCLE_EVAL_DIFF_METRIC_PRINTER_H__
diff --git a/compiler/circle-eval-diff/src/MetricPrinter.test.cpp b/compiler/circle-eval-diff/src/MetricPrinter.test.cpp
new file mode 100644
index 000000000..51ca89799
--- /dev/null
+++ b/compiler/circle-eval-diff/src/MetricPrinter.test.cpp
@@ -0,0 +1,236 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "MetricPrinter.h"
+
+#include <luci/test/TestIOGraph.h>
+
+#include <gtest/gtest.h>
+
+using Tensor = circle_eval_diff::Tensor;
+
+namespace
+{
+
+// TODO Reduce duplicate code in ResolveCustomOpMatMulPass.cpp
+template <typename T>
+luci::CircleConst *create_const_node(loco::Graph *g, const loco::DataType dtype,
+ const std::vector<uint32_t> &shape,
+ const std::vector<T> &values)
+{
+ auto node = g->nodes()->create<luci::CircleConst>();
+ node->dtype(dtype);
+ node->rank(shape.size());
+
+ uint32_t size = 1;
+ for (uint32_t i = 0; i < shape.size(); ++i)
+ {
+ node->dim(i) = shape.at(i);
+ size *= shape.at(i);
+ }
+ node->shape_status(luci::ShapeStatus::VALID);
+
+#define INIT_VALUES(DT) \
+ { \
+ node->size<DT>(size); \
+ for (uint32_t i = 0; i < values.size(); ++i) \
+ node->at<DT>(i) = values[i]; \
+ }
+
+ switch (dtype)
+ {
+ case loco::DataType::U8:
+ INIT_VALUES(loco::DataType::U8);
+ break;
+ case loco::DataType::S16:
+ INIT_VALUES(loco::DataType::S16);
+ break;
+ case loco::DataType::S32:
+ INIT_VALUES(loco::DataType::S32);
+ break;
+ case loco::DataType::FLOAT32:
+ INIT_VALUES(loco::DataType::FLOAT32)
+ break;
+ default:
+ INTERNAL_EXN("create_const_node called with unsupported type");
+ break;
+ }
+ return node;
+}
+
+/**
+ * Simple graph which adds constant (addition) to the input
+ *
+ * [Input] [Const] (addition)
+ * \ /
+ * [Add]
+ *
+ */
+class AddGraphlet
+{
+public:
+ AddGraphlet() = default;
+
+ void init(loco::Graph *g, float addition)
+ {
+ std::vector<float> addition_val;
+ for (uint32_t i = 0; i < 16; i++)
+ addition_val.push_back(addition);
+ _add_c = create_const_node(g, loco::DataType::FLOAT32, {1, 16}, addition_val);
+
+ _add = g->nodes()->create<luci::CircleAdd>();
+ _add->y(_add_c);
+ _add->fusedActivationFunction(luci::FusedActFunc::NONE);
+ _add->dtype(loco::DataType::FLOAT32);
+ _add->shape({1, 16});
+ _add->name("add");
+ }
+
+protected:
+ luci::CircleAdd *_add = nullptr;
+ luci::CircleConst *_add_c = nullptr;
+};
+
+class AddOneGraph : public luci::test::TestIOGraph, public AddGraphlet
+{
+public:
+ AddOneGraph() = default;
+
+ void init(void)
+ {
+ luci::test::TestIOGraph::init({1, 4}, {1, 16});
+ AddGraphlet::init(g(), 1.0);
+
+ _add->x(input());
+
+ output()->from(_add);
+ }
+
+ std::unique_ptr<loco::Graph> graph(void) { return std::move(_g); }
+};
+
+class AddTwoGraph : public luci::test::TestIOGraph, public AddGraphlet
+{
+public:
+ AddTwoGraph() = default;
+
+ void init(void)
+ {
+ luci::test::TestIOGraph::init({1, 4}, {1, 16});
+ AddGraphlet::init(g(), 2.0);
+
+ _add->x(input());
+
+ output()->from(_add);
+ }
+
+ std::unique_ptr<loco::Graph> graph(void) { return std::move(_g); }
+};
+
+// Return number of elements of the node.
+uint32_t numElements(const luci::CircleNode *node)
+{
+ uint32_t num_elem = 1;
+ for (uint32_t i = 0; i < node->rank(); ++i)
+ num_elem *= node->dim(i).value();
+ return num_elem;
+}
+
+// Return Tensor which has the same dtype and shape with node.
+// Buffer does not have any data yet.
+std::shared_ptr<Tensor> create_empty_tensor(const luci::CircleNode *node)
+{
+ auto tensor = std::make_shared<Tensor>();
+ {
+ tensor->dtype(node->dtype());
+ tensor->rank(node->rank());
+ for (uint32_t i = 0; i < node->rank(); i++)
+ tensor->dim(i) = node->dim(i);
+ tensor->size<loco::DataType::FLOAT32>(numElements(node));
+ }
+
+ return tensor;
+}
+
+std::shared_ptr<Tensor> output_tensor_with_value(const luci::Module *module, float value)
+{
+ auto outputs = loco::output_nodes(module->graph());
+ assert(outputs.size() == 1);
+ auto output = *outputs.begin();
+ auto output_cnode = loco::must_cast<luci::CircleNode *>(output);
+ auto tensor = create_empty_tensor(output_cnode);
+ auto tensor_size = tensor->size<loco::DataType::FLOAT32>();
+ for (uint32_t i = 0; i < tensor_size; i++)
+ {
+ tensor->at<loco::DataType::FLOAT32>(i) = value;
+ }
+ return tensor;
+}
+
+} // namespace
+
+namespace circle_eval_diff
+{
+
+TEST(CircleEvalMetricPrinterTest, MAE_simple)
+{
+ luci::Module first;
+ AddOneGraph first_g;
+ first_g.init();
+
+ first.add(std::move(first_g.graph()));
+
+ luci::Module second;
+ AddTwoGraph second_g;
+ second_g.init();
+
+ second.add(std::move(second_g.graph()));
+
+ MAEPrinter mae;
+
+ mae.init(&first, &second);
+
+ // This test does not actually evaluate the modules, but create
+ // fake results.
+ std::vector<std::shared_ptr<Tensor>> first_result;
+ {
+ auto output = output_tensor_with_value(&first, 1.0);
+ first_result.emplace_back(output);
+ }
+
+ std::vector<std::shared_ptr<Tensor>> second_result;
+ {
+ auto output = output_tensor_with_value(&second, 2.0);
+ second_result.emplace_back(output);
+ }
+
+ mae.accumulate(first_result, second_result);
+
+ std::stringstream ss;
+ mae.dump(ss);
+ std::string result = ss.str();
+
+ EXPECT_NE(std::string::npos, result.find("MAE for output_0 is 1"));
+}
+
+TEST(CircleEvalMetricPrinterTest, MAE_init_with_null_NEG)
+{
+ MAEPrinter mae;
+
+ EXPECT_ANY_THROW(mae.init(nullptr, nullptr));
+}
+
+} // namespace circle_eval_diff
diff --git a/compiler/circle-eval-diff/src/ModuleEvalDiff.cpp b/compiler/circle-eval-diff/src/ModuleEvalDiff.cpp
new file mode 100644
index 000000000..85f985873
--- /dev/null
+++ b/compiler/circle-eval-diff/src/ModuleEvalDiff.cpp
@@ -0,0 +1,216 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ModuleEvalDiff.h"
+#include "Tensor.h"
+
+#include <luci_interpreter/Interpreter.h>
+#include <dio_hdf5/HDF5Importer.h>
+
+#include <string>
+#include <stdexcept>
+#include <iostream>
+#include <cassert>
+
+using Tensor = circle_eval_diff::Tensor;
+using DataType = loco::DataType;
+using Shape = std::vector<loco::Dimension>;
+using HDF5Importer = dio::hdf5::HDF5Importer;
+
+namespace
+{
+
+// Check the type and the shape of CircleInput
+void verifyTypeShape(const luci::CircleInput *input_node, const DataType &dtype, const Shape &shape)
+{
+ // Type check
+ if (dtype != input_node->dtype())
+ throw std::runtime_error("Wrong input type.");
+
+ if (shape.size() != input_node->rank())
+ throw std::runtime_error("Input rank mismatch.");
+
+ for (uint32_t i = 0; i < shape.size(); i++)
+ {
+ if (not(shape.at(i) == input_node->dim(i)))
+ throw std::runtime_error("Input shape mismatch.");
+ }
+}
+
+// Return number of elements of the node.
+uint32_t numElements(const luci::CircleNode *node)
+{
+ uint32_t num_elem = 1;
+ for (uint32_t i = 0; i < node->rank(); ++i)
+ num_elem *= node->dim(i).value();
+ return num_elem;
+}
+
+// Return Tensor which has the same dtype and shape with node.
+// Buffer does not have any data yet.
+std::shared_ptr<Tensor> createEmptyTensor(const luci::CircleNode *node)
+{
+ auto tensor = std::make_shared<Tensor>();
+ {
+ tensor->dtype(node->dtype());
+ tensor->rank(node->rank());
+ for (uint32_t i = 0; i < node->rank(); i++)
+ tensor->dim(i) = node->dim(i);
+
+ switch (node->dtype())
+ {
+ case loco::DataType::FLOAT32:
+ tensor->size<loco::DataType::FLOAT32>(numElements(node));
+ break;
+ case loco::DataType::U8:
+ tensor->size<loco::DataType::U8>(numElements(node));
+ break;
+ case loco::DataType::S16:
+ tensor->size<loco::DataType::S16>(numElements(node));
+ break;
+ case loco::DataType::S32:
+ tensor->size<loco::DataType::S32>(numElements(node));
+ break;
+ case loco::DataType::S64:
+ tensor->size<loco::DataType::S64>(numElements(node));
+ break;
+ default:
+ throw std::runtime_error("Unsupported input tensor dtype for " + node->name());
+ }
+ }
+
+ return tensor;
+}
+
+} // namespace
+
+namespace circle_eval_diff
+{
+
+void H5InputEvalDiff::evalDiff(const std::string &first_input_data_path,
+ const std::string &second_input_data_path) const
+{
+ const auto interp = std::make_unique<luci_interpreter::Interpreter>(_first_module.get());
+
+ _metric->init(_first_module.get(), _second_module.get());
+
+ try
+ {
+ HDF5Importer first_h5(first_input_data_path);
+ first_h5.importGroup("value");
+
+ HDF5Importer second_h5(second_input_data_path);
+ second_h5.importGroup("value");
+
+ const auto first_num_data = first_h5.numData();
+ const auto second_num_data = second_h5.numData();
+
+ if (first_num_data != second_num_data)
+ throw std::runtime_error(
+ "Number of data in the first data file and the second data file mismatches.");
+
+ if (first_num_data == 0)
+ throw std::runtime_error("Input data file does not contain any record.");
+
+ const auto first_input_nodes = loco::input_nodes(_first_module->graph());
+ const auto first_num_inputs = first_input_nodes.size();
+ const auto first_output_nodes = loco::output_nodes(_first_module->graph());
+ const auto first_num_outputs = first_output_nodes.size();
+
+ const auto second_input_nodes = loco::input_nodes(_second_module->graph());
+ const auto second_num_inputs = second_input_nodes.size();
+ const auto second_output_nodes = loco::output_nodes(_second_module->graph());
+ const auto second_num_outputs = second_output_nodes.size();
+
+ for (int32_t data_idx = 0; data_idx < first_num_data; data_idx++)
+ {
+ std::cout << "Evaluating " << data_idx << "'th data" << std::endl;
+
+ if (first_num_inputs != first_h5.numInputs(data_idx) ||
+ second_num_inputs != second_h5.numInputs(data_idx))
+ throw std::runtime_error("Wrong number of inputs in " + std::to_string(data_idx) +
+ "th data.");
+
+ // Do inference and return output
+ auto eval = [&](HDF5Importer &h5, uint32_t num_inputs,
+ const std::vector<loco::Node *> &input_nodes, uint32_t num_outputs,
+ const std::vector<loco::Node *> &output_nodes) {
+ // Write input data
+ for (uint32_t input_idx = 0; input_idx < num_inputs; input_idx++)
+ {
+ const auto *input_node =
+ loco::must_cast<const luci::CircleInput *>(input_nodes[input_idx]);
+ assert(input_node->index() == input_idx);
+
+ auto tensor = createEmptyTensor(input_node);
+ if (h5.isRawData())
+ {
+ h5.readTensor(data_idx, input_idx, tensor->buffer());
+ }
+ else
+ {
+ DataType dtype;
+ Shape shape;
+ h5.readTensor(data_idx, input_idx, &dtype, &shape, tensor->buffer());
+
+ // Check the type and the shape of the input data is valid
+ verifyTypeShape(input_node, dtype, shape);
+ }
+
+ interp->writeInputTensor(input_node, tensor->buffer(), tensor->byte_size());
+ }
+
+ // Interpret
+ interp->interpret();
+
+ // Read output data
+ std::vector<std::shared_ptr<Tensor>> outputs;
+ for (uint32_t output_idx = 0; output_idx < num_outputs; output_idx++)
+ {
+ const auto *output_node =
+ loco::must_cast<const luci::CircleOutput *>(output_nodes[output_idx]);
+ assert(output_node->index() == output_idx);
+
+ auto tensor = createEmptyTensor(output_node);
+ interp->readOutputTensor(output_node, tensor->buffer(), tensor->byte_size());
+ outputs.emplace_back(tensor);
+ }
+
+ return outputs;
+ };
+
+ auto first_output =
+ eval(first_h5, first_num_inputs, first_input_nodes, first_num_outputs, first_output_nodes);
+ auto second_output = eval(second_h5, second_num_inputs, second_input_nodes,
+ second_num_outputs, second_output_nodes);
+
+ // Accumulate diffs
+ _metric->accumulate(first_output, second_output);
+ }
+
+ std::cout << "Evaluation finished. Number of data: " << first_num_data << std::endl;
+ }
+ catch (const H5::Exception &e)
+ {
+ H5::Exception::printErrorStack();
+ throw std::runtime_error("HDF5 error occurred.");
+ }
+
+ // Print metric
+ std::cout << _metric.get() << std::endl;
+}
+
+} // namespace circle_eval_diff
diff --git a/compiler/circle-eval-diff/src/ModuleEvalDiff.h b/compiler/circle-eval-diff/src/ModuleEvalDiff.h
new file mode 100644
index 000000000..c7642f60b
--- /dev/null
+++ b/compiler/circle-eval-diff/src/ModuleEvalDiff.h
@@ -0,0 +1,67 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __CIRCLE_EVAL_DIFF_MODULE_EVAL_DIFF_H__
+#define __CIRCLE_EVAL_DIFF_MODULE_EVAL_DIFF_H__
+
+#include "MetricPrinter.h"
+
+#include <luci/IR/Module.h>
+
+#include <memory>
+
+namespace circle_eval_diff
+{
+
+class ModuleEvalDiff
+{
+public:
+ ModuleEvalDiff(std::unique_ptr<luci::Module> &&first, std::unique_ptr<luci::Module> &&second,
+ std::unique_ptr<MetricPrinter> &&metric)
+ : _first_module(std::move(first)), _second_module(std::move(second)), _metric(std::move(metric))
+ {
+ }
+
+ virtual ~ModuleEvalDiff() = default;
+
+ // Implement this in the child class
+ virtual void evalDiff(const std::string &first_input_data_path,
+ const std::string &second_input_data_path) const = 0;
+
+protected:
+ std::unique_ptr<luci::Module> _first_module;
+ std::unique_ptr<luci::Module> _second_module;
+ std::unique_ptr<MetricPrinter> _metric;
+};
+
+class H5InputEvalDiff final : public ModuleEvalDiff
+{
+public:
+ H5InputEvalDiff(std::unique_ptr<luci::Module> &&first, std::unique_ptr<luci::Module> &&second,
+ std::unique_ptr<MetricPrinter> &&metric)
+ : ModuleEvalDiff(std::move(first), std::move(second), std::move(metric))
+ {
+ }
+
+ void evalDiff(const std::string &first_input_data_path,
+ const std::string &second_input_data_path) const;
+};
+
+// TODO Implement ModuleEvalDiff for random input and directory input
+
+} // namespace circle_eval_diff
+
+#endif // __CIRCLE_EVAL_DIFF_MODULE_EVAL_DIFF_H__
diff --git a/compiler/circle-eval-diff/src/Tensor.cpp b/compiler/circle-eval-diff/src/Tensor.cpp
new file mode 100644
index 000000000..6710e8c3d
--- /dev/null
+++ b/compiler/circle-eval-diff/src/Tensor.cpp
@@ -0,0 +1,72 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Tensor.h"
+
+#include <cassert>
+
+namespace circle_eval_diff
+{
+
+#define THROW_UNLESS(COND, MSG) \
+ if (not(COND)) \
+ throw std::runtime_error(MSG);
+
+template <loco::DataType DT> uint32_t Tensor::size(void) const
+{
+ assert(dtype() == DT);
+ assert(_data.size() % sizeof(typename loco::DataTypeImpl<DT>::Type) == 0);
+ return _data.size() / sizeof(typename loco::DataTypeImpl<DT>::Type);
+}
+
+template <loco::DataType DT> void Tensor::size(uint32_t l)
+{
+ assert(dtype() == DT);
+ _data.resize(l * sizeof(typename loco::DataTypeImpl<DT>::Type));
+}
+
+template <loco::DataType DT>
+const typename loco::DataTypeImpl<DT>::Type &Tensor::at(uint32_t n) const
+{
+ assert(dtype() == DT);
+ THROW_UNLESS(n < size<DT>(), "Access to out of buffer boundary.");
+ return *(reinterpret_cast<const typename loco::DataTypeImpl<DT>::Type *>(_data.data()) + n);
+}
+
+template <loco::DataType DT> typename loco::DataTypeImpl<DT>::Type &Tensor::at(uint32_t n)
+{
+ assert(dtype() == DT);
+ THROW_UNLESS(n < size<DT>(), "Access to out of buffer boundary.");
+ return *(reinterpret_cast<typename loco::DataTypeImpl<DT>::Type *>(_data.data()) + n);
+}
+
+#undef THROW_UNLESS
+
+#define INSTANTIATE(DT) \
+ template uint32_t Tensor::size<DT>(void) const; \
+ template void Tensor::size<DT>(uint32_t); \
+ template const typename loco::DataTypeImpl<DT>::Type &Tensor::at<DT>(uint32_t) const; \
+ template typename loco::DataTypeImpl<DT>::Type &Tensor::at<DT>(uint32_t);
+
+INSTANTIATE(loco::DataType::S64);
+INSTANTIATE(loco::DataType::S32);
+INSTANTIATE(loco::DataType::S16);
+INSTANTIATE(loco::DataType::U8);
+INSTANTIATE(loco::DataType::FLOAT32);
+
+#undef INSTANTIATE
+
+} // namespace circle_eval_diff
diff --git a/compiler/circle-eval-diff/src/Tensor.h b/compiler/circle-eval-diff/src/Tensor.h
new file mode 100644
index 000000000..65ab60638
--- /dev/null
+++ b/compiler/circle-eval-diff/src/Tensor.h
@@ -0,0 +1,81 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __CIRCLE_EVAL_DIFF_TENSOR_H__
+#define __CIRCLE_EVAL_DIFF_TENSOR_H__
+
+#include <loco.h>
+
+#include <vector>
+
+namespace circle_eval_diff
+{
+
+struct TensorDataType
+{
+public:
+ const loco::DataType &dtype(void) const { return _dtype; }
+ void dtype(const loco::DataType &dtype) { _dtype = dtype; }
+
+private:
+ loco::DataType _dtype = loco::DataType::Unknown;
+};
+
+struct TensorShape
+{
+public:
+ uint32_t rank(void) const { return _dims.size(); }
+ void rank(uint32_t value) { _dims.resize(value); }
+
+ const loco::Dimension &dim(uint32_t axis) const { return _dims.at(axis); }
+ loco::Dimension &dim(uint32_t axis) { return _dims.at(axis); }
+
+ void shape(std::initializer_list<uint32_t> dims)
+ {
+ rank(dims.size());
+
+ uint32_t axis = 0;
+ for (auto d : dims)
+ {
+ dim(axis++) = d;
+ }
+ }
+
+private:
+ std::vector<loco::Dimension> _dims;
+};
+
+// Tensor has three kinds of data
+// 1. DataType (_dtype)
+// 2. Shape (_dims)
+// 3. Buffer (_data)
+struct Tensor final : public TensorShape, public TensorDataType
+{
+public:
+ template <loco::DataType DT> uint32_t size(void) const;
+ template <loco::DataType DT> void size(uint32_t size);
+ template <loco::DataType DT> const typename loco::DataTypeImpl<DT>::Type &at(uint32_t n) const;
+ template <loco::DataType DT> typename loco::DataTypeImpl<DT>::Type &at(uint32_t n);
+ uint8_t *buffer(void) { return _data.data(); }
+ uint32_t byte_size(void) const { return _data.size(); }
+
+private:
+ std::vector<uint8_t> _data;
+};
+
+} // namespace circle_eval_diff
+
+#endif // __CIRCLE_EVAL_DIFF_TENSOR_H__
diff --git a/compiler/circle-eval-diff/src/Tensor.test.cpp b/compiler/circle-eval-diff/src/Tensor.test.cpp
new file mode 100644
index 000000000..3bdeaecdf
--- /dev/null
+++ b/compiler/circle-eval-diff/src/Tensor.test.cpp
@@ -0,0 +1,101 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Tensor.h"
+
+#include <gtest/gtest.h>
+
+using Tensor = circle_eval_diff::Tensor;
+
+namespace
+{
+
+template <loco::DataType DT> void test_out_of_buffer_range()
+{
+ Tensor t;
+
+ t.shape({1, 2, 3});
+ t.dtype(DT);
+ t.size<DT>(6);
+
+ EXPECT_ANY_THROW(t.at<DT>(6));
+}
+
+template <loco::DataType DT> void test_getter_setter()
+{
+ Tensor t;
+
+ // Check shape
+ t.shape({1, 2, 3});
+ EXPECT_EQ(3, t.rank());
+ EXPECT_EQ(1, t.dim(0));
+ EXPECT_EQ(2, t.dim(1));
+ EXPECT_EQ(3, t.dim(2));
+
+ // Check dtype
+ t.dtype(DT);
+ EXPECT_EQ(DT, t.dtype());
+
+ // Check buffer
+ t.size<DT>(6);
+ EXPECT_EQ(6 * sizeof(typename loco::DataTypeImpl<DT>::Type), t.byte_size());
+ for (uint32_t i = 0; i < 6; i++)
+ t.at<DT>(i) = i;
+
+ for (uint32_t i = 0; i < 6; i++)
+ EXPECT_EQ(i, t.at<DT>(i));
+}
+
+} // namespace
+
+TEST(CircleEvalDiffTensorTest, constructor)
+{
+ Tensor t;
+
+ EXPECT_EQ(0, t.byte_size());
+ EXPECT_EQ(0, t.rank());
+ EXPECT_EQ(loco::DataType::Unknown, t.dtype());
+}
+
+TEST(CircleEvalDiffTensorTest, getter_setter)
+{
+ test_getter_setter<loco::DataType::S64>();
+ test_getter_setter<loco::DataType::S32>();
+ test_getter_setter<loco::DataType::S16>();
+ test_getter_setter<loco::DataType::U8>();
+ test_getter_setter<loco::DataType::FLOAT32>();
+
+ SUCCEED();
+}
+
+TEST(CircleEvalDiffTensorTest, out_of_shape_range_NEG)
+{
+ Tensor t;
+ t.shape({1, 2, 2, 3});
+
+ EXPECT_ANY_THROW(t.dim(4));
+}
+
+TEST(CircleEvalDiffTensorTest, out_of_buffer_range_NEG)
+{
+ test_out_of_buffer_range<loco::DataType::S64>();
+ test_out_of_buffer_range<loco::DataType::S32>();
+ test_out_of_buffer_range<loco::DataType::S16>();
+ test_out_of_buffer_range<loco::DataType::U8>();
+ test_out_of_buffer_range<loco::DataType::FLOAT32>();
+
+ SUCCEED();
+}
diff --git a/compiler/circle-execution-plan/CMakeLists.txt b/compiler/circle-execution-plan/CMakeLists.txt
index 115d24860..2f657c171 100644
--- a/compiler/circle-execution-plan/CMakeLists.txt
+++ b/compiler/circle-execution-plan/CMakeLists.txt
@@ -1,4 +1,9 @@
set(SOURCES
+ pal/IScratchpadHelper.h
+ pal/ScratchpadHelperLinux.h
+ pal/ScratchpadHelperMCU.h
+ pal/ScratchpadHelperCMSISNN.h
+ pal/TargetPlatform.h
src/CircleExecutionPlan.cpp
src/ExecutionPlanner.cpp
src/ExecutionPlanner.h
@@ -13,4 +18,5 @@ target_link_libraries(circle_execution_plan luci_export)
target_link_libraries(circle_execution_plan luci_plan)
target_link_libraries(circle_execution_plan arser)
+target_include_directories(circle_execution_plan PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/pal")
install(TARGETS circle_execution_plan DESTINATION bin)
diff --git a/compiler/circle-execution-plan/README.md b/compiler/circle-execution-plan/README.md
index e789a55db..dbb7d4f85 100644
--- a/compiler/circle-execution-plan/README.md
+++ b/compiler/circle-execution-plan/README.md
@@ -10,13 +10,12 @@ The output circle file contains plan (`CircleNodeMemoryPlan`) information for ev
- number which determines order in which nodes will be executed
- memory offsets for node output tensors from the beginning of shared memory buffer
-In order to record and read this metadata, we use `CircleImportMetadata` and `CircleExportMetadata`.
-For this purpose we use `std::map<uint32_t, std::vector<uint32_t>> _memory_plan_table` which for each node with key ID contains encoded `CircleNodeMemoryPlan` data.
+In order to record and read this data, we use `luci::CircleNodeExecutionPlan`.
### Execution plan building
In order to build "execution plan" we use `ExecutionPlanner` class.
-The main method is `get_execution_plan()` which for each node finds and writes to its annotations
+The main method is `make_execution_plan()` which for each node finds and writes to its annotations
"execution plan". For this purpose there are two steps:
- determining the order of execution of nodes, which is stored in `_ordered_nodes` vector.
Now for this purpose there is only one default method `get_default_execution_order_plan()` that uses `loco::postorder_traversal(const std::vector<loco::Node *> &roots)`.
diff --git a/compiler/circle-execution-plan/pal/IScratchpadHelper.h b/compiler/circle-execution-plan/pal/IScratchpadHelper.h
new file mode 100644
index 000000000..f5a991526
--- /dev/null
+++ b/compiler/circle-execution-plan/pal/IScratchpadHelper.h
@@ -0,0 +1,51 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef CIRCLE_EXECUTION_PLAN_ISCRATCHPAD_HELPER_H
+#define CIRCLE_EXECUTION_PLAN_ISCRATCHPAD_HELPER_H
+
+#include <luci/IR/Nodes/CircleAveragePool2D.h>
+#include <luci/IR/Nodes/CircleBatchMatMul.h>
+#include <luci/IR/Nodes/CircleConv2D.h>
+#include <luci/IR/Nodes/CircleDepthwiseConv2D.h>
+#include <luci/IR/Nodes/CircleSVDF.h>
+#include <cstdint>
+
+namespace circle_planner
+{
+
+class IScratchpadHelper
+{
+public:
+ virtual uint32_t
+ ComputeScratchpadSizeAveragePool2d(const luci::CircleAveragePool2D *avg_pool) = 0;
+
+ virtual std::vector<uint32_t>
+ ComputeScratchpadSizeBatchMatMul(const luci::CircleBatchMatMul *batch_mat_mul) = 0;
+
+ virtual uint32_t ComputeScratchpadSizeConv2d(const luci::CircleConv2D *conv) = 0;
+
+ virtual uint32_t
+ ComputeScratchpadSizeDepthwiseConv2d(const luci::CircleDepthwiseConv2D *depthwise_conv) = 0;
+
+ virtual std::vector<uint32_t> ComputeScratchpadSizeSVDF(const luci::CircleSVDF *svdf) = 0;
+
+ virtual ~IScratchpadHelper() = default;
+};
+
+} // namespace circle_planner
+
+#endif // CIRCLE_EXECUTION_PLAN_ISCRATCHPAD_HELPER_H
diff --git a/compiler/circle-execution-plan/pal/ScratchpadHelperCMSISNN.h b/compiler/circle-execution-plan/pal/ScratchpadHelperCMSISNN.h
new file mode 100644
index 000000000..5369c0937
--- /dev/null
+++ b/compiler/circle-execution-plan/pal/ScratchpadHelperCMSISNN.h
@@ -0,0 +1,187 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef CIRCLE_EXECUTION_PLAN_SCRATCHPAD_HELPER_CMSISNN_H
+#define CIRCLE_EXECUTION_PLAN_SCRATCHPAD_HELPER_CMSISNN_H
+
+#include "IScratchpadHelper.h"
+#include <cassert>
+
+namespace circle_planner
+{
+
+namespace
+{
+
+inline int32_t computePadding(int32_t stride, int32_t dilation_rate, int32_t in_size,
+ int32_t filter_size, int32_t out_size)
+{
+ const int32_t effective_filter_size = (filter_size - 1) * dilation_rate + 1;
+ const int32_t padding = ((out_size - 1) * stride + effective_filter_size - in_size) / 2;
+ return padding > 0 ? padding : 0;
+}
+
+} // namespace
+
+class ScratchpadHelperCMSISNN : public IScratchpadHelper
+{
+public:
+ explicit ScratchpadHelperCMSISNN(bool use_dsp) : _use_dsp(use_dsp)
+ {
+ // Do nothing
+ }
+
+ uint32_t ComputeScratchpadSizeAveragePool2d(const luci::CircleAveragePool2D *avg_pool) final
+ {
+ // Main logic of arm_avgpool_s8_get_buffer_size
+
+ const auto avg_pool_input = loco::must_cast<luci::CircleNode *>(avg_pool->value());
+
+ if (avg_pool_input->dtype() != loco::DataType::S8 or !_use_dsp)
+ return 0;
+
+ const auto depth = static_cast<int32_t>(avg_pool_input->dim(3).value());
+
+ return depth * sizeof(int32_t);
+ }
+
+ std::vector<uint32_t>
+ ComputeScratchpadSizeBatchMatMul(const luci::CircleBatchMatMul *batch_mat_mul) final
+ {
+ throw std::runtime_error("BatchMatMul is not currently supported for cmsisnn platform");
+ }
+
+ uint32_t ComputeScratchpadSizeConv2d(const luci::CircleConv2D *conv) final
+ {
+ // Main logic of arm_convolve_wrapper_s8_get_buffer_size
+
+ const auto dilation_height_factor = static_cast<int32_t>(conv->dilation()->h());
+ const auto dilation_width_factor = static_cast<int32_t>(conv->dilation()->w());
+
+ const auto conv_input = loco::must_cast<luci::CircleNode *>(conv->input());
+ const auto filter = loco::must_cast<luci::CircleNode *>(conv->filter());
+
+ if (dilation_width_factor != 1 or dilation_height_factor != 1 or
+ conv_input->dtype() != loco::DataType::S8)
+ {
+ return 0;
+ }
+
+ const auto input_depth = static_cast<int32_t>(conv_input->dim(3).value());
+
+ const auto input_height = static_cast<int32_t>(conv_input->dim(1).value());
+ const auto input_width = static_cast<int32_t>(conv_input->dim(2).value());
+
+ const auto filter_height = static_cast<int32_t>(filter->dim(1).value());
+ const auto filter_width = static_cast<int32_t>(filter->dim(2).value());
+
+ const auto stride_height = static_cast<int32_t>(conv->stride()->h());
+ const auto stride_width = static_cast<int32_t>(conv->stride()->w());
+
+ const auto output_height = static_cast<int32_t>(conv->dim(1).value());
+ const auto output_width = static_cast<int32_t>(conv->dim(2).value());
+
+ assert(conv_input->quantparam()->zerop.size() == 1);
+ assert(conv->quantparam()->zerop.size() == 1);
+
+ const auto padding_height = computePadding(stride_height, dilation_height_factor, input_height,
+ filter_height, output_height);
+ const auto padding_width =
+ computePadding(stride_width, dilation_width_factor, input_width, filter_width, output_width);
+
+ if ((padding_width == 0) && (padding_height == 0) && (input_depth % 4 == 0) &&
+ (stride_width == 1) && (stride_height == 1) && (filter_width == 1) && (filter_height == 1))
+ {
+ return 0;
+ }
+
+ if (_use_dsp)
+ {
+ return (2 * input_depth * filter_width * filter_height) * sizeof(int16_t);
+ }
+
+ return 0;
+ }
+
+ uint32_t
+ ComputeScratchpadSizeDepthwiseConv2d(const luci::CircleDepthwiseConv2D *depthwise_conv) final
+ {
+ // Main logic of arm_depthwise_conv_wrapper_s8_get_buffer_size
+
+ const auto dilation_height_factor = static_cast<int32_t>(depthwise_conv->dilation()->h());
+ const auto dilation_width_factor = static_cast<int32_t>(depthwise_conv->dilation()->w());
+
+ const auto depthwise_conv_input = loco::must_cast<luci::CircleNode *>(depthwise_conv->input());
+ const auto filter = loco::must_cast<luci::CircleNode *>(depthwise_conv->filter());
+
+ if (dilation_width_factor != 1 or dilation_height_factor != 1 or
+ depthwise_conv_input->dtype() != loco::DataType::S8)
+ {
+ return 0;
+ }
+
+ const auto input_depth = static_cast<int32_t>(depthwise_conv_input->dim(3).value());
+ const auto output_depth = static_cast<int32_t>(depthwise_conv->dim(3).value());
+ const auto batch_size = static_cast<int32_t>(depthwise_conv_input->dim(0).value());
+
+ if (input_depth != output_depth or batch_size != 1 or !_use_dsp)
+ return 0;
+
+ const auto filter_height = static_cast<int32_t>(filter->dim(1).value());
+ const auto filter_width = static_cast<int32_t>(filter->dim(2).value());
+
+ return input_depth * filter_height * filter_width * sizeof(int16_t);
+ }
+
+ std::vector<uint32_t> ComputeScratchpadSizeSVDF(const luci::CircleSVDF *svdf) final
+ {
+ const auto svdf_input = loco::must_cast<luci::CircleNode *>(svdf->input());
+ const auto weight_feature_input = loco::must_cast<luci::CircleNode *>(svdf->weight_feature());
+
+ if (svdf_input->dtype() == loco::DataType::FLOAT32 and
+ (weight_feature_input->dtype() == loco::DataType::S8 or
+ weight_feature_input->dtype() == loco::DataType::U8))
+ {
+      throw std::runtime_error("Hybrid type is not currently supported for cmsisnn platform");
+ }
+
+ std::vector<uint32_t> scratchpad_sizes;
+
+ const auto batch_size = svdf_input->dim(0).value();
+ const auto num_filters = weight_feature_input->dim(0).value();
+ const auto rank = svdf->svdf_rank();
+ const auto num_units = num_filters / rank;
+
+ if (svdf_input->dtype() == loco::DataType::S8)
+ {
+ scratchpad_sizes.push_back(batch_size * num_filters * sizeof(int32_t));
+ scratchpad_sizes.push_back(batch_size * num_units * sizeof(int32_t));
+ }
+ else
+ {
+ scratchpad_sizes.push_back(batch_size * num_filters * sizeof(float));
+ }
+
+ return scratchpad_sizes;
+ }
+
+private:
+ bool _use_dsp;
+};
+
+} // namespace circle_planner
+
+#endif // CIRCLE_EXECUTION_PLAN_SCRATCHPAD_HELPER_CMSISNN_H
diff --git a/compiler/circle-execution-plan/pal/ScratchpadHelperLinux.h b/compiler/circle-execution-plan/pal/ScratchpadHelperLinux.h
new file mode 100644
index 000000000..811aa67c3
--- /dev/null
+++ b/compiler/circle-execution-plan/pal/ScratchpadHelperLinux.h
@@ -0,0 +1,137 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef CIRCLE_EXECUTION_PLAN_SCRATCHPAD_HELPER_LINUX_H
+#define CIRCLE_EXECUTION_PLAN_SCRATCHPAD_HELPER_LINUX_H
+
+#include "IScratchpadHelper.h"
+#include <loco/IR/DataTypeTraits.h>
+
+namespace circle_planner
+{
+
+class ScratchpadHelperLinux : public IScratchpadHelper
+{
+public:
+ uint32_t ComputeScratchpadSizeAveragePool2d(const luci::CircleAveragePool2D *avg_pool) final
+ {
+ // for linux AveragePool2d scratchpad tensors size = 0
+ return 0;
+ }
+
+ std::vector<uint32_t>
+ ComputeScratchpadSizeBatchMatMul(const luci::CircleBatchMatMul *batch_mat_mul) final
+ {
+ const auto lhs = loco::must_cast<luci::CircleNode *>(batch_mat_mul->x());
+ const auto rhs = loco::must_cast<luci::CircleNode *>(batch_mat_mul->y());
+
+ std::vector<uint32_t> scratchpad_sizes;
+
+ // Scratchpad for lhs
+ uint32_t scratchpad_size = 1;
+ for (int32_t i = 0; i < lhs->rank(); ++i)
+ scratchpad_size *= lhs->dim(i).value();
+
+ scratchpad_sizes.push_back(scratchpad_size * loco::size(lhs->dtype()));
+
+ // Scratchpad for rhs
+ scratchpad_size = 1;
+ for (int32_t i = 0; i < rhs->rank(); ++i)
+ scratchpad_size *= rhs->dim(i).value();
+
+ scratchpad_sizes.push_back(scratchpad_size * loco::size(rhs->dtype()));
+
+ return scratchpad_sizes;
+ }
+
+ uint32_t ComputeScratchpadSizeConv2d(const luci::CircleConv2D *conv) final
+ {
+ const auto conv_input = loco::must_cast<luci::CircleNode *>(conv->input());
+ const auto filter = loco::must_cast<luci::CircleNode *>(conv->filter());
+
+ const uint32_t stride_height = conv->stride()->h();
+ const uint32_t stride_width = conv->stride()->w();
+
+ const uint32_t dilation_height_factor = conv->dilation()->h();
+ const uint32_t dilation_width_factor = conv->dilation()->w();
+
+ const uint32_t filter_height = filter->dim(1).value();
+ const uint32_t filter_width = filter->dim(2).value();
+
+ const bool need_dilated_im2col = dilation_height_factor != 1 || dilation_width_factor != 1;
+ const bool need_non_dilated_im2col =
+ stride_height != 1 || stride_width != 1 || filter_height != 1 || filter_width != 1;
+ const bool need_im2col = conv_input->dtype() != loco::DataType::S16 &&
+ (need_dilated_im2col || need_non_dilated_im2col);
+
+ if (!need_im2col)
+ {
+ return 0;
+ }
+
+ const uint32_t input_depth = conv_input->dim(3).value();
+ const uint32_t batches = conv_input->dim(0).value();
+
+ const uint32_t output_height = conv->dim(1).value();
+ const uint32_t output_width = conv->dim(2).value();
+
+ return batches * output_height * output_width * input_depth * filter_height * filter_width *
+ size(conv_input->dtype());
+ }
+
+ uint32_t
+ ComputeScratchpadSizeDepthwiseConv2d(const luci::CircleDepthwiseConv2D *depthwise_conv) final
+ {
+ // for linux DepthwiseConv2d scratchpad tensors size = 0
+ return 0;
+ }
+
+ std::vector<uint32_t> ComputeScratchpadSizeSVDF(const luci::CircleSVDF *svdf) final
+ {
+ const auto svdf_input = loco::must_cast<luci::CircleNode *>(svdf->input());
+ const auto weight_feature_input = loco::must_cast<luci::CircleNode *>(svdf->weight_feature());
+
+ if (svdf_input->dtype() == loco::DataType::FLOAT32 and
+ (weight_feature_input->dtype() == loco::DataType::S8 or
+ weight_feature_input->dtype() == loco::DataType::U8))
+ {
+ throw std::runtime_error("Hybrid type is not currently supported for linux platform");
+ }
+
+ std::vector<uint32_t> scratchpad_sizes;
+
+ const auto batch_size = svdf_input->dim(0).value();
+ const auto num_filters = weight_feature_input->dim(0).value();
+ const auto rank = svdf->svdf_rank();
+ const auto num_units = num_filters / rank;
+
+ if (svdf_input->dtype() == loco::DataType::S8)
+ {
+ scratchpad_sizes.push_back(batch_size * num_filters * sizeof(int32_t));
+ scratchpad_sizes.push_back(batch_size * num_units * sizeof(int32_t));
+ }
+ else
+ {
+ scratchpad_sizes.push_back(batch_size * num_filters * sizeof(float));
+ }
+
+ return scratchpad_sizes;
+ }
+};
+
+} // namespace circle_planner
+
+#endif // CIRCLE_EXECUTION_PLAN_SCRATCHPAD_HELPER_LINUX_H
diff --git a/compiler/circle-execution-plan/pal/ScratchpadHelperMCU.h b/compiler/circle-execution-plan/pal/ScratchpadHelperMCU.h
new file mode 100644
index 000000000..14b41640c
--- /dev/null
+++ b/compiler/circle-execution-plan/pal/ScratchpadHelperMCU.h
@@ -0,0 +1,88 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef CIRCLE_EXECUTION_PLAN_SCRATCHPAD_HELPER_MCU_H
+#define CIRCLE_EXECUTION_PLAN_SCRATCHPAD_HELPER_MCU_H
+
+#include "IScratchpadHelper.h"
+
+namespace circle_planner
+{
+
+class ScratchpadHelperMCU : public IScratchpadHelper
+{
+public:
+ uint32_t ComputeScratchpadSizeAveragePool2d(const luci::CircleAveragePool2D *avg_pool) final
+ {
+ // for mcu AveragePool2d scratchpad tensors size = 0
+ return 0;
+ }
+
+ std::vector<uint32_t>
+ ComputeScratchpadSizeBatchMatMul(const luci::CircleBatchMatMul *batch_mat_mul) final
+ {
+ throw std::runtime_error("BatchMatMul is not currently supported for mcu platform");
+ }
+
+ uint32_t ComputeScratchpadSizeConv2d(const luci::CircleConv2D *) final
+ {
+ // for mcu scratchpad size = 0
+ return 0;
+ }
+
+ uint32_t
+ ComputeScratchpadSizeDepthwiseConv2d(const luci::CircleDepthwiseConv2D *depthwise_conv) final
+ {
+ // for mcu DepthwiseConv2d scratchpad tensors size = 0
+ return 0;
+ }
+
+ std::vector<uint32_t> ComputeScratchpadSizeSVDF(const luci::CircleSVDF *svdf) final
+ {
+ const auto svdf_input = loco::must_cast<luci::CircleNode *>(svdf->input());
+ const auto weight_feature_input = loco::must_cast<luci::CircleNode *>(svdf->weight_feature());
+
+ if (svdf_input->dtype() == loco::DataType::FLOAT32 and
+ (weight_feature_input->dtype() == loco::DataType::S8 or
+ weight_feature_input->dtype() == loco::DataType::U8))
+ {
+ throw std::runtime_error("Hybrid type is not currently supported for linux platform");
+ }
+
+ std::vector<uint32_t> scratchpad_sizes;
+
+ const auto batch_size = svdf_input->dim(0).value();
+ const auto num_filters = weight_feature_input->dim(0).value();
+ const auto rank = svdf->svdf_rank();
+ const auto num_units = num_filters / rank;
+
+ if (svdf_input->dtype() == loco::DataType::S8)
+ {
+ scratchpad_sizes.push_back(batch_size * num_filters * sizeof(int32_t));
+ scratchpad_sizes.push_back(batch_size * num_units * sizeof(int32_t));
+ }
+ else
+ {
+ scratchpad_sizes.push_back(batch_size * num_filters * sizeof(float));
+ }
+
+ return scratchpad_sizes;
+ }
+};
+
+} // namespace circle_planner
+
+#endif // CIRCLE_EXECUTION_PLAN_SCRATCHPAD_HELPER_MCU_H
diff --git a/compiler/circle-execution-plan/pal/TargetPlatform.h b/compiler/circle-execution-plan/pal/TargetPlatform.h
new file mode 100644
index 000000000..538a502fe
--- /dev/null
+++ b/compiler/circle-execution-plan/pal/TargetPlatform.h
@@ -0,0 +1,38 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef CIRCLE_EXECUTION_PLAN_TARGET_PLATFORM_H
+#define CIRCLE_EXECUTION_PLAN_TARGET_PLATFORM_H
+
+namespace circle_planner
+{
+
+enum SupportedPlatformType
+{
+ LINUX,
+ MCU,
+ CMSISNN
+};
+
+struct TargetPlatform
+{
+ SupportedPlatformType platform_type;
+ bool use_dsp;
+};
+
+} // namespace circle_planner
+
+#endif // CIRCLE_EXECUTION_PLAN_TARGET_PLATFORM_H
diff --git a/compiler/circle-execution-plan/src/CircleExecutionPlan.cpp b/compiler/circle-execution-plan/src/CircleExecutionPlan.cpp
index a54100b8c..1788124c3 100644
--- a/compiler/circle-execution-plan/src/CircleExecutionPlan.cpp
+++ b/compiler/circle-execution-plan/src/CircleExecutionPlan.cpp
@@ -35,6 +35,18 @@ int entry(int argc, char **argv)
arser.add_argument("input").nargs(1).type(arser::DataType::STR).help("Input circle model");
arser.add_argument("output").nargs(1).type(arser::DataType::STR).help("Output circle model");
+ arser.add_argument("--platform")
+ .nargs(1)
+ .type(arser::DataType::STR)
+ .required(false)
+ .default_value("linux")
+ .help("Platform name: linux mcu cmsisnn");
+ arser.add_argument("--use_dsp")
+ .nargs(1)
+ .type(arser::DataType::BOOL)
+ .required(false)
+ .default_value(false)
+ .help("Plan with or without dsp (now can be used only with cmsisnn)");
try
{
@@ -47,8 +59,35 @@ int entry(int argc, char **argv)
return 255;
}
- std::string input_path = arser.get<std::string>("input");
- std::string output_path = arser.get<std::string>("output");
+ const std::string input_path = arser.get<std::string>("input");
+ const std::string output_path = arser.get<std::string>("output");
+ const std::string platform_name = arser.get<std::string>("--platform");
+ const bool use_dsp = arser.get<bool>("--use_dsp");
+
+ if (platform_name != "cmsisnn" && use_dsp)
+ {
+ std::cerr << "ERROR: Now use_dsp can be used only with cmsisnn" << std::endl;
+ return EXIT_FAILURE;
+ }
+
+ circle_planner::SupportedPlatformType platform_type;
+ if (platform_name == "linux")
+ {
+ platform_type = circle_planner::SupportedPlatformType::LINUX;
+ }
+ else if (platform_name == "mcu")
+ {
+ platform_type = circle_planner::SupportedPlatformType::MCU;
+ }
+ else if (platform_name == "cmsisnn")
+ {
+ platform_type = circle_planner::SupportedPlatformType::CMSISNN;
+ }
+ else
+ {
+ std::cerr << "ERROR: Invalid platform name '" << platform_name << "'" << std::endl;
+ return EXIT_FAILURE;
+ }
foder::FileLoader file_loader{input_path};
std::vector<char> model_data;
@@ -82,8 +121,8 @@ int entry(int argc, char **argv)
auto module = importer.importModule(circle_model);
// Do main job
- luci::ExecutionPlanner execution_planner(module->graph());
- execution_planner.get_execution_plan();
+ circle_planner::ExecutionPlanner execution_planner(module->graph(), {platform_type, use_dsp});
+ execution_planner.make_execution_plan();
// Export to output Circle file
luci::CircleExporter exporter;
diff --git a/compiler/circle-execution-plan/src/ExecutionPlanner.cpp b/compiler/circle-execution-plan/src/ExecutionPlanner.cpp
index c37d1e5f5..ec2ec1362 100644
--- a/compiler/circle-execution-plan/src/ExecutionPlanner.cpp
+++ b/compiler/circle-execution-plan/src/ExecutionPlanner.cpp
@@ -18,72 +18,49 @@
#include <loco/IR/Algorithm.h>
#include <luci/UserSettings.h>
-namespace luci
+namespace circle_planner
{
namespace
{
-constexpr uint32_t nodeNotAssigned = std::numeric_limits<int32_t>::max();
+constexpr uint32_t node_not_assigned = std::numeric_limits<int32_t>::max();
-uint32_t compute_output_size(Padding padding, uint32_t image_size, uint32_t filter_size,
- uint32_t stride, uint32_t dilation_rate = 1)
+bool isExecutableNode(const luci::CircleNode *node)
{
- const int32_t effective_filter_size = (filter_size - 1) * dilation_rate + 1;
- switch (padding)
+ switch (node->opcode())
{
- case Padding::SAME:
- return (image_size + stride - 1) / stride;
- case Padding::VALID:
- return (image_size + stride - effective_filter_size) / stride;
+ // The following nodes denote outputs of multiple-output nodes.
+ // The list is synchronized with the same list from luci-interpreter/src/loader/GraphLoader.cpp
+ case luci::CircleOpcode::CIRCLEIFOUT:
+ case luci::CircleOpcode::CIRCLESPLITOUT:
+ case luci::CircleOpcode::CIRCLESPLITVOUT:
+ case luci::CircleOpcode::CIRCLEUNPACKOUT:
+ case luci::CircleOpcode::CIRCLEWHILEOUT:
+ return false;
default:
- assert(false);
+ return true;
}
}
-// Method finds (if necessary) size for im2col temporary tensor.
-uint32_t compute_im2col_size(const luci::CircleConv2D *conv)
+bool isTensorProducingNode(const luci::CircleNode *node)
{
- auto conv_input = loco::must_cast<luci::CircleNode *>(conv->input());
- auto filter = loco::must_cast<luci::CircleNode *>(conv->filter());
- auto padding = (conv->padding());
- uint32_t stride_height = conv->stride()->h();
- uint32_t stride_width = conv->stride()->w();
-
- uint32_t dilation_height_factor = conv->dilation()->h();
- uint32_t dilation_width_factor = conv->dilation()->w();
-
- uint32_t filter_height = filter->dim(1).value();
- uint32_t filter_width = filter->dim(2).value();
-
- const bool need_dilated_im2col = dilation_height_factor != 1 || dilation_width_factor != 1;
- const bool need_non_dilated_im2col =
- stride_height != 1 || stride_width != 1 || filter_height != 1 || filter_width != 1;
- bool need_im2col =
- conv_input->dtype() != loco::DataType::S16 && (need_dilated_im2col || need_non_dilated_im2col);
-
- if (!need_im2col)
+ switch (node->opcode())
{
- return 0;
+ // The following nodes are multiple-output nodes. They do not produce tensors, the tensors
+ // are produced by the corresponding *Out nodes instead.
+ // The list is synchronized with the same list from luci-interpreter/src/loader/GraphLoader.cpp
+ case luci::CircleOpcode::IF:
+ case luci::CircleOpcode::SPLIT:
+ case luci::CircleOpcode::UNPACK:
+ return false;
+ default:
+ return true;
}
-
- uint32_t input_depth = conv_input->dim(3).value();
- uint32_t input_height = conv_input->dim(1).value();
- uint32_t input_width = conv_input->dim(2).value();
-
- uint32_t output_height = compute_output_size(padding, input_height, filter_height, stride_height,
- dilation_height_factor);
- uint32_t output_width =
- compute_output_size(padding, input_width, filter_width, stride_width, dilation_width_factor);
-
- uint32_t batches = conv_input->dim(0).value();
-
- return batches * output_height * output_width * input_depth * filter_height * filter_width *
- size(conv_input->dtype());
}
} // namespace
-void ExecutionPlanner::get_execution_plan()
+void ExecutionPlanner::make_execution_plan()
{
get_default_execution_order_plan();
_required_size = get_offsets_with_greedy_by_size();
@@ -106,23 +83,23 @@ void ExecutionPlanner::get_default_execution_order_plan()
void ExecutionPlanner::get_usage_interval()
{
// Initialize vectors of first and last nodes for usage interval
- _alloc_node.assign(_ordered_nodes.size(), nodeNotAssigned);
- _dealloc_node.assign(_ordered_nodes.size(), nodeNotAssigned);
+ _alloc_node.assign(_ordered_nodes.size(), node_not_assigned);
+ _dealloc_node.assign(_ordered_nodes.size(), node_not_assigned);
// Vector for count usages
std::vector<int> usages_counts(_ordered_nodes.size(), 0);
auto allocate = [this](uint32_t node, uint32_t tensor) {
- if (_alloc_node[tensor] != nodeNotAssigned)
+ if (_alloc_node[tensor] != node_not_assigned)
{
return;
}
- assert(_dealloc_node[tensor] == nodeNotAssigned);
+ assert(_dealloc_node[tensor] == node_not_assigned);
_alloc_node[tensor] = node;
};
auto deallocate = [this](uint32_t node, uint32_t tensor) {
- assert(_dealloc_node[tensor] == nodeNotAssigned);
+ assert(_dealloc_node[tensor] == node_not_assigned);
_dealloc_node[tensor] = node;
};
@@ -158,13 +135,24 @@ void ExecutionPlanner::get_usage_interval()
for (uint32_t i = 0; i < _ordered_nodes.size(); i++)
{
const auto node = _ordered_nodes.at(i);
+ auto prev_nodes = preds(node);
if (const auto *const_node = dynamic_cast<const luci::CircleConst *>(node))
{
allocate(0, i);
}
- allocate(i, i);
+ else if (!isExecutableNode(loco::must_cast<luci::CircleNode *>(node)))
+ {
+      // If the current node is a multi-output node, then the lifetime of the current node
+      // should start when the previous node starts to live
+ auto it = std::find(_ordered_nodes.begin(), _ordered_nodes.end(), *prev_nodes.begin());
+ size_t index = std::distance(_ordered_nodes.begin(), it);
+ allocate(index, i);
+ }
+ else
+ {
+ allocate(i, i);
+ }
- auto prev_nodes = preds(node);
for (auto &prev_node : prev_nodes)
{
auto it = std::find(_ordered_nodes.begin(), _ordered_nodes.end(), prev_node);
@@ -203,7 +191,7 @@ uint32_t ExecutionPlanner::get_offsets_with_greedy_by_size()
uint32_t ExecutionPlanner::greedy_by_size_approach()
{
size_t result_size = 0;
- create_alloc_node_inform_vector(false, false, false);
+ create_alloc_node_inform_vector(_is_null_consts, _is_null_inputs, _is_null_scratchpads);
std::vector<AllocationNodeInformation> ordered_alloc_inform;
for (auto &current_node : _alloc_node_inform_vector)
{
@@ -250,22 +238,22 @@ uint32_t ExecutionPlanner::greedy_by_size_approach()
}
void ExecutionPlanner::create_alloc_node_inform_vector(bool null_consts, bool null_inputs,
- bool null_im2col)
+ bool null_scratchpad)
{
auto node_compare = [this](const AllocationNodeInformation &alloc_1,
const AllocationNodeInformation &alloc_2) {
auto idx1 = alloc_1.node_num;
auto idx2 = alloc_2.node_num;
- if (this->_alloc_node[idx1] == 0 && this->_dealloc_node[idx1] == nodeNotAssigned)
+ if (this->_alloc_node[idx1] == 0 && this->_dealloc_node[idx1] == node_not_assigned)
{
- if (this->_alloc_node[idx2] == 0 && this->_dealloc_node[idx2] == nodeNotAssigned)
+ if (this->_alloc_node[idx2] == 0 && this->_dealloc_node[idx2] == node_not_assigned)
{
return idx1 < idx2;
}
return true;
}
- if (this->_alloc_node[idx2] == 0 && this->_dealloc_node[idx2] == nodeNotAssigned)
+ if (this->_alloc_node[idx2] == 0 && this->_dealloc_node[idx2] == node_not_assigned)
{
return false;
}
@@ -305,30 +293,66 @@ void ExecutionPlanner::create_alloc_node_inform_vector(bool null_consts, bool nu
{
_alloc_node_inform_vector[i].size = 0;
}
+ else if (!isTensorProducingNode(circle_node))
+ {
+ _alloc_node_inform_vector[i].size = 0;
+ }
else
{
_alloc_node_inform_vector[i].size = node_size;
}
- // Im2col
- auto opcode = circle_node->opcode();
- if (opcode == luci::CircleOpcode::CONV_2D)
+    // Scratchpad, if needed
+ std::vector<uint32_t> scratchpad_sizes;
+ if (!null_scratchpad)
{
- auto conv = loco::must_cast<const luci::CircleConv2D *>(circle_node);
- auto im2col_size = compute_im2col_size(conv);
- if (im2col_size > 0)
+ switch (circle_node->opcode())
{
- AllocationNodeInformation temp_alloc;
-
- if (null_im2col)
+ case luci::CircleOpcode::AVERAGE_POOL_2D:
{
- temp_alloc.size = 0;
+ const auto avg_pool = loco::must_cast<const luci::CircleAveragePool2D *>(circle_node);
+ scratchpad_sizes.push_back(
+ _scratchpad_helper->ComputeScratchpadSizeAveragePool2d(avg_pool));
+ break;
}
- else
+ case luci::CircleOpcode::BATCH_MATMUL:
{
- temp_alloc.size = im2col_size;
+ const auto batch_mat_mul = loco::must_cast<const luci::CircleBatchMatMul *>(circle_node);
+ scratchpad_sizes = _scratchpad_helper->ComputeScratchpadSizeBatchMatMul(batch_mat_mul);
+ break;
}
+ case luci::CircleOpcode::CONV_2D:
+ {
+ const auto conv = loco::must_cast<const luci::CircleConv2D *>(circle_node);
+ scratchpad_sizes.push_back(_scratchpad_helper->ComputeScratchpadSizeConv2d(conv));
+ break;
+ }
+ case luci::CircleOpcode::DEPTHWISE_CONV_2D:
+ {
+ const auto depthwise_conv =
+ loco::must_cast<const luci::CircleDepthwiseConv2D *>(circle_node);
+ scratchpad_sizes.push_back(
+ _scratchpad_helper->ComputeScratchpadSizeDepthwiseConv2d(depthwise_conv));
+ break;
+ }
+ case luci::CircleOpcode::SVDF:
+ {
+ const auto svdf = loco::must_cast<const luci::CircleSVDF *>(circle_node);
+ scratchpad_sizes = _scratchpad_helper->ComputeScratchpadSizeSVDF(svdf);
+ break;
+ }
+ default:
+ break;
+ }
+ }
+
+ for (const auto scratchpad_size : scratchpad_sizes)
+ {
+ if (scratchpad_size > 0)
+ {
+ AllocationNodeInformation temp_alloc;
+ temp_alloc.size = scratchpad_size;
temp_alloc.first_node = i - 1;
temp_alloc.last_node = i + 1;
temp_alloc.node_num = i;
@@ -352,7 +376,7 @@ void ExecutionPlanner::dump_inform()
{
auto current_node_it = std::find_if(
_alloc_node_inform_vector.begin(), _alloc_node_inform_vector.end(),
- [this, i](const AllocationNodeInformation &x) { return x.node_num == i && !x.is_temp; });
+ [i](const AllocationNodeInformation &x) { return x.node_num == i && !x.is_temp; });
for (uint32_t j = 0; j < _ordered_nodes.size(); j++)
{
auto first_node = _alloc_node[j];
@@ -360,7 +384,7 @@ void ExecutionPlanner::dump_inform()
auto it = std::find_if(
_alloc_node_inform_vector.begin(), _alloc_node_inform_vector.end(),
- [this, j](const AllocationNodeInformation &x) { return x.node_num == j && !x.is_temp; });
+ [j](const AllocationNodeInformation &x) { return x.node_num == j && !x.is_temp; });
if (i >= first_node && i <= last_node)
{
current_node_it->breadth += it->size;
@@ -386,4 +410,4 @@ void ExecutionPlanner::dump_inform()
});
}
-} // namespace luci
+} // namespace circle_planner
diff --git a/compiler/circle-execution-plan/src/ExecutionPlanner.h b/compiler/circle-execution-plan/src/ExecutionPlanner.h
index 8e3d9b46a..e0833c407 100644
--- a/compiler/circle-execution-plan/src/ExecutionPlanner.h
+++ b/compiler/circle-execution-plan/src/ExecutionPlanner.h
@@ -17,10 +17,15 @@
#ifndef CIRCLE_EXECUTION_PLANNER_H
#define CIRCLE_EXECUTION_PLANNER_H
+#include "TargetPlatform.h"
+#include "IScratchpadHelper.h"
+#include "ScratchpadHelperLinux.h"
+#include "ScratchpadHelperMCU.h"
+#include "ScratchpadHelperCMSISNN.h"
#include <luci/IR/Module.h>
#include <luci/Plan/CircleNodeExecutionPlan.h>
-namespace luci
+namespace circle_planner
{
// struct for additional information for the node. it helps build allocations plan for nodes.
struct AllocationNodeInformation
@@ -50,7 +55,7 @@ struct AllocationNodeInformation
uint32_t last_node;
// is the current node temporary or not
bool is_temp;
- // operation breadth of current node
+  // Breadth is the sum of live tensors' sizes at the moment of execution of a given node
uint32_t breadth;
bool operator<(const AllocationNodeInformation &other) const { return offset < other.offset; }
@@ -60,12 +65,44 @@ class ExecutionPlanner
{
public:
ExecutionPlanner() = delete;
- explicit ExecutionPlanner(loco::Graph *graph) { _graph = graph; };
+ explicit ExecutionPlanner(loco::Graph *graph) : _graph(graph)
+ {
+ _scratchpad_helper = std::make_unique<ScratchpadHelperLinux>();
+ }
+
+ explicit ExecutionPlanner(loco::Graph *graph, TargetPlatform target_platform) : _graph(graph)
+ {
+ switch (target_platform.platform_type)
+ {
+ case LINUX:
+ _scratchpad_helper = std::make_unique<ScratchpadHelperLinux>();
+ break;
+ case MCU:
+ _scratchpad_helper = std::make_unique<ScratchpadHelperMCU>();
+ break;
+ case CMSISNN:
+ _scratchpad_helper = std::make_unique<ScratchpadHelperCMSISNN>(target_platform.use_dsp);
+ break;
+ default:
+ assert(false && "Use unsupported platform");
+ }
+ };
// Method provides execution plan, which contains execution order and
// memory offsets for all nodes in _graph.
// This plan writes in nodes annotation information with help of CircleNodeExecutionPlan class.
- void get_execution_plan();
+ void make_execution_plan();
+
+  // Method changes planning mode:
+ // is_null_consts = true - constants are no longer taken into account when planning
+  // is_null_inputs = true - inputs are no longer taken into account when planning
+ // is_null_scratchpads = true - scratchpads are no longer taken into account when planning
+ void change_planning_mode(bool is_null_consts, bool is_null_inputs, bool is_null_scratchpads)
+ {
+ _is_null_consts = is_null_consts;
+ _is_null_inputs = is_null_inputs;
+ _is_null_scratchpads = is_null_scratchpads;
+ };
private:
// Method gets default execution order plan and saves it in _ordered_nodes vector.
@@ -83,18 +120,19 @@ private:
// Return: required size of buffer.
uint32_t get_offsets_with_greedy_by_size();
- // Realization of greedy by size approach to find offsets for nodes.
+ // Realization of greedy by size approach (algorithm is mentioned in
+ // "EFFICIENT MEMORY MANAGEMENT FOR DEEP NEURAL NET INFERENCE" paper) to find offsets for nodes.
uint32_t greedy_by_size_approach();
// Method creates and fills _alloc_node_inform_vector with usage interval inform and node's sizes.
// null_consts = true - size of const nodes will be equal 0;
// null_inputs = true - size of input nodes will be equal 0;
- // null_im2col = true - size of im2col nodes will be equal 0;
- // It using if we don't want to take input(const or im2col) nodes into account
+ // null_scratchpad = true - size of scratchpad nodes will be equal 0;
+  // It is used if we don't want to take input (const or scratchpad) nodes into account
// when determining offsets and calculating the required buffer size. This is uses for
// experiments.
void create_alloc_node_inform_vector(bool null_consts = false, bool null_inputs = false,
- bool null_im2col = false);
+ bool null_scratchpad = false);
// Stores allocation additional information for the all nodes from _graph.
std::vector<AllocationNodeInformation> _alloc_node_inform_vector;
@@ -121,10 +159,21 @@ private:
loco::Graph *_graph;
+ // Calculate size of scratchpad tensors for current platform
+ std::unique_ptr<IScratchpadHelper> _scratchpad_helper;
+
// Required memory size.
uint32_t _required_size = 0;
+
+ // Flags for choosing different planning modes:
+ // _is_null_consts = true - constants are no longer taken into account when planning
+  // _is_null_inputs = true - inputs are no longer taken into account when planning
+ // _is_null_scratchpads = true - scratchpads are no longer taken into account when planning
+ bool _is_null_consts = false;
+ bool _is_null_inputs = false;
+ bool _is_null_scratchpads = false;
};
-} // namespace luci
+} // namespace circle_planner
#endif // CIRCLE_EXECUTION_PLANNER_H
diff --git a/compiler/circle-inspect/CMakeLists.txt b/compiler/circle-inspect/CMakeLists.txt
index d0775ea2d..10d26d191 100644
--- a/compiler/circle-inspect/CMakeLists.txt
+++ b/compiler/circle-inspect/CMakeLists.txt
@@ -1,6 +1,6 @@
-if(NOT TARGET mio_circle)
+if(NOT TARGET mio_circle04)
return()
-endif(NOT TARGET mio_circle)
+endif(NOT TARGET mio_circle04)
set(DRIVER "driver/Driver.cpp")
@@ -10,5 +10,6 @@ add_executable(circle-inspect ${DRIVER} ${SOURCES})
target_include_directories(circle-inspect PRIVATE src)
target_link_libraries(circle-inspect arser)
target_link_libraries(circle-inspect foder)
-target_link_libraries(circle-inspect mio_circle)
+target_link_libraries(circle-inspect mio_circle04)
+target_link_libraries(circle-inspect mio_circle04_helper)
target_link_libraries(circle-inspect safemain)
diff --git a/compiler/circle-inspect/README.md b/compiler/circle-inspect/README.md
index 1f76c8ede..94eea7b08 100644
--- a/compiler/circle-inspect/README.md
+++ b/compiler/circle-inspect/README.md
@@ -20,3 +20,19 @@ ADD
```
To get the count of specific operator, use other tools like sort, uniq, etc.
+
+Operators with `--tensor_dtype`
+- show name and dtype of each tensor one line at a time
+
+Example
+```
+$ circle-inspect --tensor_dtype quantized_conv2d.circle
+```
+
+Result
+```
+ifm UINT8
+weights UINT8
+bias INT32
+ofm UINT8
+```
diff --git a/compiler/circle-inspect/driver/Driver.cpp b/compiler/circle-inspect/driver/Driver.cpp
index a450fd9e0..10e185de5 100644
--- a/compiler/circle-inspect/driver/Driver.cpp
+++ b/compiler/circle-inspect/driver/Driver.cpp
@@ -35,6 +35,7 @@ int entry(int argc, char **argv)
.nargs(0)
.help("Dump Conv2D series weight operators in circle file");
arser.add_argument("--op_version").nargs(0).help("Dump versions of the operators in circle file");
+ arser.add_argument("--tensor_dtype").nargs(0).help("Dump dtype of tensors");
arser.add_argument("circle").type(arser::DataType::STR).help("Circle file to inspect");
try
@@ -48,7 +49,8 @@ int entry(int argc, char **argv)
return 255;
}
- if (!arser["--operators"] && !arser["--conv2d_weight"] && !arser["--op_version"])
+ if (!arser["--operators"] && !arser["--conv2d_weight"] && !arser["--op_version"] &&
+ !arser["--tensor_dtype"])
{
std::cout << "At least one option must be specified" << std::endl;
std::cout << arser;
@@ -63,6 +65,8 @@ int entry(int argc, char **argv)
dumps.push_back(std::make_unique<circleinspect::DumpConv2DWeight>());
if (arser["--op_version"])
dumps.push_back(std::make_unique<circleinspect::DumpOperatorVersion>());
+ if (arser["--tensor_dtype"])
+ dumps.push_back(std::make_unique<circleinspect::DumpTensorDType>());
std::string model_file = arser.get<std::string>("circle");
diff --git a/compiler/circle-inspect/requires.cmake b/compiler/circle-inspect/requires.cmake
index 81e0f0dbd..362d67cf4 100644
--- a/compiler/circle-inspect/requires.cmake
+++ b/compiler/circle-inspect/requires.cmake
@@ -1,3 +1,3 @@
require("arser")
-require("mio-circle")
+require("mio-circle04")
require("safemain")
diff --git a/compiler/circle-inspect/src/Dump.cpp b/compiler/circle-inspect/src/Dump.cpp
index 5c71afb3f..bba5e56c3 100644
--- a/compiler/circle-inspect/src/Dump.cpp
+++ b/compiler/circle-inspect/src/Dump.cpp
@@ -175,3 +175,28 @@ void DumpOperatorVersion::run(std::ostream &os, const circle::Model *model)
}
} // namespace circleinspect
+
+namespace circleinspect
+{
+
+void DumpTensorDType::run(std::ostream &os, const circle::Model *model)
+{
+ circleinspect::Reader reader(model);
+
+ const uint32_t subgraph_size = reader.num_subgraph();
+
+ for (uint32_t g = 0; g < subgraph_size; g++)
+ {
+ reader.select_subgraph(g);
+ auto tensors = reader.tensors();
+
+ for (uint32_t i = 0; i < tensors->Length(); ++i)
+ {
+ const auto tensor = tensors->Get(i);
+
+ os << reader.tensor_name(tensor) << " " << reader.tensor_dtype(tensor) << std::endl;
+ }
+ }
+}
+
+} // namespace circleinspect
diff --git a/compiler/circle-inspect/src/Dump.h b/compiler/circle-inspect/src/Dump.h
index 996c421f9..8ca6838d1 100644
--- a/compiler/circle-inspect/src/Dump.h
+++ b/compiler/circle-inspect/src/Dump.h
@@ -60,6 +60,15 @@ public:
void run(std::ostream &os, const circle::Model *model);
};
+class DumpTensorDType final : public DumpInterface
+{
+public:
+ DumpTensorDType() = default;
+
+public:
+ void run(std::ostream &os, const circle::Model *model);
+};
+
} // namespace circleinspect
#endif // __DUMP_H__
diff --git a/compiler/circle-inspect/src/Reader.cpp b/compiler/circle-inspect/src/Reader.cpp
index 7807db38a..0e2865254 100644
--- a/compiler/circle-inspect/src/Reader.cpp
+++ b/compiler/circle-inspect/src/Reader.cpp
@@ -16,66 +16,14 @@
#include "Reader.h"
+#include <mio_circle/Helper.h>
+
#include <sstream>
#include <string>
namespace circleinspect
{
-bool is_valid(const circle::OperatorCode *opcode)
-{
- circle::BuiltinOperator code = opcode->builtin_code();
- return (circle::BuiltinOperator_MIN <= code && code <= circle::BuiltinOperator_MAX);
-}
-
-bool is_custom(const circle::OperatorCode *opcode)
-{
- circle::BuiltinOperator code = opcode->builtin_code();
- return (code == circle::BuiltinOperator_CUSTOM);
-}
-
-std::string opcode_name(const circle::OperatorCode *opcode)
-{
- assert(opcode);
-
- if (!is_valid(opcode))
- {
- std::ostringstream oss;
- oss << "(invalid)";
- return oss.str();
- }
-
- if (is_custom(opcode))
- {
- if (!opcode->custom_code())
- return "(invalid custom)";
-
- std::string custom_op = "CUSTOM(";
- custom_op += opcode->custom_code()->c_str();
- custom_op += ")";
- return custom_op;
- }
-
- circle::BuiltinOperator code = opcode->builtin_code();
- return circle::EnumNameBuiltinOperator(code);
-}
-
-const char *tensor_type(const circle::Tensor *tensor)
-{
- return circle::EnumNameTensorType(tensor->type());
-}
-
-const char *tensor_name(const circle::Tensor *tensor)
-{
- static const char *kEmptyTensorName = "(noname)";
-
- auto name = tensor->name();
- if (name)
- return name->c_str();
-
- return kEmptyTensorName;
-}
-
Reader::Reader(const circle::Model *model)
{
_subgraphs = model->subgraphs();
@@ -122,7 +70,7 @@ circle::BuiltinOperator Reader::builtin_code(const circle::Operator *op) const
assert(index < _op_codes.size());
const circle::OperatorCode *opcode = _op_codes.at(index);
- return opcode->builtin_code();
+ return mio::circle::builtin_code_neutral(opcode);
}
std::string Reader::opcode_name(const circle::Operator *op) const
@@ -131,14 +79,24 @@ std::string Reader::opcode_name(const circle::Operator *op) const
assert(index < _op_codes.size());
const circle::OperatorCode *opcode = _op_codes.at(index);
- if (!is_valid(opcode))
+ if (!mio::circle::is_valid(opcode))
{
std::ostringstream oss;
oss << "(invalid: " << index << ")";
return oss.str();
}
- return circleinspect::opcode_name(opcode);
+ return mio::circle::opcode_name(opcode);
+}
+
+std::string Reader::tensor_name(const circle::Tensor *tensor) const
+{
+ return mio::circle::tensor_name(tensor);
+}
+
+std::string Reader::tensor_dtype(const circle::Tensor *tensor) const
+{
+ return mio::circle::tensor_type(tensor);
}
bool Reader::select_subgraph(uint32_t sgindex)
diff --git a/compiler/circle-inspect/src/Reader.h b/compiler/circle-inspect/src/Reader.h
index b5a99df3f..c38ec3990 100644
--- a/compiler/circle-inspect/src/Reader.h
+++ b/compiler/circle-inspect/src/Reader.h
@@ -36,12 +36,6 @@ template <typename T> std::vector<T> as_index_vector(const flatbuffers::Vector<T
return ret;
}
-bool is_valid(const circle::OperatorCode *opcode);
-bool is_custom(const circle::OperatorCode *opcode);
-std::string opcode_name(const circle::OperatorCode *opcode);
-const char *tensor_type(const circle::Tensor *tensor);
-const char *tensor_name(const circle::Tensor *tensor);
-
/**
* @brief Loads Circle file and provides helpers to access attributes
*/
@@ -71,6 +65,8 @@ public:
size_t buffer_info(uint32_t buf_idx, const uint8_t **buff_data);
circle::BuiltinOperator builtin_code(const circle::Operator *op) const;
std::string opcode_name(const circle::Operator *op) const;
+ std::string tensor_name(const circle::Tensor *tensor) const;
+ std::string tensor_dtype(const circle::Tensor *tensor) const;
public:
bool select_subgraph(uint32_t subgraph);
diff --git a/compiler/circle-opselector/README.md b/compiler/circle-opselector/README.md
index c06899ab5..5ea2d32c4 100644
--- a/compiler/circle-opselector/README.md
+++ b/compiler/circle-opselector/README.md
@@ -1,21 +1,21 @@
-# circle-opselector
-
-`circle-opselector` is a tool for creating new circle models by selecting nodes from a model.
-
-## Example
-
-### 1. Select from location numbers
-
-```bash
-./circle-opselector --by_id "1-3,5" input.circle output.circle
-```
-
-Then, output.circle which has node 1, 2, 3 and 5 will be created.
-
-### 2. Select from node names
-
-```bash
-./circle-opselector --by_name "Add_1,Sub_1,Concat_2" input.circle output.circle
-```
-
-Then, output.circle which has node Add_1, Sub_1 and Concat_2 will be created.
+# circle-opselector
+
+`circle-opselector` is a tool for creating new circle models by selecting nodes from a model.
+
+## Example
+
+### 1. Select from location numbers
+
+```bash
+./circle-opselector --by_id "1-3,5" input.circle output.circle
+```
+
+Then, output.circle which has node 1, 2, 3 and 5 will be created.
+
+### 2. Select from node names
+
+```bash
+./circle-opselector --by_name "Add_1,Sub_1,Concat_2" input.circle output.circle
+```
+
+Then, output.circle which has node Add_1, Sub_1 and Concat_2 will be created.
diff --git a/compiler/circle-part-value-test/CMakeLists.txt b/compiler/circle-part-value-test/CMakeLists.txt
index 1cfbcbd9b..0657607d2 100644
--- a/compiler/circle-part-value-test/CMakeLists.txt
+++ b/compiler/circle-part-value-test/CMakeLists.txt
@@ -82,8 +82,8 @@ foreach(IDX RANGE ${RECIPE_LENGTH_M1})
# Run partitioner
add_custom_command(OUTPUT ${PARTITIONER_CONN_JSON}
- COMMAND circle_partitioner "${PART_FILE}" "${PARTITION_NAME}.circle" "${PARTITIONER_OUTPUT_PATH}"
- DEPENDS circle_partitioner ${PART_DST_PATH} ${CIRCLE_DST_PATH}
+ COMMAND circle-partitioner "${PART_FILE}" "${PARTITION_NAME}.circle" "${PARTITIONER_OUTPUT_PATH}"
+ DEPENDS circle-partitioner ${PART_DST_PATH} ${CIRCLE_DST_PATH}
COMMENT "Parition ${RECIPE_NAME}.circle with ${PART_FILE}"
)
list(APPEND TEST_DEPS ${PARTITIONER_CONN_JSON})
@@ -106,7 +106,7 @@ add_dependencies(circle_part_value_test_prepare common_artifacts_deps)
add_test(NAME circle_part_value_test
COMMAND "${CMAKE_CURRENT_SOURCE_DIR}/part_eval_all.sh"
"${CMAKE_CURRENT_BINARY_DIR}"
- "${NNCC_OVERLAY_DIR}/venv_2_6_0"
+ "${NNCC_OVERLAY_DIR}/venv_2_8_0"
"$<TARGET_FILE:circle_part_driver>"
${PARTITION_LIST}
)
diff --git a/compiler/circle-part-value-test/part_eval_one.py b/compiler/circle-part-value-test/part_eval_one.py
index 91e32d78f..44661c78b 100755
--- a/compiler/circle-part-value-test/part_eval_one.py
+++ b/compiler/circle-part-value-test/part_eval_one.py
@@ -53,21 +53,37 @@ except:
interpreter = tf.lite.Interpreter(tflite_model)
interpreter.allocate_tensors()
+# Read SignatureDef and get output tensor id orders for remapping
+full_signatures = interpreter._get_full_signature_list()
+full_signatures_outputs_remap = None
+if full_signatures != None:
+ signature_serving_default = full_signatures.get('serving_default', None)
+ if signature_serving_default != None:
+ signature_outputs = signature_serving_default['outputs']
+
+ full_signatures_outputs_remap = []
+ for index, (key, value) in enumerate(signature_outputs.items()):
+ full_signatures_outputs_remap.append(value)
+
# Generate random input data.
num_inputs = len(interpreter.get_input_details())
for i in range(num_inputs):
input_details = interpreter.get_input_details()[i]
- if input_details["dtype"] == np.float32:
+ input_details_dtype = input_details["dtype"]
+ input_details_shape = input_details["shape"]
+ if input_details_dtype == np.float32:
input_data = np.array(
- np.random.random_sample(input_details["shape"]), input_details["dtype"])
- elif input_details["dtype"] == np.uint8:
+ np.random.random_sample(input_details_shape), input_details_dtype)
+ elif input_details_dtype == np.int16:
input_data = np.array(
- np.random.randint(0, 256, size=input_details["shape"]),
- input_details["dtype"])
- elif input_details["dtype"] == np.bool_:
+ np.random.randint(0, 100, size=input_details_shape), input_details_dtype)
+ elif input_details_dtype == np.uint8:
input_data = np.array(
- np.random.choice(a=[True, False], size=input_details["shape"]),
- input_details["dtype"])
+ np.random.randint(0, 256, size=input_details_shape), input_details_dtype)
+ elif input_details_dtype == np.bool_:
+ input_data = np.array(
+ np.random.choice(a=[True, False], size=input_details_shape),
+ input_details_dtype)
else:
raise SystemExit("Unsupported input dtype")
@@ -90,52 +106,42 @@ print("", flush=True)
subprocess.run(partition_command, check=True)
# Compare the results.
-for idx in range(len(interpreter.get_output_details())):
- output_details = interpreter.get_output_details()[idx]
- output_data = np.fromfile(circle_model + ".output" + str(idx),
- output_details["dtype"])
+inpt_output_details = interpreter.get_output_details()
+for idx in range(len(inpt_output_details)):
+ output_details = inpt_output_details[idx]
+ output_dtype = output_details["dtype"]
+ output_data = np.fromfile(circle_model + ".output" + str(idx), output_dtype)
shape_file = open(circle_model + ".output" + str(idx) + ".shape", 'r')
output_shape = [int(i) for i in shape_file.read().split(',')]
luci_output_data = np.reshape(output_data, output_shape)
+ output_tensor = output_details["index"]
+ if full_signatures_outputs_remap != None:
+ output_tensor = full_signatures_outputs_remap[idx]
+ intp_output_data = interpreter.get_tensor(output_tensor)
try:
- if output_details["dtype"] == np.uint8:
- if np.allclose(
- luci_output_data,
- interpreter.get_tensor(
- interpreter.get_output_details()[idx]["index"]),
- rtol=0,
- atol=0) == False:
+ if output_dtype == np.uint8:
+ if np.allclose(luci_output_data, intp_output_data, rtol=0, atol=0) == False:
raise SystemExit("Execution result of " + tflite_model +
" does not match with " + circle_model)
- elif output_details["dtype"] == np.float32:
+ elif output_dtype == np.float32:
if np.allclose(
- luci_output_data,
- interpreter.get_tensor(
- interpreter.get_output_details()[idx]["index"]),
- rtol=1.e-5,
- atol=1.e-5) == False:
+ luci_output_data, intp_output_data, rtol=1.e-5, atol=1.e-5) == False:
raise SystemExit("Execution result of " + tflite_model +
" does not match with " + circle_model)
- elif output_details["dtype"] == np.int64:
- if np.allclose(
- luci_output_data,
- interpreter.get_tensor(
- interpreter.get_output_details()[idx]["index"]),
- rtol=0,
- atol=0) == False:
+ elif output_dtype == np.int64:
+ if np.allclose(luci_output_data, intp_output_data, rtol=0, atol=0) == False:
raise SystemExit("Execution result of " + tflite_model +
" does not match with " + circle_model)
- elif output_details["dtype"] == np.int32:
- if np.allclose(
- luci_output_data,
- interpreter.get_tensor(
- interpreter.get_output_details()[idx]["index"]),
- rtol=0,
- atol=0) == False:
+ elif output_dtype == np.int32:
+ if np.allclose(luci_output_data, intp_output_data, rtol=0, atol=0) == False:
+ raise SystemExit("Execution result of " + tflite_model +
+ " does not match with " + circle_model)
+ elif output_dtype == np.int16:
+ if np.allclose(luci_output_data, intp_output_data, rtol=0, atol=0) == False:
raise SystemExit("Execution result of " + tflite_model +
" does not match with " + circle_model)
else:
- raise SystemExit("Unsupported data type: ", output_details["dtype"])
+ raise SystemExit("Unsupported data type: ", output_dtype)
except:
print(traceback.format_exc())
quit(255)
diff --git a/compiler/circle-part-value-test/parts/Net_UnpackAdd_001.001.part b/compiler/circle-part-value-test/parts/Net_UnpackAdd_001.001.part
new file mode 100644
index 000000000..496971e55
--- /dev/null
+++ b/compiler/circle-part-value-test/parts/Net_UnpackAdd_001.001.part
@@ -0,0 +1,7 @@
+[partition]
+backends=cpu,npu
+default=cpu
+comply=opcode
+
+[OPCODE]
+ADD=npu
diff --git a/compiler/circle-part-value-test/parts/Net_UnpackAdd_001.002.part b/compiler/circle-part-value-test/parts/Net_UnpackAdd_001.002.part
new file mode 100644
index 000000000..9913fea96
--- /dev/null
+++ b/compiler/circle-part-value-test/parts/Net_UnpackAdd_001.002.part
@@ -0,0 +1,7 @@
+[partition]
+backends=cpu,npu
+default=cpu
+comply=opcode
+
+[OPCODE]
+UNPACK=npu
diff --git a/compiler/circle-part-value-test/parts/Net_UnpackAdd_001.part b/compiler/circle-part-value-test/parts/Net_UnpackAdd_001.part
new file mode 100644
index 000000000..c63efc592
--- /dev/null
+++ b/compiler/circle-part-value-test/parts/Net_UnpackAdd_001.part
@@ -0,0 +1,7 @@
+[partition]
+backends=cpu,npu
+default=npu
+comply=opcode
+
+[OPCODE]
+UNPACK=cpu
diff --git a/compiler/circle-part-value-test/parts/Part_Mul_Sqrt_FC_nobias_000_000.part b/compiler/circle-part-value-test/parts/Part_Mul_Sqrt_FC_nobias_000_000.part
new file mode 100644
index 000000000..ad0842165
--- /dev/null
+++ b/compiler/circle-part-value-test/parts/Part_Mul_Sqrt_FC_nobias_000_000.part
@@ -0,0 +1,7 @@
+[partition]
+backends=cpu,npu
+default=npu
+comply=opcode
+
+[OPCODE]
+MUL=npu
diff --git a/compiler/circle-part-value-test/parts/Part_Mul_Sqrt_FC_nobias_000_001.part b/compiler/circle-part-value-test/parts/Part_Mul_Sqrt_FC_nobias_000_001.part
new file mode 100644
index 000000000..c82b741b0
--- /dev/null
+++ b/compiler/circle-part-value-test/parts/Part_Mul_Sqrt_FC_nobias_000_001.part
@@ -0,0 +1,7 @@
+[partition]
+backends=cpu,npu
+default=npu
+comply=opcode
+
+[OPCODE]
+SQRT=cpu
diff --git a/compiler/circle-part-value-test/parts/Part_Mul_Sqrt_FC_nobias_000_002.part b/compiler/circle-part-value-test/parts/Part_Mul_Sqrt_FC_nobias_000_002.part
new file mode 100644
index 000000000..d9d2a8e59
--- /dev/null
+++ b/compiler/circle-part-value-test/parts/Part_Mul_Sqrt_FC_nobias_000_002.part
@@ -0,0 +1,7 @@
+[partition]
+backends=cpu,npu
+default=npu
+comply=opcode
+
+[OPCODE]
+FULLY_CONNECTED=cpu
diff --git a/compiler/circle-part-value-test/parts/Part_Split_Add_000.part b/compiler/circle-part-value-test/parts/Part_Split_Add_000.part
new file mode 100644
index 000000000..91af566cd
--- /dev/null
+++ b/compiler/circle-part-value-test/parts/Part_Split_Add_000.part
@@ -0,0 +1,7 @@
+[partition]
+backends=cpu,npu
+default=npu
+comply=opcode
+
+[OPCODE]
+SPLIT=cpu
diff --git a/compiler/circle-part-value-test/parts/Part_Tanh_FC_nobias.part b/compiler/circle-part-value-test/parts/Part_Tanh_FC_nobias.part
new file mode 100644
index 000000000..d4d439d27
--- /dev/null
+++ b/compiler/circle-part-value-test/parts/Part_Tanh_FC_nobias.part
@@ -0,0 +1,7 @@
+[partition]
+backends=cpu,acl_cl
+default=cpu
+comply=opcode
+
+[OPCODE]
+DIV=acl_cl
diff --git a/compiler/circle-part-value-test/parts/Part_Tanh_FC_nobias_001.part b/compiler/circle-part-value-test/parts/Part_Tanh_FC_nobias_001.part
new file mode 100644
index 000000000..dbd174ee1
--- /dev/null
+++ b/compiler/circle-part-value-test/parts/Part_Tanh_FC_nobias_001.part
@@ -0,0 +1,7 @@
+[partition]
+backends=cpu,npu
+default=npu
+comply=opcode
+
+[OPCODE]
+TANH=cpu
diff --git a/compiler/circle-part-value-test/parts/Part_Tanh_FC_nobias_002.part b/compiler/circle-part-value-test/parts/Part_Tanh_FC_nobias_002.part
new file mode 100644
index 000000000..475439a9d
--- /dev/null
+++ b/compiler/circle-part-value-test/parts/Part_Tanh_FC_nobias_002.part
@@ -0,0 +1,7 @@
+[partition]
+backends=cpu,npu
+default=cpu
+comply=opcode
+
+[OPCODE]
+FULLY_CONNECTED=npu
diff --git a/compiler/circle-part-value-test/parts/Part_Tanh_FC_nobias_003.part b/compiler/circle-part-value-test/parts/Part_Tanh_FC_nobias_003.part
new file mode 100644
index 000000000..d9d2a8e59
--- /dev/null
+++ b/compiler/circle-part-value-test/parts/Part_Tanh_FC_nobias_003.part
@@ -0,0 +1,7 @@
+[partition]
+backends=cpu,npu
+default=npu
+comply=opcode
+
+[OPCODE]
+FULLY_CONNECTED=cpu
diff --git a/compiler/circle-part-value-test/parts/SignatureDef_MultiOut_000.part b/compiler/circle-part-value-test/parts/SignatureDef_MultiOut_000.part
new file mode 100644
index 000000000..e469eeb26
--- /dev/null
+++ b/compiler/circle-part-value-test/parts/SignatureDef_MultiOut_000.part
@@ -0,0 +1,7 @@
+[partition]
+backends=cpu,acl_cl
+default=cpu
+comply=opcode
+
+[OPCODE]
+MAXIMUM=acl_cl
diff --git a/compiler/circle-part-value-test/parts/SignatureDef_MultiOut_001.part b/compiler/circle-part-value-test/parts/SignatureDef_MultiOut_001.part
new file mode 100644
index 000000000..e469eeb26
--- /dev/null
+++ b/compiler/circle-part-value-test/parts/SignatureDef_MultiOut_001.part
@@ -0,0 +1,7 @@
+[partition]
+backends=cpu,acl_cl
+default=cpu
+comply=opcode
+
+[OPCODE]
+MAXIMUM=acl_cl
diff --git a/compiler/circle-part-value-test/test.lst b/compiler/circle-part-value-test/test.lst
index af2f5ba5c..b7a3f403a 100644
--- a/compiler/circle-part-value-test/test.lst
+++ b/compiler/circle-part-value-test/test.lst
@@ -35,3 +35,24 @@ add(Part_If_Add_Sub_001 Part_If_Add_Sub_001.001 3)
# WHILE with subgraphs
add(Part_While_000 Part_While_000 3)
add(Part_While_001 Part_While_001 3)
+
+# UNPACK with multiple outputs
+add(Net_UnpackAdd_001 Net_UnpackAdd_001 2)
+add(Net_UnpackAdd_001 Net_UnpackAdd_001.001 2)
+add(Net_UnpackAdd_001 Net_UnpackAdd_001.002 2)
+
+# Other multiple outputs
+add(Part_Split_Add_000 Part_Split_Add_000 2)
+
+# test SignatureDef, with any OPCODE
+add(SignatureDef_MultiOut_000 SignatureDef_MultiOut_000 0)
+add(SignatureDef_MultiOut_001 SignatureDef_MultiOut_001 0)
+
+# FC with nobias
+add(Part_Tanh_FC_nobias Part_Tanh_FC_nobias 1)
+add(Part_Tanh_FC_nobias Part_Tanh_FC_nobias_001 2)
+add(Part_Tanh_FC_nobias Part_Tanh_FC_nobias_002 2)
+add(Part_Tanh_FC_nobias Part_Tanh_FC_nobias_003 2)
+add(Part_Mul_Sqrt_FC_nobias_000 Part_Mul_Sqrt_FC_nobias_000_000 0)
+add(Part_Mul_Sqrt_FC_nobias_000 Part_Mul_Sqrt_FC_nobias_000_001 0)
+add(Part_Mul_Sqrt_FC_nobias_000 Part_Mul_Sqrt_FC_nobias_000_002 0)
diff --git a/compiler/circle-partitioner-test/CMakeLists.txt b/compiler/circle-partitioner-test/CMakeLists.txt
index ed8c97948..e29a66b41 100644
--- a/compiler/circle-partitioner-test/CMakeLists.txt
+++ b/compiler/circle-partitioner-test/CMakeLists.txt
@@ -57,8 +57,8 @@ foreach(IDX RANGE ${RECIPE_LENGTH_M1})
# Run partitioner
set(PART_CONN_JSON "${PART_OUT_PATH}/${PART_NAME}.conn.json")
add_custom_command(OUTPUT ${PART_CONN_JSON}
- COMMAND circle_partitioner "${PART_FILE}" "${PART_NAME}.circle" "${PART_OUT_PATH}"
- DEPENDS circle_partitioner ${CIRCLE_DST_PATH} ${PART_DST_PATH}
+ COMMAND circle-partitioner "${PART_FILE}" "${PART_NAME}.circle" "${PART_OUT_PATH}"
+ DEPENDS circle-partitioner ${CIRCLE_DST_PATH} ${PART_DST_PATH}
COMMENT "Parition ${RECIPE_NAME}.circle with ${PART_FILE}"
)
# NOTE this is checked in build time and not added with 'add_test' command
diff --git a/compiler/circle-partitioner-test/parts/Part_Add_SVDF_000.part b/compiler/circle-partitioner-test/parts/Part_Add_SVDF_000.part
new file mode 100644
index 000000000..01b8c704e
--- /dev/null
+++ b/compiler/circle-partitioner-test/parts/Part_Add_SVDF_000.part
@@ -0,0 +1,7 @@
+[partition]
+backends=cpu,acl_cl
+default=cpu
+comply=opcode
+
+[OPCODE]
+ADD=acl_cl
diff --git a/compiler/circle-partitioner-test/test.lst b/compiler/circle-partitioner-test/test.lst
index b731f8d0e..c0c185c7e 100644
--- a/compiler/circle-partitioner-test/test.lst
+++ b/compiler/circle-partitioner-test/test.lst
@@ -5,3 +5,7 @@
# add(RECIPE_NAME PART_NAME)
add(Net_InstanceNorm_003 Net_InstanceNorm_003)
+
+# NOTE SVDF partition test is done here as value test may need custom tolerance
+# TODO move Part_Add_SVDF_000 to circle-part-value-test when ready
+add(Part_Add_SVDF_000 Part_Add_SVDF_000)
diff --git a/compiler/circle-partitioner/CMakeLists.txt b/compiler/circle-partitioner/CMakeLists.txt
index 28a16c9fc..9b8f5afae 100644
--- a/compiler/circle-partitioner/CMakeLists.txt
+++ b/compiler/circle-partitioner/CMakeLists.txt
@@ -1,5 +1,24 @@
file(GLOB_RECURSE SOURCES "src/*.cpp")
+add_executable(circle-partitioner "${SOURCES}")
+target_link_libraries(circle-partitioner foder)
+target_link_libraries(circle-partitioner crew)
+target_link_libraries(circle-partitioner safemain)
+target_link_libraries(circle-partitioner luci_lang)
+target_link_libraries(circle-partitioner luci_log)
+target_link_libraries(circle-partitioner luci_import)
+target_link_libraries(circle-partitioner luci_service)
+target_link_libraries(circle-partitioner luci_pass)
+target_link_libraries(circle-partitioner luci_export)
+target_link_libraries(circle-partitioner luci_partition)
+target_link_libraries(circle-partitioner arser)
+target_link_libraries(circle-partitioner pepper_csv2vec)
+target_link_libraries(circle-partitioner vconone)
+target_link_libraries(circle-partitioner nncc_common)
+
+install(TARGETS circle-partitioner DESTINATION bin)
+
+# TODO remove circle_partitioner
add_executable(circle_partitioner "${SOURCES}")
target_link_libraries(circle_partitioner foder)
target_link_libraries(circle_partitioner crew)
diff --git a/compiler/circle-partitioner/README.md b/compiler/circle-partitioner/README.md
index 5fd312e33..2e0a98638 100644
--- a/compiler/circle-partitioner/README.md
+++ b/compiler/circle-partitioner/README.md
@@ -94,7 +94,7 @@ Net_InstanceNorm_003/
Command example
```
-./circle_partitioner Net_InstanceNorm_003.part Net_InstanceNorm_003.circle Net_InstanceNorm_003
+./circle-partitioner Net_InstanceNorm_003.part Net_InstanceNorm_003.circle Net_InstanceNorm_003
```
Result of _circle-partitioner_
@@ -163,3 +163,131 @@ as the `source` model: `[ "Input" ]`.
`Net_InstanceNorm_003.00002_acl_cl.circle` which they should be connected.
- And `outputs` `[ "Div" ]` should be connected to `inputs` of
third model `Net_InstanceNorm_003.00003_cpu.circle`.
+
+### Execution example
+
+Consider partitioning with backends of OneRT
+- `cpu`, `acl_cl`, `acl_neon`, `ruy`, `xnnpack`
+
+Let's try with this command:
+```
+circle_partitioner \
+ --partition Net_InstanceNorm_003.part \
+ --backends cpu,acl_cl \
+ --default cpu \
+ Net_InstanceNorm_003.circle Net_InstanceNorm_003
+```
+
+where `Net_InstanceNorm_003.part` is like this for initial design
+```
+[partition]
+backends=cpu,acl_cl
+default=cpu
+comply=opcode
+
+[OPCODE]
+ADD=acl_cl
+```
+where in `[partition]` section,
+- `backends` is available backends and can be overridden by `--backends`
+- `default` is default backend for OpCodes not assigned in `[OPCODE]` section can be overridden by `--default`
+- `comply` is which rule to apply, where only `opcode` is available for now
+
+#### Use Op name to assign backend
+
+```
+[OP]
+Reduction_indices=GPU
+```
+- there are very long names that may be inconvenient
+
+### Partitioned output
+
+#### Output files
+
+After partition is applied, output files will look something like these
+- `Net_InstanceNorm_003.part.00001_cpu.circle`
+- `Net_InstanceNorm_003.part.00002_acl_cl.circle`
+- `Net_InstanceNorm_003.part.00003_cpu.circle`
+- `Net_InstanceNorm_003.part.conn.ini`
+- `Net_InstanceNorm_003.part.conn.json`
+
+Assume only `Div` node is assigned to `acl_cl`
+
+#### Connection information of partitioned circle files
+
+##### Format with ini
+- `Net_InstanceNorm_003.conn.ini` provides connection of each circle files.
+```
+[source]
+file=Net_InstanceNorm_003.circle
+i1=Input
+o1=Add_as_terminal
+
+[models]
+m1=Net_InstanceNorm_003.part.00001_cpu.circle
+m2=Net_InstanceNorm_003.part.00002_acl_cl.circle
+m3=Net_InstanceNorm_003.part.00003_cpu.circle
+
+[Net_InstanceNorm_003.part.00001_cpu.circle]
+file=Net_InstanceNorm_003.part.00001_cpu.circle
+i1=Input
+o1=Pow
+o2=Sub
+
+[Net_InstanceNorm_003.part.00002_acl_cl.circle]
+file=Net_InstanceNorm_003.part.00002_acl_cl.circle
+i1=Sub
+i2=Pow
+o1=Div
+
+[Net_InstanceNorm_003.part.00003_cpu.circle]
+file=Net_InstanceNorm_003.part.00003_cpu.circle
+i1=Div
+o1=Add_as_terminal
+```
+
+Predefined section
+- `source`: Source circle model information. Has `file` as filename, `iN` for inputs and `oN` for outputs.
+- `models`: Partitioned circle models. Has `mN` for model filename.
+
+Partitioned Model section
+- `iN`: inputs of this model
+- `oN`: outputs of this model
+
+In graph diagram, output order of `Net_InstanceNorm_003.part.00001_cpu.circle`
+looks like `Pow,Sub` but `Div` Op in `Net_InstanceNorm_003.part.00002_acl_cl.circle`
+requires order of `Sub,Pow`.
+
+##### Format with JSON
+- Use JSON format, `Net_InstanceNorm_003.part.conn.json`
+```json
+{
+ "source" : {
+ "file" : "Net_InstanceNorm_003.circle",
+ "inputs" : [ "Input" ],
+ "outputs" : [ "Add_as_terminal" ]
+ },
+ "parts" : [
+ {
+ "file" : "Net_InstanceNorm_003.part.00001_cpu.circle",
+ "inputs" : [ "Input" ],
+ "outputs" : [ "Pow", "Sub" ],
+ },
+ {
+ "file" : "Net_InstanceNorm_003.part.00002_acl_cl.circle",
+ "inputs" : [ "Pow", "Sub" ],
+ "outputs" : [ "Div" ]
+ },
+ {
+ "file" : "Net_InstanceNorm_003.part.00003_cpu.circle",
+ "inputs" : [ "Div" ],
+ "outputs" : [ "Add_as_terminal" ]
+ }
+ ]
+}
+```
+
+### Future works
+
+How to partition with multiple inputs?
diff --git a/compiler/circle-quantizer-dredd-recipe-test/CMakeLists.txt b/compiler/circle-quantizer-dredd-recipe-test/CMakeLists.txt
new file mode 100644
index 000000000..5ec8b6ee5
--- /dev/null
+++ b/compiler/circle-quantizer-dredd-recipe-test/CMakeLists.txt
@@ -0,0 +1,144 @@
+if(NOT ENABLE_TEST)
+ return()
+endif(NOT ENABLE_TEST)
+
+nnas_include(TargetRequire)
+
+unset(REQUIRED_TARGETS)
+list(APPEND REQUIRED_TARGETS circle-inspect)
+list(APPEND REQUIRED_TARGETS circle-verify)
+list(APPEND REQUIRED_TARGETS circle-quantizer)
+list(APPEND REQUIRED_TARGETS record-minmax)
+list(APPEND REQUIRED_TARGETS dredd_rule_lib)
+TargetRequire_Return(${REQUIRED_TARGETS})
+
+unset(TEST_DEPS)
+unset(TEST_NAMES)
+
+get_target_property(ARTIFACTS_BIN_PATH testDataGenerator BINARY_DIR)
+
+set(options USE_QCONFIG)
+set(oneValueArgs DTYPE GRANULARITY)
+set(multiValueArgs "")
+
+macro(Add RECIPE)
+ cmake_parse_arguments(ARG "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
+
+ set(QCONFIG_OPT "")
+ if(ARG_USE_QCONFIG)
+ set(QCONFIG_OPT "--config" "${ARTIFACTS_BIN_PATH}/${RECIPE}.qconf.json")
+ endif()
+
+ set(CIRCLE_PATH "${ARTIFACTS_BIN_PATH}/${RECIPE}.circle")
+ set(FAKE_QUANT_CIRCLE_PATH "${CMAKE_CURRENT_BINARY_DIR}/${RECIPE}.fq.circle")
+ set(RECORDED_CIRCLE_PATH "${CMAKE_CURRENT_BINARY_DIR}/${RECIPE}.recorded.circle")
+ set(QUANT_CIRCLE_PATH "${CMAKE_CURRENT_BINARY_DIR}/${RECIPE}.q.circle")
+
+ # Generate quantized .circle
+ add_custom_command(OUTPUT ${QUANT_CIRCLE_PATH}
+ COMMAND $<TARGET_FILE:circle-quantizer> --quantize_dequantize_weights float32 ${ARG_DTYPE} ${ARG_GRANULARITY} ${QCONFIG_OPT} ${CIRCLE_PATH} ${FAKE_QUANT_CIRCLE_PATH}
+ COMMAND $<TARGET_FILE:record-minmax> --input_model ${FAKE_QUANT_CIRCLE_PATH} --output_model ${RECORDED_CIRCLE_PATH}
+ COMMAND $<TARGET_FILE:circle-quantizer> --quantize_with_minmax float32 ${ARG_DTYPE} ${ARG_GRANULARITY} ${QCONFIG_OPT} ${RECORDED_CIRCLE_PATH} ${QUANT_CIRCLE_PATH}
+ DEPENDS
+ circle-quantizer
+ record-minmax
+ ${CIRCLE_PATH}
+ COMMENT "Generate ${RECIPE}.q.circle"
+ )
+
+ list(APPEND TEST_DEPS ${QUANT_CIRCLE_PATH})
+ list(APPEND TEST_NAMES ${RECIPE})
+endmacro(Add)
+
+# Macro to generate fully fake-quantized models
+macro(AddFakeQuant RECIPE)
+ set(CIRCLE_PATH "${ARTIFACTS_BIN_PATH}/${RECIPE}.circle")
+ # NOTE We use .q.circle because it is convention for output file (see testall.sh for more details)
+ set(FULL_FAKE_QUANT_CIRCLE_PATH "${CMAKE_CURRENT_BINARY_DIR}/${RECIPE}.q.circle")
+
+ # Generate fully fake-quantized .circle
+ add_custom_command(OUTPUT ${FULL_FAKE_QUANT_CIRCLE_PATH}
+ COMMAND $<TARGET_FILE:circle-quantizer> --fake_quantize ${CIRCLE_PATH} ${FULL_FAKE_QUANT_CIRCLE_PATH}
+ DEPENDS
+ circle-quantizer
+ ${CIRCLE_PATH}
+ COMMENT "Generate ${RECIPE}.q.circle"
+ )
+
+ list(APPEND TEST_DEPS ${FULL_FAKE_QUANT_CIRCLE_PATH})
+ list(APPEND TEST_NAMES ${RECIPE})
+endmacro(AddFakeQuant)
+
+# Read "test.lst"
+include("test.lst")
+
+##
+## Copy testall
+##
+set(TEST_RUNNER "${CMAKE_CURRENT_BINARY_DIR}/testall.sh")
+set(TEST_RUNNER_SOURCE "${CMAKE_CURRENT_SOURCE_DIR}/testall.sh")
+
+add_custom_command(
+ OUTPUT ${TEST_RUNNER}
+ COMMAND ${CMAKE_COMMAND} -E copy "${TEST_RUNNER_SOURCE}" "${TEST_RUNNER}"
+ DEPENDS ${TEST_RUNNER_SOURCE}
+ COMMENT "Generate test runner"
+)
+
+list(APPEND TEST_DEPS "${TEST_RUNNER}")
+
+###
+### Generate test.config
+###
+set(TEST_CONFIG "${CMAKE_CURRENT_BINARY_DIR}/test.config")
+
+add_custom_command(
+ OUTPUT ${TEST_CONFIG}
+ COMMAND ${CMAKE_COMMAND} -E remove -f ${TEST_CONFIG}
+ COMMAND ${CMAKE_COMMAND} -E echo 'CIRCLE_INSPECT_PATH=\"$<TARGET_FILE:circle-inspect>\"' >> ${TEST_CONFIG}
+ COMMAND ${CMAKE_COMMAND} -E echo 'CIRCLE_VERIFY_PATH=\"$<TARGET_FILE:circle-verify>\"' >> ${TEST_CONFIG}
+ COMMAND ${CMAKE_COMMAND} -E echo 'RECORD_MINMAX_PATH=\"$<TARGET_FILE:record-minmax>\"' >> ${TEST_CONFIG}
+ COMMAND ${CMAKE_COMMAND} -E echo 'CIRCLE_QUANTIZER_PATH=\"$<TARGET_FILE:circle-quantizer>\"' >> ${TEST_CONFIG}
+ DEPENDS
+ circle-inspect
+ circle-verify
+ record-minmax
+ circle-quantizer
+ COMMENT "Generate test configuration"
+)
+
+list(APPEND TEST_DEPS "${TEST_CONFIG}")
+
+#
+# copy rule-lib.sh (a library of shell script functions)
+#
+
+# getting path for rule-lib.sh in dredd-rule-lib
+get_target_property(DREDD_RULE_LIB_DIR dredd_rule_lib BINARY_DIR)
+
+set(RULE_LIB_SOURCE_PATH "${DREDD_RULE_LIB_DIR}/rule-lib.sh")
+set(RULE_LIB_BINARY_PATH "${CMAKE_CURRENT_BINARY_DIR}/rule-lib.sh")
+
+add_custom_command(
+ OUTPUT ${RULE_LIB_BINARY_PATH}
+ COMMAND ${CMAKE_COMMAND} -E copy "${RULE_LIB_SOURCE_PATH}" "${RULE_LIB_BINARY_PATH}"
+ DEPENDS ${RULE_LIB_SOURCE_PATH}
+ COMMENT "Generate rule lib"
+)
+
+list(APPEND TEST_DEPS "${RULE_LIB_BINARY_PATH}")
+
+# Generate dependencies
+add_custom_target(circle_quantizer_dredd_recipe_test ALL DEPENDS ${TEST_DEPS})
+add_dependencies(circle_quantizer_dredd_recipe_test common_artifacts_deps)
+
+get_target_property(ARTIFACTS_BIN_PATH testDataGenerator BINARY_DIR)
+
+# Run tests
+add_test(
+ NAME circle_quantizer_dredd_recipe_test
+ COMMAND ${TEST_RUNNER}
+ ${TEST_CONFIG}
+ ${ARTIFACTS_BIN_PATH}
+ ${TEST_NAMES}
+)
diff --git a/compiler/circle-quantizer-dredd-recipe-test/README.md b/compiler/circle-quantizer-dredd-recipe-test/README.md
new file mode 100644
index 000000000..61525495a
--- /dev/null
+++ b/compiler/circle-quantizer-dredd-recipe-test/README.md
@@ -0,0 +1,37 @@
+# circle-quantizer-dredd-recipe-test
+
+It tests non-functional conditions of a quantized circle model generated by circle-quantizer.
+
+## How to add a test?
+
+1. Create a directory under `res/TensorFlowLiteRecipes/` or `res/CircleRecipes/`.
+
+2. Make a recipe (`test.recipe`) for fp32 model under the directory.
+
+3. Make a rule (`test.rule`) you want to test under the directory. (For more information on dredd-test-rules, see _dredd-rule-lib_ module.)
+
+4. Add test to `test.lst` in this module with `Add` macro.
+
+```
+Add(RECIPE_DIR DTYPE dtype GRANULARITY granularity USE_QCONFIG)
+```
+
+- `RECIPE_DIR`: Path to the directory where the recipe file is saved.
+- `DTYPE`: Default quantization dtype (uint8, int16)
+- `GRANULARITY`: Quantization granularity (channel, layer)
+- `USE_QCONFIG`: (Optional) Whether to use a quantization configuration file or not. If this is set, `test.qconf.json` should exist under `RECIPE_DIR`
+
+## Example
+
+```
+# TensorFlowLiteRecipes
+res/TensorFlowLiteRecipes/Quant_Conv_Mul_Add_000
+├── test.recipe # What you want to test
+└── test.rule # Non-functional conditions to be satisfied
+└── test.qconf.json # Quantization configuration file (optional)
+
+# test.lst
+...
+Add(Quant_Conv_Mul_Add_000 DTYPE uint8 GRANULARITY channel USE_QCONFIG)
+...
+```
diff --git a/compiler/circle-quantizer-dredd-recipe-test/requires.cmake b/compiler/circle-quantizer-dredd-recipe-test/requires.cmake
new file mode 100644
index 000000000..7450f7322
--- /dev/null
+++ b/compiler/circle-quantizer-dredd-recipe-test/requires.cmake
@@ -0,0 +1,6 @@
+require("circle-quantizer")
+require("record-minmax")
+require("circle-inspect")
+require("circle-verify")
+require("common-artifacts")
+require("dredd-rule-lib")
diff --git a/compiler/circle-quantizer-dredd-recipe-test/test.lst b/compiler/circle-quantizer-dredd-recipe-test/test.lst
new file mode 100644
index 000000000..188103016
--- /dev/null
+++ b/compiler/circle-quantizer-dredd-recipe-test/test.lst
@@ -0,0 +1,15 @@
+## EXAMPLE
+#
+# Add(RECIPE_DIR DTYPE dtype GRANULARITY granularity USE_QCONFIG(optional))
+# AddFakeQuant(RECIPE_DIR)
+#
+
+## TFLITE RECIPE
+
+Add(Quant_Conv_Mul_Add_000 DTYPE uint8 GRANULARITY channel USE_QCONFIG)
+Add(Quant_Conv_Mul_Add_001 DTYPE uint8 GRANULARITY channel USE_QCONFIG)
+Add(Quant_Conv_Mul_Add_002 DTYPE uint8 GRANULARITY channel USE_QCONFIG)
+Add(Quant_Split_Add_000 DTYPE uint8 GRANULARITY channel USE_QCONFIG)
+Add(Quant_Split_Add_001 DTYPE uint8 GRANULARITY channel USE_QCONFIG)
+
+AddFakeQuant(Quant_Add_000)
diff --git a/compiler/circle-quantizer-dredd-recipe-test/testall.sh b/compiler/circle-quantizer-dredd-recipe-test/testall.sh
new file mode 100755
index 000000000..e5d5cf2b8
--- /dev/null
+++ b/compiler/circle-quantizer-dredd-recipe-test/testall.sh
@@ -0,0 +1,100 @@
+#!/bin/bash
+
+# Need at least 2 arguments
+if [[ $# -lt 2 ]]; then
+ echo "USAGE: $0 ..."
+ echo
+ echo "ARGUMENTS:"
+ echo " [test.config path]"
+ echo " [WORKDIR]"
+ echo " [Prefix1]"
+ echo " [Prefix2]"
+ echo " ..."
+ exit 255
+fi
+
+WORKDIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
+CONFIG_PATH="$1"; shift
+RESOURCE_DIR="$1"; shift
+
+source "${CONFIG_PATH}"
+
+echo "-- Found circle-inspect: ${CIRCLE_INSPECT_PATH}"
+echo "-- Found circle-verify: ${CIRCLE_VERIFY_PATH}"
+echo "-- Found circle-quantizer: ${CIRCLE_QUANTIZER_PATH}"
+echo "-- Found record-minmax: ${RECORD_MINMAX_PATH}"
+echo "-- Found common-artifacts: ${RESOURCE_DIR}"
+
+TESTED=()
+PASSED=()
+FAILED=()
+
+pushd ${WORKDIR}
+while [[ $# -ne 0 ]]; do
+ PREFIX="$1"; shift
+
+ TESTED+=("${PREFIX}")
+
+ PASSED_TAG="${PREFIX}.passed"
+
+ rm -f "${PASSED_TAG}"
+
+ cat > "${PREFIX}.log" <(
+ exec 2>&1
+
+ echo "-- Found circle: ${PREFIX}.q.circle"
+
+ # Exit immediately if any command fails
+ set -e
+ # Show commands
+ set -x
+
+ #
+ # Check if rule is satisfied
+ #
+
+ # Note: turn off 'command printing'. Otherwise printing will be so messy
+ set +x
+
+ # (COMPILED_FILE, INSPECT_PROG_PATH, VERIFY_PROG_PATH, ERROR_LOG) must be set for rule-lib.sh
+ COMPILED_FILE="${PREFIX}.q.circle"
+ INSPECT_PROG_PATH=${CIRCLE_INSPECT_PATH}
+ VERIFY_PROG_PATH=${CIRCLE_VERIFY_PATH}
+ ERROR_LOG="${PREFIX}.error"
+
+ rm -f "${ERROR_LOG}"
+
+ # in case error while running rule-lib.sh, prints error msg
+ trap 'echo "** ERROR **" ; cat "${ERROR_LOG}"' ERR
+
+ source rule-lib.sh
+ source "${RESOURCE_DIR}/${PREFIX}.rule"
+
+ # unset
+ trap - ERR
+ set -x
+
+ # At this point, the exit code of all commands is 0
+ # If not 0, execution of this script ends because of "set -e"
+ touch "${PASSED_TAG}"
+ )
+
+ if [[ -f "${PASSED_TAG}" ]]; then
+ PASSED+=("$PREFIX")
+ else
+ FAILED+=("$PREFIX")
+ fi
+done
+popd
+
+if [[ ${#TESTED[@]} -ne ${#PASSED[@]} ]]; then
+ echo "FAILED"
+ for TEST in "${FAILED[@]}"
+ do
+ echo "- ${TEST}"
+ done
+ exit 255
+fi
+
+echo "PASSED"
+exit 0
diff --git a/compiler/circle-quantizer/CMakeLists.txt b/compiler/circle-quantizer/CMakeLists.txt
index a5f5f61c4..14e00972b 100644
--- a/compiler/circle-quantizer/CMakeLists.txt
+++ b/compiler/circle-quantizer/CMakeLists.txt
@@ -1,11 +1,19 @@
+nnas_find_package(Jsoncpp)
+if(NOT Jsoncpp_FOUND)
+ message(STATUS "Build jsoncpp: FAILED (missing jsoncpp)")
+ return()
+endif(NOT Jsoncpp_FOUND)
+
set (SOURCES src/CircleQuantizer.cpp)
add_executable(circle-quantizer "${SOURCES}")
+target_include_directories(circle-quantizer PRIVATE ${Jsoncpp_INCLUDE_DIRS})
+
+target_link_libraries(circle-quantizer ${Jsoncpp_STATIC_LIB})
target_link_libraries(circle-quantizer foder)
target_link_libraries(circle-quantizer safemain)
target_link_libraries(circle-quantizer oops)
target_link_libraries(circle-quantizer loco)
-target_link_libraries(circle-quantizer mio_circle)
target_link_libraries(circle-quantizer luci_import)
target_link_libraries(circle-quantizer luci_service)
target_link_libraries(circle-quantizer luci_pass)
diff --git a/compiler/circle-quantizer/src/CircleQuantizer.cpp b/compiler/circle-quantizer/src/CircleQuantizer.cpp
index 57ac30a87..e0c85cb6e 100644
--- a/compiler/circle-quantizer/src/CircleQuantizer.cpp
+++ b/compiler/circle-quantizer/src/CircleQuantizer.cpp
@@ -17,7 +17,7 @@
#include <foder/FileLoader.h>
#include <luci/Importer.h>
-#include <luci/CircleOptimizer.h>
+#include <luci/CircleQuantizer.h>
#include <luci/Service/Validate.h>
#include <luci/CircleExporter.h>
#include <luci/CircleFileExpContract.h>
@@ -26,6 +26,7 @@
#include <oops/InternalExn.h>
#include <arser/arser.h>
#include <vconone/vconone.h>
+#include <json.h>
#include <functional>
#include <iostream>
@@ -34,8 +35,41 @@
using OptionHook = std::function<int(const char **)>;
-using Algorithms = luci::CircleOptimizer::Options::Algorithm;
-using AlgorithmParameters = luci::CircleOptimizer::Options::AlgorithmParameters;
+using LayerParam = luci::CircleQuantizer::Options::LayerParam;
+using Algorithms = luci::CircleQuantizer::Options::Algorithm;
+using AlgorithmParameters = luci::CircleQuantizer::Options::AlgorithmParameters;
+
+std::vector<std::shared_ptr<LayerParam>> read_layer_params(std::string &filename)
+{
+ Json::Value root;
+ std::ifstream ifs(filename);
+
+ // Failed to open cfg file
+ if (not ifs.is_open())
+ throw std::runtime_error("Cannot open config file. " + filename);
+
+ Json::CharReaderBuilder builder;
+ JSONCPP_STRING errs;
+
+ // Failed to parse
+ if (not parseFromStream(builder, ifs, &root, &errs))
+ throw std::runtime_error("Cannot parse config file (json format). " + errs);
+
+ auto layers = root["layers"];
+ std::vector<std::shared_ptr<LayerParam>> p;
+ for (auto layer : layers)
+ {
+ auto l = std::make_shared<LayerParam>();
+ {
+ l->name = layer["name"].asString();
+ l->dtype = layer["dtype"].asString();
+ l->granularity = layer["granularity"].asString();
+ }
+ p.emplace_back(l);
+ }
+
+ return p;
+}
void print_exclusive_options(void)
{
@@ -56,15 +90,20 @@ int entry(int argc, char **argv)
{
// Simple argument parser (based on map)
std::map<std::string, OptionHook> argparse;
- luci::CircleOptimizer optimizer;
+ luci::CircleQuantizer quantizer;
- auto options = optimizer.options();
+ auto options = quantizer.options();
auto settings = luci::UserSettings::settings();
const std::string qdqw = "--quantize_dequantize_weights";
const std::string qwmm = "--quantize_with_minmax";
const std::string rq = "--requantize";
const std::string fq = "--force_quantparam";
+ const std::string cq = "--copy_quantparam";
+ const std::string fake_quant = "--fake_quantize";
+ const std::string cfg = "--config";
+
+ const std::string tf_maxpool = "--TF-style_maxpool";
const std::string gpd = "--generate_profile_data";
@@ -99,6 +138,19 @@ int entry(int argc, char **argv)
"Three arguments required: input_model_dtype(float32) "
"output_model_dtype(uint8) granularity(layer, channel)");
+ arser.add_argument(tf_maxpool)
+ .nargs(0)
+ .required(false)
+ .default_value(false)
+ .help("Force MaxPool Op to have the same input/output quantparams. NOTE: This feature can "
+ "degrade accuracy of some models");
+
+ arser.add_argument(fake_quant)
+ .nargs(0)
+ .required(false)
+ .help("Convert a quantized model to a fake-quantized model. NOTE: This feature will "
+ "generate an fp32 model.");
+
arser.add_argument(rq)
.nargs(2)
.type(arser::DataType::STR_VEC)
@@ -116,6 +168,15 @@ int entry(int argc, char **argv)
"Three arguments required: tensor_name(string), "
"scale(float) zero_point(int)");
+ arser.add_argument(cq)
+ .nargs(2)
+ .type(arser::DataType::STR_VEC)
+ .required(false)
+ .accumulated(true)
+ .help("Copy quantization parameter from a tensor to another tensor."
+ "Two arguments required: source_tensor_name(string), "
+ "destination_tensor_name(string)");
+
arser.add_argument("--input_type")
.nargs(1)
.type(arser::DataType::STR)
@@ -128,6 +189,12 @@ int entry(int argc, char **argv)
.required(false)
.help("Output type of quantized model (uint8 or int16)");
+ arser.add_argument(cfg)
+ .nargs(1)
+ .type(arser::DataType::STR)
+ .required(false)
+ .help("Path to the quantization configuration file");
+
arser.add_argument("input").nargs(1).type(arser::DataType::STR).help("Input circle model");
arser.add_argument("output").nargs(1).type(arser::DataType::STR).help("Output circle model");
@@ -146,11 +213,13 @@ int entry(int argc, char **argv)
}
{
- // only one of qdqw, qwmm, rq, fq option can be used
+ // only one of qdqw, qwmm, rq, fq, cq, fake_quant option can be used
int32_t opt_used = arser[qdqw] ? 1 : 0;
opt_used += arser[qwmm] ? 1 : 0;
opt_used += arser[rq] ? 1 : 0;
opt_used += arser[fq] ? 1 : 0;
+ opt_used += arser[cq] ? 1 : 0;
+ opt_used += arser[fake_quant] ? 1 : 0;
if (opt_used != 1)
{
print_exclusive_options();
@@ -178,6 +247,22 @@ int entry(int argc, char **argv)
options->param(AlgorithmParameters::Quantize_input_model_dtype, values.at(0));
options->param(AlgorithmParameters::Quantize_output_model_dtype, values.at(1));
options->param(AlgorithmParameters::Quantize_granularity, values.at(2));
+
+ if (arser[cfg])
+ {
+ auto filename = arser.get<std::string>(cfg);
+ try
+ {
+ auto layer_params = read_layer_params(filename);
+
+ options->layer_params(AlgorithmParameters::Quantize_layer_params, layer_params);
+ }
+ catch (const std::runtime_error &e)
+ {
+ std::cerr << e.what() << '\n';
+ return 255;
+ }
+ }
}
if (arser[qwmm])
@@ -201,6 +286,25 @@ int entry(int argc, char **argv)
if (arser["--output_type"])
options->param(AlgorithmParameters::Quantize_output_type,
arser.get<std::string>("--output_type"));
+
+ if (arser[tf_maxpool] and arser.get<bool>(tf_maxpool))
+ options->param(AlgorithmParameters::Quantize_TF_style_maxpool, "True");
+
+ if (arser[cfg])
+ {
+ auto filename = arser.get<std::string>(cfg);
+ try
+ {
+ auto layer_params = read_layer_params(filename);
+
+ options->layer_params(AlgorithmParameters::Quantize_layer_params, layer_params);
+ }
+ catch (const std::runtime_error &e)
+ {
+ std::cerr << e.what() << '\n';
+ return 255;
+ }
+ }
}
if (arser[rq])
@@ -245,6 +349,34 @@ int entry(int argc, char **argv)
options->params(AlgorithmParameters::Quantize_zero_points, zero_points);
}
+ if (arser[cq])
+ {
+ auto values = arser.get<std::vector<std::vector<std::string>>>(cq);
+
+ std::vector<std::string> src;
+ std::vector<std::string> dst;
+
+ for (auto const value : values)
+ {
+ if (value.size() != 2)
+ {
+ std::cerr << arser;
+ return 255;
+ }
+
+ src.push_back(value[0]);
+ dst.push_back(value[1]);
+ }
+
+ options->enable(Algorithms::CopyQuantParam);
+
+ options->params(AlgorithmParameters::Quantize_src_tensor_names, src);
+ options->params(AlgorithmParameters::Quantize_dst_tensor_names, dst);
+ }
+
+ if (arser[fake_quant])
+ options->enable(Algorithms::ConvertToFakeQuantizedModel);
+
std::string input_path = arser.get<std::string>("input");
std::string output_path = arser.get<std::string>("output");
@@ -279,7 +411,7 @@ int entry(int argc, char **argv)
auto graph = module->graph(idx);
// quantize the graph
- optimizer.quantize(graph);
+ quantizer.quantize(graph);
if (!luci::validate(graph))
{
diff --git a/compiler/circle-tensordump/CMakeLists.txt b/compiler/circle-tensordump/CMakeLists.txt
index 4524260c4..676aecd53 100644
--- a/compiler/circle-tensordump/CMakeLists.txt
+++ b/compiler/circle-tensordump/CMakeLists.txt
@@ -1,6 +1,6 @@
-if(NOT TARGET mio_circle)
+if(NOT TARGET mio_circle04)
return()
-endif(NOT TARGET mio_circle)
+endif(NOT TARGET mio_circle04)
nnas_find_package(HDF5 COMPONENTS STATIC QUIET)
@@ -19,7 +19,8 @@ target_include_directories(circle-tensordump PRIVATE ${HDF5_INCLUDE_DIRS})
target_link_libraries(circle-tensordump PRIVATE ${HDF5_CXX_LIBRARIES})
target_link_libraries(circle-tensordump PRIVATE arser)
target_link_libraries(circle-tensordump PRIVATE foder)
-target_link_libraries(circle-tensordump PRIVATE mio_circle)
+target_link_libraries(circle-tensordump PRIVATE mio_circle04)
+target_link_libraries(circle-tensordump PRIVATE mio_circle04_helper)
target_link_libraries(circle-tensordump PRIVATE safemain)
install(TARGETS circle-tensordump DESTINATION bin)
diff --git a/compiler/circle-tensordump/requires.cmake b/compiler/circle-tensordump/requires.cmake
index 1c754f518..183dfe227 100644
--- a/compiler/circle-tensordump/requires.cmake
+++ b/compiler/circle-tensordump/requires.cmake
@@ -1,4 +1,4 @@
require("arser")
require("foder")
-require("mio-circle")
+require("mio-circle04")
require("safemain")
diff --git a/compiler/circle-tensordump/src/Reader.cpp b/compiler/circle-tensordump/src/Reader.cpp
index 429736bfe..47b876054 100644
--- a/compiler/circle-tensordump/src/Reader.cpp
+++ b/compiler/circle-tensordump/src/Reader.cpp
@@ -16,66 +16,14 @@
#include "Reader.h"
+#include <mio_circle/Helper.h>
+
#include <sstream>
#include <string>
namespace circletensordump
{
-bool is_valid(const circle::OperatorCode *opcode)
-{
- circle::BuiltinOperator code = opcode->builtin_code();
- return (circle::BuiltinOperator_MIN <= code && code <= circle::BuiltinOperator_MAX);
-}
-
-bool is_custom(const circle::OperatorCode *opcode)
-{
- circle::BuiltinOperator code = opcode->builtin_code();
- return (code == circle::BuiltinOperator_CUSTOM);
-}
-
-std::string opcode_name(const circle::OperatorCode *opcode)
-{
- assert(opcode);
-
- if (!is_valid(opcode))
- {
- std::ostringstream oss;
- oss << "(invalid)";
- return oss.str();
- }
-
- if (is_custom(opcode))
- {
- if (!opcode->custom_code())
- return "(invalid custom)";
-
- std::string custom_op = "CUSTOM(";
- custom_op += opcode->custom_code()->c_str();
- custom_op += ")";
- return custom_op;
- }
-
- circle::BuiltinOperator code = opcode->builtin_code();
- return circle::EnumNameBuiltinOperator(code);
-}
-
-const char *tensor_type(const circle::Tensor *tensor)
-{
- return circle::EnumNameTensorType(tensor->type());
-}
-
-const char *tensor_name(const circle::Tensor *tensor)
-{
- static const char *kEmptyTensorName = "(noname)";
-
- auto name = tensor->name();
- if (name)
- return name->c_str();
-
- return kEmptyTensorName;
-}
-
Reader::Reader(const circle::Model *model)
{
_subgraphs = model->subgraphs();
@@ -122,7 +70,7 @@ circle::BuiltinOperator Reader::builtin_code(const circle::Operator *op) const
assert(index < _op_codes.size());
const circle::OperatorCode *opcode = _op_codes.at(index);
- return opcode->builtin_code();
+ return mio::circle::builtin_code_neutral(opcode);
}
std::string Reader::opcode_name(const circle::Operator *op) const
@@ -131,14 +79,14 @@ std::string Reader::opcode_name(const circle::Operator *op) const
assert(index < _op_codes.size());
const circle::OperatorCode *opcode = _op_codes.at(index);
- if (!is_valid(opcode))
+ if (!mio::circle::is_valid(opcode))
{
std::ostringstream oss;
oss << "(invalid: " << index << ")";
return oss.str();
}
- return circletensordump::opcode_name(opcode);
+ return mio::circle::opcode_name(opcode);
}
bool Reader::select_subgraph(uint32_t sgindex)
diff --git a/compiler/circle-tensordump/src/Reader.h b/compiler/circle-tensordump/src/Reader.h
index bbb039552..c868bc277 100644
--- a/compiler/circle-tensordump/src/Reader.h
+++ b/compiler/circle-tensordump/src/Reader.h
@@ -36,12 +36,6 @@ template <typename T> std::vector<T> as_index_vector(const flatbuffers::Vector<T
return ret;
}
-bool is_valid(const circle::OperatorCode *opcode);
-bool is_custom(const circle::OperatorCode *opcode);
-std::string opcode_name(const circle::OperatorCode *opcode);
-const char *tensor_type(const circle::Tensor *tensor);
-const char *tensor_name(const circle::Tensor *tensor);
-
/**
* @brief Loads Circle file and provides helpers to access attributes
*/
diff --git a/compiler/circle-verify/CMakeLists.txt b/compiler/circle-verify/CMakeLists.txt
index f22174865..5d0eb9468 100644
--- a/compiler/circle-verify/CMakeLists.txt
+++ b/compiler/circle-verify/CMakeLists.txt
@@ -1,13 +1,14 @@
-if(NOT TARGET mio_circle)
+if(NOT TARGET mio_circle04)
+ message(STATUS "Skip circle-verify: mio_circle04 not found")
return()
-endif(NOT TARGET mio_circle)
+endif(NOT TARGET mio_circle04)
file(GLOB_RECURSE SOURCES "src/*.cpp")
add_executable(circle-verify ${SOURCES})
target_include_directories(circle-verify PRIVATE src)
target_link_libraries(circle-verify arser)
-target_link_libraries(circle-verify mio_circle)
+target_link_libraries(circle-verify mio_circle04)
target_link_libraries(circle-verify safemain)
target_link_libraries(circle-verify cwrap)
target_link_libraries(circle-verify foder)
diff --git a/compiler/circle-verify/requires.cmake b/compiler/circle-verify/requires.cmake
index e1b7fb212..74c8f448b 100644
--- a/compiler/circle-verify/requires.cmake
+++ b/compiler/circle-verify/requires.cmake
@@ -1,5 +1,5 @@
require("arser")
-require("mio-circle")
+require("mio-circle04")
require("safemain")
require("cwrap")
require("foder")
diff --git a/compiler/circle2circle-dredd-recipe-test/CMakeLists.txt b/compiler/circle2circle-dredd-recipe-test/CMakeLists.txt
index ee73d63e3..9ccfd0008 100644
--- a/compiler/circle2circle-dredd-recipe-test/CMakeLists.txt
+++ b/compiler/circle2circle-dredd-recipe-test/CMakeLists.txt
@@ -1,3 +1,7 @@
+if(NOT ENABLE_TEST)
+ return()
+endif(NOT ENABLE_TEST)
+
nnas_include(TargetRequire)
unset(REQUIRED_TARGETS)
diff --git a/compiler/circle2circle/CMakeLists.txt b/compiler/circle2circle/CMakeLists.txt
index 358fc4e2c..cd79967b7 100644
--- a/compiler/circle2circle/CMakeLists.txt
+++ b/compiler/circle2circle/CMakeLists.txt
@@ -11,7 +11,6 @@ target_link_libraries(circle2circle oops)
target_link_libraries(circle2circle hermes)
target_link_libraries(circle2circle hermes_std)
target_link_libraries(circle2circle loco)
-target_link_libraries(circle2circle mio_circle)
target_link_libraries(circle2circle luci_env)
target_link_libraries(circle2circle luci_import)
target_link_libraries(circle2circle luci_service)
@@ -36,7 +35,6 @@ target_link_libraries(circle2circle_test oops)
target_link_libraries(circle2circle_test hermes)
target_link_libraries(circle2circle_test hermes_std)
target_link_libraries(circle2circle_test loco)
-target_link_libraries(circle2circle_test mio_circle)
target_link_libraries(circle2circle_test luci_env)
target_link_libraries(circle2circle_test luci_import)
target_link_libraries(circle2circle_test luci_service)
diff --git a/compiler/circle2circle/requires.cmake b/compiler/circle2circle/requires.cmake
index 36a9efd16..b6c61198f 100644
--- a/compiler/circle2circle/requires.cmake
+++ b/compiler/circle2circle/requires.cmake
@@ -3,7 +3,6 @@ require("loco")
require("locop")
require("logo-core")
require("safemain")
-require("mio-circle")
require("oops")
require("hermes")
require("hermes-std")
diff --git a/compiler/circle2circle/src/Circle2Circle.cpp b/compiler/circle2circle/src/Circle2Circle.cpp
index a5ddb26dc..ae677a321 100644
--- a/compiler/circle2circle/src/Circle2Circle.cpp
+++ b/compiler/circle2circle/src/Circle2Circle.cpp
@@ -104,6 +104,12 @@ int entry(int argc, char **argv)
.default_value(false)
.help("This will fold Depthwise Convolution operator with constant inputs");
+ arser.add_argument("--fold_gather")
+ .nargs(0)
+ .required(false)
+ .default_value(false)
+ .help("This will fold Gather operator");
+
arser.add_argument("--fold_sparse_to_dense")
.nargs(0)
.required(false)
@@ -203,6 +209,12 @@ int entry(int argc, char **argv)
.default_value(false)
.help("This will remove Quantize-Dequantize sequence");
+ arser.add_argument("--remove_redundant_quantize")
+ .nargs(0)
+ .required(false)
+ .default_value(false)
+ .help("This will remove redundant Quantize operators");
+
arser.add_argument("--remove_redundant_reshape")
.nargs(0)
.required(false)
@@ -452,6 +464,8 @@ int entry(int argc, char **argv)
options->enable(Algorithms::FoldDequantize);
if (arser.get<bool>("--fold_dwconv"))
options->enable(Algorithms::FoldDepthwiseConv2D);
+ if (arser.get<bool>("--fold_gather"))
+ options->enable(Algorithms::FoldGather);
if (arser.get<bool>("--fold_sparse_to_dense"))
options->enable(Algorithms::FoldSparseToDense);
if (arser.get<bool>("--forward_reshape_to_unaryop"))
@@ -484,6 +498,8 @@ int entry(int argc, char **argv)
options->enable(Algorithms::RemoveFakeQuant);
if (arser.get<bool>("--remove_quantdequant"))
options->enable(Algorithms::RemoveQuantDequantSeq);
+ if (arser.get<bool>("--remove_redundant_quantize"))
+ options->enable(Algorithms::RemoveRedundantQuantize);
if (arser.get<bool>("--remove_redundant_reshape"))
options->enable(Algorithms::RemoveRedundantReshape);
if (arser.get<bool>("--remove_redundant_transpose"))
diff --git a/compiler/circlechef/CMakeLists.txt b/compiler/circlechef/CMakeLists.txt
index 3e2ddcbb3..b124d3027 100644
--- a/compiler/circlechef/CMakeLists.txt
+++ b/compiler/circlechef/CMakeLists.txt
@@ -1,12 +1,14 @@
nnas_find_package(Protobuf QUIET)
if(NOT Protobuf_FOUND)
+ message(STATUS "circlechef: SKIP (missing Protobuf)")
return()
endif(NOT Protobuf_FOUND)
-if(NOT TARGET mio_circle)
+if(NOT TARGET mio_circle04)
+ message(STATUS "circlechef: SKIP (missing mio-circle04)")
return()
-endif(NOT TARGET mio_circle)
+endif(NOT TARGET mio_circle04)
# Recipe Parser
add_subdirectory(proto)
diff --git a/compiler/circlechef/circle/CMakeLists.txt b/compiler/circlechef/circle/CMakeLists.txt
index 98a284c30..12dc7217b 100644
--- a/compiler/circlechef/circle/CMakeLists.txt
+++ b/compiler/circlechef/circle/CMakeLists.txt
@@ -4,6 +4,7 @@ add_library(circlechef_circle STATIC ${SOURCES})
target_include_directories(circlechef_circle PUBLIC include)
target_include_directories(circlechef_circle PRIVATE src)
target_link_libraries(circlechef_circle circlechef_proto)
-target_link_libraries(circlechef_circle mio_circle)
+target_link_libraries(circlechef_circle mio_circle04)
+target_link_libraries(circlechef_circle mio_circle04_helper)
target_link_libraries(circlechef_circle cwrap)
target_link_libraries(circlechef_circle souschef)
diff --git a/compiler/circlechef/circle/src/CircleImport.cpp b/compiler/circlechef/circle/src/CircleImport.cpp
index e970fbce3..f8756ef94 100644
--- a/compiler/circlechef/circle/src/CircleImport.cpp
+++ b/compiler/circlechef/circle/src/CircleImport.cpp
@@ -18,38 +18,13 @@
#include "Convert.h"
+#include <mio_circle/Helper.h>
+
#include <sstream>
namespace circlechef
{
-const char *kEmptyTensorName = "(noname)";
-
-const char *tensor_type(const circle::Tensor *tensor)
-{
- return circle::EnumNameTensorType(tensor->type());
-}
-
-const char *tensor_name(const circle::Tensor *tensor)
-{
- auto name = tensor->name();
- if (name)
- return name->c_str();
- return kEmptyTensorName;
-}
-
-bool is_valid(const circle::OperatorCode *opcode)
-{
- circle::BuiltinOperator code = opcode->builtin_code();
- return (circle::BuiltinOperator_MIN <= code && code <= circle::BuiltinOperator_MAX);
-}
-
-bool is_custom(const circle::OperatorCode *opcode)
-{
- circle::BuiltinOperator code = opcode->builtin_code();
- return (code == circle::BuiltinOperator_CUSTOM);
-}
-
CircleImport::CircleImport(const circle::Model *model)
{
_subgraphs = model->subgraphs();
@@ -92,7 +67,7 @@ circle::BuiltinOperator CircleImport::builtin_code(const circle::Operator *op) c
assert(index < _op_codes.size());
const circle::OperatorCode *opcode = _op_codes.at(index);
- return opcode->builtin_code();
+ return mio::circle::builtin_code_neutral(opcode);
}
std::string CircleImport::opcode_name(const circle::Operator *op) const
@@ -101,14 +76,14 @@ std::string CircleImport::opcode_name(const circle::Operator *op) const
assert(index < _op_codes.size());
const circle::OperatorCode *opcode = _op_codes.at(index);
- if (!is_valid(opcode))
+ if (!mio::circle::is_valid(opcode))
{
std::ostringstream oss;
oss << "(invalid: " << index << ")";
return oss.str();
}
- if (is_custom(opcode))
+ if (mio::circle::is_custom(opcode))
{
if (!opcode->custom_code())
return "(invalid custom)";
diff --git a/compiler/circlechef/circle/src/CircleImport.h b/compiler/circlechef/circle/src/CircleImport.h
index 23ca29beb..9c1d161b6 100644
--- a/compiler/circlechef/circle/src/CircleImport.h
+++ b/compiler/circlechef/circle/src/CircleImport.h
@@ -34,11 +34,6 @@ using CircleTensors_t = flatbuffers::Vector<flatbuffers::Offset<circle::Tensor>>
using CircleBuffers_t = flatbuffers::Vector<flatbuffers::Offset<circle::Buffer>>;
using CircleOperators_t = flatbuffers::Vector<flatbuffers::Offset<circle::Operator>>;
-const char *tensor_type(const circle::Tensor *tensor);
-const char *tensor_name(const circle::Tensor *tensor);
-bool is_valid(const circle::OperatorCode *opcode);
-bool is_custom(const circle::OperatorCode *opcode);
-
/**
* @brief Loads TF lite file and provides helpers to access attributes
*/
diff --git a/compiler/circlechef/circle/src/RecipeChef.cpp b/compiler/circlechef/circle/src/RecipeChef.cpp
index cd520cbc3..e21bca8a6 100644
--- a/compiler/circlechef/circle/src/RecipeChef.cpp
+++ b/compiler/circlechef/circle/src/RecipeChef.cpp
@@ -15,6 +15,7 @@
*/
#include <circlechef/RecipeChef.h>
+#include <mio_circle/Helper.h>
#include "Convert.h"
#include "CircleImport.h"
@@ -42,7 +43,7 @@ void set_inputs(CircleImport *import, circlechef::Operation *operation, const ci
else
{
auto tensor = tensors->Get(input);
- std::string name = tensor_name(tensor);
+ std::string name = mio::circle::tensor_name(tensor);
operation->add_input(name);
}
}
@@ -56,7 +57,7 @@ void set_outputs(CircleImport *import, circlechef::Operation *operation, const c
for (auto output : outputs)
{
auto tensor = tensors->Get(output);
- std::string name = tensor_name(tensor);
+ std::string name = mio::circle::tensor_name(tensor);
operation->add_output(name);
}
}
@@ -108,7 +109,7 @@ std::unique_ptr<ModelRecipe> generate_recipe(const circle::Model *model)
::circlechef::Operand *operand = model_recipe->add_operand();
- operand->set_name(tensor_name(tensor));
+ operand->set_name(mio::circle::tensor_name(tensor));
operand->set_type(as_circlechef_type(tensor->type()));
std::vector<int32_t> dims = as_index_vector(tensor->shape());
@@ -224,14 +225,14 @@ std::unique_ptr<ModelRecipe> generate_recipe(const circle::Model *model)
for (const auto input : inputs)
{
auto tensor = tensors->Get(input);
- std::string name = tensor_name(tensor);
+ std::string name = mio::circle::tensor_name(tensor);
model_recipe->add_input(name);
}
for (const auto output : outputs)
{
auto tensor = tensors->Get(output);
- std::string name = tensor_name(tensor);
+ std::string name = mio::circle::tensor_name(tensor);
model_recipe->add_output(name);
}
diff --git a/compiler/circlechef/core/CMakeLists.txt b/compiler/circlechef/core/CMakeLists.txt
index 0e8f47483..415954767 100644
--- a/compiler/circlechef/core/CMakeLists.txt
+++ b/compiler/circlechef/core/CMakeLists.txt
@@ -7,7 +7,7 @@ target_include_directories(circlechef_core PUBLIC include)
target_include_directories(circlechef_core PRIVATE src)
target_link_libraries(circlechef_core PUBLIC circlechef_proto)
target_link_libraries(circlechef_core PUBLIC circlechef_log)
-target_link_libraries(circlechef_core PUBLIC mio_circle)
+target_link_libraries(circlechef_core PUBLIC mio_circle04)
target_link_libraries(circlechef_core PUBLIC souschef)
target_link_libraries(circlechef_core PRIVATE nncc_coverage)
diff --git a/compiler/circlechef/core/src/ModelChef.cpp b/compiler/circlechef/core/src/ModelChef.cpp
index 6975f42a3..6c5206dfc 100644
--- a/compiler/circlechef/core/src/ModelChef.cpp
+++ b/compiler/circlechef/core/src/ModelChef.cpp
@@ -520,6 +520,10 @@ GeneratedModel cook(const ::circlechef::ModelRecipe &model_recipe)
for (auto const &opcode : builtin_code_map)
{
circle::OperatorCodeBuilder code_builder{*flatbuffer_builder};
+ int8_t dep_code = 127; // BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES
+ if (opcode.first < circle::BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES)
+ dep_code = static_cast<int8_t>(opcode.first);
+ code_builder.add_deprecated_builtin_code(dep_code);
code_builder.add_builtin_code(opcode.first);
code_builder.add_version(opcode.second);
auto code = code_builder.Finish();
diff --git a/compiler/circlechef/requires.cmake b/compiler/circlechef/requires.cmake
index 2106146d7..a5d6bedaa 100644
--- a/compiler/circlechef/requires.cmake
+++ b/compiler/circlechef/requires.cmake
@@ -1,9 +1,10 @@
require("arser")
require("nnkit")
require("cwrap")
-require("mio-circle")
+require("mio-circle04")
require("safemain")
require("hermes")
require("hermes-std")
require("foder")
require("souschef")
+require("circle-verify")
diff --git a/compiler/circlechef/tests/CMakeLists.txt b/compiler/circlechef/tests/CMakeLists.txt
index 773ff5403..7ae619f8b 100644
--- a/compiler/circlechef/tests/CMakeLists.txt
+++ b/compiler/circlechef/tests/CMakeLists.txt
@@ -3,6 +3,15 @@ set(CIRCLERECIPES_DIR "${CircleRecipes_DIR}")
file(GLOB RECIPES RELATIVE ${CIRCLERECIPES_DIR} "${CIRCLERECIPES_DIR}/*/test.recipe")
+set(CIRCLECHEF_FILE_PATH $<TARGET_FILE:circlechef-file>)
+set(CIRCLECHEF_REVERSE_PATH $<TARGET_FILE:circlechef-reverse>)
+if(DEFINED ENV{BUILD_HOST_EXEC})
+ # TODO use better way to represent path for host executable
+ set(CIRCLECHEF_FILE_PATH $ENV{BUILD_HOST_EXEC}/compiler/circlechef/tools/file/circlechef-file)
+ set(CIRCLECHEF_REVERSE_PATH $ENV{BUILD_HOST_EXEC}/compiler/circlechef/tools/reverse/circlechef-reverse)
+ message(STATUS "CIRCLECHEF_FILE_PATH = ${CIRCLECHEF_FILE_PATH}")
+endif(DEFINED ENV{BUILD_HOST_EXEC})
+
foreach(RECIPE IN ITEMS ${RECIPES})
get_filename_component(RECIPE_PREFIX ${RECIPE} DIRECTORY)
@@ -18,8 +27,8 @@ foreach(RECIPE IN ITEMS ${RECIPES})
# Generate .circle
add_custom_command(OUTPUT ${RECIPE_OUTPUT_FILE}
- COMMAND circlechef-file ${RECIPE_SOURCE_FILE} ${RECIPE_OUTPUT_FILE}
- DEPENDS circlechef-file ${RECIPE_SOURCE_FILE}
+ COMMAND ${CIRCLECHEF_FILE_PATH} ${RECIPE_SOURCE_FILE} ${RECIPE_OUTPUT_FILE}
+ DEPENDS ${CIRCLECHEF_FILE_PATH} ${RECIPE_SOURCE_FILE}
COMMENT "Generating ${RECIPE_OUTPUT_FILE}")
list(APPEND TESTS ${RECIPE_PREFIX})
@@ -44,8 +53,8 @@ foreach(RECIPE IN ITEMS ${RECIPES})
# Generate .circle
add_custom_command(OUTPUT ${RECIPE_OUTPUT_FILE}
- COMMAND circlechef-file ${RECIPE_SOURCE_FILE} ${RECIPE_OUTPUT_FILE}
- DEPENDS circlechef-file ${RECIPE_SOURCE_FILE}
+ COMMAND ${CIRCLECHEF_FILE_PATH} ${RECIPE_SOURCE_FILE} ${RECIPE_OUTPUT_FILE}
+ DEPENDS ${CIRCLECHEF_FILE_PATH} ${RECIPE_SOURCE_FILE}
COMMENT "Generating ${RECIPE_OUTPUT_FILE}")
list(APPEND TESTS ${RECIPE_PREFIX})
@@ -68,16 +77,16 @@ foreach(CIRCLEFILE IN ITEMS ${GEN_CIRCLEFILES})
# Generate .gen.recipe from generated .circle
add_custom_command(OUTPUT ${RECIPE_GEN_OUTPUT_FILE}
- COMMAND circlechef-reverse ${RECIPE_OUTPUT_FILE} ${RECIPE_GEN_OUTPUT_FILE}
- DEPENDS circlechef-reverse ${RECIPE_OUTPUT_FILE}
+ COMMAND ${CIRCLECHEF_REVERSE_PATH} ${RECIPE_OUTPUT_FILE} ${RECIPE_GEN_OUTPUT_FILE}
+ DEPENDS ${CIRCLECHEF_REVERSE_PATH} ${RECIPE_OUTPUT_FILE}
COMMENT "Generating ${RECIPE_GEN_OUTPUT_FILE}")
# now we are going to generate .gen.circle from .gen.recipe
# to check generated .gen.recipe file is correct by using it.
# as weight values may be different, binary comparision is not acceptable.
add_custom_command(OUTPUT ${RECIPE_GEN_OUTPUT_FILE2}
- COMMAND circlechef-file ${RECIPE_GEN_OUTPUT_FILE} ${RECIPE_GEN_OUTPUT_FILE2}
- DEPENDS circlechef-file ${RECIPE_GEN_OUTPUT_FILE}
+ COMMAND ${CIRCLECHEF_FILE_PATH} ${RECIPE_GEN_OUTPUT_FILE} ${RECIPE_GEN_OUTPUT_FILE2}
+ DEPENDS ${CIRCLECHEF_FILE_PATH} ${RECIPE_GEN_OUTPUT_FILE}
COMMENT "Generating ${RECIPE_GEN_OUTPUT_FILE2}")
list(APPEND TESTS ${CIRCLE_PREFIX}.gen)
@@ -96,13 +105,13 @@ foreach(CIRCLEFILE IN ITEMS ${GEN_CIRCLEFILES})
# Generate .gen.recipe from generated .circle
add_custom_command(OUTPUT ${RECIPE_GEN_OUTPUT_FILE}
- COMMAND circlechef-reverse ${RECIPE_OUTPUT_FILE} ${RECIPE_GEN_OUTPUT_FILE}
- DEPENDS circlechef-reverse ${RECIPE_OUTPUT_FILE}
+ COMMAND ${CIRCLECHEF_REVERSE_PATH} ${RECIPE_OUTPUT_FILE} ${RECIPE_GEN_OUTPUT_FILE}
+ DEPENDS ${CIRCLECHEF_REVERSE_PATH} ${RECIPE_OUTPUT_FILE}
COMMENT "Generating ${RECIPE_GEN_OUTPUT_FILE}")
add_custom_command(OUTPUT ${RECIPE_GEN_OUTPUT_FILE2}
- COMMAND circlechef-file ${RECIPE_GEN_OUTPUT_FILE} ${RECIPE_GEN_OUTPUT_FILE2}
- DEPENDS circlechef-file ${RECIPE_GEN_OUTPUT_FILE}
+ COMMAND ${CIRCLECHEF_FILE_PATH} ${RECIPE_GEN_OUTPUT_FILE} ${RECIPE_GEN_OUTPUT_FILE2}
+ DEPENDS ${CIRCLECHEF_FILE_PATH} ${RECIPE_GEN_OUTPUT_FILE}
COMMENT "Generating ${RECIPE_GEN_OUTPUT_FILE2}")
list(APPEND TESTS ${CIRCLE_PREFIX}.gen)
diff --git a/compiler/circledump/CMakeLists.txt b/compiler/circledump/CMakeLists.txt
index 7848ac722..b65c06677 100644
--- a/compiler/circledump/CMakeLists.txt
+++ b/compiler/circledump/CMakeLists.txt
@@ -1,6 +1,7 @@
-if(NOT TARGET mio_circle)
+if(NOT TARGET mio_circle04)
+ message(STATUS "Skip circledump: mio_circle04 not found")
return()
-endif(NOT TARGET mio_circle)
+endif(NOT TARGET mio_circle04)
set(DRIVER "driver/Driver.cpp")
@@ -9,8 +10,8 @@ file(GLOB_RECURSE SOURCES "src/*.cpp")
add_executable(circledump ${DRIVER} ${SOURCES})
target_include_directories(circledump PRIVATE include)
target_link_libraries(circledump arser)
-target_link_libraries(circledump mio_circle)
+target_link_libraries(circledump mio_circle04)
+target_link_libraries(circledump mio_circle04_helper)
target_link_libraries(circledump safemain)
-target_link_libraries(circledump flatbuffers-1.10)
install(TARGETS circledump DESTINATION bin)
diff --git a/compiler/circledump/README.md b/compiler/circledump/README.md
index e31c2d560..d2baf26b3 100644
--- a/compiler/circledump/README.md
+++ b/compiler/circledump/README.md
@@ -65,6 +65,6 @@ O T(3) ofm
### Dependency
-- mio-circle
+- mio-circle04
- safemain
- FlatBuffers
diff --git a/compiler/circledump/requires.cmake b/compiler/circledump/requires.cmake
index 81e0f0dbd..362d67cf4 100644
--- a/compiler/circledump/requires.cmake
+++ b/compiler/circledump/requires.cmake
@@ -1,3 +1,3 @@
require("arser")
-require("mio-circle")
+require("mio-circle04")
require("safemain")
diff --git a/compiler/circledump/src/Dump.cpp b/compiler/circledump/src/Dump.cpp
index 42b4ad97a..0b256dda8 100644
--- a/compiler/circledump/src/Dump.cpp
+++ b/compiler/circledump/src/Dump.cpp
@@ -15,6 +15,7 @@
*/
#include <circledump/Dump.h>
+#include <mio_circle/Helper.h>
#include "Read.h"
#include "OpPrinter.h"
@@ -141,7 +142,7 @@ void dump_sub_graph(std::ostream &os, circleread::Reader &reader)
// dump operands(tensors)
os << "Operands: T(subgraph index : tensor index) TYPE (shape) (shape_signature) "
- << "B(buffer index) OperandName" << std::endl;
+ << "B(buffer index) (variable) OperandName" << std::endl;
for (uint32_t i = 0; i < tensors->Length(); ++i)
{
// TODO refactor to some better structure
@@ -151,7 +152,7 @@ void dump_sub_graph(std::ostream &os, circleread::Reader &reader)
if (tensor->shape())
dims = circleread::as_index_vector(tensor->shape());
- os << "T(" << reader.subgraph_index() << ":" << i << ") " << circleread::tensor_type(tensor)
+ os << "T(" << reader.subgraph_index() << ":" << i << ") " << mio::circle::tensor_type(tensor)
<< " ";
os << "(" << dims << ") ";
if (tensor->shape_signature())
@@ -160,7 +161,11 @@ void dump_sub_graph(std::ostream &os, circleread::Reader &reader)
os << "(" << dims_sig << ") ";
}
os << "B(" << tensor->buffer() << ") ";
- os << circleread::tensor_name(tensor) << std::endl;
+ if (tensor->is_variable())
+ {
+ os << "(variable) ";
+ }
+ os << mio::circle::tensor_name(tensor) << std::endl;
if (auto q_params = tensor->quantization())
{
@@ -312,7 +317,7 @@ void dump_sub_graph(std::ostream &os, circleread::Reader &reader)
if (input >= 0)
{
auto tensor = tensors->Get(input);
- os << circleread::tensor_name(tensor);
+ os << mio::circle::tensor_name(tensor);
}
os << std::endl;
}
@@ -322,7 +327,7 @@ void dump_sub_graph(std::ostream &os, circleread::Reader &reader)
if (output >= 0)
{
auto tensor = tensors->Get(output);
- os << circleread::tensor_name(tensor);
+ os << mio::circle::tensor_name(tensor);
}
os << std::endl;
}
@@ -335,14 +340,14 @@ void dump_sub_graph(std::ostream &os, circleread::Reader &reader)
for (const auto input : reader.inputs())
{
auto tensor = tensors->Get(input);
- std::string name = circleread::tensor_name(tensor);
+ std::string name = mio::circle::tensor_name(tensor);
os << "I T(" << reader.subgraph_index() << ":" << input << ") " << name << std::endl;
}
for (const auto output : reader.outputs())
{
auto tensor = tensors->Get(output);
- std::string name = circleread::tensor_name(tensor);
+ std::string name = mio::circle::tensor_name(tensor);
os << "O T(" << reader.subgraph_index() << ":" << output << ") " << name << std::endl;
}
@@ -364,6 +369,7 @@ void dump_model(std::ostream &os, const circle::Model *model)
auto opcodes = reader.opcodes();
auto buffers = reader.buffers();
auto metadata = reader.metadata();
+ auto signaturedefs = reader.signature_defs();
// dump operator_codes
os << "Operator Codes: [order] OpCodeName (OpCode Enum)" << std::endl;
@@ -371,11 +377,14 @@ void dump_model(std::ostream &os, const circle::Model *model)
for (auto opcode : opcodes)
{
circle::BuiltinOperator op_code = opcode->builtin_code();
- auto op_name = circleread::opcode_name(opcode);
+ // cast to int32_t to print as number or int8_t will print as ascii code
+ int32_t dp_code = static_cast<int32_t>(opcode->deprecated_builtin_code());
+
+ auto op_name = mio::circle::opcode_name(opcode);
auto op_version = opcode->version();
os << "[" << opcode_index << "] " << op_name << " (code: " << op_code
- << ", version: " << op_version << ")" << std::endl;
+ << ", dep_code: " << dp_code << ", version: " << op_version << ")" << std::endl;
opcode_index++;
}
@@ -417,6 +426,37 @@ void dump_model(std::ostream &os, const circle::Model *model)
os << std::endl;
}
+ // dump signaturedef
+ if (signaturedefs != nullptr)
+ {
+ os << "SignatureDef" << std::endl;
+ for (uint32_t i = 0; i < signaturedefs->Length(); ++i)
+ {
+ auto sign_i = signaturedefs->Get(i);
+ os << "S(" << i << ") signature_key(" << sign_i->signature_key()->c_str() << "), sub_graph("
+ << sign_i->subgraph_index() << ")" << std::endl;
+
+ auto inputs_i = sign_i->inputs();
+ for (uint32_t t = 0; t < inputs_i->Length(); ++t)
+ {
+ auto inputs_i_t = inputs_i->Get(t);
+ os << " I(" << t << ")"
+ << " T(" << sign_i->subgraph_index() << ":" << inputs_i_t->tensor_index() << ") "
+ << inputs_i_t->name()->c_str() << std::endl;
+ }
+
+ auto outputs_i = sign_i->outputs();
+ for (uint32_t t = 0; t < outputs_i->Length(); ++t)
+ {
+ auto outputs_i_t = outputs_i->Get(t);
+ os << " O(" << t << ")"
+ << " T(" << sign_i->subgraph_index() << ":" << outputs_i_t->tensor_index() << ") "
+ << outputs_i_t->name()->c_str() << std::endl;
+ }
+ }
+ os << std::endl;
+ }
+
for (uint32_t sg = 0; sg < num_subgraph; ++sg)
{
reader.select_subgraph(sg);
diff --git a/compiler/circledump/src/Load.cpp b/compiler/circledump/src/Load.cpp
index ec91ed189..67e7fa5a6 100644
--- a/compiler/circledump/src/Load.cpp
+++ b/compiler/circledump/src/Load.cpp
@@ -76,7 +76,7 @@ public:
{
if (_value != -1)
{
- // Close on descturction
+ // Close on destructor
close(_value);
}
}
diff --git a/compiler/circledump/src/OpPrinter.cpp b/compiler/circledump/src/OpPrinter.cpp
index 7af3ff641..02e5c26b5 100644
--- a/compiler/circledump/src/OpPrinter.cpp
+++ b/compiler/circledump/src/OpPrinter.cpp
@@ -341,6 +341,7 @@ public:
<< ") ";
os << "Activation(" << EnumNameActivationFunctionType(params->fused_activation_function())
<< ") ";
+ os << "keep_num_dims(" << params->keep_num_dims() << ") ";
os << std::endl;
}
@@ -619,6 +620,23 @@ public:
}
};
+class SVDFPrinter : public OpPrinter
+{
+public:
+ void options(const circle::Operator *op, std::ostream &os) const override
+ {
+ if (auto *params = op->builtin_options_as_SVDFOptions())
+ {
+ os << " ";
+ os << "rank(" << params->rank() << ") ";
+ os << "activation(" << EnumNameActivationFunctionType(params->fused_activation_function())
+ << ") ";
+ os << "asymmetric_quantize_inputs(" << params->asymmetric_quantize_inputs() << ") ";
+ os << std::endl;
+ }
+ }
+};
+
class TransposeConvPrinter : public OpPrinter
{
public:
@@ -754,6 +772,22 @@ public:
}
};
+class InstanceNormPrinter : public OpPrinter
+{
+public:
+ void options(const circle::Operator *op, std::ostream &os) const override
+ {
+ if (auto *params = op->builtin_options_as_InstanceNormOptions())
+ {
+ os << " ";
+ os << "epsilon(" << params->epsilon() << ") ";
+ os << "Activation(" << EnumNameActivationFunctionType(params->fused_activation_function())
+ << ") ";
+ os << std::endl;
+ }
+ }
+};
+
OpPrinterRegistry::OpPrinterRegistry()
{
_op_map[circle::BuiltinOperator_ADD] = make_unique<AddPrinter>();
@@ -824,6 +858,7 @@ OpPrinterRegistry::OpPrinterRegistry()
_op_map[circle::BuiltinOperator_STRIDED_SLICE] = make_unique<StridedSlicePrinter>();
_op_map[circle::BuiltinOperator_SUB] = make_unique<SubPrinter>();
_op_map[circle::BuiltinOperator_SUM] = make_unique<ReducerPrinter>();
+ _op_map[circle::BuiltinOperator_SVDF] = make_unique<SVDFPrinter>();
_op_map[circle::BuiltinOperator_TRANSPOSE_CONV] = make_unique<TransposeConvPrinter>();
// There is no Option for TOPK_V2
_op_map[circle::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM] =
@@ -835,6 +870,7 @@ OpPrinterRegistry::OpPrinterRegistry()
// Circle only
_op_map[circle::BuiltinOperator_BCQ_FULLY_CONNECTED] = make_unique<BCQFullyConnectedPrinter>();
_op_map[circle::BuiltinOperator_BCQ_GATHER] = make_unique<BCQGatherPrinter>();
+ _op_map[circle::BuiltinOperator_INSTANCE_NORM] = make_unique<InstanceNormPrinter>();
}
} // namespace circledump
diff --git a/compiler/circledump/src/Read.cpp b/compiler/circledump/src/Read.cpp
index db8298585..3a7e98cde 100644
--- a/compiler/circledump/src/Read.cpp
+++ b/compiler/circledump/src/Read.cpp
@@ -16,72 +16,21 @@
#include "Read.h"
+#include <mio_circle/Helper.h>
+
#include <sstream>
#include <string>
namespace circleread
{
-bool is_valid(const circle::OperatorCode *opcode)
-{
- circle::BuiltinOperator code = opcode->builtin_code();
- return (circle::BuiltinOperator_MIN <= code && code <= circle::BuiltinOperator_MAX);
-}
-
-bool is_custom(const circle::OperatorCode *opcode)
-{
- circle::BuiltinOperator code = opcode->builtin_code();
- return (code == circle::BuiltinOperator_CUSTOM);
-}
-
-std::string opcode_name(const circle::OperatorCode *opcode)
-{
- assert(opcode);
-
- if (!is_valid(opcode))
- {
- std::ostringstream oss;
- oss << "(invalid)";
- return oss.str();
- }
-
- if (is_custom(opcode))
- {
- if (!opcode->custom_code())
- return "(invalid custom)";
-
- std::string custom_op = "CUSTOM(";
- custom_op += opcode->custom_code()->c_str();
- custom_op += ")";
- return custom_op;
- }
-
- circle::BuiltinOperator code = opcode->builtin_code();
- return circle::EnumNameBuiltinOperator(code);
-}
-
-const char *tensor_type(const circle::Tensor *tensor)
-{
- return circle::EnumNameTensorType(tensor->type());
-}
-
-const char *tensor_name(const circle::Tensor *tensor)
-{
- static const char *kEmptyTensorName = "(noname)";
-
- auto name = tensor->name();
- if (name)
- return name->c_str();
-
- return kEmptyTensorName;
-}
-
Reader::Reader(const circle::Model *model)
{
_version = model->version();
_subgraphs = model->subgraphs();
_buffers = model->buffers();
_metadata = model->metadata();
+ _signature_defs = model->signature_defs();
auto opcodes = model->operator_codes();
for (const ::circle::OperatorCode *opcode : *opcodes)
@@ -127,14 +76,14 @@ std::string Reader::opcode_name(const circle::Operator *op) const
assert(index < _op_codes.size());
const circle::OperatorCode *opcode = _op_codes.at(index);
- if (!is_valid(opcode))
+ if (!mio::circle::is_valid(opcode))
{
std::ostringstream oss;
oss << "(invalid: " << index << ")";
return oss.str();
}
- return circleread::opcode_name(opcode);
+ return mio::circle::opcode_name(opcode);
}
bool Reader::select_subgraph(uint32_t sgindex)
diff --git a/compiler/circledump/src/Read.h b/compiler/circledump/src/Read.h
index c61a1ab6d..05b0e5072 100644
--- a/compiler/circledump/src/Read.h
+++ b/compiler/circledump/src/Read.h
@@ -41,12 +41,6 @@ template <typename T> std::vector<T> as_index_vector(const flatbuffers::Vector<T
return ret;
}
-bool is_valid(const circle::OperatorCode *opcode);
-bool is_custom(const circle::OperatorCode *opcode);
-std::string opcode_name(const circle::OperatorCode *opcode);
-const char *tensor_type(const circle::Tensor *tensor);
-const char *tensor_name(const circle::Tensor *tensor);
-
/**
* @brief Loads Circle file and provides helpers to access attributes
*/
@@ -58,6 +52,7 @@ private:
using CircleTensors_t = flatbuffers::Vector<flatbuffers::Offset<circle::Tensor>>;
using CircleOperators_t = flatbuffers::Vector<flatbuffers::Offset<circle::Operator>>;
using CircleMetadata_t = flatbuffers::Vector<flatbuffers::Offset<circle::Metadata>>;
+ using CircleSignatureDef_t = flatbuffers::Vector<flatbuffers::Offset<circle::SignatureDef>>;
public:
Reader(const circle::Model *model);
@@ -75,6 +70,7 @@ public:
const std::vector<int32_t> &outputs() const { return _outputs; }
const circle::DataFormat &data_format() const { return _data_format; }
const CircleMetadata_t *metadata() const { return _metadata; }
+ const CircleSignatureDef_t *signature_defs() const { return _signature_defs; }
uint32_t num_subgraph() const { return _subgraphs->Length(); }
@@ -95,6 +91,7 @@ private:
const CircleTensors_t *_tensors{nullptr};
const CircleOperators_t *_operators{nullptr};
const CircleMetadata_t *_metadata{nullptr};
+ const CircleSignatureDef_t *_signature_defs{nullptr};
uint32_t _subgraph_index = 0;
std::string _subgraph_name;
diff --git a/compiler/cli/CMakeLists.txt b/compiler/cli/CMakeLists.txt
index 2ab8c0529..0fb99ddba 100644
--- a/compiler/cli/CMakeLists.txt
+++ b/compiler/cli/CMakeLists.txt
@@ -4,11 +4,11 @@ list(APPEND TESTS "src/App.test.cpp")
add_library(cli ${SOURCES})
target_include_directories(cli PUBLIC include)
-nnas_find_package(GTest QUIET)
-
-if(NOT GTest_FOUND)
+if(NOT ENABLE_TEST)
return()
-endif(NOT GTest_FOUND)
+endif(NOT ENABLE_TEST)
+
+nnas_find_package(GTest QUIET)
GTest_AddTEst(cli_test ${TESTS})
target_link_libraries(cli_test cli)
diff --git a/compiler/common-artifacts/CMakeLists.txt b/compiler/common-artifacts/CMakeLists.txt
index 6de634a25..404149c15 100644
--- a/compiler/common-artifacts/CMakeLists.txt
+++ b/compiler/common-artifacts/CMakeLists.txt
@@ -1,82 +1,63 @@
#[[ Generate common python virtual enviornment ]]
-find_package(PythonInterp 3 QUIET)
-find_package(PythonLibs 3 QUIET)
+find_package(PythonInterp 3.8 QUIET)
+find_package(PythonLibs 3.8 QUIET)
if(NOT ${PYTHONINTERP_FOUND})
message(STATUS "Build common-artifacts: FALSE (Python3 is missing)")
return()
endif()
-if(${PYTHON_VERSION_MINOR} LESS 3)
- message(STATUS "Build common-artifacts: FALSE (You need to install Python version higher than 3.3)")
+if(${PYTHON_VERSION_MINOR} LESS 8)
+ message(STATUS "Build common-artifacts: FALSE (You need to install Python version higher than 3.8)")
return()
endif()
-# Create python virtual environment with tensorflow 1.13.2
-set(VIRTUALENV_OVERLAY_TF_1_13_2 "${NNCC_OVERLAY_DIR}/venv_1_13_2")
-
-# Create python virtual environment with tensorflow 2.3.0
-set(VIRTUALENV_OVERLAY_TF_2_3_0 "${NNCC_OVERLAY_DIR}/venv_2_3_0")
# Create python virtual environment with tensorflow 2.6.0
set(VIRTUALENV_OVERLAY_TF_2_6_0 "${NNCC_OVERLAY_DIR}/venv_2_6_0")
add_custom_command(
- OUTPUT ${VIRTUALENV_OVERLAY_TF_1_13_2}
- COMMAND ${PYTHON_EXECUTABLE} -m venv ${VIRTUALENV_OVERLAY_TF_1_13_2}
-)
-
-add_custom_command(
- OUTPUT ${VIRTUALENV_OVERLAY_TF_2_3_0}
- COMMAND ${PYTHON_EXECUTABLE} -m venv ${VIRTUALENV_OVERLAY_TF_2_3_0}
-)
-add_custom_command(
OUTPUT ${VIRTUALENV_OVERLAY_TF_2_6_0}
COMMAND ${PYTHON_EXECUTABLE} -m venv ${VIRTUALENV_OVERLAY_TF_2_6_0}
)
-# Create requirements.txt and install required pip packages
-set(REQUIREMENTS_FILE "requirements.txt")
-set(REQUIREMENTS_OVERLAY_PATH_TF_1_13_2 "${VIRTUALENV_OVERLAY_TF_1_13_2}/${REQUIREMENTS_FILE}")
-set(REQUIREMENTS_OVERLAY_PATH_TF_2_3_0 "${VIRTUALENV_OVERLAY_TF_2_3_0}/${REQUIREMENTS_FILE}")
-set(REQUIREMENTS_OVERLAY_PATH_TF_2_6_0 "${VIRTUALENV_OVERLAY_TF_2_6_0}/${REQUIREMENTS_FILE}")
+# Create python virtual environment with tensorflow 2.8.0
+set(VIRTUALENV_OVERLAY_TF_2_8_0 "${NNCC_OVERLAY_DIR}/venv_2_8_0")
-# TODO remove version number of '--upgrade pip==20.2.1 setuptools==49.3.0'
-# NOTE adding version is for temporary hotfix of setuptools 50.x.y version
add_custom_command(
- OUTPUT ${REQUIREMENTS_OVERLAY_PATH_TF_1_13_2}
- COMMAND ${CMAKE_COMMAND} -E echo "tensorflow==1.13.2" > ${REQUIREMENTS_OVERLAY_PATH_TF_1_13_2}
- COMMAND ${VIRTUALENV_OVERLAY_TF_1_13_2}/bin/python -m pip --default-timeout=1000 install --upgrade pip==20.2.1 setuptools==49.3.0
- COMMAND ${VIRTUALENV_OVERLAY_TF_1_13_2}/bin/python -m pip --default-timeout=1000 install -r ${REQUIREMENTS_OVERLAY_PATH_TF_1_13_2} --upgrade
- DEPENDS ${VIRTUALENV_OVERLAY_TF_1_13_2}
+ OUTPUT ${VIRTUALENV_OVERLAY_TF_2_8_0}
+ COMMAND ${PYTHON_EXECUTABLE} -m venv ${VIRTUALENV_OVERLAY_TF_2_8_0}
)
-add_custom_command(
- OUTPUT ${REQUIREMENTS_OVERLAY_PATH_TF_2_3_0}
- COMMAND ${CMAKE_COMMAND} -E remove -f ${REQUIREMENTS_OVERLAY_PATH_TF_2_3_0}
- COMMAND ${CMAKE_COMMAND} -E echo "tensorflow-cpu==2.3.0" >> ${REQUIREMENTS_OVERLAY_PATH_TF_2_3_0}
- COMMAND ${CMAKE_COMMAND} -E echo "flatbuffers==1.12" >> ${REQUIREMENTS_OVERLAY_PATH_TF_2_3_0}
- COMMAND ${VIRTUALENV_OVERLAY_TF_2_3_0}/bin/python -m pip --default-timeout=1000 install --upgrade pip==20.2.1 setuptools==49.3.0
- COMMAND ${VIRTUALENV_OVERLAY_TF_2_3_0}/bin/python -m pip --default-timeout=1000 install -r ${REQUIREMENTS_OVERLAY_PATH_TF_2_3_0} --upgrade
- DEPENDS ${VIRTUALENV_OVERLAY_TF_2_3_0}
-)
+# Create requirements.txt and install required pip packages
+set(REQUIREMENTS_FILE "requirements.txt")
+set(REQUIREMENTS_OVERLAY_PATH_TF_2_6_0 "${VIRTUALENV_OVERLAY_TF_2_6_0}/${REQUIREMENTS_FILE}")
+set(REQUIREMENTS_OVERLAY_PATH_TF_2_8_0 "${VIRTUALENV_OVERLAY_TF_2_8_0}/${REQUIREMENTS_FILE}")
add_custom_command(
OUTPUT ${REQUIREMENTS_OVERLAY_PATH_TF_2_6_0}
COMMAND ${CMAKE_COMMAND} -E remove -f ${REQUIREMENTS_OVERLAY_PATH_TF_2_6_0}
COMMAND ${CMAKE_COMMAND} -E echo "tensorflow-cpu==2.6.0" >> ${REQUIREMENTS_OVERLAY_PATH_TF_2_6_0}
COMMAND ${CMAKE_COMMAND} -E echo "flatbuffers==1.12" >> ${REQUIREMENTS_OVERLAY_PATH_TF_2_6_0}
- COMMAND ${VIRTUALENV_OVERLAY_TF_2_6_0}/bin/python -m pip --default-timeout=1000 install --upgrade pip==20.2.1 setuptools==49.3.0
- COMMAND ${VIRTUALENV_OVERLAY_TF_2_6_0}/bin/python -m pip --default-timeout=1000 install -r ${REQUIREMENTS_OVERLAY_PATH_TF_2_6_0} --upgrade
+ COMMAND ${VIRTUALENV_OVERLAY_TF_2_6_0}/bin/python3.8 -m pip --default-timeout=1000 install --upgrade pip setuptools
+ COMMAND ${VIRTUALENV_OVERLAY_TF_2_6_0}/bin/python3.8 -m pip --default-timeout=1000 install -r ${REQUIREMENTS_OVERLAY_PATH_TF_2_6_0} --upgrade
DEPENDS ${VIRTUALENV_OVERLAY_TF_2_6_0}
)
+add_custom_command(
+ OUTPUT ${REQUIREMENTS_OVERLAY_PATH_TF_2_8_0}
+ COMMAND ${CMAKE_COMMAND} -E remove -f ${REQUIREMENTS_OVERLAY_PATH_TF_2_8_0}
+ COMMAND ${CMAKE_COMMAND} -E echo "tensorflow-cpu==2.8.0" >> ${REQUIREMENTS_OVERLAY_PATH_TF_2_8_0}
+ COMMAND ${CMAKE_COMMAND} -E echo "flatbuffers==1.12" >> ${REQUIREMENTS_OVERLAY_PATH_TF_2_8_0}
+ COMMAND ${VIRTUALENV_OVERLAY_TF_2_8_0}/bin/python3.8 -m pip --default-timeout=1000 install --upgrade pip setuptools
+ COMMAND ${VIRTUALENV_OVERLAY_TF_2_8_0}/bin/python3.8 -m pip --default-timeout=1000 install -r ${REQUIREMENTS_OVERLAY_PATH_TF_2_8_0} --upgrade
+ DEPENDS ${VIRTUALENV_OVERLAY_TF_2_8_0}
+)
+
add_custom_target(common_artifacts_python_deps ALL
- DEPENDS ${VIRTUALENV_OVERLAY_TF_1_13_2}
- ${VIRTUALENV_OVERLAY_TF_2_3_0}
- ${VIRTUALENV_OVERLAY_TF_2_6_0}
- ${REQUIREMENTS_OVERLAY_PATH_TF_1_13_2}
- ${REQUIREMENTS_OVERLAY_PATH_TF_2_3_0}
+ DEPENDS ${VIRTUALENV_OVERLAY_TF_2_6_0}
+ ${VIRTUALENV_OVERLAY_TF_2_8_0}
${REQUIREMENTS_OVERLAY_PATH_TF_2_6_0}
+ ${REQUIREMENTS_OVERLAY_PATH_TF_2_8_0}
)
#[[ Generate common resources ]]
@@ -97,7 +78,6 @@ target_link_libraries(testDataGenerator PRIVATE arser)
target_link_libraries(testDataGenerator PRIVATE foder)
target_link_libraries(testDataGenerator PRIVATE luci_import)
target_link_libraries(testDataGenerator PRIVATE luci_interpreter)
-target_link_libraries(testDataGenerator PRIVATE mio_circle)
target_link_libraries(testDataGenerator PRIVATE safemain)
unset(TEST_DEPS)
@@ -109,6 +89,7 @@ set(TFLITE_RECIPE_REPO "${TensorFlowLiteRecipes_DIR}")
set(CIRCLE_RECIPE_REPO "${CircleRecipes_DIR}")
set(TEST_RECIPE_FILENAME "test.recipe")
set(TEST_RULE_FILENAME "test.rule")
+set(TEST_QCONFIG_FILENAME "test.qconf.json")
set(MODEL2NNPKG "${NNAS_PROJECT_SOURCE_DIR}/tools/nnpackage_tool/model2nnpkg/model2nnpkg.sh")
# Get test case list
@@ -140,12 +121,20 @@ endmacro()
include("exclude.lst")
+# TODO revise using variadic arguments
+macro(tcgenerate_option NAME OPTION ARG1 ARG2 ARG3)
+ set(TCGEN_OPT_${NAME} ${OPTION} ${ARG1} ${ARG2} ${ARG3})
+endmacro()
+
+include("options.lst")
+
foreach(RECIPE IN ITEMS ${RECIPES})
unset(OPT_FORMAT)
unset(MODEL_FORMAT)
set(RECIPE_FILE "${RECIPE}.recipe")
set(RULE_FILE "${RECIPE}.rule")
+ set(QCONFIG_FILE "${RECIPE}.qconf.json")
set(TFLITE_RECIPE_SOURCE_PATH "${TFLITE_RECIPE_REPO}/${RECIPE}/${TEST_RECIPE_FILENAME}")
set(CIRCLE_RECIPE_SOURCE_PATH "${CIRCLE_RECIPE_REPO}/${RECIPE}/${TEST_RECIPE_FILENAME}")
@@ -174,8 +163,20 @@ foreach(RECIPE IN ITEMS ${RECIPES})
set(RULE_SOURCE_PATH ${CIRCLE_RULE_SOURCE_PATH})
endif()
+ set(TFLITE_QCONFIG_SOURCE_PATH "${TFLITE_RECIPE_REPO}/${RECIPE}/${TEST_QCONFIG_FILENAME}")
+ set(CIRCLE_QCONFIG_SOURCE_PATH "${CIRCLE_RECIPE_REPO}/${RECIPE}/${TEST_QCONFIG_FILENAME}")
+
+ unset(QCONFIG_SOURCE_PATH)
+ if(EXISTS "${TFLITE_QCONFIG_SOURCE_PATH}")
+ set(QCONFIG_SOURCE_PATH ${TFLITE_QCONFIG_SOURCE_PATH})
+ endif()
+ if(EXISTS "${CIRCLE_QCONFIG_SOURCE_PATH}")
+ set(QCONFIG_SOURCE_PATH ${CIRCLE_QCONFIG_SOURCE_PATH})
+ endif()
+
set(RECIPE_BINARY_PATH "${CMAKE_CURRENT_BINARY_DIR}/${RECIPE_FILE}")
set(RULE_BINARY_PATH "${CMAKE_CURRENT_BINARY_DIR}/${RULE_FILE}")
+ set(QCONFIG_BINARY_PATH "${CMAKE_CURRENT_BINARY_DIR}/${QCONFIG_FILE}")
set(TFLITE_FILE "${RECIPE}.tflite")
set(TFLITE_OUTPUT_PATH "${CMAKE_CURRENT_BINARY_DIR}/${TFLITE_FILE}")
@@ -200,6 +201,16 @@ foreach(RECIPE IN ITEMS ${RECIPES})
list(APPEND TEST_DEPS ${RULE_BINARY_PATH})
endif()
+ if(DEFINED QCONFIG_SOURCE_PATH)
+ # Copy .qconf.json
+ add_custom_command(OUTPUT ${QCONFIG_BINARY_PATH}
+ COMMAND ${CMAKE_COMMAND} -E copy "${QCONFIG_SOURCE_PATH}" "${QCONFIG_BINARY_PATH}"
+ DEPENDS ${QCONFIG_SOURCE_PATH}
+ COMMENT "Generate ${QCONFIG_FILE}"
+ )
+ list(APPEND TEST_DEPS ${QCONFIG_BINARY_PATH})
+ endif()
+
if(${MODEL_FORMAT} STREQUAL "tflite")
# Generate .tflite
add_custom_command(OUTPUT ${TFLITE_OUTPUT_PATH}
@@ -274,11 +285,21 @@ foreach(RECIPE IN ITEMS ${RECIPES})
)
list(APPEND TEST_DEPS ${TC_DIRECTORY})
+ # set ADDITIONAL_OPTIONS as empty (one space before closing is intentional)
+ set(ADDITIONAL_OPTIONS )
+ if(DEFINED TCGEN_OPT_${RECIPE})
+ set(ADDITIONAL_OPTIONS ${ADDITIONAL_OPTIONS} ${TCGEN_OPT_${RECIPE}})
+ endif()
+
# Generate input.h5, expected.h5
set(INPUT_HDF5_FILE "${TC_DIRECTORY}/input.h5")
set(EXPECTED_HDF5_FILE "${TC_DIRECTORY}/expected.h5")
add_custom_command(OUTPUT ${INPUT_HDF5_FILE} ${EXPECTED_HDF5_FILE}
- COMMAND $<TARGET_FILE:testDataGenerator> --input_data ${INPUT_HDF5_FILE} --expected_data ${EXPECTED_HDF5_FILE} ${MODEL_FILE}
+ COMMAND $<TARGET_FILE:testDataGenerator>
+ --input_data ${INPUT_HDF5_FILE}
+ --expected_data ${EXPECTED_HDF5_FILE}
+ ${ADDITIONAL_OPTIONS}
+ ${MODEL_FILE}
DEPENDS $<TARGET_FILE:testDataGenerator> ${MODEL_FILE} ${TC_DIRECTORY}
COMMENT "Generate input.h5 and expected.h5 in ${NNPKG_FILE}/metadata/tc"
)
diff --git a/compiler/common-artifacts/exclude.lst b/compiler/common-artifacts/exclude.lst
index f32e00413..92b07fde8 100644
--- a/compiler/common-artifacts/exclude.lst
+++ b/compiler/common-artifacts/exclude.lst
@@ -14,7 +14,6 @@ optimize(UnidirectionalSequenceLSTM_001) # This recipe contains is_variable Tens
tcgenerate(Abs_000)
tcgenerate(AddN_000)
tcgenerate(Add_001) # runtime doesn't support
-tcgenerate(Add_U8_000)
tcgenerate(Add_STR_000) # STRING is not supported
tcgenerate(Add_STR_001) # STRING is not supported
tcgenerate(All_000)
@@ -26,32 +25,24 @@ tcgenerate(ArgMin_U8_000)
tcgenerate(ArgMin_U8_001)
tcgenerate(ArgMin_U8_002)
tcgenerate(ArgMin_U8_003)
-tcgenerate(BatchMatMul_000)
tcgenerate(BatchMatMulV2_000)
tcgenerate(BatchMatMulV2_001)
tcgenerate(BatchToSpaceND_000)
tcgenerate(BroadcastTo_000) # luci-interpreter doesn't support custom operator
-tcgenerate(Cast_000)
-tcgenerate(Cast_001)
tcgenerate(Ceil_000)
tcgenerate(Conv2D_003) # runtime doesn't support dilation
tcgenerate(Cos_000)
tcgenerate(DepthwiseConv2D_001) # runtime doesn't support dilation
tcgenerate(DepthwiseConv2D_003) # runtime doesn't support dilation
tcgenerate(DepthwiseConv2D_U8_001) # luci-interpreter doesn't support channel-wise quantization yet
-tcgenerate(Dequantize_000) # runtime and luci-interpreter doesn't support Dequantize op yet
-tcgenerate(ExpandDims_000)
-tcgenerate(ExpandDims_001)
-tcgenerate(ExpandDims_002)
-tcgenerate(ExpandDims_003)
-tcgenerate(ExpandDims_004)
+tcgenerate(ExpandDims_001) # luci-interpreter doesn't support undefined shape
+tcgenerate(ExpandDims_002) # luci-interpreter doesn't support undefined shape
tcgenerate(FakeQuant_000) # runtime and luci-interpreter doesn't support yet
tcgenerate(Fill_000)
tcgenerate(Fill_001)
tcgenerate(FloorMod_000)
tcgenerate(FloorMod_001)
tcgenerate(FullyConnected_U8_000)
-tcgenerate(Gather_000)
tcgenerate(GatherNd_000)
tcgenerate(GatherNd_001)
tcgenerate(L2Pool2D_U8_000)
@@ -75,8 +66,8 @@ tcgenerate(Mul_U8_000)
tcgenerate(Neg_000)
tcgenerate(Net_BroadcastTo_AddV2_001) # luci-interpreter doesn't support custom operator
tcgenerate(Net_Conv_FakeQuant_000) # luci-interpreter doesn't support FakeQuant yet
-tcgenerate(Net_Conv_QuantDequant_000) # luci-interpreter doesn't support Quantize/Dequantize yet
tcgenerate(Net_Dangle_001)
+tcgenerate(Net_Gather_SparseToDense_AddV2_000) # luci-interpreter doesn't support custom operator
tcgenerate(Net_ZeroDim_001) # luci-interpreter doesn't support zero dim
tcgenerate(OneHot_000)
tcgenerate(OneHot_001)
@@ -157,13 +148,11 @@ tcgenerate(While_001) # Needs luci-interpreter int32_t support for ADD, EQUAL
tcgenerate(While_002) # Needs luci-interpreter int32_t support for ADD, EQUAL
tcgenerate(While_003) # Needs luci-interpreter int32_t support for ADD, EQUAL, and dynamic shape for WHILE
tcgenerate(YUV_TO_RGB_000)
-tcgenerate(YUV_TO_RGB_U8_000)
tcgenerate(ZerosLike_000)
## CircleRecipes
tcgenerate(BCQFullyConnected_000)
tcgenerate(BCQFullyConnected_001)
tcgenerate(BCQGather_000)
-tcgenerate(CircleBatchMatMul_000)
tcgenerate(InstanceNorm_000)
tcgenerate(InstanceNorm_001)
diff --git a/compiler/common-artifacts/options.lst b/compiler/common-artifacts/options.lst
new file mode 100644
index 000000000..5e0ff9da5
--- /dev/null
+++ b/compiler/common-artifacts/options.lst
@@ -0,0 +1,6 @@
+## Additional Options for test recipe
+
+#[[ tcgenerate_option : add additional option(s) for generation ]]
+
+# make valid 'indices' input value
+tcgenerate_option(Gather_001 --input_range indices 0 3)
diff --git a/compiler/common-artifacts/requires.cmake b/compiler/common-artifacts/requires.cmake
index d7bed21fe..cc07e17f6 100644
--- a/compiler/common-artifacts/requires.cmake
+++ b/compiler/common-artifacts/requires.cmake
@@ -4,6 +4,6 @@ require("circlechef")
require("foder")
require("luci")
require("luci-interpreter")
-require("mio-circle")
require("safemain")
require("tflchef")
+require("tflite2circle")
diff --git a/compiler/common-artifacts/src/TestDataGenerator.cpp b/compiler/common-artifacts/src/TestDataGenerator.cpp
index b00e93e88..33cecbbe2 100644
--- a/compiler/common-artifacts/src/TestDataGenerator.cpp
+++ b/compiler/common-artifacts/src/TestDataGenerator.cpp
@@ -18,7 +18,6 @@
#include <foder/FileLoader.h>
#include <luci/Importer.h>
#include <luci_interpreter/Interpreter.h>
-#include <mio/circle/schema_generated.h>
#include <H5Cpp.h>
@@ -27,6 +26,9 @@
#include <memory>
#include <random>
#include <string>
+#include <vector>
+#include <cassert>
+#include <cstdlib>
namespace
{
@@ -43,6 +45,8 @@ H5::PredType hdf5_dtype_cast(const loco::DataType loco_dtype)
{
case loco::DataType::U8:
return H5::PredType::NATIVE_UINT8;
+ case loco::DataType::S16:
+ return H5::PredType::NATIVE_INT16;
case loco::DataType::S32:
return H5::PredType::NATIVE_INT32;
case loco::DataType::S64:
@@ -56,7 +60,7 @@ H5::PredType hdf5_dtype_cast(const loco::DataType loco_dtype)
}
}
-template <typename T> void geneate_random_data(std::mt19937 &gen, void *data, uint32_t size)
+template <typename T> void generate_random_data(std::mt19937 &gen, void *data, uint32_t size)
{
std::normal_distribution<float> distrib(0, 2); // mean(0), stddev(2)
for (uint32_t i = 0; i < size; i++)
@@ -65,7 +69,7 @@ template <typename T> void geneate_random_data(std::mt19937 &gen, void *data, ui
}
}
-template <> void geneate_random_data<bool>(std::mt19937 &gen, void *data, uint32_t size)
+template <> void generate_random_data<bool>(std::mt19937 &gen, void *data, uint32_t size)
{
std::normal_distribution<float> distrib(0, 2); // mean(0), stddev(2)
for (uint32_t i = 0; i < size; i++)
@@ -74,6 +78,20 @@ template <> void geneate_random_data<bool>(std::mt19937 &gen, void *data, uint32
}
}
+template <typename T>
+void generate_random_range(void *data, uint32_t size, int32_t range_min, int32_t range_max)
+{
+ assert(range_min <= range_max);
+
+ for (uint32_t i = 0; i < size; i++)
+ {
+ // +1 will make value of [range_min, range_max]
+ int32_t range = range_max - range_min + 1;
+ int32_t value = (rand() % range) + range_min;
+ static_cast<T *>(data)[i] = static_cast<T>(value);
+ }
+}
+
void fill_random_data(void *data, uint32_t size, loco::DataType dtype, uint32_t seed)
{
std::mt19937 gen(seed); // standard mersenne_twister_engine seeded with rd()
@@ -81,19 +99,38 @@ void fill_random_data(void *data, uint32_t size, loco::DataType dtype, uint32_t
switch (dtype)
{
case loco::DataType::U8:
- geneate_random_data<uint8_t>(gen, data, size);
+ generate_random_data<uint8_t>(gen, data, size);
+ break;
+ case loco::DataType::S16:
+ generate_random_data<int16_t>(gen, data, size);
break;
case loco::DataType::S32:
- geneate_random_data<int32_t>(gen, data, size);
+ generate_random_data<int32_t>(gen, data, size);
break;
case loco::DataType::S64:
- geneate_random_data<int64_t>(gen, data, size);
+ generate_random_data<int64_t>(gen, data, size);
break;
case loco::DataType::FLOAT32:
- geneate_random_data<float>(gen, data, size);
+ generate_random_data<float>(gen, data, size);
break;
case loco::DataType::BOOL:
- geneate_random_data<bool>(gen, data, size);
+ generate_random_data<bool>(gen, data, size);
+ break;
+ default:
+ throw std::runtime_error("NYI data type.");
+ }
+}
+
+void fill_random_range(void *data, uint32_t size, loco::DataType dtype, int32_t range_min,
+ int32_t range_max)
+{
+ switch (dtype)
+ {
+ case loco::DataType::S32:
+ generate_random_range<int32_t>(data, size, range_min, range_max);
+ break;
+ case loco::DataType::S64:
+ generate_random_range<int64_t>(data, size, range_min, range_max);
break;
default:
throw std::runtime_error("NYI data type.");
@@ -120,6 +157,11 @@ int entry(int argc, char **argv)
.required(false)
.nargs(0)
.help("Put a fixed seed into the random number generator");
+ arser.add_argument("--input_range")
+ .required(false)
+ .nargs(3)
+ .type(arser::DataType::STR_VEC)
+ .help("Set random number range [min max] for the input as 'name min max'");
try
{
@@ -176,6 +218,24 @@ int entry(int argc, char **argv)
std::unique_ptr<H5::Group> output_value_group =
std::make_unique<H5::Group>(output_file.createGroup("value"));
+ std::string range_name;
+ int32_t range_min = 0;
+ int32_t range_max = 0;
+ bool range_check = false;
+ bool range_input_found = false;
+ if (arser["--input_range"])
+ {
+ // NOTE limitation: we can only set one input range
+ // TODO expand this for multiple inputs
+ std::vector<std::string> values = arser.get<std::vector<std::string>>("--input_range");
+ assert(values.size() == 3);
+ range_name = values.at(0);
+ // TODO add check for valid numbers
+ range_min = std::atoi(values.at(1).c_str());
+ range_max = std::atoi(values.at(2).c_str());
+ range_check = true;
+ }
+
std::random_device rd; // used to obtain a seed for the random number engine
uint32_t input_index = 0;
// TODO remove indentation
@@ -187,6 +247,7 @@ int entry(int argc, char **argv)
{
const auto *input_node = dynamic_cast<const luci::CircleInput *>(node);
std::string name = input_node->name();
+ assert(not name.empty());
if (name.find(":") == std::string::npos)
name += ":0";
@@ -217,7 +278,12 @@ int entry(int argc, char **argv)
std::vector<int8_t> data(byte_size);
// generate random data
- if (arser["--fixed_seed"])
+ if (range_name == input_node->name())
+ {
+ fill_random_range(data.data(), data_size, input_node->dtype(), range_min, range_max);
+ range_input_found = true;
+ }
+ else if (arser["--fixed_seed"])
fill_random_data(data.data(), data_size, input_node->dtype(), 0);
else
fill_random_data(data.data(), data_size, input_node->dtype(), rd());
@@ -230,6 +296,12 @@ int entry(int argc, char **argv)
}
}
+ if (range_check && not range_input_found)
+ {
+ std::cerr << "ERROR: input_range for input [" << range_name << "] not found." << std::endl;
+ return EXIT_FAILURE;
+ }
+
interpreter.interpret();
// dump output data into hdf5 file
diff --git a/compiler/dio-hdf5/CMakeLists.txt b/compiler/dio-hdf5/CMakeLists.txt
new file mode 100644
index 000000000..199c0d59d
--- /dev/null
+++ b/compiler/dio-hdf5/CMakeLists.txt
@@ -0,0 +1,30 @@
+nnas_find_package(HDF5 COMPONENTS STATIC QUIET)
+
+if(NOT HDF5_FOUND)
+ message(STATUS "Build dio_hdf5: FAILED (missing HDF5)")
+ return()
+endif(NOT HDF5_FOUND)
+
+file(GLOB_RECURSE SOURCES "src/*.cpp")
+file(GLOB_RECURSE TESTS "src/*.test.cpp")
+list(REMOVE_ITEM SOURCES ${TESTS})
+
+add_library(dio_hdf5 SHARED ${SOURCES})
+target_include_directories(dio_hdf5 PUBLIC include)
+target_include_directories(dio_hdf5 PUBLIC ${HDF5_INCLUDE_DIRS})
+target_link_libraries(dio_hdf5 PUBLIC ${HDF5_CXX_LIBRARIES})
+target_link_libraries(dio_hdf5 PUBLIC loco)
+
+install(TARGETS dio_hdf5 DESTINATION lib)
+install(DIRECTORY include/ DESTINATION include
+ FILES_MATCHING PATTERN "*.h")
+
+if(NOT ENABLE_TEST)
+ return()
+endif(NOT ENABLE_TEST)
+
+nnas_find_package(GTest REQUIRED)
+
+GTest_AddTest(dio_hdf5_test ${TESTS})
+target_include_directories(dio_hdf5_test PRIVATE include)
+target_link_libraries(dio_hdf5_test dio_hdf5)
diff --git a/compiler/dio-hdf5/README.md b/compiler/dio-hdf5/README.md
new file mode 100644
index 000000000..aa2398ce8
--- /dev/null
+++ b/compiler/dio-hdf5/README.md
@@ -0,0 +1,29 @@
+# dio-hdf5
+
+_dio-hdf5_ is a library that helps load HDF5 files (_dio_ stands for data I/O).
+
+The hdf5 file should have the following structure.
+
+```
+Group "/"
+ > Group <group_name>
+ > Group <data_idx>
+ > Dataset <input_idx>
+```
+
+## Example
+
+```cpp
+dio_hdf5::HDF5Importer h5{input_path};
+
+h5.importGroup("value");
+
+// Prepare buffer
+const uint32_t input_byte_size = 16;
+std::vector<char> buffer(input_byte_size);
+
+// Write the first input of the first data to buffer
+h5.readTensor(0, 0, buffer.data());
+
+DO_SOMETHING_WITH(buffer);
+```
diff --git a/compiler/dio-hdf5/include/dio_hdf5/HDF5Importer.h b/compiler/dio-hdf5/include/dio_hdf5/HDF5Importer.h
new file mode 100644
index 000000000..aafcfbbf3
--- /dev/null
+++ b/compiler/dio-hdf5/include/dio_hdf5/HDF5Importer.h
@@ -0,0 +1,82 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __DIO_HDF5_H__
+#define __DIO_HDF5_H__
+
+#include <H5Cpp.h>
+
+#include <loco.h>
+
+#include <string>
+#include <vector>
+
+namespace dio
+{
+namespace hdf5
+{
+
+// HDF5Importer reads input data saved in the HDF5 file at the given path
+// The hierarchy of the hdf5 file is as follows.
+// Group "/"
+// > Group <group_name>
+// > Group <data_idx>
+// > Dataset <input_idx>
+// data_idx : index of the data (dataset file can contain multiple data)
+// input_idx : index of the input (DNN model can have multiple inputs)
+// Ex: the j'th input of the i'th data of group 'value' can be accessed by "/value/i/j"
+class HDF5Importer final
+{
+public:
+ explicit HDF5Importer(const std::string &path);
+
+public:
+ /**
+ * @note importGroup has to be called before readTensor is called
+ * Otherwise, readTensor will throw an exception
+ */
+ void importGroup(const std::string &group) { _group = _file.openGroup(group); }
+
+ /**
+ * @brief Read tensor data from file and store it into buffer
+ * @details A tensor in the file can be retrieved with (data_idx, input_idx)
+ * @param data_idx : index of the data
+ * @param input_idx : index of the input
+ * @param dtype : pointer to write the tensor's data type
+ * @param shape : pointer to write the tensor's shape
+ * @param buffer : pointer to write the tensor's data
+ */
+ void readTensor(int32_t data_idx, int32_t input_idx, loco::DataType *dtype,
+ std::vector<loco::Dimension> *shape, void *buffer);
+
+ // Read a raw tensor (no type/shape is specified)
+ void readTensor(int32_t data_idx, int32_t input_idx, void *buffer);
+
+ bool isRawData() { return _group.attrExists("rawData"); }
+
+ int32_t numData() { return _group.getNumObjs(); }
+
+ int32_t numInputs(int32_t data_idx);
+
+private:
+ H5::H5File _file;
+ H5::Group _group;
+};
+
+} // namespace hdf5
+} // namespace dio
+
+#endif // __DIO_HDF5_H__
diff --git a/compiler/dio-hdf5/requires.cmake b/compiler/dio-hdf5/requires.cmake
new file mode 100644
index 000000000..44f6870da
--- /dev/null
+++ b/compiler/dio-hdf5/requires.cmake
@@ -0,0 +1 @@
+require("loco")
diff --git a/compiler/record-minmax/src/HDF5Importer.cpp b/compiler/dio-hdf5/src/HDF5Importer.cpp
index cfb270ce0..9ae556b77 100644
--- a/compiler/record-minmax/src/HDF5Importer.cpp
+++ b/compiler/dio-hdf5/src/HDF5Importer.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -14,16 +14,17 @@
* limitations under the License.
*/
-#include "HDF5Importer.h"
+#include "dio_hdf5/HDF5Importer.h"
#include <H5Cpp.h>
#include <string>
+#include <vector>
#include <cassert>
#include <stdexcept>
-using Shape = luci_interpreter::Shape;
-using DataType = luci_interpreter::DataType;
+using Shape = std::vector<loco::Dimension>;
+using DataType = loco::DataType;
namespace
{
@@ -36,10 +37,10 @@ Shape toInternalShape(const H5::DataSpace &dataspace)
dims.resize(rank, 0);
dataspace.getSimpleExtentDims(dims.data());
- Shape res(rank);
+ Shape res;
for (int axis = 0; axis < rank; ++axis)
{
- res.dim(axis) = dims[axis];
+ res.emplace_back(dims[axis]);
}
return res;
@@ -108,18 +109,28 @@ void readTensorData(H5::DataSet &tensor, int64_t *buffer)
} // namespace
-namespace record_minmax
+namespace dio
{
+namespace hdf5
+{
+
+HDF5Importer::HDF5Importer(const std::string &path)
+{
+ if (_file.isHdf5(path) == false)
+ throw std::runtime_error("Given data file is not HDF5");
+
+ _file = H5::H5File(path, H5F_ACC_RDONLY);
+}
int32_t HDF5Importer::numInputs(int32_t record_idx)
{
- auto records = _value_grp.openGroup(std::to_string(record_idx));
+ auto records = _group.openGroup(std::to_string(record_idx));
return records.getNumObjs();
}
void HDF5Importer::readTensor(int32_t record_idx, int32_t input_idx, void *buffer)
{
- auto record = _value_grp.openGroup(std::to_string(record_idx));
+ auto record = _group.openGroup(std::to_string(record_idx));
auto tensor = record.openDataSet(std::to_string(input_idx));
readTensorData(tensor, static_cast<uint8_t *>(buffer));
@@ -128,7 +139,7 @@ void HDF5Importer::readTensor(int32_t record_idx, int32_t input_idx, void *buffe
void HDF5Importer::readTensor(int32_t record_idx, int32_t input_idx, DataType *dtype, Shape *shape,
void *buffer)
{
- auto record = _value_grp.openGroup(std::to_string(record_idx));
+ auto record = _group.openGroup(std::to_string(record_idx));
auto tensor = record.openDataSet(std::to_string(input_idx));
auto tensor_dtype = tensor.getDataType();
@@ -156,4 +167,5 @@ void HDF5Importer::readTensor(int32_t record_idx, int32_t input_idx, DataType *d
}
}
-} // namespace record_minmax
+} // namespace hdf5
+} // namespace dio
diff --git a/compiler/dio-hdf5/src/HDF5Importer.test.cpp b/compiler/dio-hdf5/src/HDF5Importer.test.cpp
new file mode 100644
index 000000000..61a027fc5
--- /dev/null
+++ b/compiler/dio-hdf5/src/HDF5Importer.test.cpp
@@ -0,0 +1,134 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "dio_hdf5/HDF5Importer.h"
+
+#include <loco.h>
+
+#include <H5Cpp.h>
+
+#include <cstdio>
+
+#include <gtest/gtest.h>
+
+using HDF5Importer = dio::hdf5::HDF5Importer;
+using Shape = std::vector<loco::Dimension>;
+using DataType = loco::DataType;
+
+namespace
+{
+
+const std::string file_name("dio_hdf5_test.h5");
+
+void createFile()
+{
+ // Remove a leftover file from a previous run, if it exists.
+ if (auto f = fopen(file_name.c_str(), "r"))
+ {
+ fclose(f);
+ if (remove(file_name.c_str()) != 0)
+ throw std::runtime_error("Error deleting file.");
+ }
+
+ const auto rank = 3;
+ hsize_t dim[3] = {1, 2, 3};
+ H5::DataSpace space(rank, dim);
+
+ float data[] = {0, 1, 2, 3, 4, 5};
+
+ // Create test file in the current directory
+ H5::H5File file(file_name, H5F_ACC_TRUNC);
+ {
+ file.createGroup("/value");
+ file.createGroup("/value/0");
+ H5::DataSet dataset(file.createDataSet("/value/0/0", H5::PredType::IEEE_F32BE, space));
+ dataset.write(data, H5::PredType::IEEE_F32LE);
+ }
+}
+
+} // namespace
+
+TEST(dio_hdf5_test, read_with_type_shape)
+{
+ createFile();
+
+ HDF5Importer h5(::file_name);
+
+ h5.importGroup("value");
+
+ std::vector<float> buffer(6);
+
+ DataType dtype;
+ Shape shape;
+ h5.readTensor(0, 0, &dtype, &shape, buffer.data());
+
+ for (uint32_t i = 0; i < 6; i++)
+ EXPECT_EQ(i, buffer[i]);
+
+ EXPECT_EQ(DataType::FLOAT32, dtype);
+ EXPECT_EQ(3, shape.size());
+ EXPECT_EQ(1, shape[0]);
+ EXPECT_EQ(2, shape[1]);
+ EXPECT_EQ(3, shape[2]);
+}
+
+TEST(dio_hdf5_test, wrong_path_NEG)
+{
+ const std::string wrong_path = "not_existing_file_for_dio_hdf5_test";
+
+ EXPECT_ANY_THROW(HDF5Importer h5(wrong_path));
+}
+
+TEST(dio_hdf5_test, wrong_group_name_NEG)
+{
+ createFile();
+
+ HDF5Importer h5(::file_name);
+
+ EXPECT_ANY_THROW(h5.importGroup("wrong"));
+}
+
+TEST(dio_hdf5_test, data_out_of_index_NEG)
+{
+ createFile();
+
+ HDF5Importer h5(::file_name);
+
+ h5.importGroup("value");
+
+ std::vector<float> buffer(6);
+
+ DataType dtype;
+ Shape shape;
+ // Read non-existing data (data_idx = 1)
+ EXPECT_ANY_THROW(h5.readTensor(1, 0, &dtype, &shape, buffer.data()));
+}
+
+TEST(dio_hdf5_test, input_out_of_index_NEG)
+{
+ createFile();
+
+ HDF5Importer h5(::file_name);
+
+ h5.importGroup("value");
+
+ std::vector<float> buffer(6);
+
+ DataType dtype;
+ Shape shape;
+ // Read non-existing input (input_idx = 1)
+ EXPECT_ANY_THROW(h5.readTensor(0, 1, &dtype, &shape, buffer.data()));
+}
diff --git a/compiler/dredd-rule-lib/rule-lib.sh b/compiler/dredd-rule-lib/rule-lib.sh
index 9254cc9a7..c25dc5fb4 100755
--- a/compiler/dredd-rule-lib/rule-lib.sh
+++ b/compiler/dredd-rule-lib/rule-lib.sh
@@ -217,4 +217,21 @@ op_version()
echo ${ACTUAL}
}
+tensor_dtype()
+{
+ argc_check $# 1
+ file_path_check ${COMPILED_FILE}
+ file_path_check ${INSPECT_PROG_PATH}
+
+ set -o pipefail
+
+ ACTUAL=`init_error_log ; \
+ ${INSPECT_PROG_PATH} --tensor_dtype ${COMPILED_FILE} | \
+ awk -v tensor_name="$1" '{ if ($1 == tensor_name) print $2}'`
+
+ check_success_exit_code $? 0
+
+ echo ${ACTUAL}
+}
+
# TODO define more qullity test function
diff --git a/compiler/embedded-import-value-test/.gitignore b/compiler/embedded-import-value-test/.gitignore
new file mode 100644
index 000000000..8dbfa9012
--- /dev/null
+++ b/compiler/embedded-import-value-test/.gitignore
@@ -0,0 +1 @@
+/test.local.lst
diff --git a/compiler/embedded-import-value-test/CMakeLists.txt b/compiler/embedded-import-value-test/CMakeLists.txt
new file mode 100644
index 000000000..785edfc7d
--- /dev/null
+++ b/compiler/embedded-import-value-test/CMakeLists.txt
@@ -0,0 +1,34 @@
+if(NOT ENABLE_TEST)
+ return()
+endif(NOT ENABLE_TEST)
+
+set(SRCS_TEST_DRIVER src/TestDriver.cpp)
+
+# create driver
+add_executable(test_driver ${SRCS_TEST_DRIVER})
+target_link_libraries(test_driver PRIVATE luci_interpreter_import)
+target_link_libraries(test_driver PRIVATE luci_interpreter)
+target_link_libraries(test_driver PRIVATE safemain)
+
+unset(EMBEDDED_IMPORT_VALUE_TESTS)
+
+macro(addeval NAME)
+ list(APPEND EMBEDDED_IMPORT_VALUE_TESTS ${NAME})
+endmacro(addeval)
+
+# Read "test.lst"
+include("test.lst")
+# Read "test.local.lst" if exists
+include("test.local.lst" OPTIONAL)
+
+# Generate dependencies
+add_custom_target(embedded_import_testfiles ALL DEPENDS ${TESTFILES})
+
+get_target_property(ARTIFACTS_BIN_PATH testDataGenerator BINARY_DIR)
+
+add_test(NAME embedded_import_value_test
+ COMMAND "${CMAKE_CURRENT_SOURCE_DIR}/evalverify.sh"
+ "${CMAKE_CURRENT_BINARY_DIR}"
+ "${ARTIFACTS_BIN_PATH}"
+ ${EMBEDDED_IMPORT_VALUE_TESTS}
+)
diff --git a/compiler/embedded-import-value-test/README.md b/compiler/embedded-import-value-test/README.md
new file mode 100644
index 000000000..71a95486f
--- /dev/null
+++ b/compiler/embedded-import-value-test/README.md
@@ -0,0 +1,13 @@
+# embedded-import-value-test
+
+`embedded-import-value-test` checks that models imported with and without constant copying produce the same output values.
+
+The test proceeds as follows:
+
+1. Generate random input for provided circle model.
+
+2. Import circle model to luci in 2 modes:
+ - With constant copying (default mode).
+ - Without constant copying (experimental feature)
+
+3. Compare the execution result of both modes. The result must be the same.
diff --git a/compiler/embedded-import-value-test/evalverify.sh b/compiler/embedded-import-value-test/evalverify.sh
new file mode 100755
index 000000000..a99e76f3e
--- /dev/null
+++ b/compiler/embedded-import-value-test/evalverify.sh
@@ -0,0 +1,58 @@
+#!/bin/bash
+
+# This script verifies that models imported without constant copying execute well in luci_interpreter
+#
+# HOW TO USE
+#
+# ./evalverify.sh <path/to/bin_dir> <path/to/work_dir> <TEST 1> <TEST 2> ...
+# bin_dir : build directory of embedded-import-value-test (ex: build/compiler/embedded-import-value-test)
+# work_dir : artifacts directory where test materials exist
+
+BINDIR="$1"; shift
+WORKDIR="$1"; shift
+TEST_DRIVER_PATH="${BINDIR}/test_driver"
+TEST_RESULT_DIR="${BINDIR}/result"
+
+TESTED=()
+PASSED=()
+FAILED=()
+
+mkdir -p "${TEST_RESULT_DIR}"
+for TESTCASE in "$@"; do
+ TESTED+=("${TESTCASE}")
+
+ TESTCASE_FILE="${WORKDIR}/${TESTCASE}"
+ TEST_RESULT_FILE="${TEST_RESULT_DIR}/${TESTCASE}"
+
+ PASSED_TAG="${TEST_RESULT_FILE}.passed"
+ rm -f "${PASSED_TAG}"
+
+ cat > "${TEST_RESULT_FILE}.log" <(
+ exec 2>&1
+ set -ex
+
+ "${TEST_DRIVER_PATH}" --model "${TESTCASE_FILE}.circle"
+
+ if [[ $? -eq 0 ]]; then
+ touch "${PASSED_TAG}"
+ fi
+ )
+
+ if [[ -f "${PASSED_TAG}" ]]; then
+ PASSED+=("${TESTCASE}")
+ else
+ FAILED+=("${TESTCASE}")
+ fi
+done
+
+if [[ ${#TESTED[@]} -ne ${#PASSED[@]} ]]; then
+ echo "FAILED"
+ for TEST in "${FAILED[@]}"
+ do
+ echo "- ${TEST}"
+ done
+ exit 255
+fi
+
+echo "PASSED"
+exit 0
diff --git a/compiler/embedded-import-value-test/requires.cmake b/compiler/embedded-import-value-test/requires.cmake
new file mode 100644
index 000000000..f8af5f27e
--- /dev/null
+++ b/compiler/embedded-import-value-test/requires.cmake
@@ -0,0 +1,6 @@
+require("common-artifacts")
+require("luci")
+require("luci-interpreter")
+require("safemain")
+require("oops")
+require("loco")
diff --git a/compiler/embedded-import-value-test/src/TestDriver.cpp b/compiler/embedded-import-value-test/src/TestDriver.cpp
new file mode 100644
index 000000000..63fd745eb
--- /dev/null
+++ b/compiler/embedded-import-value-test/src/TestDriver.cpp
@@ -0,0 +1,242 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <luci_interpreter/GraphBuilderRegistry.h>
+#include <luci_interpreter/Interpreter.h>
+
+#include <luci/Importer.h>
+
+#include <cstdlib>
+#include <fstream>
+#include <iostream>
+#include <vector>
+#include <string>
+#include <random>
+
+namespace
+{
+
+uint32_t tensor_size_of(const luci::CircleNode *node)
+{
+ uint32_t tensor_size = loco::size(node->dtype());
+ for (uint32_t i = 0; i < node->rank(); ++i)
+ tensor_size *= node->dim(i).value();
+ return tensor_size;
+}
+
+std::vector<uint8_t> random_data_for(const luci::CircleInput *node)
+{
+ // allocate data buffer
+ std::vector<uint8_t> inputs_data(tensor_size_of(node));
+ auto *buffer = inputs_data.data();
+
+ // define size of buffer in elements
+ const auto dtype = node->dtype();
+ assert(inputs_data.size() % loco::size(dtype) == 0); // FIX ME UNLESS
+ const auto element_count = inputs_data.size() / loco::size(dtype);
+
+ // random generator engine
+ std::random_device device;
+ std::mt19937 engine{device()};
+
+ // fill buffer with random data
+ switch (node->dtype())
+ {
+ case loco::DataType::FLOAT32:
+ {
+ auto element_buffer = reinterpret_cast<float *>(buffer);
+
+ std::uniform_real_distribution<float> distrib(-3, 3);
+ const auto generator = [&distrib, &engine]() { return distrib(engine); };
+ std::generate(element_buffer, element_buffer + element_count, generator);
+
+ break;
+ }
+ case loco::DataType::U8:
+ {
+ auto element_buffer = buffer;
+
+ std::uniform_int_distribution<uint8_t> distrib(100, 200);
+ const auto generator = [&distrib, &engine]() { return distrib(engine); };
+ std::generate(element_buffer, element_buffer + element_count, generator);
+
+ break;
+ }
+ case loco::DataType::S16:
+ {
+ auto element_buffer = reinterpret_cast<int16_t *>(buffer);
+
+ std::uniform_int_distribution<int16_t> distrib(0, 100);
+ const auto generator = [&distrib, &engine]() { return distrib(engine); };
+ std::generate(element_buffer, element_buffer + element_count, generator);
+
+ break;
+ }
+ case loco::DataType::S32:
+ {
+ auto element_buffer = reinterpret_cast<int32_t *>(buffer);
+
+ std::uniform_int_distribution<int32_t> distrib(0, 100);
+ const auto generator = [&distrib, &engine]() { return distrib(engine); };
+ std::generate(element_buffer, element_buffer + element_count, generator);
+
+ break;
+ }
+ case loco::DataType::BOOL:
+ {
+ // bool data is represented as uint8_t values restricted to the [0, 1] range
+ auto element_buffer = buffer;
+
+ std::uniform_int_distribution<uint8_t> distrib(0, 1);
+ const auto generator = [&distrib, &engine]() { return distrib(engine); };
+ std::generate(element_buffer, element_buffer + element_count, generator);
+
+ break;
+ }
+ default:
+ // TODO Support other dtypes
+ throw std::runtime_error("Unsupported data type, yet!");
+ }
+
+ return inputs_data;
+}
+
+} // namespace
+
+int entry(int argc, char **argv)
+{
+ // check arguments
+ if (argc != 3 || std::string(argv[1]) != "--model")
+ {
+ std::cerr << "Usage: " << argv[0] << " --model <path/to/model>" << std::endl;
+ return EXIT_FAILURE;
+ }
+
+ // open file with model
+ const auto model_file = std::string(argv[2]);
+ std::ifstream fs(model_file, std::ifstream::binary);
+ if (fs.fail())
+ {
+ std::cerr << "Cannot open model file \"" << model_file << "\"." << std::endl;
+ return EXIT_FAILURE;
+ }
+
+ // create constant circle model
+ const std::vector<char> model_buffer((std::istreambuf_iterator<char>(fs)),
+ std::istreambuf_iterator<char>());
+ const auto circle_model = circle::GetModel(model_buffer.data());
+
+ // create random model's inputs
+ std::vector<std::vector<uint8_t>> inputs_data;
+ {
+ // model inputs
+ auto model = luci::Importer(nullptr).importModule(circle_model);
+ const auto inputs = loco::input_nodes(model->graph());
+
+ // create random data for each input
+ for (const auto *input : inputs)
+ {
+ const auto input_node = loco::must_cast<const luci::CircleInput *>(input);
+ inputs_data.emplace_back(random_data_for(input_node));
+ }
+ }
+
+ // interpret given module
+ const auto interpret_module_and_compute_output =
+ [&](const std::unique_ptr<luci::Module> &module) {
+ // create interpreter
+ luci_interpreter::Interpreter interpreter(module.get());
+
+ // model's input and output nodes
+ const auto input_nodes = loco::input_nodes(module->graph());
+ const auto output_nodes = loco::output_nodes(module->graph());
+
+ // set inputs
+ for (uint32_t i = 0; i < input_nodes.size(); ++i)
+ {
+ const auto input_node = loco::must_cast<const luci::CircleInput *>(input_nodes[i]);
+ const auto &data = inputs_data.at(i);
+ interpreter.writeInputTensor(input_node, data.data(), data.size());
+ }
+
+ // do inference
+ interpreter.interpret();
+
+ // read outputs
+ std::vector<std::vector<uint8_t>> outputs_data;
+ for (const auto *node : output_nodes)
+ {
+ const auto output_node = loco::must_cast<const luci::CircleOutput *>(node);
+
+ // allocate output buffer
+ outputs_data.emplace_back(tensor_size_of(output_node));
+
+ auto &data = outputs_data.back();
+ interpreter.readOutputTensor(output_node, data.data(), data.size());
+ }
+
+ return outputs_data;
+ };
+
+ // import with copying, execute and save
+ std::vector<std::vector<uint8_t>> outputs_data_1;
+ {
+ const auto default_source = &luci::GraphBuilderRegistry::get();
+ const auto module = luci::Importer(default_source).importModule(circle_model);
+ if (not module)
+ {
+ std::cerr << "Fail to import model with constant copying." << std::endl;
+ return EXIT_FAILURE;
+ }
+
+ outputs_data_1 = interpret_module_and_compute_output(module);
+ }
+
+ // import without copying, execute and save
+ std::vector<std::vector<uint8_t>> outputs_data_2;
+ {
+ const auto optimized_source = luci_interpreter::source_without_constant_copying();
+ const auto module = luci::Importer(optimized_source.get()).importModule(circle_model);
+ if (not module)
+ {
+ std::cerr << "Fail to import model without constant copying." << std::endl;
+ return EXIT_FAILURE;
+ }
+
+ outputs_data_2 = interpret_module_and_compute_output(module);
+ }
+
+ // check all tensors are equal
+ assert(outputs_data_1.size() == outputs_data_2.size());
+ for (uint32_t n = 0; n < outputs_data_1.size(); ++n)
+ {
+ const auto &output_1 = outputs_data_1.at(n);
+ const auto &output_2 = outputs_data_2.at(n);
+ assert(output_1.size() == output_2.size());
+
+ for (uint32_t o = 0; o < output_1.size(); ++o)
+ {
+ if (output_1[o] != output_2[o])
+ {
+ std::cerr << "Values mismatch in model's output number " << n << std::endl;
+ return EXIT_FAILURE;
+ }
+ }
+ }
+
+ std::cout << "[TEST PASSED]" << std::endl;
+ return EXIT_SUCCESS;
+}
diff --git a/compiler/embedded-import-value-test/test.lst b/compiler/embedded-import-value-test/test.lst
new file mode 100644
index 000000000..924a60dcc
--- /dev/null
+++ b/compiler/embedded-import-value-test/test.lst
@@ -0,0 +1,192 @@
+#addeval(Abs_000)
+addeval(Add_000)
+#addeval(Add_001)
+addeval(Add_U8_000)
+#addeval(AddN_000)
+addeval(ArgMax_000)
+addeval(ArgMax_001)
+addeval(ArgMax_002)
+addeval(ArgMax_003)
+addeval(ArgMax_U8_000)
+addeval(ArgMax_U8_001)
+addeval(ArgMax_U8_002)
+addeval(ArgMax_U8_003)
+#addeval(ArgMin_000)
+#addeval(ArgMin_001)
+#addeval(ArgMin_002)
+#addeval(ArgMin_003)
+#addeval(ArgMin_U8_000)
+#addeval(ArgMin_U8_001)
+#addeval(ArgMin_U8_002)
+#addeval(ArgMin_U8_003)
+addeval(AveragePool2D_000)
+#addeval(BatchMatMul_000)
+#addeval(BatchMatMulV2_000)
+#addeval(BatchMatMulV2_001)
+#addeval(BatchToSpaceND_000)
+addeval(Cast_000)
+addeval(Cast_001)
+#addeval(Ceil_000)
+addeval(Concatenation_000)
+addeval(Concatenation_U8_000)
+addeval(Conv2D_000)
+addeval(Conv2D_001)
+addeval(Conv2D_002)
+addeval(Conv2D_003)
+addeval(Conv2D_U8_000)
+addeval(Conv2D_U8_001)
+#addeval(Cos_000)
+addeval(DepthToSpace_000)
+addeval(DepthwiseConv2D_000)
+addeval(DepthwiseConv2D_U8_000)
+#addeval(DepthwiseConv2D_U8_001)
+addeval(DepthwiseConv2D_001)
+addeval(Div_000)
+addeval(ELU_000)
+addeval(Equal_000)
+addeval(Exp_000)
+#addeval(ExpandDims_000)
+#addeval(ExpandDims_001)
+#addeval(ExpandDims_002)
+#addeval(ExpandDims_003)
+#addeval(Fill_000)
+#addeval(Fill_001)
+addeval(Floor_000)
+#addeval(FloorDiv_000)
+#addeval(FloorDiv_001)
+#addeval(FloorMod_000)
+#addeval(FloorMod_001)
+addeval(FullyConnected_000)
+addeval(FullyConnected_001)
+addeval(FullyConnected_002)
+#addeval(FullyConnected_U8_000)
+addeval(Gather_000)
+#addeval(GatherNd_000)
+#addeval(Greater_000)
+#addeval(GreaterEqual_000)
+addeval(If_000)
+addeval(If_001)
+addeval(L2Normalize_000)
+addeval(L2Pool2D_000)
+#addeval(L2Pool2D_U8_000)
+addeval(LeakyRelu_000)
+addeval(Less_000)
+addeval(LessEqual_000)
+addeval(LocalResponseNormalization_000)
+#addeval(Log_000)
+addeval(LogicalAnd_000)
+addeval(LogicalNot_000)
+addeval(LogicalOr_000)
+addeval(Logistic_000)
+addeval(LogSoftmax_000)
+#addeval(MatMul_000)
+#addeval(MatrixDiag_000)
+#addeval(MatrixSetDiag_000)
+addeval(Maximum_000)
+addeval(MaxPool2D_000)
+addeval(MaxPool2D_U8_000)
+addeval(Mean_000)
+addeval(Mean_001)
+#addeval(Mean_U8_000)
+#addeval(Minimum_000)
+#addeval(MirrorPad_000)
+addeval(Mul_000)
+#addeval(Mul_U8_000)
+addeval(Neg_000)
+addeval(NotEqual_000)
+addeval(OneHot_000)
+addeval(OneHot_001)
+addeval(OneHot_002)
+#addeval(OneHot_003)
+addeval(Pack_000)
+addeval(Pack_U8_000)
+addeval(Pad_000)
+addeval(Pad_U8_000)
+addeval(Pow_000)
+addeval(PRelu_000)
+#addeval(Range_000)
+#addeval(Rank_000)
+#addeval(ReduceAny_000)
+#addeval(ReduceAny_001)
+#addeval(ReduceAny_002)
+#addeval(ReduceAny_003)
+#addeval(ReduceMax_000)
+#addeval(ReduceMin_000)
+#addeval(ReduceProd_000)
+#addeval(ReduceProd_001)
+#addeval(ReduceProd_002)
+#addeval(ReduceProd_003)
+addeval(ReLU_000)
+addeval(ReLU6_000)
+#addeval(ReLUN1To1_000)
+addeval(Reshape_000)
+addeval(Reshape_001)
+addeval(Reshape_002)
+#addeval(Reshape_003)
+addeval(Reshape_U8_000)
+addeval(ResizeBilinear_000)
+addeval(ResizeNearestNeighbor_000)
+#addeval(ReverseSequence_000)
+#addeval(ReverseV2_000)
+#addeval(Round_000)
+addeval(Rsqrt_000)
+#addeval(ScatterNd_000)
+#addeval(SegmentSum_000)
+#addeval(Select_000)
+#addeval(Select_001)
+#addeval(Select_002)
+#addeval(SelectV2_000)
+#addeval(SelectV2_001)
+#addeval(SelectV2_002)
+#addeval(Shape_000)
+addeval(SignatureDef_MultiOut_000)
+addeval(SignatureDef_MultiOut_001)
+#addeval(Sin_000)
+addeval(Slice_000)
+addeval(Softmax_000)
+addeval(Softmax_U8_000)
+addeval(SpaceToBatchND_000)
+addeval(SpaceToBatchND_001)
+addeval(SpaceToBatchND_002)
+addeval(SpaceToBatchND_003)
+addeval(SpaceToDepth_000)
+#addeval(SparseToDense_000)
+addeval(Split_000)
+addeval(SplitV_000)
+addeval(Sqrt_000)
+addeval(Square_000)
+addeval(SquaredDifference_000)
+addeval(Squeeze_000)
+addeval(Squeeze_001)
+addeval(StridedSlice_000)
+addeval(StridedSlice_001)
+addeval(StridedSlice_002)
+addeval(Sub_000)
+addeval(Sub_U8_000)
+#addeval(Sum_000)
+#addeval(Sum_001)
+addeval(SVDF_000)
+addeval(SVDF_001)
+addeval(Tanh_000)
+#addeval(Tile_000)
+#addeval(Tile_U8_000)
+#addeval(TopKV2_000)
+#addeval(TopKV2_001)
+addeval(Transpose_000)
+addeval(TransposeConv_000)
+addeval(Unpack_000)
+addeval(Unpack_001)
+addeval(Unpack_002)
+addeval(Unpack_003)
+#addeval(Where_000)
+#addeval(Where_001)
+#addeval(While_000)
+#addeval(While_001)
+#addeval(While_002)
+#addeval(While_003)
+addeval(YUV_TO_RGB_U8_000)
+#addeval(ZerosLike_000)
+
+# Simple Network test
+addeval(Part_While_000)
+addeval(Part_While_001)
diff --git a/compiler/enco/CMakeLists.txt b/compiler/enco/CMakeLists.txt
index 17300e25e..3702f9501 100644
--- a/compiler/enco/CMakeLists.txt
+++ b/compiler/enco/CMakeLists.txt
@@ -1,4 +1,9 @@
add_subdirectory(core)
add_subdirectory(frontend)
add_subdirectory(cli)
+
+if(NOT ENABLE_TEST)
+ return()
+endif(NOT ENABLE_TEST)
+
add_subdirectory(test)
diff --git a/compiler/enco/core/CMakeLists.txt b/compiler/enco/core/CMakeLists.txt
index 25dad2bc6..19a64231a 100644
--- a/compiler/enco/core/CMakeLists.txt
+++ b/compiler/enco/core/CMakeLists.txt
@@ -20,11 +20,11 @@ target_link_libraries(enco_core PRIVATE morph)
# Let's use nncc project-wide build options
target_link_libraries(enco_core PRIVATE nncc_common)
-nnas_find_package(GTest QUIET)
-
-if(NOT GTest_FOUND)
+if(NOT ENABLE_TEST)
return()
-endif(NOT GTest_FOUND)
+endif(NOT ENABLE_TEST)
+
+nnas_find_package(GTest QUIET)
add_executable(enco_core_test ${TESTS})
target_include_directories(enco_core_test PRIVATE src)
diff --git a/compiler/enco/frontend/caffe/CMakeLists.txt b/compiler/enco/frontend/caffe/CMakeLists.txt
index 9722392a1..baf7f7bd6 100644
--- a/compiler/enco/frontend/caffe/CMakeLists.txt
+++ b/compiler/enco/frontend/caffe/CMakeLists.txt
@@ -17,11 +17,11 @@ target_link_libraries(enco_caffe_frontend enco_intf_cmdline)
target_link_libraries(enco_caffe_frontend morph)
target_link_libraries(enco_caffe_frontend caffeproto)
-nnas_find_package(GTest QUIET)
-
-if(NOT GTest_FOUND)
+if(NOT ENABLE_TEST)
return()
-endif(NOT GTest_FOUND)
+endif(NOT ENABLE_TEST)
+
+nnas_find_package(GTest QUIET)
nnas_find_package(Caffe QUIET)
diff --git a/compiler/enco/frontend/tflite/CMakeLists.txt b/compiler/enco/frontend/tflite/CMakeLists.txt
index b2de2b34b..995e66f81 100644
--- a/compiler/enco/frontend/tflite/CMakeLists.txt
+++ b/compiler/enco/frontend/tflite/CMakeLists.txt
@@ -1,4 +1,4 @@
-nnas_find_package(FlatBuffers EXACT 1.10 QUIET)
+nnas_find_package(FlatBuffers EXACT 2.0 QUIET)
if(NOT FlatBuffers_FOUND)
return()
@@ -17,16 +17,15 @@ add_library(enco_tflite_frontend SHARED ${SOURCES})
target_include_directories(enco_tflite_frontend PRIVATE src)
target_link_libraries(enco_tflite_frontend enco_intf_frontend)
target_link_libraries(enco_tflite_frontend enco_intf_cmdline)
-target_link_libraries(enco_tflite_frontend flatbuffers-1.10)
target_link_libraries(enco_tflite_frontend enco_tflite_schema)
target_link_libraries(enco_tflite_frontend morph)
target_link_libraries(enco_tflite_frontend cwrap)
-nnas_find_package(GTest QUIET)
-
-if(NOT GTest_FOUND)
+if(NOT ENABLE_TEST)
return()
-endif(NOT GTest_FOUND)
+endif(NOT ENABLE_TEST)
+
+nnas_find_package(GTest QUIET)
add_executable(enco_tflite_frontend_test ${TESTS})
target_include_directories(enco_tflite_frontend_test PRIVATE src)
diff --git a/compiler/exo/CMakeLists.txt b/compiler/exo/CMakeLists.txt
index 9d02f7cba..645db714c 100644
--- a/compiler/exo/CMakeLists.txt
+++ b/compiler/exo/CMakeLists.txt
@@ -1,4 +1,4 @@
-nnas_find_package(FlatBuffers EXACT 1.10 QUIET)
+nnas_find_package(FlatBuffers EXACT 2.0 QUIET)
if(NOT FlatBuffers_FOUND)
message(STATUS "Build exo: FALSE (missing FlatBuffers)")
@@ -15,7 +15,7 @@ endif(NOT TensorFlowSource_FOUND)
message(STATUS "Build exo: TRUE")
set(TFLITE_SCHEMA_DIR "${TensorFlowSource_DIR}/tensorflow/lite/schema")
-set(CIRCLE_SCHEMA_DIR "${NNAS_PROJECT_SOURCE_DIR}/nnpackage/schema")
+set(CIRCLE_SCHEMA_DIR "${NNAS_PROJECT_SOURCE_DIR}/res/CircleSchema/0.3")
FlatBuffers_Target(exo_tflite_fbs
OUTPUT_DIR "${CMAKE_CURRENT_BINARY_DIR}/gen"
diff --git a/compiler/hermes-std/CMakeLists.txt b/compiler/hermes-std/CMakeLists.txt
index 8fce31953..673d7056c 100644
--- a/compiler/hermes-std/CMakeLists.txt
+++ b/compiler/hermes-std/CMakeLists.txt
@@ -3,7 +3,9 @@ file(GLOB_RECURSE TESTS "src/*.test.cpp")
list(REMOVE_ITEM SOURCES ${TESTS})
add_library(hermes_std STATIC ${SOURCES})
-set_target_properties(hermes_std PROPERTIES POSITION_INDEPENDENT_CODE ON)
+if (NOT NNCC_LIBRARY_NO_PIC)
+ set_target_properties(hermes_std PROPERTIES POSITION_INDEPENDENT_CODE ON)
+endif(NOT NNCC_LIBRARY_NO_PIC)
target_include_directories(hermes_std PUBLIC include)
target_link_libraries(hermes_std PUBLIC hermes)
target_link_libraries(hermes_std PRIVATE pepper_strcast)
diff --git a/compiler/hermes-std/include/hermes/ConsoleReporter.h b/compiler/hermes-std/include/hermes/ConsoleReporter.h
index e09dd5785..c55e46a17 100644
--- a/compiler/hermes-std/include/hermes/ConsoleReporter.h
+++ b/compiler/hermes-std/include/hermes/ConsoleReporter.h
@@ -28,6 +28,10 @@ namespace hermes
struct ConsoleReporter final : public hermes::Sink
{
void notify(const Message *m) final;
+ void set_colored_mode(bool is_colored) { _is_colored = is_colored; }
+
+private:
+ bool _is_colored = false;
};
} // namespace hermes
diff --git a/compiler/hermes-std/src/ConsoleReporter.cpp b/compiler/hermes-std/src/ConsoleReporter.cpp
index 3cc9f09ed..524ed59d8 100644
--- a/compiler/hermes-std/src/ConsoleReporter.cpp
+++ b/compiler/hermes-std/src/ConsoleReporter.cpp
@@ -17,16 +17,68 @@
#include "hermes/ConsoleReporter.h"
#include <iostream>
+#include <cstdlib>
+#include <string>
namespace hermes
{
+static constexpr const char *kTermColorRedTextCode = "\033[0;31m";
+static constexpr const char *kTermColorGreenTextCode = "\033[0;32m";
+static constexpr const char *kTermColorOrangeTextCode = "\033[0;33m";
+static constexpr const char *kTermColorBlueTextCode = "\033[0;34m";
+static constexpr const char *kTermColorMagentaTextCode = "\033[0;35m";
+static constexpr const char *kTermColorCyanTextCode = "\033[0;36m";
+static constexpr const char *kTermColorWhiteTextCode = "\033[0;37m";
+
+static constexpr const char *kTermBoldTextCode = "\033[1m";
+static constexpr const char *kTermUnderlineTextCode = "\033[4m";
+static constexpr const char *kTermInverseTextCode = "\033[7m";
+static constexpr const char *kTermBoldOffTextCode = "\033[21m";
+static constexpr const char *kTermUnderlineOffTextCode = "\033[24m";
+static constexpr const char *kTermInverseOffTextCode = "\033[27m";
+
+static constexpr const char *kTermColorResetAllCode = "\033[0m";
+
void ConsoleReporter::notify(const hermes::Message *m)
{
+ const char *env_color_p = std::getenv("ONE_HERMES_COLOR");
+ if (env_color_p)
+ {
+ auto env_color_str = std::string(env_color_p);
+ if ((env_color_str == "1") or (env_color_str == "ON"))
+ _is_colored = true;
+ }
+
+ if (_is_colored)
+ {
+ switch (m->get_severity())
+ {
+ case FATAL:
+ std::cout << kTermColorRedTextCode << kTermBoldTextCode << kTermUnderlineTextCode;
+ break;
+ case ERROR:
+ std::cout << kTermColorRedTextCode;
+ break;
+ case WARN:
+ std::cout << kTermColorOrangeTextCode;
+ break;
+ case INFO:
+ std::cout << kTermColorGreenTextCode;
+ break;
+ case VERBOSE:
+ std::cout << kTermColorResetAllCode;
+ break;
+ };
+ }
for (uint32_t n = 0; n < m->text()->lines(); ++n)
{
std::cout << m->text()->line(n) << std::endl;
}
+ if (_is_colored)
+ {
+ std::cout << kTermColorResetAllCode;
+ }
}
} // namespace hermes
diff --git a/compiler/hermes-std/src/ConsoleReporter.test.cpp b/compiler/hermes-std/src/ConsoleReporter.test.cpp
index a65585a6a..d959ff3d9 100644
--- a/compiler/hermes-std/src/ConsoleReporter.test.cpp
+++ b/compiler/hermes-std/src/ConsoleReporter.test.cpp
@@ -43,3 +43,168 @@ TEST(ConsoleReporterTest, notify)
ASSERT_NO_THROW(r.notify(&m));
}
+
+TEST(ConsoleReporterTest, notify_fatal)
+{
+ hermes::Message m;
+ {
+ std::stringstream ss;
+
+ ss << "This message is colored as FATAL" << std::endl;
+
+ m.text(std::make_unique<hermes::MessageText>(ss), hermes::FATAL);
+ }
+
+ hermes::ConsoleReporter r;
+
+ r.set_colored_mode(true);
+ ASSERT_NO_THROW(r.notify(&m));
+}
+
+TEST(ConsoleReporterTest, notify_error)
+{
+ hermes::Message m;
+ {
+ std::stringstream ss;
+
+ ss << "This message is colored as ERROR" << std::endl;
+
+ m.text(std::make_unique<hermes::MessageText>(ss), hermes::ERROR);
+ }
+
+ hermes::ConsoleReporter r;
+
+ r.set_colored_mode(true);
+ ASSERT_NO_THROW(r.notify(&m));
+}
+
+TEST(ConsoleReporterTest, notify_warn)
+{
+ hermes::Message m;
+ {
+ std::stringstream ss;
+
+ ss << "This message is colored as WARN" << std::endl;
+
+ m.text(std::make_unique<hermes::MessageText>(ss), hermes::WARN);
+ }
+
+ hermes::ConsoleReporter r;
+
+ r.set_colored_mode(true);
+ ASSERT_NO_THROW(r.notify(&m));
+}
+
+TEST(ConsoleReporterTest, notify_info)
+{
+ hermes::Message m;
+ {
+ std::stringstream ss;
+
+ ss << "This message is colored as INFO" << std::endl;
+
+ m.text(std::make_unique<hermes::MessageText>(ss), hermes::INFO);
+ }
+
+ hermes::ConsoleReporter r;
+
+ r.set_colored_mode(true);
+ ASSERT_NO_THROW(r.notify(&m));
+}
+
+TEST(ConsoleReporterTest, notify_verbose)
+{
+ hermes::Message m;
+ {
+ std::stringstream ss;
+
+ ss << "This message is colored as VERBOSE" << std::endl;
+
+ m.text(std::make_unique<hermes::MessageText>(ss), hermes::VERBOSE);
+ }
+
+ hermes::ConsoleReporter r;
+
+ r.set_colored_mode(true);
+ ASSERT_NO_THROW(r.notify(&m));
+}
+
+TEST(ConsoleReporterTest, notify_fatal_NEG)
+{
+ hermes::Message m;
+ {
+ std::stringstream ss;
+
+ ss << "This message is not colored as FATAL" << std::endl;
+
+ m.text(std::make_unique<hermes::MessageText>(ss), hermes::FATAL);
+ }
+
+ hermes::ConsoleReporter r;
+
+ ASSERT_NO_THROW(r.notify(&m));
+}
+
+TEST(ConsoleReporterTest, notify_error_NEG)
+{
+ hermes::Message m;
+ {
+ std::stringstream ss;
+
+ ss << "This message is not colored as ERROR" << std::endl;
+
+ m.text(std::make_unique<hermes::MessageText>(ss), hermes::ERROR);
+ }
+
+ hermes::ConsoleReporter r;
+
+ ASSERT_NO_THROW(r.notify(&m));
+}
+
+TEST(ConsoleReporterTest, notify_warn_NEG)
+{
+ hermes::Message m;
+ {
+ std::stringstream ss;
+
+ ss << "This message is not colored as WARN" << std::endl;
+
+ m.text(std::make_unique<hermes::MessageText>(ss), hermes::WARN);
+ }
+
+ hermes::ConsoleReporter r;
+
+ ASSERT_NO_THROW(r.notify(&m));
+}
+
+TEST(ConsoleReporterTest, notify_info_NEG)
+{
+ hermes::Message m;
+ {
+ std::stringstream ss;
+
+ ss << "This message is not colored as INFO" << std::endl;
+
+ m.text(std::make_unique<hermes::MessageText>(ss), hermes::INFO);
+ }
+
+ hermes::ConsoleReporter r;
+
+ ASSERT_NO_THROW(r.notify(&m));
+}
+
+TEST(ConsoleReporterTest, notify_verbose_NEG)
+{
+ hermes::Message m;
+ {
+ std::stringstream ss;
+
+ ss << "This message is not colored as VERBOSE" << std::endl;
+
+ m.text(std::make_unique<hermes::MessageText>(ss), hermes::VERBOSE);
+ }
+
+ hermes::ConsoleReporter r;
+
+ ASSERT_NO_THROW(r.notify(&m));
+}
diff --git a/compiler/hermes/CMakeLists.txt b/compiler/hermes/CMakeLists.txt
index e1a71c2b4..d33e2d735 100644
--- a/compiler/hermes/CMakeLists.txt
+++ b/compiler/hermes/CMakeLists.txt
@@ -3,7 +3,9 @@ file(GLOB_RECURSE TESTS "src/*.test.cpp")
list(REMOVE_ITEM SOURCES ${TESTS})
add_library(hermes STATIC ${SOURCES})
-set_target_properties(hermes PROPERTIES POSITION_INDEPENDENT_CODE ON)
+if (NOT NNCC_LIBRARY_NO_PIC)
+ set_target_properties(hermes PROPERTIES POSITION_INDEPENDENT_CODE ON)
+endif(NOT NNCC_LIBRARY_NO_PIC)
target_include_directories(hermes PUBLIC include)
# Let's apply nncc common compile options
#
diff --git a/compiler/hermes/include/hermes/core/Message.h b/compiler/hermes/include/hermes/core/Message.h
index 460163f64..d76f0eb6f 100644
--- a/compiler/hermes/include/hermes/core/Message.h
+++ b/compiler/hermes/include/hermes/core/Message.h
@@ -17,6 +17,8 @@
#ifndef __HERMES_MESSAGE_H__
#define __HERMES_MESSAGE_H__
+#include "Severity.h"
+
#include <memory>
#include <sstream>
#include <string>
@@ -48,7 +50,6 @@ private:
* @brief Message with metadata
*
* TODO Add "Timestamp" field
- * TODO Add "Severity" field
* TODO Support extensible "attribute" annotation
*/
class Message final
@@ -58,10 +59,17 @@ public:
public:
void text(std::unique_ptr<MessageText> &&text) { _text = std::move(text); }
+ void text(std::unique_ptr<MessageText> &&text, SeverityCategory severity)
+ {
+ _text = std::move(text);
+ _severity = severity;
+ }
const MessageText *text(void) const { return _text.get(); }
+ SeverityCategory get_severity(void) const { return _severity; }
private:
std::unique_ptr<MessageText> _text;
+ SeverityCategory _severity = SeverityCategory::INFO;
};
} // namespace hermes
diff --git a/compiler/hermes/include/hermes/core/MessageBuffer.h b/compiler/hermes/include/hermes/core/MessageBuffer.h
index a2f1de74d..1e2e9b9dc 100644
--- a/compiler/hermes/include/hermes/core/MessageBuffer.h
+++ b/compiler/hermes/include/hermes/core/MessageBuffer.h
@@ -18,6 +18,7 @@
#define __HERMES_MESSAGE_BUFFER_H__
#include "hermes/core/MessageBus.h"
+#include "hermes/core/Severity.h"
#include <ostream>
#include <sstream>
@@ -34,6 +35,7 @@ class MessageBuffer final
{
public:
MessageBuffer(MessageBus *);
+ MessageBuffer(MessageBus *bus, SeverityCategory severity);
~MessageBuffer();
public:
@@ -41,6 +43,7 @@ public:
private:
MessageBus *_bus;
+ SeverityCategory _severity = SeverityCategory::INFO;
/// @brief Content buffer
std::stringstream _ss;
diff --git a/compiler/hermes/src/core/MessageBuffer.cpp b/compiler/hermes/src/core/MessageBuffer.cpp
index a4ff4eeff..ce1f176d9 100644
--- a/compiler/hermes/src/core/MessageBuffer.cpp
+++ b/compiler/hermes/src/core/MessageBuffer.cpp
@@ -26,13 +26,19 @@ MessageBuffer::MessageBuffer(MessageBus *bus) : _bus{bus}
// DO NOTHING
}
+MessageBuffer::MessageBuffer(MessageBus *bus, SeverityCategory severity)
+ : _bus{bus}, _severity{severity}
+{
+ // DO NOTHING
+}
+
MessageBuffer::~MessageBuffer()
{
// NOTE The current implementation is unsafe as it may throw an excpetion.
// TODO Find a better safe implementation.
auto msg = std::make_unique<Message>();
- msg->text(std::make_unique<MessageText>(_ss));
+ msg->text(std::make_unique<MessageText>(_ss), _severity);
_bus->post(std::move(msg));
}
diff --git a/compiler/hermes/src/core/Source.cpp b/compiler/hermes/src/core/Source.cpp
index d124f4430..cb60d9a31 100644
--- a/compiler/hermes/src/core/Source.cpp
+++ b/compiler/hermes/src/core/Source.cpp
@@ -60,10 +60,9 @@ void Source::deactivate(void)
void Source::reload(const Config *c) { c->configure(this, _setting); }
-std::unique_ptr<MessageBuffer> Source::buffer(const Severity &) const
+std::unique_ptr<MessageBuffer> Source::buffer(const Severity &severity) const
{
- // TODO Pass Severity
- return std::make_unique<MessageBuffer>(_bus);
+ return std::make_unique<MessageBuffer>(_bus, severity.category());
}
} // namespace hermes
diff --git a/compiler/locomotiv/CMakeLists.txt b/compiler/locomotiv/CMakeLists.txt
index 308f48619..34835e483 100644
--- a/compiler/locomotiv/CMakeLists.txt
+++ b/compiler/locomotiv/CMakeLists.txt
@@ -3,7 +3,9 @@ file(GLOB_RECURSE TESTS "src/*.test.cpp")
list(REMOVE_ITEM SOURCES ${TESTS})
add_library(locomotiv STATIC ${SOURCES})
-set_target_properties(locomotiv PROPERTIES POSITION_INDEPENDENT_CODE ON)
+if (NOT NNCC_LIBRARY_NO_PIC)
+ set_target_properties(locomotiv PROPERTIES POSITION_INDEPENDENT_CODE ON)
+endif (NOT NNCC_LIBRARY_NO_PIC)
target_include_directories(locomotiv PUBLIC include)
target_include_directories(locomotiv PRIVATE src)
target_link_libraries(locomotiv PUBLIC loco)
diff --git a/compiler/locop/CMakeLists.txt b/compiler/locop/CMakeLists.txt
index f02fb1a72..43ec41af4 100644
--- a/compiler/locop/CMakeLists.txt
+++ b/compiler/locop/CMakeLists.txt
@@ -3,7 +3,9 @@ file(GLOB_RECURSE TESTS "src/*.test.cpp")
list(REMOVE_ITEM SOURCES ${TESTS})
add_library(locop STATIC ${SOURCES})
-set_target_properties(locop PROPERTIES POSITION_INDEPENDENT_CODE ON)
+if (NOT NNCC_LIBRARY_NO_PIC)
+ set_target_properties(locop PROPERTIES POSITION_INDEPENDENT_CODE ON)
+endif(NOT NNCC_LIBRARY_NO_PIC)
target_include_directories(locop PUBLIC include)
target_link_libraries(locop PUBLIC loco)
# Let's apply nncc common compile options
diff --git a/compiler/logo-core/CMakeLists.txt b/compiler/logo-core/CMakeLists.txt
index 3bc71dbd0..374794f90 100644
--- a/compiler/logo-core/CMakeLists.txt
+++ b/compiler/logo-core/CMakeLists.txt
@@ -3,7 +3,9 @@ file(GLOB_RECURSE TESTS "src/*.test.cpp")
list(REMOVE_ITEM SOURCES ${TESTS})
add_library(logo_core STATIC ${SOURCES})
-set_target_properties(logo_core PROPERTIES POSITION_INDEPENDENT_CODE ON)
+if (NOT NNCC_LIBRARY_NO_PIC)
+ set_target_properties(logo_core PROPERTIES POSITION_INDEPENDENT_CODE ON)
+endif(NOT NNCC_LIBRARY_NO_PIC)
target_include_directories(logo_core PRIVATE src)
target_include_directories(logo_core PUBLIC include)
target_link_libraries(logo_core PUBLIC loco)
diff --git a/compiler/logo-ex/CMakeLists.txt b/compiler/logo-ex/CMakeLists.txt
new file mode 100644
index 000000000..31d76025e
--- /dev/null
+++ b/compiler/logo-ex/CMakeLists.txt
@@ -0,0 +1,23 @@
+file(GLOB_RECURSE SOURCES "src/*.cpp")
+file(GLOB_RECURSE TESTS "src/*.test.cpp")
+list(REMOVE_ITEM SOURCES ${TESTS})
+
+add_library(logo_ex STATIC ${SOURCES})
+if (NOT NNCC_LIBRARY_NO_PIC)
+ set_target_properties(logo_ex PROPERTIES POSITION_INDEPENDENT_CODE ON)
+endif(NOT NNCC_LIBRARY_NO_PIC)
+target_include_directories(logo_ex PRIVATE src)
+target_include_directories(logo_ex PUBLIC include)
+target_link_libraries(logo_ex PUBLIC loco)
+target_link_libraries(logo_ex PUBLIC logo_core)
+target_link_libraries(logo_ex PRIVATE locomotiv)
+
+if(NOT ENABLE_TEST)
+ return()
+endif(NOT ENABLE_TEST)
+
+nnas_find_package(GTest REQUIRED)
+
+GTest_AddTest(logo_ex_test ${TESTS})
+target_include_directories(logo_ex_test PRIVATE src)
+target_link_libraries(logo_ex_test logo_ex)
diff --git a/compiler/logo-ex/README.md b/compiler/logo-ex/README.md
new file mode 100644
index 000000000..8ea55a202
--- /dev/null
+++ b/compiler/logo-ex/README.md
@@ -0,0 +1,6 @@
+# logo-ex
+
+_logo-ex_ provides _loco_ Extended Graph Passes for Transformation and Optimization
+that gets help from _locomotiv_
+
+NOTE: f2e7c38dcc601cb290c380d8314a3ae627923f58 is where this came from
diff --git a/compiler/logo/include/logo/ConstantFoldingPass.h b/compiler/logo-ex/include/logo/ConstantFoldingPass.h
index 99ccdc315..9143ae49b 100644
--- a/compiler/logo/include/logo/ConstantFoldingPass.h
+++ b/compiler/logo-ex/include/logo/ConstantFoldingPass.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -14,8 +14,8 @@
* limitations under the License.
*/
-#ifndef __LOGO_CONSTANT_FOLDING_PASS_H__
-#define __LOGO_CONSTANT_FOLDING_PASS_H__
+#ifndef __LOGO_EX_CONSTANT_FOLDING_PASS_H__
+#define __LOGO_EX_CONSTANT_FOLDING_PASS_H__
#include <logo/Pass.h>
@@ -38,4 +38,4 @@ public:
} // namespace logo
-#endif // __LOGO_CONSTANT_FOLDING_PASS_H__
+#endif // __LOGO_EX_CONSTANT_FOLDING_PASS_H__
diff --git a/compiler/logo-ex/include/logo/PassesEx.h b/compiler/logo-ex/include/logo/PassesEx.h
new file mode 100644
index 000000000..8bdf93bd9
--- /dev/null
+++ b/compiler/logo-ex/include/logo/PassesEx.h
@@ -0,0 +1,24 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LOGO_PASSES_EX_H__
+#define __LOGO_PASSES_EX_H__
+
+// Please keep this in alphabetical order
+
+#include <logo/ConstantFoldingPass.h>
+
+#endif // __LOGO_PASSES_EX_H__
diff --git a/compiler/logo-ex/requires.cmake b/compiler/logo-ex/requires.cmake
new file mode 100644
index 000000000..c76183353
--- /dev/null
+++ b/compiler/logo-ex/requires.cmake
@@ -0,0 +1,3 @@
+require("loco")
+require("logo-core")
+require("locomotiv")
diff --git a/compiler/logo/src/Passes/ConstantFoldingPass.cpp b/compiler/logo-ex/src/Passes/ConstantFoldingPass.cpp
index 2bd4759ca..97d75458b 100644
--- a/compiler/logo/src/Passes/ConstantFoldingPass.cpp
+++ b/compiler/logo-ex/src/Passes/ConstantFoldingPass.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
diff --git a/compiler/logo/src/Passes/ConstantFoldingPass.test.cpp b/compiler/logo-ex/src/Passes/ConstantFoldingPass.test.cpp
index 5d222eb00..ba571a7f6 100644
--- a/compiler/logo/src/Passes/ConstantFoldingPass.test.cpp
+++ b/compiler/logo-ex/src/Passes/ConstantFoldingPass.test.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
diff --git a/compiler/logo-ex/src/TestHelper.h b/compiler/logo-ex/src/TestHelper.h
new file mode 100644
index 000000000..07e3b20aa
--- /dev/null
+++ b/compiler/logo-ex/src/TestHelper.h
@@ -0,0 +1,44 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __TEST_HELPER_H__
+#define __TEST_HELPER_H__
+
+#include <loco.h>
+
+namespace logo
+{
+namespace test
+{
+
+template <typename T> T *find_first_node_by_type(loco::Graph *g)
+{
+ T *first_node = nullptr;
+
+ for (auto node : loco::postorder_traversal(loco::output_nodes(g)))
+ {
+ first_node = dynamic_cast<T *>(node);
+ if (first_node != nullptr)
+ break;
+ }
+
+ return first_node;
+}
+
+} // namespace test
+} // namespace logo
+
+#endif // __TEST_HELPER_H__
diff --git a/compiler/logo/CMakeLists.txt b/compiler/logo/CMakeLists.txt
index a8efd9b03..e6a6f907f 100644
--- a/compiler/logo/CMakeLists.txt
+++ b/compiler/logo/CMakeLists.txt
@@ -3,12 +3,13 @@ file(GLOB_RECURSE TESTS "src/*.test.cpp")
list(REMOVE_ITEM SOURCES ${TESTS})
add_library(logo STATIC ${SOURCES})
-set_target_properties(logo PROPERTIES POSITION_INDEPENDENT_CODE ON)
+if (NOT NNCC_LIBRARY_NO_PIC)
+ set_target_properties(logo PROPERTIES POSITION_INDEPENDENT_CODE ON)
+endif(NOT NNCC_LIBRARY_NO_PIC)
target_include_directories(logo PRIVATE src)
target_include_directories(logo PUBLIC include)
target_link_libraries(logo PUBLIC loco)
target_link_libraries(logo PUBLIC logo_core)
-target_link_libraries(logo PRIVATE locomotiv)
if(NOT ENABLE_TEST)
return()
diff --git a/compiler/logo/include/logo/Passes.h b/compiler/logo/include/logo/Passes.h
index 636251e45..06fd3212b 100644
--- a/compiler/logo/include/logo/Passes.h
+++ b/compiler/logo/include/logo/Passes.h
@@ -19,7 +19,6 @@
// Please keep this in alphabetical order
-#include <logo/ConstantFoldingPass.h>
#include <logo/RemoveDeadNodePass.h>
#include <logo/RemoveForwardNodePass.h>
#include <logo/ReorderDecodePass.h>
diff --git a/compiler/logo/requires.cmake b/compiler/logo/requires.cmake
index c76183353..3e4d227cd 100644
--- a/compiler/logo/requires.cmake
+++ b/compiler/logo/requires.cmake
@@ -1,3 +1,2 @@
require("loco")
require("logo-core")
-require("locomotiv")
diff --git a/compiler/luci-interpreter/README.md b/compiler/luci-interpreter/README.md
index 4a9a34e6d..77ec5c81c 100644
--- a/compiler/luci-interpreter/README.md
+++ b/compiler/luci-interpreter/README.md
@@ -111,7 +111,7 @@ Note that one memory manager could be shared between multiple interpreter instan
List of predefined memory managers:
- `SimpleMemoryManager` This is a simple wrapper around new/delete, default one.
-- `TestMemoryManager` Memorizes all allocated memory and releases it in Manager desctuctor, used in kernel unit tests.
+- `TestMemoryManager` Memorizes all allocated memory and releases it in Manager destructor, used in kernel unit tests.
- `BuddyMemoryManager` Implements Buddy algorithm, uses external buffer for tensor data allocations, does not need new/delete.
- `StaticMemoryManger` Uses precomputed memory allocation plan. Requires preparation with MemoryPlanner, but could reduce memory consumption in restricted environments (like MCUs).
diff --git a/compiler/luci-interpreter/include/luci_interpreter/GraphBuilderRegistry.h b/compiler/luci-interpreter/include/luci_interpreter/GraphBuilderRegistry.h
new file mode 100644
index 000000000..375b1ae20
--- /dev/null
+++ b/compiler/luci-interpreter/include/luci_interpreter/GraphBuilderRegistry.h
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_INTERPRETER_GRAPH_BUILDER_REGISTRY__
+#define __LUCI_INTERPRETER_GRAPH_BUILDER_REGISTRY__
+
+#include <luci/Import/GraphBuilderRegistry.h>
+
+namespace luci_interpreter
+{
+
+/**
+ * @brief Creates and returns GraphBuilderSource, which allows to not copy constant buffers from
+ * model's file.
+ *
+ * @warning Use this source only in case when model's buffer alive longer than Interpreter.
+ */
+std::unique_ptr<luci::GraphBuilderSource> source_without_constant_copying();
+
+} // namespace luci_interpreter
+
+#endif // __LUCI_INTERPRETER_GRAPH_BUILDER_REGISTRY__
diff --git a/compiler/luci-interpreter/include/luci_interpreter/Interpreter.h b/compiler/luci-interpreter/include/luci_interpreter/Interpreter.h
index 7dee8a7f2..8e2f457a5 100644
--- a/compiler/luci-interpreter/include/luci_interpreter/Interpreter.h
+++ b/compiler/luci-interpreter/include/luci_interpreter/Interpreter.h
@@ -50,7 +50,9 @@ public:
class Interpreter
{
public:
- explicit Interpreter(const luci::Module *module, IMemoryManager *memory_manager = nullptr);
+ explicit Interpreter(const luci::Module *module);
+
+ explicit Interpreter(const luci::Module *module, IMemoryManager *memory_manager);
~Interpreter();
@@ -69,7 +71,6 @@ private:
// the order of deletion in the destructor
std::unique_ptr<IMemoryManager> _default_memory_manager = nullptr;
std::unique_ptr<class RuntimeModule> _runtime_module;
- IMemoryManager *_memory_manager = nullptr;
// Observer functionality support.
std::unique_ptr<struct RuntimeToIR> _runtime_to_ir;
diff --git a/compiler/luci-interpreter/pal/cmsisnn/KernelsToBuild.lst b/compiler/luci-interpreter/pal/cmsisnn/KernelsToBuild.lst
index 771974afe..d134a6b95 100644
--- a/compiler/luci-interpreter/pal/cmsisnn/KernelsToBuild.lst
+++ b/compiler/luci-interpreter/pal/cmsisnn/KernelsToBuild.lst
@@ -7,9 +7,11 @@ REGISTER_KERNEL(Concatenation)
REGISTER_KERNEL(Conv2D)
REGISTER_KERNEL(DepthToSpace)
REGISTER_KERNEL(DepthwiseConv2D)
+REGISTER_KERNEL(Dequantize)
REGISTER_KERNEL(Div)
REGISTER_KERNEL(Elu)
REGISTER_KERNEL(Exp)
+REGISTER_KERNEL(ExpandDims)
REGISTER_KERNEL(Floor)
REGISTER_KERNEL(FloorDiv)
REGISTER_KERNEL(Equal)
@@ -37,6 +39,7 @@ REGISTER_KERNEL(NotEqual)
REGISTER_KERNEL(Pad)
REGISTER_KERNEL(PadV2)
REGISTER_KERNEL(PRelu)
+REGISTER_KERNEL(Quantize)
REGISTER_KERNEL(Reshape)
REGISTER_KERNEL(ResizeBilinear)
REGISTER_KERNEL(ResizeNearestNeighbor)
@@ -50,6 +53,7 @@ REGISTER_KERNEL(Square)
REGISTER_KERNEL(SquaredDifference)
REGISTER_KERNEL(Squeeze)
REGISTER_KERNEL(Sub)
+REGISTER_KERNEL(SVDF)
REGISTER_KERNEL(Tanh)
REGISTER_KERNEL(Transpose)
REGISTER_KERNEL(TransposeConv)
diff --git a/compiler/luci-interpreter/pal/cmsisnn/PALAveragePool2d.h b/compiler/luci-interpreter/pal/cmsisnn/PALAveragePool2d.h
new file mode 100644
index 000000000..a274afb7e
--- /dev/null
+++ b/compiler/luci-interpreter/pal/cmsisnn/PALAveragePool2d.h
@@ -0,0 +1,124 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LUCI_INTERPRETER_PAL_AVERAGEPOOL2D_H
+#define LUCI_INTERPRETER_PAL_AVERAGEPOOL2D_H
+
+#include <tensorflow/lite/kernels/internal/reference/integer_ops/pooling.h>
+#include <tensorflow/lite/kernels/internal/reference/pooling.h>
+#include <arm_nn_types.h>
+#include <arm_nnfunctions.h>
+
+namespace luci_interpreter_pal
+{
+template <typename T>
+static inline void AveragePool(const tflite::PoolParams &params,
+ const tflite::RuntimeShape &input_shape, const T *input_data,
+ const tflite::RuntimeShape &output_shape, T *output_data,
+ const tflite::RuntimeShape &scratchpad_shape, T *scratchpad_data)
+{
+ {
+ // MARK: At this moment this operation is not supported
+ assert(false && "AveragePool NYI");
+ (void)params;
+ (void)input_shape;
+ (void)input_data;
+ (void)output_shape;
+ (void)output_data;
+ (void)scratchpad_shape;
+ (void)scratchpad_data;
+ }
+}
+
+template <>
+inline void AveragePool<int8_t>(const tflite::PoolParams &params,
+ const tflite::RuntimeShape &input_shape, const int8_t *input_data,
+ const tflite::RuntimeShape &output_shape, int8_t *output_data,
+ const tflite::RuntimeShape &scratchpad_shape,
+ int8_t *scratchpad_data)
+{
+ assert(input_shape.DimensionsCount() == 4);
+ assert(output_shape.DimensionsCount() == 4);
+ assert(scratchpad_data != nullptr);
+
+ const int32_t batches = tflite::MatchingDim(input_shape, 0, output_shape, 0);
+ assert(batches == 1);
+
+ const int depth = tflite::MatchingDim(input_shape, 3, output_shape, 3);
+
+ cmsis_nn_dims input_dims;
+ input_dims.n = 1;
+ input_dims.h = input_shape.Dims(1);
+ input_dims.w = input_shape.Dims(2);
+ input_dims.c = depth;
+
+ cmsis_nn_dims output_dims;
+ output_dims.n = 1;
+ output_dims.h = output_shape.Dims(1);
+ output_dims.w = output_shape.Dims(2);
+ output_dims.c = depth;
+
+ cmsis_nn_pool_params pool_params;
+ pool_params.stride.h = params.stride_height;
+ pool_params.stride.w = params.stride_width;
+ pool_params.padding.h = params.padding_values.height;
+ pool_params.padding.w = params.padding_values.width;
+ pool_params.activation.min = params.quantized_activation_min;
+ pool_params.activation.max = params.quantized_activation_max;
+
+ cmsis_nn_dims filter_dims;
+ filter_dims.n = 1;
+ filter_dims.h = params.filter_height;
+ filter_dims.w = params.filter_width;
+ filter_dims.c = 1;
+
+ cmsis_nn_context ctx;
+ ctx.buf = scratchpad_data;
+ ctx.size = scratchpad_shape.Dims(0);
+ auto res = arm_avgpool_s8(&ctx, &pool_params, &input_dims, input_data, &filter_dims, &output_dims,
+ output_data);
+ assert(res == ARM_MATH_SUCCESS);
+}
+
+static inline void SetupScratchpadTensor(luci_interpreter::Tensor *scratchpad,
+ const luci_interpreter::DataType &input_data_type,
+ const tflite::RuntimeShape &input_shape,
+ const tflite::RuntimeShape &output_shape)
+
+{
+ if (input_data_type == luci_interpreter::DataType::S8)
+ {
+ assert(input_shape.DimensionsCount() == 4);
+ assert(output_shape.DimensionsCount() == 4);
+
+ const int32_t output_width = output_shape.Dims(2);
+ const int32_t depth = tflite::MatchingDim(input_shape, 3, output_shape, 3);
+
+ const int32_t buf_size = arm_avgpool_s8_get_buffer_size(output_width, depth);
+ auto data_type_size = static_cast<int32_t>(luci_interpreter::getDataTypeSize(input_data_type));
+
+ luci_interpreter::Shape scratchpad_shape{buf_size * data_type_size};
+ scratchpad->resize(scratchpad_shape);
+ }
+ else
+ {
+ scratchpad->set_allocatable(false);
+ }
+}
+
+} // namespace luci_interpreter_pal
+
+#endif // LUCI_INTERPRETER_PAL_AVERAGEPOOL2D_H
diff --git a/compiler/luci-interpreter/pal/cmsisnn/PALConv2d.h b/compiler/luci-interpreter/pal/cmsisnn/PALConv2d.h
index 0a8ae4e48..cfb84ea60 100644
--- a/compiler/luci-interpreter/pal/cmsisnn/PALConv2d.h
+++ b/compiler/luci-interpreter/pal/cmsisnn/PALConv2d.h
@@ -19,6 +19,8 @@
#include <tensorflow/lite/kernels/internal/reference/conv.h>
#include <tensorflow/lite/kernels/internal/reference/integer_ops/conv.h>
+#include <arm_nn_types.h>
+#include <arm_nnfunctions.h>
namespace luci_interpreter_pal
{
@@ -26,11 +28,11 @@ static inline void Conv(const tflite::ConvParams &params, const tflite::RuntimeS
const float *input_data, const tflite::RuntimeShape &filter_shape,
const float *filter_data, const tflite::RuntimeShape &bias_shape,
const float *bias_data, const tflite::RuntimeShape &output_shape,
- float *output_data, const tflite::RuntimeShape &im2col_shape,
- float *im2col_data)
+ float *output_data, const tflite::RuntimeShape &scratchpad_shape,
+ float *scratchpad_data)
{
- (void)im2col_shape;
- (void)im2col_data;
+ (void)scratchpad_shape;
+ (void)scratchpad_data;
tflite::reference_ops::Conv(params, input_shape, input_data, filter_shape, filter_data,
bias_shape, bias_data, output_shape, output_data,
tflite::RuntimeShape(), nullptr);
@@ -40,14 +42,14 @@ static inline void Conv(const tflite::ConvParams &params, const tflite::RuntimeS
const uint8 *input_data, const tflite::RuntimeShape &filter_shape,
const uint8 *filter_data, const tflite::RuntimeShape &bias_shape,
const int32 *bias_data, const tflite::RuntimeShape &output_shape,
- uint8 *output_data, const tflite::RuntimeShape &im2col_shape,
- uint8 *im2col_data)
+ uint8 *output_data, const tflite::RuntimeShape &scratchpad_shape,
+ uint8 *scratchpad_data)
{
- (void)im2col_shape;
- (void)im2col_data;
+ (void)scratchpad_shape;
+ (void)scratchpad_data;
tflite::reference_ops::Conv(params, input_shape, input_data, filter_shape, filter_data,
- bias_shape, bias_data, output_shape, output_data, im2col_shape,
- im2col_data, nullptr);
+ bias_shape, bias_data, output_shape, output_data, scratchpad_shape,
+ scratchpad_data, nullptr);
}
static inline void ConvPerChannel(const tflite::ConvParams &params, const int32_t *mult,
@@ -55,14 +57,141 @@ static inline void ConvPerChannel(const tflite::ConvParams &params, const int32_
const int8 *input_data, const tflite::RuntimeShape &filter_shape,
const int8 *filter_data, const tflite::RuntimeShape &bias_shape,
const int32 *bias_data, const tflite::RuntimeShape &output_shape,
- int8 *output_data, const tflite::RuntimeShape &im2col_shape,
- int8 *im2col_data)
+ int8 *output_data, const tflite::RuntimeShape &scratchpad_shape,
+ int8 *scratchpad_data)
{
- (void)im2col_shape;
- (void)im2col_data;
- tflite::reference_integer_ops::ConvPerChannel(params, mult, shifts, input_shape, input_data,
- filter_shape, filter_data, bias_shape, bias_data,
- output_shape, output_data);
+ if (scratchpad_data)
+ {
+ cmsis_nn_conv_params conv_params;
+ conv_params.dilation.h = params.dilation_height_factor;
+ conv_params.dilation.w = params.dilation_width_factor;
+
+ assert(conv_params.dilation.h == 1);
+ assert(conv_params.dilation.w == 1);
+
+ conv_params.input_offset = params.input_offset;
+ conv_params.output_offset = params.output_offset;
+ conv_params.stride.h = params.stride_height;
+ conv_params.stride.w = params.stride_width;
+ conv_params.padding.h = params.padding_values.height;
+ conv_params.padding.w = params.padding_values.width;
+ conv_params.activation.min = params.quantized_activation_min;
+ conv_params.activation.max = params.quantized_activation_max;
+
+ cmsis_nn_per_channel_quant_params quant_params;
+ quant_params.multiplier = const_cast<int32_t *>(mult);
+ quant_params.shift = const_cast<int32_t *>(shifts);
+
+ assert(conv_params.activation.min <= conv_params.activation.max);
+ assert(input_shape.DimensionsCount() == 4);
+ assert(filter_shape.DimensionsCount() == 4);
+ assert(output_shape.DimensionsCount() == 4);
+ const int batch_size = tflite::MatchingDim(input_shape, 0, output_shape, 0);
+ const int input_depth = tflite::MatchingDim(input_shape, 3, filter_shape, 3);
+ const int output_depth = tflite::MatchingDim(filter_shape, 0, output_shape, 3);
+ if (bias_data)
+ {
+ assert(bias_shape.FlatSize() == output_depth);
+ }
+
+ cmsis_nn_dims input_dims;
+ input_dims.n = batch_size;
+ input_dims.h = input_shape.Dims(1);
+ input_dims.w = input_shape.Dims(2);
+ input_dims.c = input_depth;
+
+ cmsis_nn_dims filter_dims;
+ filter_dims.n = output_depth;
+ filter_dims.h = filter_shape.Dims(1);
+ filter_dims.w = filter_shape.Dims(2);
+ filter_dims.c = input_depth;
+
+ cmsis_nn_dims bias_dims;
+ bias_dims.n = 1;
+ bias_dims.h = 1;
+ bias_dims.w = 1;
+ bias_dims.c = output_depth;
+
+ cmsis_nn_dims output_dims;
+ output_dims.n = batch_size;
+ output_dims.h = output_shape.Dims(1);
+ output_dims.w = output_shape.Dims(2);
+ output_dims.c = output_depth;
+
+ cmsis_nn_context ctx;
+ ctx.buf = scratchpad_data;
+ ctx.size = scratchpad_shape.Dims(0);
+
+ auto res = arm_convolve_wrapper_s8(&ctx, &conv_params, &quant_params, &input_dims, input_data,
+ &filter_dims, filter_data, &bias_dims, bias_data,
+ &output_dims, output_data);
+ assert(res == ARM_MATH_SUCCESS);
+ }
+ else
+ {
+ tflite::reference_integer_ops::ConvPerChannel(params, mult, shifts, input_shape, input_data,
+ filter_shape, filter_data, bias_shape, bias_data,
+ output_shape, output_data);
+ }
+}
+
+static inline void SetupScratchpadTensor(luci_interpreter::Tensor *scratchpad,
+ const luci_interpreter::DataType &input_data_type,
+ const tflite::ConvParams &params,
+ const tflite::RuntimeShape &input_shape,
+ const tflite::RuntimeShape &filter_shape,
+ const tflite::RuntimeShape &output_shape)
+{
+ cmsis_nn_conv_params conv_params;
+ conv_params.dilation.h = params.dilation_height_factor;
+ conv_params.dilation.w = params.dilation_width_factor;
+
+ if (input_data_type == loco::DataType::S8 && conv_params.dilation.h == 1 &&
+ conv_params.dilation.w == 1)
+ {
+ const int32_t batches = tflite::MatchingDim(input_shape, 0, output_shape, 0);
+ const int32_t input_depth = tflite::MatchingDim(input_shape, 3, filter_shape, 3);
+ const int32_t output_depth = tflite::MatchingDim(filter_shape, 0, output_shape, 3);
+ const int32_t filter_height = filter_shape.Dims(1);
+ const int32_t filter_width = filter_shape.Dims(2);
+ const int32_t output_height = output_shape.Dims(1);
+ const int32_t output_width = output_shape.Dims(2);
+
+ conv_params.input_offset = params.input_offset;
+ conv_params.output_offset = params.output_offset;
+ conv_params.stride.h = params.stride_height;
+ conv_params.stride.w = params.stride_width;
+ conv_params.padding.h = params.padding_values.height;
+ conv_params.padding.w = params.padding_values.width;
+
+ cmsis_nn_dims input_dims;
+ input_dims.n = batches;
+ input_dims.h = input_shape.Dims(1);
+ input_dims.w = input_shape.Dims(2);
+ input_dims.c = input_depth;
+
+ cmsis_nn_dims filter_dims;
+ filter_dims.n = output_depth;
+ filter_dims.h = filter_height;
+ filter_dims.w = filter_width;
+ filter_dims.c = input_depth;
+
+ cmsis_nn_dims output_dims;
+ output_dims.n = batches;
+ output_dims.h = output_height;
+ output_dims.w = output_width;
+ output_dims.c = output_depth;
+
+ const int32_t buf_size = arm_convolve_wrapper_s8_get_buffer_size(&conv_params, &input_dims,
+ &filter_dims, &output_dims);
+
+ luci_interpreter::Shape scratchpad_shape{buf_size};
+ scratchpad->resize(scratchpad_shape);
+ }
+ else
+ {
+ scratchpad->set_allocatable(false);
+ }
}
} // namespace luci_interpreter_pal
diff --git a/compiler/luci-interpreter/pal/cmsisnn/PALDepthwiseConv2d.h b/compiler/luci-interpreter/pal/cmsisnn/PALDepthwiseConv2d.h
new file mode 100644
index 000000000..120dcd803
--- /dev/null
+++ b/compiler/luci-interpreter/pal/cmsisnn/PALDepthwiseConv2d.h
@@ -0,0 +1,192 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LUCI_INTERPRETER_PAL_DEPTHWISECONV2D_H
+#define LUCI_INTERPRETER_PAL_DEPTHWISECONV2D_H
+
+#include <tensorflow/lite/kernels/internal/reference/depthwiseconv_float.h>
+#include <tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h>
+#include <tensorflow/lite/kernels/internal/reference/integer_ops/depthwise_conv.h>
+#include <arm_nnfunctions.h>
+
+namespace luci_interpreter_pal
+{
+template <typename T>
+static inline void
+DepthwiseConvPerChannel(const tflite::DepthwiseParams &params, const int32_t *output_multiplier,
+ const int32_t *output_shift, const tflite::RuntimeShape &input_shape,
+ const T *input_data, const tflite::RuntimeShape &filter_shape,
+ const T *filter_data, const tflite::RuntimeShape &bias_shape,
+ const int32_t *bias_data, const tflite::RuntimeShape &output_shape,
+ T *output_data, const tflite::RuntimeShape &scratchpad_shape,
+ T *scratchpad_data)
+{
+ {
+ // MARK: At this moment this operation is not supported
+ assert(false && "DepthwiseConvPerChannel NYI");
+ (void)params;
+ (void)output_multiplier;
+ (void)output_shift;
+ (void)input_shape;
+ (void)output_data;
+ (void)input_data;
+ (void)filter_shape;
+ (void)filter_data;
+ (void)bias_shape;
+ (void)bias_data;
+ (void)output_shape;
+ (void)output_data;
+ (void)scratchpad_shape;
+ (void)scratchpad_data;
+ }
+}
+
+template <>
+inline void DepthwiseConvPerChannel<int8_t>(
+ const tflite::DepthwiseParams &params, const int32_t *output_multiplier,
+ const int32_t *output_shift, const tflite::RuntimeShape &input_shape, const int8_t *input_data,
+ const tflite::RuntimeShape &filter_shape, const int8_t *filter_data,
+ const tflite::RuntimeShape &bias_shape, const int32_t *bias_data,
+ const tflite::RuntimeShape &output_shape, int8_t *output_data,
+ const tflite::RuntimeShape &scratchpad_shape, int8_t *scratchpad_data)
+{
+ if (scratchpad_data)
+ {
+ cmsis_nn_dw_conv_params dw_conv_params;
+ dw_conv_params.dilation.h = params.dilation_height_factor;
+ dw_conv_params.dilation.w = params.dilation_width_factor;
+ assert(dw_conv_params.dilation.h == 1);
+ assert(dw_conv_params.dilation.w == 1);
+
+ dw_conv_params.input_offset = params.input_offset;
+ dw_conv_params.output_offset = params.output_offset;
+ dw_conv_params.stride.h = params.stride_height;
+ dw_conv_params.stride.w = params.stride_width;
+ dw_conv_params.padding.h = params.padding_values.height;
+ dw_conv_params.padding.w = params.padding_values.width;
+
+ dw_conv_params.activation.min = params.quantized_activation_min;
+ dw_conv_params.activation.max = params.quantized_activation_max;
+ dw_conv_params.ch_mult = params.depth_multiplier;
+
+ cmsis_nn_per_channel_quant_params quant_params;
+ int32_t output_multiplier = params.output_multiplier;
+ int32_t output_shift = params.output_shift;
+
+ quant_params.multiplier = &output_multiplier;
+ quant_params.shift = &output_shift;
+
+ assert(dw_conv_params.activation.min <= dw_conv_params.activation.max);
+ const int batch_size = tflite::MatchingDim(input_shape, 0, output_shape, 0);
+ const int output_depth = tflite::MatchingDim(filter_shape, 3, output_shape, 3);
+ if (bias_data)
+ {
+ assert(bias_shape.FlatSize() == output_depth);
+ }
+
+ cmsis_nn_dims input_dims;
+ input_dims.n = batch_size;
+ input_dims.h = input_shape.Dims(1);
+ input_dims.w = input_shape.Dims(2);
+ input_dims.c = input_shape.Dims(3);
+
+ cmsis_nn_dims filter_dims;
+ filter_dims.n = filter_shape.Dims(0);
+ filter_dims.h = filter_shape.Dims(1);
+ filter_dims.w = filter_shape.Dims(2);
+ filter_dims.c = output_depth;
+
+ cmsis_nn_dims bias_dims;
+ bias_dims.n = 1;
+ bias_dims.h = 1;
+ bias_dims.w = 1;
+ bias_dims.c = output_depth;
+
+ cmsis_nn_dims output_dims;
+ output_dims.n = batch_size;
+ output_dims.h = output_shape.Dims(1);
+ output_dims.w = output_shape.Dims(2);
+ output_dims.c = output_depth;
+
+ cmsis_nn_context ctx;
+ ctx.buf = scratchpad_data;
+ ctx.size = scratchpad_shape.Dims(0);
+
+ auto res = arm_depthwise_conv_wrapper_s8(&ctx, &dw_conv_params, &quant_params, &input_dims,
+ input_data, &filter_dims, filter_data, &bias_dims,
+ bias_data, &output_dims, output_data);
+ assert(res == ARM_MATH_SUCCESS);
+ }
+ else
+ {
+ tflite::reference_integer_ops::DepthwiseConvPerChannel(
+ params, output_multiplier, output_shift, input_shape, input_data, filter_shape, filter_data,
+ bias_shape, bias_data, output_shape, output_data);
+ }
+}
+
+static inline void SetupScratchpadTensor(luci_interpreter::Tensor *scratchpad,
+ const tflite::DepthwiseParams &params,
+ const luci_interpreter::DataType &input_data_type,
+ const tflite::RuntimeShape &input_shape,
+ const tflite::RuntimeShape &filter_shape,
+ const tflite::RuntimeShape &output_shape)
+{
+ cmsis_nn_dw_conv_params dw_conv_params;
+ dw_conv_params.dilation.h = params.dilation_height_factor;
+ dw_conv_params.dilation.w = params.dilation_width_factor;
+
+ if (input_data_type == loco::DataType::S8 && dw_conv_params.dilation.h == 1 &&
+ dw_conv_params.dilation.w == 1)
+ {
+ const int batch_size = tflite::MatchingDim(input_shape, 0, output_shape, 0);
+ const int output_depth = tflite::MatchingDim(filter_shape, 3, output_shape, 3);
+
+ cmsis_nn_dims input_dims;
+ input_dims.n = batch_size;
+ input_dims.h = input_shape.Dims(1);
+ input_dims.w = input_shape.Dims(2);
+ input_dims.c = input_shape.Dims(3);
+
+ cmsis_nn_dims filter_dims;
+ filter_dims.n = filter_shape.Dims(0);
+ filter_dims.h = filter_shape.Dims(1);
+ filter_dims.w = filter_shape.Dims(2);
+ filter_dims.c = output_depth;
+
+ cmsis_nn_dims output_dims;
+ output_dims.n = batch_size;
+ output_dims.h = output_shape.Dims(1);
+ output_dims.w = output_shape.Dims(2);
+ output_dims.c = output_depth;
+
+ const int32_t buf_size = arm_depthwise_conv_wrapper_s8_get_buffer_size(
+ &dw_conv_params, &input_dims, &filter_dims, &output_dims);
+
+ auto data_type_size = static_cast<int32_t>(luci_interpreter::getDataTypeSize(input_data_type));
+
+ luci_interpreter::Shape scratchpad_shape{buf_size * data_type_size};
+ scratchpad->resize(scratchpad_shape);
+ }
+ else
+ {
+ scratchpad->set_allocatable(false);
+ }
+}
+
+} // namespace luci_interpreter_pal
+
+#endif // LUCI_INTERPRETER_PAL_DEPTHWISECONV2D_H
diff --git a/compiler/luci-interpreter/pal/cmsisnn/PALDequantize.h b/compiler/luci-interpreter/pal/cmsisnn/PALDequantize.h
new file mode 100644
index 000000000..15ff0327b
--- /dev/null
+++ b/compiler/luci-interpreter/pal/cmsisnn/PALDequantize.h
@@ -0,0 +1,44 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LUCI_INTERPRETER_PAL_DEQUANTIZE_H
+#define LUCI_INTERPRETER_PAL_DEQUANTIZE_H
+
+#include "tensorflow/lite/kernels/internal/reference/integer_ops/dequantize.h"
+#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
+
+namespace luci_interpreter_pal
+{
+
+template <typename T>
+static inline void Dequantize(tflite::DequantizationParams &params,
+ const tflite::RuntimeShape &input_shape, const T *input_data,
+ const tflite::RuntimeShape &output_shape, float *output_data)
+{
+ tflite::reference_integer_ops::Dequantize<T>(params, input_shape, input_data, output_shape,
+ output_data);
+}
+
+static inline void Dequantize(tflite::DequantizationParams &params,
+ const tflite::RuntimeShape &input_shape, const uint8_t *input_data,
+ const tflite::RuntimeShape &output_shape, float *output_data)
+{
+ tflite::reference_ops::Dequantize(params, input_shape, input_data, output_shape, output_data);
+}
+
+} // namespace luci_interpreter_pal
+
+#endif // LUCI_INTERPRETER_PAL_DEQUANTIZE_H
diff --git a/compiler/luci-interpreter/pal/cmsisnn/PALFullyConnected.h b/compiler/luci-interpreter/pal/cmsisnn/PALFullyConnected.h
new file mode 100644
index 000000000..32e905761
--- /dev/null
+++ b/compiler/luci-interpreter/pal/cmsisnn/PALFullyConnected.h
@@ -0,0 +1,114 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LUCI_INTERPRETER_PAL_FULLYCONNECTED_H
+#define LUCI_INTERPRETER_PAL_FULLYCONNECTED_H
+
+#include <tensorflow/lite/kernels/internal/reference/fully_connected.h>
+#include <tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h>
+#include <arm_nnfunctions.h>
+
+namespace luci_interpreter_pal
+{
+template <typename T>
+static inline void FullyConnected(const tflite::FullyConnectedParams &params,
+ const tflite::RuntimeShape &input_shape, const T *input_data,
+ const tflite::RuntimeShape &filter_shape, const T *filter_data,
+ const tflite::RuntimeShape &bias_shape, const int32_t *bias_data,
+ const tflite::RuntimeShape &output_shape, T *output_data)
+{
+ {
+ // MARK: At this moment this operation doesn't support
+ assert(false && "FullyConnected NYI");
+ (void)params;
+ (void)input_shape;
+ (void)input_data;
+ (void)filter_shape;
+ (void)filter_data;
+ (void)bias_shape;
+ (void)bias_data;
+ (void)output_shape;
+ (void)output_data;
+ }
+}
+
+template <>
+inline void
+FullyConnected<int8_t>(const tflite::FullyConnectedParams &params,
+ const tflite::RuntimeShape &input_shape, const int8_t *input_data,
+ const tflite::RuntimeShape &filter_shape, const int8_t *filter_data,
+ const tflite::RuntimeShape &bias_shape, const int32_t *bias_data,
+ const tflite::RuntimeShape &output_shape, int8_t *output_data)
+{
+ assert(output_shape.DimensionsCount() == 2);
+
+ const int batches = output_shape.Dims(0);
+ const int output_depth = output_shape.Dims(1);
+
+ const int filter_dim_count = filter_shape.DimensionsCount();
+ const int accum_depth = filter_shape.Dims(filter_dim_count - 1);
+
+ cmsis_nn_fc_params fc_params;
+ fc_params.input_offset = params.input_offset;
+ fc_params.output_offset = params.output_offset;
+ fc_params.filter_offset = params.weights_offset;
+ fc_params.activation.min = params.quantized_activation_min;
+ fc_params.activation.max = params.quantized_activation_max;
+
+ cmsis_nn_per_tensor_quant_params quant_params;
+ quant_params.multiplier = params.output_multiplier;
+ quant_params.shift = params.output_shift;
+
+ cmsis_nn_dims input_dims;
+ input_dims.n = batches;
+ input_dims.h = 1;
+ input_dims.w = 1;
+ input_dims.c = accum_depth;
+
+ cmsis_nn_dims filter_dims;
+ filter_dims.n = accum_depth;
+ filter_dims.h = 1;
+ filter_dims.w = 1;
+ filter_dims.c = output_depth;
+
+ cmsis_nn_dims bias_dims;
+ bias_dims.n = 1;
+ bias_dims.h = 1;
+ bias_dims.w = 1;
+ bias_dims.c = output_depth;
+
+ cmsis_nn_dims output_dims;
+ output_dims.n = batches;
+ output_dims.h = 1;
+ output_dims.w = 1;
+ output_dims.c = output_depth;
+
+ int32_t buf_size = arm_fully_connected_s8_get_buffer_size(&filter_dims);
+ auto buffer = std::make_unique<int8_t[]>(buf_size);
+ assert(buffer != nullptr);
+
+ cmsis_nn_context ctx;
+ ctx.buf = buffer.get();
+ ctx.size = buf_size;
+
+ auto res =
+ arm_fully_connected_s8(&ctx, &fc_params, &quant_params, &input_dims, input_data, &filter_dims,
+ filter_data, &bias_dims, bias_data, &output_dims, output_data);
+ assert(res == ARM_MATH_SUCCESS);
+}
+} // namespace luci_interpreter_pal
+
+#endif // LUCI_INTERPRETER_PAL_FULLYCONNECTED_H
diff --git a/compiler/luci-interpreter/pal/cmsisnn/PALMul.h b/compiler/luci-interpreter/pal/cmsisnn/PALMul.h
index 2b46b100c..347a97a83 100644
--- a/compiler/luci-interpreter/pal/cmsisnn/PALMul.h
+++ b/compiler/luci-interpreter/pal/cmsisnn/PALMul.h
@@ -21,21 +21,21 @@
namespace luci_interpreter_pal
{
+template <typename T>
static inline void Mul(tflite::ArithmeticParams &params, const tflite::RuntimeShape &input1_shape,
- const float *input1_data, const tflite::RuntimeShape &input2_shape,
- const float *input2_data, const tflite::RuntimeShape &output_shape,
- float *output_data)
+ const T *input1_data, const tflite::RuntimeShape &input2_shape,
+ const T *input2_data, const tflite::RuntimeShape &output_shape,
+ T *output_data)
{
tflite::reference_ops::BroadcastMul4DSlow(params, input1_shape, input1_data, input2_shape,
input2_data, output_shape, output_data);
}
-static inline void BroadcastMul4DSlow(tflite::ArithmeticParams &params,
- const tflite::RuntimeShape &input1_shape,
- const float *input1_data,
- const tflite::RuntimeShape &input2_shape,
- const float *input2_data,
- const tflite::RuntimeShape &output_shape, float *output_data)
+template <typename T>
+static inline void
+BroadcastMul4DSlow(tflite::ArithmeticParams &params, const tflite::RuntimeShape &input1_shape,
+ const T *input1_data, const tflite::RuntimeShape &input2_shape,
+ const T *input2_data, const tflite::RuntimeShape &output_shape, T *output_data)
{
tflite::reference_ops::BroadcastMul4DSlow(params, input1_shape, input1_data, input2_shape,
input2_data, output_shape, output_data);
diff --git a/compiler/luci-interpreter/pal/cmsisnn/PALQuantize.h b/compiler/luci-interpreter/pal/cmsisnn/PALQuantize.h
new file mode 100644
index 000000000..6046789ae
--- /dev/null
+++ b/compiler/luci-interpreter/pal/cmsisnn/PALQuantize.h
@@ -0,0 +1,44 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LUCI_INTERPRETER_PAL_QUANTIZE_H
+#define LUCI_INTERPRETER_PAL_QUANTIZE_H
+
+#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
+
+namespace luci_interpreter_pal
+{
+template <typename T>
+static inline void Quantize(tflite::QuantizationParams &params,
+ const tflite::RuntimeShape &input_shape, const float *input_data,
+ const tflite::RuntimeShape &output_shape, T *output_data)
+{
+ tflite::reference_ops::AffineQuantize(params, input_shape, input_data, output_shape, output_data);
+}
+
+template <typename Input, typename Output>
+static inline void Requantize(const Input *input_data, int32_t size,
+ int32_t effective_scale_multiplier, int32_t effective_scale_shift,
+ int32_t input_zero_point, int32_t output_zero_point,
+ Output *output_data)
+{
+ tflite::reference_ops::Requantize(input_data, size, effective_scale_multiplier,
+ effective_scale_shift, input_zero_point, output_zero_point,
+ output_data);
+}
+} // namespace luci_interpreter_pal
+
+#endif // LUCI_INTERPRETER_PAL_QUANTIZE_H
diff --git a/compiler/luci-interpreter/pal/cmsisnn/PALSVDF.h b/compiler/luci-interpreter/pal/cmsisnn/PALSVDF.h
new file mode 100644
index 000000000..a4a5b2a78
--- /dev/null
+++ b/compiler/luci-interpreter/pal/cmsisnn/PALSVDF.h
@@ -0,0 +1,190 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LUCI_INTERPRETER_PAL_SVDF_H
+#define LUCI_INTERPRETER_PAL_SVDF_H
+
+#include <arm_nn_types.h>
+#include <arm_nnfunctions.h>
+
+namespace luci_interpreter_pal
+{
+static inline void
+IntegerSVDF(const TfLiteSVDFParams &params, const tflite::RuntimeShape &input_shape,
+ const int8_t *input_data, const tflite::RuntimeShape &weight_feature_shape,
+ const int8_t *weight_feature_data, const tflite::RuntimeShape &weight_time_shape,
+ const int16_t *weight_time_data, const tflite::RuntimeShape &bias_shape,
+ const int32_t *bias_data, int16_t *activation_state_data,
+ const tflite::RuntimeShape &output_shape, int8_t *output_data, int32_t *scratchpad_data,
+ int32_t *output_temp_data, int32_t scale_1_a, int scale_1_b, int32_t scale_2_a,
+ int scale_2_b, int32_t input_zp, int32_t output_zp)
+{
+ const int32_t rank = params.rank;
+ const int32_t batch_size = input_shape.Dims(0);
+ const int32_t num_filters = weight_feature_shape.Dims(0);
+ const int32_t memory_size = weight_time_shape.Dims(1);
+
+ cmsis_nn_dims input_dims;
+ input_dims.n = input_shape.Dims(0);
+ input_dims.h = input_shape.Dims(1);
+
+ cmsis_nn_dims weights_feature_dims;
+ weights_feature_dims.n = weight_feature_shape.Dims(0);
+ weights_feature_dims.h = weight_feature_shape.Dims(1);
+
+ cmsis_nn_dims weights_time_dims;
+ weights_time_dims.n = weight_time_shape.Dims(0);
+ weights_time_dims.h = weight_time_shape.Dims(1);
+
+ cmsis_nn_dims bias_dims;
+ bias_dims.n = bias_shape.Dims(0);
+
+ cmsis_nn_dims state_dims;
+ state_dims.n = batch_size;
+ state_dims.h = memory_size * num_filters;
+
+ cmsis_nn_dims output_dims;
+ output_dims.n = output_shape.Dims(0);
+ output_dims.h = output_shape.Dims(1);
+
+ cmsis_nn_svdf_params svdf_params;
+ svdf_params.rank = params.rank;
+ svdf_params.input_offset = input_zp;
+ svdf_params.output_offset = output_zp;
+
+ svdf_params.input_activation.min = INT16_MIN;
+ svdf_params.input_activation.max = INT16_MAX;
+
+ svdf_params.output_activation.min = INT8_MIN;
+ svdf_params.output_activation.max = INT8_MAX;
+
+ cmsis_nn_per_tensor_quant_params in_quant_params;
+ in_quant_params.multiplier = scale_1_a;
+ in_quant_params.shift = scale_1_b;
+
+ cmsis_nn_per_tensor_quant_params out_quant_params;
+ out_quant_params.multiplier = scale_2_a;
+ out_quant_params.shift = scale_2_b;
+
+ cmsis_nn_context scratch_ctx;
+ scratch_ctx.buf = scratchpad_data;
+
+ cmsis_nn_context scratch_output_ctx;
+ scratch_output_ctx.buf = output_temp_data;
+
+ arm_svdf_s8(&scratch_ctx, &scratch_output_ctx, &svdf_params, &in_quant_params, &out_quant_params,
+ &input_dims, input_data, &state_dims, activation_state_data, &weights_feature_dims,
+ weight_feature_data, &weights_time_dims, weight_time_data, &bias_dims, bias_data,
+ &output_dims, output_data);
+}
+static inline void
+FloatSVDF(const TfLiteSVDFParams &params, const tflite::RuntimeShape &input_shape,
+ const float *input_data, const tflite::RuntimeShape &weight_feature_shape,
+ const float *weight_feature_data, const tflite::RuntimeShape &weight_time_shape,
+ const float *weight_time_data, const tflite::RuntimeShape &bias_shape,
+ const float *bias_data, float *scratchpad_data, float *activation_state_data,
+ const tflite::RuntimeShape &output_shape, float *output_data)
+{
+ const int32_t rank = params.rank;
+ const int32_t batch_size = input_shape.Dims(0);
+ const int32_t input_size = input_shape.Dims(1);
+ const int32_t num_filters = weight_feature_shape.Dims(0);
+ const int32_t num_units = num_filters / rank;
+ const int32_t memory_size = weight_time_shape.Dims(1);
+
+ // Left shift the activation_state.
+ {
+ float *new_state_start = activation_state_data;
+ const float *old_state_start = activation_state_data + 1;
+ const float *old_state_end = activation_state_data + batch_size * num_filters * memory_size;
+ while (old_state_start != old_state_end)
+ {
+ *new_state_start++ = *old_state_start++;
+ }
+ }
+
+ // Note: no need to clear the latest activation, matmul is not accumulative.
+
+ // Compute conv1d(inputs, weights_feature).
+ // The activation_state's rightmost column is used to save current cycle
+ // activation. This is achieved by starting at state_ptr[memory_size - 1] and
+ // having the stride equal to memory_size.
+
+ // Perform batched matrix vector multiply operation:
+ {
+ const float *matrix = weight_feature_data;
+ const float *vector = input_data;
+ float *result = &activation_state_data[memory_size - 1];
+ float *result_in_batch = result;
+ for (int i = 0; i < batch_size; ++i)
+ {
+ const float *matrix_ptr = matrix;
+ for (int j = 0; j < num_filters; ++j)
+ {
+ float dot_prod = 0.0f;
+ const float *vector_in_batch = vector + i * input_size;
+ for (int k = 0; k < input_size; ++k)
+ {
+ dot_prod += *matrix_ptr++ * *vector_in_batch++;
+ }
+ *result_in_batch = dot_prod;
+ result_in_batch += memory_size;
+ }
+ }
+ }
+
+ tflite::reference_ops::ApplyTimeWeightsBiasAndActivation(
+ batch_size, memory_size, num_filters, num_units, rank, weight_time_data, bias_data,
+ params.activation, activation_state_data, scratchpad_data, output_data);
+}
+
+static inline void SetupScratchpadTensor(
+ const luci_interpreter::DataType &input_data_type,
+ const luci_interpreter::DataType &weight_feature_data_type,
+ luci_interpreter::Tensor *scratchpad_1, luci_interpreter::Tensor *scratchpad_2,
+ luci_interpreter::Tensor *scratchpad_3, luci_interpreter::Tensor *scratchpad_4,
+ luci_interpreter::Tensor *scratchpad_5, luci_interpreter::Tensor *scratchpad_6,
+ const luci_interpreter::Shape input_shape, const luci_interpreter::Shape weight_time_shape,
+ const int32_t batch_size, const int32_t num_filters, const int32_t num_units)
+{
+ if (input_data_type == loco::DataType::FLOAT32 &&
+ (weight_feature_data_type == loco::DataType::S8 ||
+ weight_feature_data_type == loco::DataType::U8))
+ {
+ (void)input_shape;
+ (void)weight_time_shape;
+ (void)scratchpad_3;
+ (void)scratchpad_4;
+ (void)scratchpad_5;
+ (void)scratchpad_6;
+
+ throw std::runtime_error("Hybrid type is not supported for cmsisnn");
+ }
+
+ // Resize scratchpad_1 tensor
+ scratchpad_1->resize({batch_size, num_filters});
+
+ if (input_data_type == loco::DataType::S8)
+ {
+ // Resize scratchpad_2 for full_integer op
+ scratchpad_2->resize({batch_size, num_units});
+ }
+}
+
+} // namespace luci_interpreter_pal
+
+#endif // LUCI_INTERPRETER_PAL_SVDF_H
diff --git a/compiler/luci-interpreter/pal/cmsisnn/pal.cmake b/compiler/luci-interpreter/pal/cmsisnn/pal.cmake
index 9a25a3c5d..a68b363d9 100644
--- a/compiler/luci-interpreter/pal/cmsisnn/pal.cmake
+++ b/compiler/luci-interpreter/pal/cmsisnn/pal.cmake
@@ -42,9 +42,12 @@ macro(add_pal_to_target TGT)
"${TensorFlowSource_DIR}")
target_include_directories(${TGT} PRIVATE ${LUCI_INTERPRETER_PAL_DIR})
- set(PAL_SOURCES ${TensorFlowSource_DIR}/tensorflow/lite/kernels/internal/quantization_util.cc)
+ file(GLOB_RECURSE PAL_SOURCES "${CMSISSource_DIR}/CMSIS/NN/Source/*.c")
+ list(APPEND PAL_SOURCES ${TensorFlowSource_DIR}/tensorflow/lite/kernels/internal/quantization_util.cc
+ ${TensorFlowSource_DIR}/tensorflow/lite/kernels/internal/tensor_utils.cc
+ ${TensorFlowSource_DIR}/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc)
add_library(luci_interpreter_cmsisnn_pal STATIC ${PAL_SOURCES})
- set_target_properties(luci_interpreter_cmsisnn_pal PROPERTIES POSITION_INDEPENDENT_CODE ON)
+ set_property(TARGET luci_interpreter_cmsisnn_pal PROPERTY POSITION_INDEPENDENT_CODE ON)
target_include_directories(luci_interpreter_cmsisnn_pal PRIVATE
"${TensorFlowRuySource_DIR}"
"${TensorFlowGEMMLowpSource_DIR}"
@@ -53,7 +56,7 @@ macro(add_pal_to_target TGT)
)
add_subdirectory(${CMSISSource_DIR}/CMSIS/NN ${CMAKE_CURRENT_BINARY_DIR}/CMSISNN)
- target_include_directories(luci_interpreter_cmsisnn_pal PRIVATE
+ target_include_directories(luci_interpreter_cmsisnn_pal PUBLIC
"${CMSISSource_DIR}/CMSIS/NN/Include"
"${CMSISSource_DIR}/CMSIS/DSP/Include"
"${CMSISSource_DIR}/CMSIS/Core/Include")
diff --git a/compiler/luci-interpreter/pal/linux/KernelsToBuild.lst b/compiler/luci-interpreter/pal/linux/KernelsToBuild.lst
index 9d541276c..428b15ee0 100644
--- a/compiler/luci-interpreter/pal/linux/KernelsToBuild.lst
+++ b/compiler/luci-interpreter/pal/linux/KernelsToBuild.lst
@@ -1,19 +1,23 @@
REGISTER_KERNEL(Add)
REGISTER_KERNEL(ArgMax)
REGISTER_KERNEL(AveragePool2D)
+REGISTER_KERNEL(BatchMatMul)
REGISTER_KERNEL(BatchToSpaceND)
REGISTER_KERNEL(Cast)
REGISTER_KERNEL(Concatenation)
REGISTER_KERNEL(Conv2D)
REGISTER_KERNEL(DepthToSpace)
REGISTER_KERNEL(DepthwiseConv2D)
+REGISTER_KERNEL(Dequantize)
REGISTER_KERNEL(Div)
REGISTER_KERNEL(Elu)
REGISTER_KERNEL(Exp)
+REGISTER_KERNEL(ExpandDims)
REGISTER_KERNEL(Floor)
REGISTER_KERNEL(FloorDiv)
REGISTER_KERNEL(Equal)
REGISTER_KERNEL(FullyConnected)
+REGISTER_KERNEL(Gather)
REGISTER_KERNEL(Greater)
REGISTER_KERNEL(GreaterEqual)
REGISTER_KERNEL(If)
@@ -37,11 +41,13 @@ REGISTER_KERNEL(MirrorPad)
REGISTER_KERNEL(Mul)
REGISTER_KERNEL(Neg)
REGISTER_KERNEL(NotEqual)
+REGISTER_KERNEL(OneHot)
REGISTER_KERNEL(Pack)
REGISTER_KERNEL(Pad)
REGISTER_KERNEL(PadV2)
REGISTER_KERNEL(Pow)
REGISTER_KERNEL(PRelu)
+REGISTER_KERNEL(Quantize)
REGISTER_KERNEL(Relu)
REGISTER_KERNEL(Relu6)
REGISTER_KERNEL(Reshape)
@@ -61,6 +67,7 @@ REGISTER_KERNEL(Square)
REGISTER_KERNEL(SquaredDifference)
REGISTER_KERNEL(Squeeze)
REGISTER_KERNEL(Sub)
+REGISTER_KERNEL(SVDF)
REGISTER_KERNEL(Tanh)
REGISTER_KERNEL(Transpose)
REGISTER_KERNEL(TransposeConv)
diff --git a/compiler/luci-interpreter/pal/linux/PALAveragePool2d.h b/compiler/luci-interpreter/pal/linux/PALAveragePool2d.h
new file mode 100644
index 000000000..cce30601f
--- /dev/null
+++ b/compiler/luci-interpreter/pal/linux/PALAveragePool2d.h
@@ -0,0 +1,73 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LUCI_INTERPRETER_PAL_AVERAGEPOOL2D_H
+#define LUCI_INTERPRETER_PAL_AVERAGEPOOL2D_H
+
+#include <tensorflow/lite/kernels/internal/reference/integer_ops/pooling.h>
+#include <tensorflow/lite/kernels/internal/reference/pooling.h>
+
+namespace luci_interpreter_pal
+{
+template <typename T>
+static inline void AveragePool(const tflite::PoolParams &params,
+ const tflite::RuntimeShape &input_shape, const T *input_data,
+ const tflite::RuntimeShape &output_shape, T *output_data,
+ const tflite::RuntimeShape &scratchpad_shape, T *scratchpad_data)
+{
+ {
+    // MARK: At this moment this operation is not supported
+ assert(false && "AveragePool NYI");
+ (void)params;
+ (void)input_shape;
+ (void)input_data;
+ (void)output_shape;
+ (void)output_data;
+ (void)scratchpad_shape;
+ (void)scratchpad_data;
+ }
+}
+
+template <>
+inline void AveragePool<int8_t>(const tflite::PoolParams &params,
+ const tflite::RuntimeShape &input_shape, const int8_t *input_data,
+ const tflite::RuntimeShape &output_shape, int8_t *output_data,
+ const tflite::RuntimeShape &scratchpad_shape,
+ int8_t *scratchpad_data)
+{
+ (void)scratchpad_shape;
+ (void)scratchpad_data;
+
+ tflite::reference_integer_ops::AveragePool(params, input_shape, input_data, output_shape,
+ output_data);
+}
+
+static inline void SetupScratchpadTensor(luci_interpreter::Tensor *scratchpad,
+ const luci_interpreter::DataType &input_data_type,
+ const tflite::RuntimeShape &input_shape,
+ const tflite::RuntimeShape &output_shape)
+
+{
+ (void)input_data_type;
+ (void)input_shape;
+ (void)output_shape;
+
+ scratchpad->set_allocatable(false);
+}
+
+} // namespace luci_interpreter_pal
+
+#endif // LUCI_INTERPRETER_PAL_AVERAGEPOOL2D_H
diff --git a/compiler/luci-interpreter/pal/linux/PALBatchMatMul.h b/compiler/luci-interpreter/pal/linux/PALBatchMatMul.h
new file mode 100644
index 000000000..3894f2d92
--- /dev/null
+++ b/compiler/luci-interpreter/pal/linux/PALBatchMatMul.h
@@ -0,0 +1,67 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LUCI_INTERPRETER_PAL_BATCHMATMUL_H
+#define LUCI_INTERPRETER_PAL_BATCHMATMUL_H
+
+#include <tensorflow/lite/kernels/internal/reference/batch_matmul.h>
+
+namespace luci_interpreter_pal
+{
+inline void BatchMatMul(const tflite::RuntimeShape &lhs_shape, const float *lhs_data,
+ const tflite::RuntimeShape &rhs_shape, const float *rhs_data,
+ const tflite::RuntimeShape &output_shape, float *output_data)
+{
+ tflite::reference_ops::BatchMatMul(lhs_shape, lhs_data, rhs_shape, rhs_data, output_shape,
+ output_data);
+}
+
+static inline void SetupScratchpadTensor(luci_interpreter::Tensor *lhs_scratchpad,
+ luci_interpreter::Tensor *rhs_scratchpad,
+ const tflite::RuntimeShape &lhs_shape,
+ const tflite::RuntimeShape &rhs_shape)
+{
+ // Scratchpad for transposed LHS
+ {
+ auto lhs_rank = lhs_shape.DimensionsCount();
+ luci_interpreter::Shape scratchpad_size(lhs_rank);
+ for (int i = 0; i < lhs_rank - 2; ++i)
+ {
+ scratchpad_size.dim(i) = lhs_shape.Dims(i);
+ }
+ scratchpad_size.dim(lhs_rank - 2) = lhs_shape.Dims(lhs_rank - 1);
+ scratchpad_size.dim(lhs_rank - 1) = lhs_shape.Dims(lhs_rank - 2);
+
+ lhs_scratchpad->resize(scratchpad_size);
+ }
+ // Scratchpad for transposed RHS
+ {
+ auto rhs_rank = rhs_shape.DimensionsCount();
+ luci_interpreter::Shape scratchpad_size(rhs_rank);
+ for (int i = 0; i < rhs_rank - 2; ++i)
+ {
+ scratchpad_size.dim(i) = rhs_shape.Dims(i);
+ }
+ scratchpad_size.dim(rhs_rank - 2) = rhs_shape.Dims(rhs_rank - 1);
+ scratchpad_size.dim(rhs_rank - 1) = rhs_shape.Dims(rhs_rank - 2);
+
+ rhs_scratchpad->resize(scratchpad_size);
+ }
+}
+
+} // namespace luci_interpreter_pal
+
+#endif // LUCI_INTERPRETER_PAL_BATCHMATMUL_H
diff --git a/compiler/luci-interpreter/pal/linux/PALConv2d.h b/compiler/luci-interpreter/pal/linux/PALConv2d.h
index 2550dd5d7..985a15f39 100644
--- a/compiler/luci-interpreter/pal/linux/PALConv2d.h
+++ b/compiler/luci-interpreter/pal/linux/PALConv2d.h
@@ -26,14 +26,24 @@ static inline void Conv(const tflite::ConvParams &params, const tflite::RuntimeS
const float *input_data, const tflite::RuntimeShape &filter_shape,
const float *filter_data, const tflite::RuntimeShape &bias_shape,
const float *bias_data, const tflite::RuntimeShape &output_shape,
- float *output_data, const tflite::RuntimeShape &im2col_shape,
- float *im2col_data)
+ float *output_data, const tflite::RuntimeShape &scratchpad_shape,
+ float *scratchpad_data)
{
- if (im2col_data)
+ (void)scratchpad_shape;
+ if (scratchpad_data)
{
+ const int32_t batches = tflite::MatchingDim(input_shape, 0, output_shape, 0);
+ const int32_t input_depth = tflite::MatchingDim(input_shape, 3, filter_shape, 3);
+ const int32_t output_height = output_shape.Dims(1);
+ const int32_t output_width = output_shape.Dims(2);
+ const int32_t filter_height = filter_shape.Dims(1);
+ const int32_t filter_width = filter_shape.Dims(2);
+ tflite::RuntimeShape im2col_shape{batches, output_height, output_width,
+ input_depth * filter_height * filter_width};
+
tflite::optimized_ops::Conv(params, input_shape, input_data, filter_shape, filter_data,
bias_shape, bias_data, output_shape, output_data, im2col_shape,
- im2col_data);
+ scratchpad_data);
}
else
tflite::reference_ops::Conv(params, input_shape, input_data, filter_shape, filter_data,
@@ -45,8 +55,8 @@ static inline void Conv(const tflite::ConvParams &params, const tflite::RuntimeS
const uint8 *input_data, const tflite::RuntimeShape &filter_shape,
const uint8 *filter_data, const tflite::RuntimeShape &bias_shape,
const int32 *bias_data, const tflite::RuntimeShape &output_shape,
- uint8 *output_data, const tflite::RuntimeShape &im2col_shape,
- uint8 *im2col_data)
+ uint8 *output_data, const tflite::RuntimeShape &scratchpad_shape,
+ uint8 *scratchpad_data)
{
// TODO This should only be done once (although it takes only a few microseconds).
// Also, the user should be able to adjust the number of threads.
@@ -54,8 +64,8 @@ static inline void Conv(const tflite::ConvParams &params, const tflite::RuntimeS
gemmlowp_context->set_max_num_threads(static_cast<int>(std::thread::hardware_concurrency()));
tflite::reference_ops::Conv(params, input_shape, input_data, filter_shape, filter_data,
- bias_shape, bias_data, output_shape, output_data, im2col_shape,
- im2col_data, gemmlowp_context.get());
+ bias_shape, bias_data, output_shape, output_data, scratchpad_shape,
+ scratchpad_data, gemmlowp_context.get());
}
static inline void ConvPerChannel(const tflite::ConvParams &params, const int32_t *mult,
@@ -63,17 +73,55 @@ static inline void ConvPerChannel(const tflite::ConvParams &params, const int32_
const int8 *input_data, const tflite::RuntimeShape &filter_shape,
const int8 *filter_data, const tflite::RuntimeShape &bias_shape,
const int32 *bias_data, const tflite::RuntimeShape &output_shape,
- int8 *output_data, const tflite::RuntimeShape &im2col_shape,
- int8 *im2col_data)
+ int8 *output_data, const tflite::RuntimeShape &scratchpad_shape,
+ int8 *scratchpad_data)
{
- (void)im2col_shape;
- (void)im2col_data;
+ (void)scratchpad_shape;
+ (void)scratchpad_data;
// TODO enable optimized version
tflite::reference_integer_ops::ConvPerChannel(params, mult, shifts, input_shape, input_data,
filter_shape, filter_data, bias_shape, bias_data,
output_shape, output_data);
}
+static inline void SetupScratchpadTensor(luci_interpreter::Tensor *scratchpad,
+ const luci_interpreter::DataType &input_data_type,
+ const tflite::ConvParams &params,
+ const tflite::RuntimeShape &input_shape,
+ const tflite::RuntimeShape &filter_shape,
+ const tflite::RuntimeShape &output_shape)
+{
+ const int32_t filter_height = filter_shape.Dims(1);
+ const int32_t filter_width = filter_shape.Dims(2);
+
+ // Allocate tensor for scratchpad, if needed.
+ // The checks here should be aligned with the actual implementation.
+ const bool need_dilated_scratchpad =
+ params.dilation_height_factor != 1 || params.dilation_width_factor != 1;
+ const bool need_non_dilated_scratchpad = params.stride_height != 1 || params.stride_width != 1 ||
+ filter_height != 1 || filter_width != 1;
+ auto _need_scratchpad = input_data_type != luci_interpreter::DataType::S16 &&
+ (need_dilated_scratchpad || need_non_dilated_scratchpad);
+
+ if (_need_scratchpad)
+ {
+ const int32_t batches = tflite::MatchingDim(input_shape, 0, output_shape, 0);
+ const int32_t input_depth = tflite::MatchingDim(input_shape, 3, filter_shape, 3);
+ const int32_t output_height = output_shape.Dims(1);
+ const int32_t output_width = output_shape.Dims(2);
+
+ auto data_type_size = static_cast<int32_t>(luci_interpreter::getDataTypeSize(input_data_type));
+ int32_t scratchpad_size = batches * output_width * output_height * input_depth * filter_height *
+ filter_width * data_type_size;
+ luci_interpreter::Shape scratchpad_shape{scratchpad_size};
+ scratchpad->resize(scratchpad_shape);
+ }
+ else
+ {
+ scratchpad->set_allocatable(false);
+ }
+}
+
} // namespace luci_interpreter_pal
#endif // LUCI_INTERPRETER_PAL_CONV2D_H
diff --git a/compiler/luci-interpreter/pal/linux/PALDepthwiseConv2d.h b/compiler/luci-interpreter/pal/linux/PALDepthwiseConv2d.h
new file mode 100644
index 000000000..c9d1a2948
--- /dev/null
+++ b/compiler/luci-interpreter/pal/linux/PALDepthwiseConv2d.h
@@ -0,0 +1,91 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LUCI_INTERPRETER_PAL_DEPTHWISECONV2D_H
+#define LUCI_INTERPRETER_PAL_DEPTHWISECONV2D_H
+
+#include <tensorflow/lite/kernels/internal/reference/depthwiseconv_float.h>
+#include <tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h>
+#include <tensorflow/lite/kernels/internal/reference/integer_ops/depthwise_conv.h>
+
+namespace luci_interpreter_pal
+{
+template <typename T>
+static inline void
+DepthwiseConvPerChannel(const tflite::DepthwiseParams &params, const int32_t *output_multiplier,
+ const int32_t *output_shift, const tflite::RuntimeShape &input_shape,
+ const T *input_data, const tflite::RuntimeShape &filter_shape,
+ const T *filter_data, const tflite::RuntimeShape &bias_shape,
+ const int32_t *bias_data, const tflite::RuntimeShape &output_shape,
+ T *output_data, const tflite::RuntimeShape &scratchpad_shape,
+ T *scratchpad_data)
+{
+ {
+ // MARK: At this moment this operation is not supported
+ assert(false && "DepthwiseConvPerChannel NYI");
+ (void)params;
+ (void)output_multiplier;
+ (void)output_shift;
+ (void)input_shape;
+ (void)output_data;
+ (void)input_data;
+ (void)filter_shape;
+ (void)filter_data;
+ (void)bias_shape;
+ (void)bias_data;
+ (void)output_shape;
+ (void)output_data;
+ (void)scratchpad_shape;
+ (void)scratchpad_data;
+ }
+}
+
+template <>
+inline void DepthwiseConvPerChannel<int8_t>(
+ const tflite::DepthwiseParams &params, const int32_t *output_multiplier,
+ const int32_t *output_shift, const tflite::RuntimeShape &input_shape, const int8_t *input_data,
+ const tflite::RuntimeShape &filter_shape, const int8_t *filter_data,
+ const tflite::RuntimeShape &bias_shape, const int32_t *bias_data,
+ const tflite::RuntimeShape &output_shape, int8_t *output_data,
+ const tflite::RuntimeShape &scratchpad_shape, int8_t *scratchpad_data)
+{
+ (void)scratchpad_shape;
+ (void)scratchpad_data;
+ tflite::reference_integer_ops::DepthwiseConvPerChannel(
+ params, output_multiplier, output_shift, input_shape, input_data, filter_shape, filter_data,
+ bias_shape, bias_data, output_shape, output_data);
+}
+
+static inline void SetupScratchpadTensor(luci_interpreter::Tensor *scratchpad,
+ const tflite::DepthwiseParams &params,
+ const luci_interpreter::DataType &input_data_type,
+ const tflite::RuntimeShape &input_shape,
+ const tflite::RuntimeShape &filter_shape,
+ const tflite::RuntimeShape &output_shape)
+
+{
+ (void)params;
+ (void)input_data_type;
+ (void)input_shape;
+ (void)filter_shape;
+ (void)output_shape;
+
+ scratchpad->set_allocatable(false);
+}
+
+} // namespace luci_interpreter_pal
+
+#endif // LUCI_INTERPRETER_PAL_DEPTHWISECONV2D_H
diff --git a/compiler/luci-interpreter/pal/linux/PALDequantize.h b/compiler/luci-interpreter/pal/linux/PALDequantize.h
new file mode 100644
index 000000000..3af6d0777
--- /dev/null
+++ b/compiler/luci-interpreter/pal/linux/PALDequantize.h
@@ -0,0 +1,34 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LUCI_INTERPRETER_PAL_DEQUANTIZE_H
+#define LUCI_INTERPRETER_PAL_DEQUANTIZE_H
+
+#include <tensorflow/lite/kernels/internal/optimized/optimized_ops.h>
+
+namespace luci_interpreter_pal
+{
+template <typename T>
+static inline void Dequantize(tflite::DequantizationParams &params,
+ const tflite::RuntimeShape &input_shape, const T *input_data,
+ const tflite::RuntimeShape &output_shape, float *output_data)
+{
+ tflite::optimized_ops::Dequantize(params, input_shape, input_data, output_shape, output_data);
+}
+
+} // namespace luci_interpreter_pal
+
+#endif // LUCI_INTERPRETER_PAL_DEQUANTIZE_H
diff --git a/compiler/luci-interpreter/pal/linux/PALFullyConnected.h b/compiler/luci-interpreter/pal/linux/PALFullyConnected.h
new file mode 100644
index 000000000..62970dbf7
--- /dev/null
+++ b/compiler/luci-interpreter/pal/linux/PALFullyConnected.h
@@ -0,0 +1,61 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LUCI_INTERPRETER_PAL_FULLYCONNECTED_H
+#define LUCI_INTERPRETER_PAL_FULLYCONNECTED_H
+
+#include <tensorflow/lite/kernels/internal/reference/fully_connected.h>
+#include <tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h>
+
+namespace luci_interpreter_pal
+{
+template <typename T>
+static inline void FullyConnected(const tflite::FullyConnectedParams &params,
+ const tflite::RuntimeShape &input_shape, const T *input_data,
+ const tflite::RuntimeShape &filter_shape, const T *filter_data,
+ const tflite::RuntimeShape &bias_shape, const int32_t *bias_data,
+ const tflite::RuntimeShape &output_shape, T *output_data)
+{
+ {
+    // MARK: At this moment this operation is not supported
+ assert(false && "FullyConnected NYI");
+ (void)params;
+ (void)input_shape;
+ (void)input_data;
+ (void)filter_shape;
+ (void)filter_data;
+ (void)bias_shape;
+ (void)bias_data;
+ (void)output_shape;
+ (void)output_data;
+ }
+}
+
+template <>
+inline void
+FullyConnected<int8_t>(const tflite::FullyConnectedParams &params,
+ const tflite::RuntimeShape &input_shape, const int8_t *input_data,
+ const tflite::RuntimeShape &filter_shape, const int8_t *filter_data,
+ const tflite::RuntimeShape &bias_shape, const int32_t *bias_data,
+ const tflite::RuntimeShape &output_shape, int8_t *output_data)
+{
+ tflite::reference_integer_ops::FullyConnected(params, input_shape, input_data, filter_shape,
+ filter_data, bias_shape, bias_data, output_shape,
+ output_data);
+}
+} // namespace luci_interpreter_pal
+
+#endif // LUCI_INTERPRETER_PAL_FULLYCONNECTED_H
diff --git a/compiler/luci-interpreter/pal/linux/PALGather.h b/compiler/luci-interpreter/pal/linux/PALGather.h
new file mode 100644
index 000000000..49ac35f93
--- /dev/null
+++ b/compiler/luci-interpreter/pal/linux/PALGather.h
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LUCI_INTERPRETER_PAL_GATHER_H
+#define LUCI_INTERPRETER_PAL_GATHER_H
+
+#include <tensorflow/lite/kernels/internal/optimized/optimized_ops.h>
+
+namespace luci_interpreter_pal
+{
+template <typename T, typename CoordsT = int32>
+static inline void Gather(const tflite::GatherParams &op_params,
+ const tflite::RuntimeShape &input_shape, const T *input_data,
+ const tflite::RuntimeShape &coords_shape, const CoordsT *coords_data,
+ const tflite::RuntimeShape &output_shape, T *output_data)
+{
+ tflite::optimized_ops::Gather(op_params, input_shape, input_data, coords_shape, coords_data,
+ output_shape, output_data);
+}
+} // namespace luci_interpreter_pal
+
+#endif // LUCI_INTERPRETER_PAL_GATHER_H
diff --git a/compiler/luci-interpreter/pal/linux/PALMul.h b/compiler/luci-interpreter/pal/linux/PALMul.h
index cfaec1b58..a8a9d4abc 100644
--- a/compiler/luci-interpreter/pal/linux/PALMul.h
+++ b/compiler/luci-interpreter/pal/linux/PALMul.h
@@ -21,21 +21,31 @@
namespace luci_interpreter_pal
{
+template <typename T>
static inline void Mul(tflite::ArithmeticParams &params, const tflite::RuntimeShape &input1_shape,
- const float *input1_data, const tflite::RuntimeShape &input2_shape,
- const float *input2_data, const tflite::RuntimeShape &output_shape,
- float *output_data)
+ const T *input1_data, const tflite::RuntimeShape &input2_shape,
+ const T *input2_data, const tflite::RuntimeShape &output_shape,
+ T *output_data)
{
tflite::optimized_ops::Mul(params, input1_shape, input1_data, input2_shape, input2_data,
output_shape, output_data);
}
-static inline void BroadcastMul4DSlow(tflite::ArithmeticParams &params,
- const tflite::RuntimeShape &input1_shape,
- const float *input1_data,
- const tflite::RuntimeShape &input2_shape,
- const float *input2_data,
- const tflite::RuntimeShape &output_shape, float *output_data)
+template <>
+inline void Mul(tflite::ArithmeticParams &params, const tflite::RuntimeShape &input1_shape,
+ const int64_t *input1_data, const tflite::RuntimeShape &input2_shape,
+ const int64_t *input2_data, const tflite::RuntimeShape &output_shape,
+ int64_t *output_data)
+{
+ tflite::optimized_ops::BroadcastMul4DSlow(params, input1_shape, input1_data, input2_shape,
+ input2_data, output_shape, output_data);
+}
+
+template <typename T>
+static inline void
+BroadcastMul4DSlow(tflite::ArithmeticParams &params, const tflite::RuntimeShape &input1_shape,
+ const T *input1_data, const tflite::RuntimeShape &input2_shape,
+ const T *input2_data, const tflite::RuntimeShape &output_shape, T *output_data)
{
tflite::optimized_ops::BroadcastMul4DSlow(params, input1_shape, input1_data, input2_shape,
input2_data, output_shape, output_data);
diff --git a/compiler/luci-interpreter/pal/linux/PALQuantize.h b/compiler/luci-interpreter/pal/linux/PALQuantize.h
new file mode 100644
index 000000000..bf1d7954e
--- /dev/null
+++ b/compiler/luci-interpreter/pal/linux/PALQuantize.h
@@ -0,0 +1,44 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LUCI_INTERPRETER_PAL_QUANTIZE_H
+#define LUCI_INTERPRETER_PAL_QUANTIZE_H
+
+#include <tensorflow/lite/kernels/internal/optimized/optimized_ops.h>
+
+namespace luci_interpreter_pal
+{
+template <typename T>
+static inline void Quantize(tflite::QuantizationParams &params,
+ const tflite::RuntimeShape &input_shape, const float *input_data,
+ const tflite::RuntimeShape &output_shape, T *output_data)
+{
+ tflite::optimized_ops::AffineQuantize(params, input_shape, input_data, output_shape, output_data);
+}
+
+template <typename Input, typename Output>
+static inline void Requantize(const Input *input_data, int32_t size,
+ int32_t effective_scale_multiplier, int32_t effective_scale_shift,
+ int32_t input_zero_point, int32_t output_zero_point,
+ Output *output_data)
+{
+ tflite::optimized_ops::Requantize(input_data, size, effective_scale_multiplier,
+ effective_scale_shift, input_zero_point, output_zero_point,
+ output_data);
+}
+} // namespace luci_interpreter_pal
+
+#endif // LUCI_INTERPRETER_PAL_QUANTIZE_H
diff --git a/compiler/luci-interpreter/pal/linux/PALSVDF.h b/compiler/luci-interpreter/pal/linux/PALSVDF.h
new file mode 100644
index 000000000..0ffba14f0
--- /dev/null
+++ b/compiler/luci-interpreter/pal/linux/PALSVDF.h
@@ -0,0 +1,90 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LUCI_INTERPRETER_PAL_SVDF_H
+#define LUCI_INTERPRETER_PAL_SVDF_H
+
+#include <tensorflow/lite/kernels/internal/reference/svdf.h>
+
+namespace luci_interpreter_pal
+{
+static inline void
+IntegerSVDF(const TfLiteSVDFParams &params, const tflite::RuntimeShape &input_shape,
+ const int8_t *input_data, const tflite::RuntimeShape &weight_feature_shape,
+ const int8_t *weight_feature_data, const tflite::RuntimeShape &weight_time_shape,
+ const int16_t *weight_time_data, const tflite::RuntimeShape &bias_shape,
+ const int32_t *bias_data, int16_t *activation_state_data,
+ const tflite::RuntimeShape &output_shape, int8_t *output_data, int32_t *scratchpad_data,
+ int32_t *output_temp_data, int32_t scale_1_a, int scale_1_b, int32_t scale_2_a,
+ int scale_2_b, int32_t input_zp, int32_t output_zp)
+{
+ tflite::reference_ops::EvalIntegerSVDF(&params, input_shape, input_data, weight_feature_shape,
+ weight_feature_data, weight_time_shape, weight_time_data,
+ bias_shape, bias_data, activation_state_data, output_shape,
+ output_data, scratchpad_data, output_temp_data, scale_1_a,
+ scale_1_b, scale_2_a, scale_2_b, input_zp, output_zp);
+}
+static inline void
+FloatSVDF(const TfLiteSVDFParams &params, const tflite::RuntimeShape &input_shape,
+ const float *input_data, const tflite::RuntimeShape &weight_feature_shape,
+ const float *weight_feature_data, const tflite::RuntimeShape &weight_time_shape,
+ const float *weight_time_data, const tflite::RuntimeShape &bias_shape,
+ const float *bias_data, float *scratchpad_data, float *activation_state_data,
+ const tflite::RuntimeShape &output_shape, float *output_data)
+{
+ tflite::reference_ops::EvalFloatSVDF(&params, input_shape, input_data, weight_feature_shape,
+ weight_feature_data, weight_time_shape, weight_time_data,
+ bias_shape, bias_data, scratchpad_data,
+ activation_state_data, output_shape, output_data);
+}
+
+static inline void SetupScratchpadTensor(
+ const luci_interpreter::DataType &input_data_type,
+ const luci_interpreter::DataType &weight_feature_data_type,
+ luci_interpreter::Tensor *scratchpad_1, luci_interpreter::Tensor *scratchpad_2,
+ luci_interpreter::Tensor *scratchpad_3, luci_interpreter::Tensor *scratchpad_4,
+ luci_interpreter::Tensor *scratchpad_5, luci_interpreter::Tensor *scratchpad_6,
+ const luci_interpreter::Shape input_shape, const luci_interpreter::Shape weight_time_shape,
+ const int32_t batch_size, const int32_t num_filters, const int32_t num_units)
+{
+
+ if (input_data_type == loco::DataType::FLOAT32 &&
+ (weight_feature_data_type == loco::DataType::S8 ||
+ weight_feature_data_type == loco::DataType::U8))
+ {
+ (void)input_shape;
+ (void)weight_time_shape;
+ (void)scratchpad_3;
+ (void)scratchpad_4;
+ (void)scratchpad_5;
+ (void)scratchpad_6;
+
+ throw std::runtime_error("Hybrid type is not currently supported for linux platform");
+ }
+
+ // Resize scratchpad_1 tensor
+ scratchpad_1->resize({batch_size, num_filters});
+
+ if (input_data_type == loco::DataType::S8)
+ {
+ // Resize scratchpad_2 for full_integer op
+ scratchpad_2->resize({batch_size, num_units});
+ }
+}
+
+} // namespace luci_interpreter_pal
+
+#endif // LUCI_INTERPRETER_PAL_SVDF_H
diff --git a/compiler/luci-interpreter/pal/linux/pal.cmake b/compiler/luci-interpreter/pal/linux/pal.cmake
index 84349e0bf..185700cf9 100644
--- a/compiler/luci-interpreter/pal/linux/pal.cmake
+++ b/compiler/luci-interpreter/pal/linux/pal.cmake
@@ -40,7 +40,35 @@ macro(add_pal_to_target TGT)
# TODO put it back, I changed my mind.
# instead add sources with visitors in this library
- set(PAL_SOURCES ${TensorFlowSource_DIR}/tensorflow/lite/kernels/internal/quantization_util.cc)
+ set(PAL_SOURCES ${TensorFlowSource_DIR}/tensorflow/lite/kernels/internal/tensor_utils.cc
+ ${TensorFlowSource_DIR}/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc
+ ${TensorFlowSource_DIR}/tensorflow/lite/kernels/internal/quantization_util.cc)
+
+ if(BUILD_ARM32_NEON)
+ # NOTE may need to revise this list for version upgrade
+ set(PAL_SOURCES ${PAL_SOURCES}
+ ${TensorFlowSource_DIR}/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc
+ ${TensorFlowSource_DIR}/tensorflow/lite/kernels/internal/optimized/cpu_check.cc
+ ${TensorFlowRuySource_DIR}/ruy/allocator.cc
+ ${TensorFlowRuySource_DIR}/ruy/block_map.cc
+ ${TensorFlowRuySource_DIR}/ruy/blocking_counter.cc
+ ${TensorFlowRuySource_DIR}/ruy/context_get_ctx.cc
+ ${TensorFlowRuySource_DIR}/ruy/cpuinfo.cc
+ ${TensorFlowRuySource_DIR}/ruy/ctx.cc
+ ${TensorFlowRuySource_DIR}/ruy/denormal.cc
+ ${TensorFlowRuySource_DIR}/ruy/frontend.cc
+ ${TensorFlowRuySource_DIR}/ruy/pack_arm.cc
+ ${TensorFlowRuySource_DIR}/ruy/prepacked_cache.cc
+ ${TensorFlowRuySource_DIR}/ruy/prepare_packed_matrices.cc
+ ${TensorFlowRuySource_DIR}/ruy/system_aligned_alloc.cc
+ ${TensorFlowRuySource_DIR}/ruy/thread_pool.cc
+ ${TensorFlowRuySource_DIR}/ruy/trmul.cc
+ ${TensorFlowRuySource_DIR}/ruy/tune.cc
+ ${TensorFlowRuySource_DIR}/ruy/wait.cc
+ ${TensorFlowRuySource_DIR}/ruy/kernel_arm32.cc
+ )
+ endif(BUILD_ARM32_NEON)
+
add_library(luci_interpreter_linux_pal STATIC ${PAL_SOURCES})
set_target_properties(luci_interpreter_linux_pal PROPERTIES POSITION_INDEPENDENT_CODE ON)
target_include_directories(luci_interpreter_linux_pal SYSTEM PRIVATE
diff --git a/compiler/luci-interpreter/pal/mcu/KernelsToBuild.lst b/compiler/luci-interpreter/pal/mcu/KernelsToBuild.lst
index 771974afe..d134a6b95 100644
--- a/compiler/luci-interpreter/pal/mcu/KernelsToBuild.lst
+++ b/compiler/luci-interpreter/pal/mcu/KernelsToBuild.lst
@@ -7,9 +7,11 @@ REGISTER_KERNEL(Concatenation)
REGISTER_KERNEL(Conv2D)
REGISTER_KERNEL(DepthToSpace)
REGISTER_KERNEL(DepthwiseConv2D)
+REGISTER_KERNEL(Dequantize)
REGISTER_KERNEL(Div)
REGISTER_KERNEL(Elu)
REGISTER_KERNEL(Exp)
+REGISTER_KERNEL(ExpandDims)
REGISTER_KERNEL(Floor)
REGISTER_KERNEL(FloorDiv)
REGISTER_KERNEL(Equal)
@@ -37,6 +39,7 @@ REGISTER_KERNEL(NotEqual)
REGISTER_KERNEL(Pad)
REGISTER_KERNEL(PadV2)
REGISTER_KERNEL(PRelu)
+REGISTER_KERNEL(Quantize)
REGISTER_KERNEL(Reshape)
REGISTER_KERNEL(ResizeBilinear)
REGISTER_KERNEL(ResizeNearestNeighbor)
@@ -50,6 +53,7 @@ REGISTER_KERNEL(Square)
REGISTER_KERNEL(SquaredDifference)
REGISTER_KERNEL(Squeeze)
REGISTER_KERNEL(Sub)
+REGISTER_KERNEL(SVDF)
REGISTER_KERNEL(Tanh)
REGISTER_KERNEL(Transpose)
REGISTER_KERNEL(TransposeConv)
diff --git a/compiler/luci-interpreter/pal/mcu/PALAveragePool2d.h b/compiler/luci-interpreter/pal/mcu/PALAveragePool2d.h
new file mode 100644
index 000000000..cce30601f
--- /dev/null
+++ b/compiler/luci-interpreter/pal/mcu/PALAveragePool2d.h
@@ -0,0 +1,73 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LUCI_INTERPRETER_PAL_AVERAGEPOOL2D_H
+#define LUCI_INTERPRETER_PAL_AVERAGEPOOL2D_H
+
+#include <tensorflow/lite/kernels/internal/reference/integer_ops/pooling.h>
+#include <tensorflow/lite/kernels/internal/reference/pooling.h>
+
+namespace luci_interpreter_pal
+{
+template <typename T>
+static inline void AveragePool(const tflite::PoolParams &params,
+ const tflite::RuntimeShape &input_shape, const T *input_data,
+ const tflite::RuntimeShape &output_shape, T *output_data,
+ const tflite::RuntimeShape &scratchpad_shape, T *scratchpad_data)
+{
+ {
+    // MARK: At this moment this operation is not supported
+ assert(false && "AveragePool NYI");
+ (void)params;
+ (void)input_shape;
+ (void)input_data;
+ (void)output_shape;
+ (void)output_data;
+ (void)scratchpad_shape;
+ (void)scratchpad_data;
+ }
+}
+
+template <>
+inline void AveragePool<int8_t>(const tflite::PoolParams &params,
+ const tflite::RuntimeShape &input_shape, const int8_t *input_data,
+ const tflite::RuntimeShape &output_shape, int8_t *output_data,
+ const tflite::RuntimeShape &scratchpad_shape,
+ int8_t *scratchpad_data)
+{
+ (void)scratchpad_shape;
+ (void)scratchpad_data;
+
+ tflite::reference_integer_ops::AveragePool(params, input_shape, input_data, output_shape,
+ output_data);
+}
+
+static inline void SetupScratchpadTensor(luci_interpreter::Tensor *scratchpad,
+ const luci_interpreter::DataType &input_data_type,
+ const tflite::RuntimeShape &input_shape,
+ const tflite::RuntimeShape &output_shape)
+
+{
+ (void)input_data_type;
+ (void)input_shape;
+ (void)output_shape;
+
+ scratchpad->set_allocatable(false);
+}
+
+} // namespace luci_interpreter_pal
+
+#endif // LUCI_INTERPRETER_PAL_AVERAGEPOOL2D_H
diff --git a/compiler/luci-interpreter/pal/mcu/PALConv2d.h b/compiler/luci-interpreter/pal/mcu/PALConv2d.h
index 0a8ae4e48..13976877a 100644
--- a/compiler/luci-interpreter/pal/mcu/PALConv2d.h
+++ b/compiler/luci-interpreter/pal/mcu/PALConv2d.h
@@ -26,11 +26,11 @@ static inline void Conv(const tflite::ConvParams &params, const tflite::RuntimeS
const float *input_data, const tflite::RuntimeShape &filter_shape,
const float *filter_data, const tflite::RuntimeShape &bias_shape,
const float *bias_data, const tflite::RuntimeShape &output_shape,
- float *output_data, const tflite::RuntimeShape &im2col_shape,
- float *im2col_data)
+ float *output_data, const tflite::RuntimeShape &scratchpad_shape,
+ float *scratchpad_data)
{
- (void)im2col_shape;
- (void)im2col_data;
+ (void)scratchpad_shape;
+ (void)scratchpad_data;
tflite::reference_ops::Conv(params, input_shape, input_data, filter_shape, filter_data,
bias_shape, bias_data, output_shape, output_data,
tflite::RuntimeShape(), nullptr);
@@ -40,14 +40,14 @@ static inline void Conv(const tflite::ConvParams &params, const tflite::RuntimeS
const uint8 *input_data, const tflite::RuntimeShape &filter_shape,
const uint8 *filter_data, const tflite::RuntimeShape &bias_shape,
const int32 *bias_data, const tflite::RuntimeShape &output_shape,
- uint8 *output_data, const tflite::RuntimeShape &im2col_shape,
- uint8 *im2col_data)
+ uint8 *output_data, const tflite::RuntimeShape &scratchpad_shape,
+ uint8 *scratchpad_data)
{
- (void)im2col_shape;
- (void)im2col_data;
+ (void)scratchpad_shape;
+ (void)scratchpad_data;
tflite::reference_ops::Conv(params, input_shape, input_data, filter_shape, filter_data,
- bias_shape, bias_data, output_shape, output_data, im2col_shape,
- im2col_data, nullptr);
+ bias_shape, bias_data, output_shape, output_data, scratchpad_shape,
+ scratchpad_data, nullptr);
}
static inline void ConvPerChannel(const tflite::ConvParams &params, const int32_t *mult,
@@ -55,16 +55,31 @@ static inline void ConvPerChannel(const tflite::ConvParams &params, const int32_
const int8 *input_data, const tflite::RuntimeShape &filter_shape,
const int8 *filter_data, const tflite::RuntimeShape &bias_shape,
const int32 *bias_data, const tflite::RuntimeShape &output_shape,
- int8 *output_data, const tflite::RuntimeShape &im2col_shape,
- int8 *im2col_data)
+ int8 *output_data, const tflite::RuntimeShape &scratchpad_shape,
+ int8 *scratchpad_data)
{
- (void)im2col_shape;
- (void)im2col_data;
+ (void)scratchpad_shape;
+ (void)scratchpad_data;
tflite::reference_integer_ops::ConvPerChannel(params, mult, shifts, input_shape, input_data,
filter_shape, filter_data, bias_shape, bias_data,
output_shape, output_data);
}
+static inline void SetupScratchpadTensor(luci_interpreter::Tensor *scratchpad,
+ const luci_interpreter::DataType &input_data_type,
+ const tflite::ConvParams &params,
+ const tflite::RuntimeShape &input_shape,
+ const tflite::RuntimeShape &filter_shape,
+ const tflite::RuntimeShape &output_shape)
+{
+ (void)input_data_type;
+ (void)params;
+ (void)input_shape;
+ (void)filter_shape;
+ (void)output_shape;
+ scratchpad->set_allocatable(false);
+}
+
} // namespace luci_interpreter_pal
#endif // LUCI_INTERPRETER_PAL_CONV2D_H
diff --git a/compiler/luci-interpreter/pal/mcu/PALDepthwiseConv2d.h b/compiler/luci-interpreter/pal/mcu/PALDepthwiseConv2d.h
new file mode 100644
index 000000000..c9d1a2948
--- /dev/null
+++ b/compiler/luci-interpreter/pal/mcu/PALDepthwiseConv2d.h
@@ -0,0 +1,91 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LUCI_INTERPRETER_PAL_DEPTHWISECONV2D_H
+#define LUCI_INTERPRETER_PAL_DEPTHWISECONV2D_H
+
+#include <tensorflow/lite/kernels/internal/reference/depthwiseconv_float.h>
+#include <tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h>
+#include <tensorflow/lite/kernels/internal/reference/integer_ops/depthwise_conv.h>
+
+namespace luci_interpreter_pal
+{
+template <typename T>
+static inline void
+DepthwiseConvPerChannel(const tflite::DepthwiseParams &params, const int32_t *output_multiplier,
+ const int32_t *output_shift, const tflite::RuntimeShape &input_shape,
+ const T *input_data, const tflite::RuntimeShape &filter_shape,
+ const T *filter_data, const tflite::RuntimeShape &bias_shape,
+ const int32_t *bias_data, const tflite::RuntimeShape &output_shape,
+ T *output_data, const tflite::RuntimeShape &scratchpad_shape,
+ T *scratchpad_data)
+{
+ {
+ // MARK: At this moment this operation is not supported
+ assert(false && "DepthwiseConvPerChannel NYI");
+ (void)params;
+ (void)output_multiplier;
+ (void)output_shift;
+ (void)input_shape;
+ (void)output_data;
+ (void)input_data;
+ (void)filter_shape;
+ (void)filter_data;
+ (void)bias_shape;
+ (void)bias_data;
+ (void)output_shape;
+ (void)output_data;
+ (void)scratchpad_shape;
+ (void)scratchpad_data;
+ }
+}
+
+template <>
+inline void DepthwiseConvPerChannel<int8_t>(
+ const tflite::DepthwiseParams &params, const int32_t *output_multiplier,
+ const int32_t *output_shift, const tflite::RuntimeShape &input_shape, const int8_t *input_data,
+ const tflite::RuntimeShape &filter_shape, const int8_t *filter_data,
+ const tflite::RuntimeShape &bias_shape, const int32_t *bias_data,
+ const tflite::RuntimeShape &output_shape, int8_t *output_data,
+ const tflite::RuntimeShape &scratchpad_shape, int8_t *scratchpad_data)
+{
+ (void)scratchpad_shape;
+ (void)scratchpad_data;
+ tflite::reference_integer_ops::DepthwiseConvPerChannel(
+ params, output_multiplier, output_shift, input_shape, input_data, filter_shape, filter_data,
+ bias_shape, bias_data, output_shape, output_data);
+}
+
+static inline void SetupScratchpadTensor(luci_interpreter::Tensor *scratchpad,
+ const tflite::DepthwiseParams &params,
+ const luci_interpreter::DataType &input_data_type,
+ const tflite::RuntimeShape &input_shape,
+ const tflite::RuntimeShape &filter_shape,
+ const tflite::RuntimeShape &output_shape)
+
+{
+ (void)params;
+ (void)input_data_type;
+ (void)input_shape;
+ (void)filter_shape;
+ (void)output_shape;
+
+ scratchpad->set_allocatable(false);
+}
+
+} // namespace luci_interpreter_pal
+
+#endif // LUCI_INTERPRETER_PAL_DEPTHWISECONV2D_H
diff --git a/compiler/luci-interpreter/pal/mcu/PALDequantize.h b/compiler/luci-interpreter/pal/mcu/PALDequantize.h
new file mode 100644
index 000000000..15ff0327b
--- /dev/null
+++ b/compiler/luci-interpreter/pal/mcu/PALDequantize.h
@@ -0,0 +1,44 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LUCI_INTERPRETER_PAL_DEQUANTIZE_H
+#define LUCI_INTERPRETER_PAL_DEQUANTIZE_H
+
+#include "tensorflow/lite/kernels/internal/reference/integer_ops/dequantize.h"
+#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
+
+namespace luci_interpreter_pal
+{
+
+template <typename T>
+static inline void Dequantize(tflite::DequantizationParams &params,
+ const tflite::RuntimeShape &input_shape, const T *input_data,
+ const tflite::RuntimeShape &output_shape, float *output_data)
+{
+ tflite::reference_integer_ops::Dequantize<T>(params, input_shape, input_data, output_shape,
+ output_data);
+}
+
+static inline void Dequantize(tflite::DequantizationParams &params,
+ const tflite::RuntimeShape &input_shape, const uint8_t *input_data,
+ const tflite::RuntimeShape &output_shape, float *output_data)
+{
+ tflite::reference_ops::Dequantize(params, input_shape, input_data, output_shape, output_data);
+}
+
+} // namespace luci_interpreter_pal
+
+#endif // LUCI_INTERPRETER_PAL_DEQUANTIZE_H
diff --git a/compiler/luci-interpreter/pal/mcu/PALFullyConnected.h b/compiler/luci-interpreter/pal/mcu/PALFullyConnected.h
new file mode 100644
index 000000000..048624d74
--- /dev/null
+++ b/compiler/luci-interpreter/pal/mcu/PALFullyConnected.h
@@ -0,0 +1,61 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LUCI_INTERPRETER_PAL_FULLYCONNECTED_H
+#define LUCI_INTERPRETER_PAL_FULLYCONNECTED_H
+
+#include <tensorflow/lite/kernels/internal/reference/fully_connected.h>
+#include <tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h>
+
+namespace luci_interpreter_pal
+{
+template <typename T>
+static inline void FullyConnected(const tflite::FullyConnectedParams &params,
+ const tflite::RuntimeShape &input_shape, const T *input_data,
+ const tflite::RuntimeShape &filter_shape, const T *filter_data,
+ const tflite::RuntimeShape &bias_shape, const int32_t *bias_data,
+ const tflite::RuntimeShape &output_shape, T *output_data)
+{
+ {
+ // MARK: At this moment this operation is not supported
+ assert(false && "FullyConnected NYI");
+ (void)params;
+ (void)input_shape;
+ (void)input_data;
+ (void)filter_shape;
+ (void)filter_data;
+ (void)bias_shape;
+ (void)bias_data;
+ (void)output_shape;
+ (void)output_data;
+ }
+}
+
+template <>
+inline void
+FullyConnected<int8_t>(const tflite::FullyConnectedParams &params,
+ const tflite::RuntimeShape &input_shape, const int8_t *input_data,
+ const tflite::RuntimeShape &filter_shape, const int8_t *filter_data,
+ const tflite::RuntimeShape &bias_shape, const int32_t *bias_data,
+ const tflite::RuntimeShape &output_shape, int8_t *output_data)
+{
+ tflite::reference_integer_ops::FullyConnected(params, input_shape, input_data, filter_shape,
+ filter_data, bias_shape, bias_data, output_shape,
+ output_data);
+}
+} // namespace luci_interpreter_pal
+
+#endif // LUCI_INTERPRETER_PAL_FULLYCONNECTED_H
diff --git a/compiler/luci-interpreter/pal/mcu/PALMul.h b/compiler/luci-interpreter/pal/mcu/PALMul.h
index 2b46b100c..347a97a83 100644
--- a/compiler/luci-interpreter/pal/mcu/PALMul.h
+++ b/compiler/luci-interpreter/pal/mcu/PALMul.h
@@ -21,21 +21,21 @@
namespace luci_interpreter_pal
{
+template <typename T>
static inline void Mul(tflite::ArithmeticParams &params, const tflite::RuntimeShape &input1_shape,
- const float *input1_data, const tflite::RuntimeShape &input2_shape,
- const float *input2_data, const tflite::RuntimeShape &output_shape,
- float *output_data)
+ const T *input1_data, const tflite::RuntimeShape &input2_shape,
+ const T *input2_data, const tflite::RuntimeShape &output_shape,
+ T *output_data)
{
tflite::reference_ops::BroadcastMul4DSlow(params, input1_shape, input1_data, input2_shape,
input2_data, output_shape, output_data);
}
-static inline void BroadcastMul4DSlow(tflite::ArithmeticParams &params,
- const tflite::RuntimeShape &input1_shape,
- const float *input1_data,
- const tflite::RuntimeShape &input2_shape,
- const float *input2_data,
- const tflite::RuntimeShape &output_shape, float *output_data)
+template <typename T>
+static inline void
+BroadcastMul4DSlow(tflite::ArithmeticParams &params, const tflite::RuntimeShape &input1_shape,
+ const T *input1_data, const tflite::RuntimeShape &input2_shape,
+ const T *input2_data, const tflite::RuntimeShape &output_shape, T *output_data)
{
tflite::reference_ops::BroadcastMul4DSlow(params, input1_shape, input1_data, input2_shape,
input2_data, output_shape, output_data);
diff --git a/compiler/luci-interpreter/pal/mcu/PALQuantize.h b/compiler/luci-interpreter/pal/mcu/PALQuantize.h
new file mode 100644
index 000000000..6046789ae
--- /dev/null
+++ b/compiler/luci-interpreter/pal/mcu/PALQuantize.h
@@ -0,0 +1,44 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LUCI_INTERPRETER_PAL_QUANTIZE_H
+#define LUCI_INTERPRETER_PAL_QUANTIZE_H
+
+#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
+
+namespace luci_interpreter_pal
+{
+template <typename T>
+static inline void Quantize(tflite::QuantizationParams &params,
+ const tflite::RuntimeShape &input_shape, const float *input_data,
+ const tflite::RuntimeShape &output_shape, T *output_data)
+{
+ tflite::reference_ops::AffineQuantize(params, input_shape, input_data, output_shape, output_data);
+}
+
+template <typename Input, typename Output>
+static inline void Requantize(const Input *input_data, int32_t size,
+ int32_t effective_scale_multiplier, int32_t effective_scale_shift,
+ int32_t input_zero_point, int32_t output_zero_point,
+ Output *output_data)
+{
+ tflite::reference_ops::Requantize(input_data, size, effective_scale_multiplier,
+ effective_scale_shift, input_zero_point, output_zero_point,
+ output_data);
+}
+} // namespace luci_interpreter_pal
+
+#endif // LUCI_INTERPRETER_PAL_QUANTIZE_H
diff --git a/compiler/luci-interpreter/pal/mcu/PALSVDF.h b/compiler/luci-interpreter/pal/mcu/PALSVDF.h
new file mode 100644
index 000000000..3bba668fb
--- /dev/null
+++ b/compiler/luci-interpreter/pal/mcu/PALSVDF.h
@@ -0,0 +1,258 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LUCI_INTERPRETER_PAL_SVDF_H
+#define LUCI_INTERPRETER_PAL_SVDF_H
+
+#include <tensorflow/lite/kernels/internal/reference/svdf.h>
+
+namespace luci_interpreter_pal
+{
+static inline void
+IntegerSVDF(const TfLiteSVDFParams &params, const tflite::RuntimeShape &input_shape,
+ const int8_t *input_data, const tflite::RuntimeShape &weight_feature_shape,
+ const int8_t *weight_feature_data, const tflite::RuntimeShape &weight_time_shape,
+ const int16_t *weight_time_data, const tflite::RuntimeShape &bias_shape,
+ const int32_t *bias_data, int16_t *activation_state_data,
+ const tflite::RuntimeShape &output_shape, int8_t *output_data, int32_t *scratchpad_data,
+ int32_t *output_temp_data, int32_t scale_1_a, int scale_1_b, int32_t scale_2_a,
+ int scale_2_b, int32_t input_zp, int32_t output_zp)
+{
+ const int n_rank = params.rank;
+ const int n_batch = input_shape.Dims(0);
+ const int n_input = input_shape.Dims(1);
+ const int n_filter = weight_feature_shape.Dims(0);
+ const int n_unit = n_filter / n_rank;
+ const int n_memory = weight_time_shape.Dims(1);
+
+ // Left shift the activation_state.
+ {
+ int16_t *new_state_start = activation_state_data;
+ const int16_t *old_state_start = activation_state_data + 1;
+ const int16_t *old_state_end = activation_state_data + n_batch * n_filter * n_memory;
+ while (old_state_start != old_state_end)
+ {
+ *new_state_start++ = *old_state_start++;
+ }
+ }
+
+ // Note: no need to clear the latest activation, matmul is not accumulative.
+
+ // Feature matmul.
+ {
+ const int32_t output_max = std::numeric_limits<int16_t>::max();
+ const int32_t output_min = std::numeric_limits<int16_t>::min();
+ int16_t *result_in_batch = activation_state_data + (n_memory - 1);
+ for (int b = 0; b < n_batch; b++)
+ {
+ const int8_t *matrix_ptr = weight_feature_data;
+ for (int r = 0; r < n_filter; r++)
+ {
+ int32_t dot_prod = 0;
+ const int8_t *vector_in_batch = input_data + b * n_input;
+ for (int c = 0; c < n_input; c++)
+ {
+ dot_prod += *matrix_ptr++ * (*vector_in_batch++ - input_zp);
+ }
+ dot_prod = tflite::MultiplyByQuantizedMultiplier(dot_prod, scale_1_a, scale_1_b);
+ dot_prod = std::min(std::max(output_min, dot_prod), output_max);
+ // This assumes state is symmetrically quantized. Otherwise last bit of
+ // state should be initialized to its zero point and accumulate the
+ // dot_prod.
+ // Equivalent as the following:
+ // result_in_batch = zero point, which happens to be zero.
+ // result_in_batch += dot_prod_56.
+ *result_in_batch = dot_prod;
+ result_in_batch += n_memory;
+ }
+ }
+ }
+
+ // Time.
+ {
+ for (int b = 0; b < n_batch; ++b)
+ {
+ int32_t *scratch_ptr_batch = scratchpad_data + b * n_filter;
+
+ // Perform batched vector dot product:
+ const int16_t *vector1_ptr = weight_time_data;
+ const int16_t *vector2_ptr = activation_state_data + b * n_memory * n_filter;
+
+ for (int i = 0; i < n_filter; i++)
+ {
+ *scratch_ptr_batch = 0;
+ for (int j = 0; j < n_memory; j++)
+ {
+ *scratch_ptr_batch += *vector1_ptr++ * *vector2_ptr++;
+ }
+ scratch_ptr_batch++;
+ }
+ }
+ }
+
+ // Reduce, add bias, rescale, activation.
+ {
+ // Add bias.
+ if (bias_data)
+ {
+ // Vector batch assign:
+ for (int i = 0; i < n_batch; ++i)
+ {
+ int32_t *output_ptr = output_temp_data + i * n_unit;
+ const int32_t *bias_ptr = bias_data;
+ for (int j = 0; j < n_unit; ++j)
+ {
+ *output_ptr++ = *bias_ptr++;
+ }
+ }
+ }
+ else
+ {
+ int32_t *output_ptr = output_temp_data;
+ for (int i = 0; i < n_batch * n_unit; ++i)
+ {
+ *output_ptr++ = 0;
+ }
+ }
+
+ // Reduce.
+ for (int b = 0; b < n_batch; ++b)
+ {
+ int32_t *output_temp_ptr = output_temp_data + b * n_unit;
+ int32_t *scratch_ptr_batch = scratchpad_data + b * n_filter;
+
+ // Reduction sum vector
+ for (int i = 0; i < n_unit; ++i)
+ {
+ for (int j = 0; j < n_rank; ++j)
+ {
+ output_temp_ptr[i] += *scratch_ptr_batch++;
+ }
+ }
+ }
+
+ // Rescale.
+ const int32_t output_max = std::numeric_limits<int8_t>::max();
+ const int32_t output_min = std::numeric_limits<int8_t>::min();
+ for (int i = 0; i < n_batch * n_unit; ++i)
+ {
+ int32_t x1 = output_temp_data[i];
+ int32_t x2 = tflite::MultiplyByQuantizedMultiplier(x1, scale_2_a, scale_2_b);
+ int32_t x3 = x2 + output_zp;
+ int32_t x4 = std::min(std::max(output_min, x3), output_max);
+ output_data[i] = static_cast<int8_t>(x4);
+ }
+ }
+}
+static inline void
+FloatSVDF(const TfLiteSVDFParams &params, const tflite::RuntimeShape &input_shape,
+ const float *input_data, const tflite::RuntimeShape &weight_feature_shape,
+ const float *weight_feature_data, const tflite::RuntimeShape &weight_time_shape,
+ const float *weight_time_data, const tflite::RuntimeShape &bias_shape,
+ const float *bias_data, float *scratchpad_data, float *activation_state_data,
+ const tflite::RuntimeShape &output_shape, float *output_data)
+{
+ const int32_t rank = params.rank;
+ const int32_t batch_size = input_shape.Dims(0);
+ const int32_t input_size = input_shape.Dims(1);
+ const int32_t num_filters = weight_feature_shape.Dims(0);
+ const int32_t num_units = num_filters / rank;
+ const int32_t memory_size = weight_time_shape.Dims(1);
+
+ // Left shift the activation_state.
+ {
+ float *new_state_start = activation_state_data;
+ const float *old_state_start = activation_state_data + 1;
+ const float *old_state_end = activation_state_data + batch_size * num_filters * memory_size;
+ while (old_state_start != old_state_end)
+ {
+ *new_state_start++ = *old_state_start++;
+ }
+ }
+
+ // Note: no need to clear the latest activation, matmul is not accumulative.
+
+ // Compute conv1d(inputs, weights_feature).
+ // The activation_state's rightmost column is used to save current cycle
+ // activation. This is achieved by starting at state_ptr[memory_size - 1] and
+ // having the stride equal to memory_size.
+
+ // Perform batched matrix vector multiply operation:
+ {
+ const float *matrix = weight_feature_data;
+ const float *vector = input_data;
+ float *result = &activation_state_data[memory_size - 1];
+ float *result_in_batch = result;
+ for (int i = 0; i < batch_size; ++i)
+ {
+ const float *matrix_ptr = matrix;
+ for (int j = 0; j < num_filters; ++j)
+ {
+ float dot_prod = 0.0f;
+ const float *vector_in_batch = vector + i * input_size;
+ for (int k = 0; k < input_size; ++k)
+ {
+ dot_prod += *matrix_ptr++ * *vector_in_batch++;
+ }
+ *result_in_batch = dot_prod;
+ result_in_batch += memory_size;
+ }
+ }
+ }
+
+ tflite::reference_ops::ApplyTimeWeightsBiasAndActivation(
+ batch_size, memory_size, num_filters, num_units, rank, weight_time_data, bias_data,
+ params.activation, activation_state_data, scratchpad_data, output_data);
+}
+
+static inline void SetupScratchpadTensor(
+ const luci_interpreter::DataType &input_data_type,
+ const luci_interpreter::DataType &weight_feature_data_type,
+ luci_interpreter::Tensor *scratchpad_1, luci_interpreter::Tensor *scratchpad_2,
+ luci_interpreter::Tensor *scratchpad_3, luci_interpreter::Tensor *scratchpad_4,
+ luci_interpreter::Tensor *scratchpad_5, luci_interpreter::Tensor *scratchpad_6,
+ const luci_interpreter::Shape input_shape, const luci_interpreter::Shape weight_time_shape,
+ const int32_t batch_size, const int32_t num_filters, const int32_t num_units)
+{
+
+ if (input_data_type == loco::DataType::FLOAT32 &&
+ (weight_feature_data_type == loco::DataType::S8 ||
+ weight_feature_data_type == loco::DataType::U8))
+ {
+ (void)input_shape;
+ (void)weight_time_shape;
+ (void)scratchpad_3;
+ (void)scratchpad_4;
+ (void)scratchpad_5;
+ (void)scratchpad_6;
+
+ throw std::runtime_error("Hybrid type is not currently supported for mcu platform");
+ }
+
+ // Resize scratchpad_1 tensor
+ scratchpad_1->resize({batch_size, num_filters});
+
+ if (input_data_type == loco::DataType::S8)
+ {
+ // Resize scratchpad_2 for full_integer op
+ scratchpad_2->resize({batch_size, num_units});
+ }
+}
+
+} // namespace luci_interpreter_pal
+
+#endif // LUCI_INTERPRETER_PAL_SVDF_H
diff --git a/compiler/luci-interpreter/pal/mcu/pal.cmake b/compiler/luci-interpreter/pal/mcu/pal.cmake
index a479d407b..907d51de6 100644
--- a/compiler/luci-interpreter/pal/mcu/pal.cmake
+++ b/compiler/luci-interpreter/pal/mcu/pal.cmake
@@ -39,7 +39,9 @@ macro(add_pal_to_target TGT)
# TODO put it back, I changed my mind.
# instead add sources with visitors in this library
- set(PAL_SOURCES ${TensorFlowSource_DIR}/tensorflow/lite/kernels/internal/quantization_util.cc)
+ set(PAL_SOURCES ${TensorFlowSource_DIR}/tensorflow/lite/kernels/internal/quantization_util.cc
+ ${TensorFlowSource_DIR}/tensorflow/lite/kernels/internal/tensor_utils.cc
+ ${TensorFlowSource_DIR}/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc)
add_library(luci_interpreter_mcu_pal STATIC ${PAL_SOURCES})
set_target_properties(luci_interpreter_mcu_pal PROPERTIES POSITION_INDEPENDENT_CODE ON)
target_include_directories(luci_interpreter_mcu_pal PRIVATE
diff --git a/compiler/luci-interpreter/src/CMakeLists.txt b/compiler/luci-interpreter/src/CMakeLists.txt
index e37150336..997b75a84 100644
--- a/compiler/luci-interpreter/src/CMakeLists.txt
+++ b/compiler/luci-interpreter/src/CMakeLists.txt
@@ -13,6 +13,7 @@ set(LUCI_INTERPRETER_BINARY "luci_interpreter${LUCI_INTERPRETER_SUFFIX}")
set(LUCI_INTERPRETER_CORE "luci_interpreter_core${LUCI_INTERPRETER_SUFFIX}")
set(LUCI_INTERPRETER_KERNELS "luci_interpreter_kernels${LUCI_INTERPRETER_SUFFIX}")
set(LUCI_INTERPRETER_LOADER "luci_interpreter_loader${LUCI_INTERPRETER_SUFFIX}")
+set(LUCI_INTERPRETER_IMPORT "luci_interpreter_import${LUCI_INTERPRETER_SUFFIX}")
add_subdirectory(core)
message(STATUS "LUCI INTERPRETER CORE")
@@ -20,6 +21,8 @@ add_subdirectory(kernels)
message(STATUS "LUCI INTERPRETER KERNELS")
add_subdirectory(loader)
message(STATUS "LUCI INTERPRETER LOADER")
+add_subdirectory(import)
+message(STATUS "LUCI INTERPRETER IMPORT")
message(STATUS "LUCI INTERPTER INITALIZED")
diff --git a/compiler/luci-interpreter/src/Interpreter.cpp b/compiler/luci-interpreter/src/Interpreter.cpp
index 1b8792a6c..8cf272efd 100644
--- a/compiler/luci-interpreter/src/Interpreter.cpp
+++ b/compiler/luci-interpreter/src/Interpreter.cpp
@@ -70,25 +70,30 @@ private:
} // namespace
+Interpreter::Interpreter(const luci::Module *module)
+{
+ _runtime_to_ir = std::make_unique<RuntimeToIR>();
+ _event_notifier = std::make_unique<EventNotifierImpl>(*_runtime_to_ir, _observers);
+ _runtime_module = std::make_unique<RuntimeModule>(_event_notifier.get());
+
+ _default_memory_manager = std::make_unique<SimpleMemoryManager>();
+
+ ModuleLoader loader(module, _runtime_module.get(), *_runtime_to_ir, _node_to_tensor,
+ _default_memory_manager.get());
+ loader.load();
+}
+
Interpreter::Interpreter(const luci::Module *module,
luci_interpreter::IMemoryManager *memory_manager)
{
+ assert(memory_manager && "Use Interpreter::Interpreter(module) constructor instead");
+
_runtime_to_ir = std::make_unique<RuntimeToIR>();
_event_notifier = std::make_unique<EventNotifierImpl>(*_runtime_to_ir, _observers);
_runtime_module = std::make_unique<RuntimeModule>(_event_notifier.get());
- if (memory_manager == nullptr)
- {
- _default_memory_manager = std::make_unique<SimpleMemoryManager>();
- _memory_manager = _default_memory_manager.get();
- }
- else
- {
- _memory_manager = memory_manager;
- }
-
ModuleLoader loader(module, _runtime_module.get(), *_runtime_to_ir, _node_to_tensor,
- _memory_manager);
+ memory_manager);
loader.load();
}
diff --git a/compiler/luci-interpreter/src/core/CMakeLists.txt b/compiler/luci-interpreter/src/core/CMakeLists.txt
index 4430cba11..c2471e01c 100644
--- a/compiler/luci-interpreter/src/core/CMakeLists.txt
+++ b/compiler/luci-interpreter/src/core/CMakeLists.txt
@@ -10,7 +10,9 @@ set(SOURCES
Tensor.cpp)
add_library(${LUCI_INTERPRETER_CORE} STATIC ${SOURCES})
-set_target_properties(${LUCI_INTERPRETER_CORE} PROPERTIES POSITION_INDEPENDENT_CODE ON)
+if (NOT NNCC_LIBRARY_NO_PIC)
+ set_target_properties(${LUCI_INTERPRETER_CORE} PROPERTIES POSITION_INDEPENDENT_CODE ON)
+endif(NOT NNCC_LIBRARY_NO_PIC)
target_include_directories(${LUCI_INTERPRETER_CORE} PUBLIC "${LUCI_INTERPRETER_INCLUDE_DIR}")
target_include_directories(${LUCI_INTERPRETER_CORE} PUBLIC "${LUCI_INTERPRETER_SOURCE_DIR}")
target_link_libraries(${LUCI_INTERPRETER_CORE} PUBLIC luci_lang)
diff --git a/compiler/luci-interpreter/src/core/KernelParams.h b/compiler/luci-interpreter/src/core/KernelParams.h
index ee0390fcc..958fd4b74 100644
--- a/compiler/luci-interpreter/src/core/KernelParams.h
+++ b/compiler/luci-interpreter/src/core/KernelParams.h
@@ -43,6 +43,12 @@ struct ArgMaxParams
DataType output_type;
};
+struct BatchMatMulParams
+{
+ bool adj_x;
+ bool adj_y;
+};
+
struct ConcatenationParams
{
int axis;
@@ -83,6 +89,13 @@ struct DivParams
struct FullyConnectedParams
{
Activation activation;
+ bool keep_num_dims = false;
+};
+
+struct GatherParams
+{
+ int32_t axis;
+ int32_t batch_dims;
};
struct InstanceNormParams
@@ -119,6 +132,11 @@ struct MulParams
Activation activation;
};
+struct OneHotParams
+{
+ int32_t axis;
+};
+
struct PackParams
{
int32_t values_count;
@@ -157,6 +175,13 @@ struct SubParams
Activation activation;
};
+struct SVDFParams
+{
+ bool asymmetric_quantize_inputs;
+ int32_t svdf_rank;
+ Activation activation;
+};
+
struct SpaceToDepthParams
{
int block_size;
diff --git a/compiler/luci-interpreter/src/import/CMakeLists.txt b/compiler/luci-interpreter/src/import/CMakeLists.txt
new file mode 100644
index 000000000..dd9733f92
--- /dev/null
+++ b/compiler/luci-interpreter/src/import/CMakeLists.txt
@@ -0,0 +1,15 @@
+set(SOURCES
+ "${LUCI_INTERPRETER_INCLUDE_DIR}/luci_interpreter/GraphBuilderRegistry.h"
+ GraphBuilderRegistry.cpp)
+
+# include specific builders
+file(GLOB_RECURSE NODES "Nodes/*")
+list(APPEND SOURCES ${NODES})
+
+add_library(${LUCI_INTERPRETER_IMPORT} STATIC ${SOURCES})
+if (NOT NNCC_LIBRARY_NO_PIC)
+ set_target_properties(${LUCI_INTERPRETER_IMPORT} PROPERTIES POSITION_INDEPENDENT_CODE ON)
+endif(NOT NNCC_LIBRARY_NO_PIC)
+
+target_include_directories(${LUCI_INTERPRETER_IMPORT} PUBLIC "${LUCI_INTERPRETER_INCLUDE_DIR}")
+target_link_libraries(${LUCI_INTERPRETER_IMPORT} PUBLIC luci_import)
diff --git a/compiler/luci-interpreter/src/import/GraphBuilderRegistry.cpp b/compiler/luci-interpreter/src/import/GraphBuilderRegistry.cpp
new file mode 100644
index 000000000..a33bca6a4
--- /dev/null
+++ b/compiler/luci-interpreter/src/import/GraphBuilderRegistry.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "luci_interpreter/GraphBuilderRegistry.h"
+#include "Nodes/CircleReferencingConst.h"
+
+namespace luci_interpreter
+{
+
+std::unique_ptr<luci::GraphBuilderSource> source_without_constant_copying()
+{
+ auto builder = std::make_unique<luci::GraphBuilderRegistry>();
+ {
+ // redefine NodeBuilder of BUFFER type
+ builder->add(std::make_unique<CircleReferencingConstNodeBuilder>());
+ }
+
+ return builder;
+}
+
+} // namespace luci_interpreter
diff --git a/compiler/luci-interpreter/src/import/Nodes/CircleReferencingConst.cpp b/compiler/luci-interpreter/src/import/Nodes/CircleReferencingConst.cpp
new file mode 100644
index 000000000..14e90f240
--- /dev/null
+++ b/compiler/luci-interpreter/src/import/Nodes/CircleReferencingConst.cpp
@@ -0,0 +1,113 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleReferencingConst.h"
+
+#include <vector>
+
+namespace
+{
+
+// helper struct which describes data loaded to custom_options of CircleReferencingConst node
+struct ConstDataReference
+{
+ const uint8_t *data = nullptr;
+ uint32_t size = 0;
+};
+
+} // namespace
+
+namespace luci_interpreter
+{
+using namespace luci;
+
+CircleNode *CircleReferencingConstNodeBuilder::build(TensorIndex tensor_index,
+ GraphBuilderContext *context) const
+{
+ assert(tensor_index >= 0);
+
+ const auto graph = context->graph();
+ const auto reader = context->reader();
+ const auto tensors = reader->tensors();
+ auto const const_tensor = tensors[tensor_index];
+ assert(const_tensor != nullptr);
+ if (const_tensor->is_variable())
+ {
+ // Create CircleVariable for variable
+ return nullptr;
+ }
+
+ auto const buffer = wrap(reader->buffers()[const_tensor->buffer()]->data());
+ auto const const_dims = wrap(const_tensor->shape()); // in NHWC
+ if (const_dims.empty() && buffer.empty())
+ {
+ // unknown shape tensor and scalar tensor
+ return nullptr;
+ }
+
+ // if tensor_index is used as output to some other operator, this is not a constant
+ auto tensoroutputs = context->tensoroutputs();
+ if (tensoroutputs->find(tensor_index))
+ {
+ // other operator output tensor
+ return nullptr;
+ }
+
+ uint32_t num_elements = 1;
+ for (uint32_t r = 0; r < const_dims.size(); ++r)
+ {
+ num_elements = num_elements * const_dims[r];
+ }
+
+ if (buffer.empty() && num_elements > 0)
+ {
+ // normal empty tensor
+ return nullptr;
+ }
+
+ // create CircleReferencingConst
+ auto custom_node = graph->nodes()->create<CircleCustom>(0, 1);
+ {
+ custom_node->custom_code("CircleReferencingConst");
+
+ copy_tensor_attributes(const_tensor, custom_node);
+ custom_node->shape_status(luci::ShapeStatus::VALID);
+
+ // custom options stores size of buffer and pointer's value to buffer's data
+ {
+ std::vector<uint8_t> custom_options(sizeof(ConstDataReference));
+ {
+ auto &const_data_ref = *reinterpret_cast<ConstDataReference *>(custom_options.data());
+ const_data_ref = {buffer.data(), buffer.size()};
+ }
+ custom_node->custom_options(custom_options);
+ }
+ }
+
+ // Output of CircleCustom node presented with CircleConstNode
+ auto out_node = graph->nodes()->create<CircleCustomOut>();
+ {
+ out_node->index(0);
+ out_node->input(custom_node);
+
+ copy_tensor_attributes(const_tensor, out_node);
+ out_node->shape_status(luci::ShapeStatus::VALID);
+ }
+
+ return out_node;
+}
+
+} // namespace luci_interpreter
diff --git a/compiler/luci-interpreter/src/import/Nodes/CircleReferencingConst.h b/compiler/luci-interpreter/src/import/Nodes/CircleReferencingConst.h
new file mode 100644
index 000000000..ed8f95124
--- /dev/null
+++ b/compiler/luci-interpreter/src/import/Nodes/CircleReferencingConst.h
@@ -0,0 +1,39 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_INTERPRETER_IMPORT_OP_CIRCLE_REFERENCING_CONST_H__
+#define __LUCI_INTERPRETER_IMPORT_OP_CIRCLE_REFERENCING_CONST_H__
+
+#include <luci/Import/NodeBuilder.h>
+
+#include <luci/IR/Nodes/CircleConst.h>
+
+namespace luci_interpreter
+{
+using namespace luci;
+
+/**
+ * @brief Builder creates CircleCustom node with pointer to constants data from Tensor with buffer.
+ */
+class CircleReferencingConstNodeBuilder : public TypedNodeBuilder<NodeBuilderType::BUFFER>
+{
+public:
+ CircleNode *build(TensorIndex tensor_index, GraphBuilderContext *ctx) const final;
+};
+
+} // namespace luci_interpreter
+
+#endif // __LUCI_INTERPRETER_IMPORT_OP_CIRCLE_REFERENCING_CONST_H__
diff --git a/compiler/luci-interpreter/src/kernels/Add.cpp b/compiler/luci-interpreter/src/kernels/Add.cpp
index 7381c3849..d7bf3084f 100644
--- a/compiler/luci-interpreter/src/kernels/Add.cpp
+++ b/compiler/luci-interpreter/src/kernels/Add.cpp
@@ -38,8 +38,11 @@ Add::Add(const Tensor *input1, const Tensor *input2, Tensor *output, const AddPa
void Add::configure()
{
LUCI_INTERPRETER_CHECK(input1()->element_type() == input2()->element_type());
+ LUCI_INTERPRETER_CHECK(input1()->element_type() == output()->element_type());
if (input1()->element_type() == DataType::S16)
{
+ LUCI_INTERPRETER_CHECK(input1()->zero_points().size() == 1 &&
+ input2()->zero_points().size() == 1);
LUCI_INTERPRETER_CHECK(input1()->zero_point() == 0 && input2()->zero_point() == 0 &&
output()->zero_point() == 0);
}
@@ -54,6 +57,12 @@ void Add::execute() const
case DataType::FLOAT32:
evalFloat();
break;
+ case DataType::S64:
+ evalInteger<int64_t>();
+ break;
+ case DataType::S32:
+ evalInteger<int32_t>();
+ break;
case DataType::U8:
evalQuantized();
break;
@@ -67,13 +76,8 @@ void Add::execute() const
void Add::evalFloat() const
{
- float activation_min{};
- float activation_max{};
- calculateActivationRange(_params.activation, &activation_min, &activation_max);
-
tflite::ArithmeticParams params{};
- params.float_activation_min = activation_min;
- params.float_activation_max = activation_max;
+ fillArithmeticActivationRange<float>(params, _params.activation);
const bool need_broadcast = tflite::reference_ops::ProcessBroadcastShapes(
getTensorShape(input1()), getTensorShape(input2()), &params);
@@ -92,6 +96,28 @@ void Add::evalFloat() const
}
}
+template <typename T> void Add::evalInteger() const
+{
+ tflite::ArithmeticParams params{};
+ fillArithmeticActivationRange<T>(params, _params.activation);
+
+ const bool need_broadcast = tflite::reference_ops::ProcessBroadcastShapes(
+ getTensorShape(input1()), getTensorShape(input2()), &params);
+
+ if (need_broadcast)
+ {
+ tflite::reference_ops::BroadcastAdd4DSlow(
+ params, getTensorShape(input1()), getTensorData<T>(input1()), getTensorShape(input2()),
+ getTensorData<T>(input2()), getTensorShape(output()), getTensorData<T>(output()));
+ }
+ else
+ {
+ tflite::reference_ops::Add(params, getTensorShape(input1()), getTensorData<T>(input1()),
+ getTensorShape(input2()), getTensorData<T>(input2()),
+ getTensorShape(output()), getTensorData<T>(output()));
+ }
+}
+
void Add::evalQuantized() const
{
const auto input1_scale = static_cast<double>(input1()->scale());
diff --git a/compiler/luci-interpreter/src/kernels/Add.h b/compiler/luci-interpreter/src/kernels/Add.h
index 79518845d..91d95b6af 100644
--- a/compiler/luci-interpreter/src/kernels/Add.h
+++ b/compiler/luci-interpreter/src/kernels/Add.h
@@ -39,6 +39,7 @@ public:
private:
void evalFloat() const;
+ template <typename T> void evalInteger() const;
void evalQuantized() const;
void evalQuantizedS16() const;
};
diff --git a/compiler/luci-interpreter/src/kernels/Add.test.cpp b/compiler/luci-interpreter/src/kernels/Add.test.cpp
index 847b65667..b8b1c3089 100644
--- a/compiler/luci-interpreter/src/kernels/Add.test.cpp
+++ b/compiler/luci-interpreter/src/kernels/Add.test.cpp
@@ -166,6 +166,69 @@ TEST_F(AddTest, Float)
}
}
+template <loco::DataType DType> void CheckInteger(luci_interpreter::IMemoryManager *memory_manager)
+{
+ using dtype = typename loco::DataTypeImpl<DType>::Type;
+ Shape base_shape = {2, 3, 1, 2};
+ std::vector<Shape> test_shapes{{1, 1, 3, 2}, {1, 3, 1, 2}, {2, 1, 3, 1}, {2, 3, 1, 1}};
+ std::vector<std::vector<dtype>> test_outputs = {
+ {3, 3, 0, 1, 0, 8, 5, 1, 0, 0, 2, 6, 8, 0, 1, 0, 5, 1,
+ 5, 4, 0, 2, 2, 9, 11, 0, 4, 0, 8, 5, 11, 2, 4, 0, 8, 7},
+ {3, 3, 0, 0, 5, 1, 5, 4, 4, 0, 8, 7},
+ {3, 6, 0, 3, 0, 0, 5, 4, 2, 1, 0, 0, 8, 0, 5, 0, 1, 0,
+ 0, 2, 2, 4, 7, 9, 6, 0, 8, 0, 13, 5, 6, 0, 8, 2, 13, 7},
+ {3, 6, 2, 1, 1, 0, 0, 2, 8, 0, 13, 7}};
+ std::vector<dtype> input1_data{-1, 2, 1, 0, 4, -5, 1, 3, 7, -1, 7, 1};
+ std::vector<dtype> input2_data{4, 1, -3, -1, 1, 6};
+ for (size_t i = 0; i < test_shapes.size(); ++i)
+ {
+ Tensor input1_tensor = makeInputTensor<DType>(base_shape, input1_data, memory_manager);
+ Tensor input2_tensor = makeInputTensor<DType>(test_shapes[i], input2_data, memory_manager);
+ Tensor output_tensor = makeOutputTensor(DType);
+
+ AddParams params{};
+ params.activation = Activation::RELU;
+
+ Add kernel(&input1_tensor, &input2_tensor, &output_tensor, params);
+ kernel.configure();
+ memory_manager->allocate_memory(output_tensor);
+ kernel.execute();
+
+ EXPECT_THAT(extractTensorData<dtype>(output_tensor), test_outputs[i])
+ << "With shape number " << i;
+ }
+ // Re-run with exchanged inputs.
+ for (size_t i = 0; i < test_shapes.size(); ++i)
+ {
+ Tensor input1_tensor = makeInputTensor<DType>(test_shapes[i], input2_data, memory_manager);
+ Tensor input2_tensor = makeInputTensor<DType>(base_shape, input1_data, memory_manager);
+ Tensor output_tensor = makeOutputTensor(DType);
+
+ AddParams params{};
+ params.activation = Activation::RELU;
+
+ Add kernel(&input1_tensor, &input2_tensor, &output_tensor, params);
+ kernel.configure();
+ memory_manager->allocate_memory(output_tensor);
+ kernel.execute();
+
+ EXPECT_THAT(extractTensorData<dtype>(output_tensor), test_outputs[i])
+ << "With shape number " << i;
+ }
+};
+
+TEST_F(AddTest, SInt32)
+{
+ CheckInteger<loco::DataType::S32>(_memory_manager.get());
+ SUCCEED();
+}
+
+TEST_F(AddTest, SInt64)
+{
+ CheckInteger<loco::DataType::S64>(_memory_manager.get());
+ SUCCEED();
+}
+
TEST_F(AddTest, SInt16)
{
Shape base_shape = {2, 3, 1, 2};
@@ -248,11 +311,24 @@ TEST_F(AddTest, Input_Output_Type_NEG)
EXPECT_ANY_THROW(kernel.configure());
}
-TEST_F(AddTest, Invalid_Input_Type_NEG)
+TEST_F(AddTest, Invalid_Output_Type_NEG)
{
Tensor input1_tensor = makeInputTensor<DataType::S64>({1}, {1}, _memory_manager.get());
Tensor input2_tensor = makeInputTensor<DataType::S64>({1}, {2}, _memory_manager.get());
- Tensor output_tensor = makeOutputTensor(DataType::S64);
+ Tensor output_tensor = makeOutputTensor(DataType::S32);
+
+ AddParams params{};
+ params.activation = Activation::RELU;
+
+ Add kernel(&input1_tensor, &input2_tensor, &output_tensor, params);
+ EXPECT_ANY_THROW(kernel.configure());
+}
+
+TEST_F(AddTest, Invalid_Input_Type_NEG)
+{
+ Tensor input1_tensor = makeInputTensor<DataType::U64>({1}, {1}, _memory_manager.get());
+ Tensor input2_tensor = makeInputTensor<DataType::U64>({1}, {2}, _memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::U64);
AddParams params{};
params.activation = Activation::RELU;
@@ -263,6 +339,19 @@ TEST_F(AddTest, Invalid_Input_Type_NEG)
EXPECT_ANY_THROW(kernel.execute());
}
+TEST_F(AddTest, Invalid_Quantization_NEG)
+{
+ Tensor input1_tensor = makeInputTensor<DataType::S16>({1}, {1}, _memory_manager.get());
+ Tensor input2_tensor = makeInputTensor<DataType::S16>({1}, {2}, _memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::S16);
+
+ AddParams params{};
+ params.activation = Activation::NONE;
+
+ Add kernel(&input1_tensor, &input2_tensor, &output_tensor, params);
+ EXPECT_ANY_THROW(kernel.configure());
+}
+
} // namespace
} // namespace kernels
} // namespace luci_interpreter
diff --git a/compiler/luci-interpreter/src/kernels/ArgMax.test.cpp b/compiler/luci-interpreter/src/kernels/ArgMax.test.cpp
index 119c69ccf..474f4b321 100644
--- a/compiler/luci-interpreter/src/kernels/ArgMax.test.cpp
+++ b/compiler/luci-interpreter/src/kernels/ArgMax.test.cpp
@@ -57,7 +57,7 @@ template <typename T> class ArgMaxTest : public ::testing::Test
};
using DataTypes = ::testing::Types<float, uint8_t>;
-TYPED_TEST_CASE(ArgMaxTest, DataTypes);
+TYPED_TEST_SUITE(ArgMaxTest, DataTypes);
TYPED_TEST(ArgMaxTest, Simple)
{
diff --git a/compiler/luci-interpreter/src/kernels/AveragePool2D.cpp b/compiler/luci-interpreter/src/kernels/AveragePool2D.cpp
index 5545fb4d4..d3bade9e4 100644
--- a/compiler/luci-interpreter/src/kernels/AveragePool2D.cpp
+++ b/compiler/luci-interpreter/src/kernels/AveragePool2D.cpp
@@ -18,8 +18,7 @@
#include "kernels/Utils.h"
-#include <tensorflow/lite/kernels/internal/reference/integer_ops/pooling.h>
-#include <tensorflow/lite/kernels/internal/reference/pooling.h>
+#include "PALAveragePool2d.h"
#include <stdexcept>
@@ -29,8 +28,9 @@ namespace luci_interpreter
namespace kernels
{
-AveragePool2D::AveragePool2D(const Tensor *input, Tensor *output, const Pool2DParams &params)
- : KernelWithParams<Pool2DParams>({input}, {output}, params)
+AveragePool2D::AveragePool2D(const Tensor *input, Tensor *output, Tensor *scratchpad,
+ const Pool2DParams &params)
+ : KernelWithParams<Pool2DParams>({input}, {output, scratchpad}, params)
{
}
@@ -76,6 +76,10 @@ void AveragePool2D::configure()
LUCI_INTERPRETER_CHECK(output()->zero_point() == input()->zero_point());
}
output()->resize({batches, output_height, output_width, depth});
+
+ auto scratchpad = getOutputTensors()[1];
+ luci_interpreter_pal::SetupScratchpadTensor(scratchpad, input()->element_type(),
+ getTensorShape(input()), getTensorShape(output()));
}
void AveragePool2D::execute() const
@@ -155,9 +159,14 @@ void AveragePool2D::evalSInt8() const
params.quantized_activation_min = activation_min;
params.quantized_activation_max = activation_max;
- tflite::reference_integer_ops::AveragePool(
+ auto scratchpad = getOutputTensors()[1];
+ int8_t *scratchpad_data = nullptr;
+ if (scratchpad->is_allocatable())
+ scratchpad_data = scratchpad->data<int8_t>();
+
+ luci_interpreter_pal::AveragePool<int8_t>(
params, getTensorShape(input()), getTensorData<int8_t>(input()), getTensorShape(output()),
- getTensorData<int8_t>(output()));
+ getTensorData<int8_t>(output()), getTensorShape(scratchpad), scratchpad_data);
}
void AveragePool2D::evalSInt16() const
diff --git a/compiler/luci-interpreter/src/kernels/AveragePool2D.h b/compiler/luci-interpreter/src/kernels/AveragePool2D.h
index b98367f31..2c8fe16e7 100644
--- a/compiler/luci-interpreter/src/kernels/AveragePool2D.h
+++ b/compiler/luci-interpreter/src/kernels/AveragePool2D.h
@@ -28,7 +28,8 @@ namespace kernels
class AveragePool2D : public KernelWithParams<Pool2DParams>
{
public:
- AveragePool2D(const Tensor *input, Tensor *output, const Pool2DParams &params);
+ AveragePool2D(const Tensor *input, Tensor *output, Tensor *scratchpad,
+ const Pool2DParams &params);
const Tensor *input() const { return _inputs[0]; }
Tensor *output() const { return _outputs[0]; }
diff --git a/compiler/luci-interpreter/src/kernels/AveragePool2D.test.cpp b/compiler/luci-interpreter/src/kernels/AveragePool2D.test.cpp
index 7ed421129..478bfa68e 100644
--- a/compiler/luci-interpreter/src/kernels/AveragePool2D.test.cpp
+++ b/compiler/luci-interpreter/src/kernels/AveragePool2D.test.cpp
@@ -46,6 +46,7 @@ TEST_F(AveragePool2DTest, Float)
Tensor input_tensor =
makeInputTensor<DataType::FLOAT32>(input_shape, input_data, _memory_manager.get());
Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);
+ Tensor scratchpad(DataType::FLOAT32, Shape({}), {}, "");
Pool2DParams params{};
params.padding = Padding::VALID;
@@ -55,8 +56,9 @@ TEST_F(AveragePool2DTest, Float)
params.stride_width = 2;
params.activation = Activation::RELU6;
- AveragePool2D kernel(&input_tensor, &output_tensor, params);
+ AveragePool2D kernel(&input_tensor, &output_tensor, &scratchpad, params);
kernel.configure();
+ _memory_manager->allocate_memory(scratchpad);
_memory_manager->allocate_memory(output_tensor);
kernel.execute();
@@ -78,6 +80,7 @@ TEST_F(AveragePool2DTest, Uint8_0)
Tensor input_tensor = makeInputTensor<DataType::U8>(
{1, 2, 4, 1}, quant_param.first, quant_param.second, input_data, _memory_manager.get());
Tensor output_tensor = makeOutputTensor(DataType::U8, quant_param.first, quant_param.second);
+ Tensor scratchpad(DataType::U8, Shape({}), {}, "");
Pool2DParams params{};
params.padding = Padding::VALID;
@@ -87,8 +90,9 @@ TEST_F(AveragePool2DTest, Uint8_0)
params.stride_width = 2;
params.activation = Activation::RELU6;
- AveragePool2D kernel(&input_tensor, &output_tensor, params);
+ AveragePool2D kernel(&input_tensor, &output_tensor, &scratchpad, params);
kernel.configure();
+ _memory_manager->allocate_memory(scratchpad);
_memory_manager->allocate_memory(output_tensor);
kernel.execute();
@@ -107,6 +111,7 @@ TEST_F(AveragePool2DTest, Uint8_1)
Tensor input_tensor = makeInputTensor<DataType::U8>(
{1, 2, 4, 1}, quant_param.first, quant_param.second, input_data, _memory_manager.get());
Tensor output_tensor = makeOutputTensor(DataType::U8, quant_param.first, quant_param.second);
+ Tensor scratchpad(DataType::U8, Shape({}), {}, "");
Pool2DParams params{};
params.padding = Padding::VALID;
@@ -116,9 +121,10 @@ TEST_F(AveragePool2DTest, Uint8_1)
params.stride_width = 2;
params.activation = Activation::RELU6;
- AveragePool2D kernel(&input_tensor, &output_tensor, params);
+ AveragePool2D kernel(&input_tensor, &output_tensor, &scratchpad, params);
kernel.configure();
_memory_manager->allocate_memory(output_tensor);
+ _memory_manager->allocate_memory(scratchpad);
kernel.execute();
EXPECT_THAT(dequantizeTensorData(output_tensor), FloatArrayNear({2.75, 6.0}));
@@ -141,6 +147,7 @@ TEST_F(AveragePool2DTest, SInt16)
Tensor input_tensor =
makeInputTensor<DataType::S16>(input_shape, 0.5, 0, input_data, _memory_manager.get());
Tensor output_tensor = makeOutputTensor(DataType::S16, 0.5, 0);
+ Tensor scratchpad(DataType::S16, Shape({}), {}, "");
Pool2DParams params{};
params.padding = Padding::VALID;
@@ -150,8 +157,9 @@ TEST_F(AveragePool2DTest, SInt16)
params.stride_width = 2;
params.activation = Activation::RELU6;
- AveragePool2D kernel(&input_tensor, &output_tensor, params);
+ AveragePool2D kernel(&input_tensor, &output_tensor, &scratchpad, params);
kernel.configure();
+ _memory_manager->allocate_memory(scratchpad);
_memory_manager->allocate_memory(output_tensor);
kernel.execute();
@@ -174,6 +182,7 @@ TEST_F(AveragePool2DTest, SInt8)
Tensor input_tensor = makeInputTensor<DataType::S8>(
input_shape, quant_param.first, quant_param.second, input_data, _memory_manager.get());
Tensor output_tensor = makeOutputTensor(DataType::S8, quant_param.first, quant_param.second);
+ Tensor scratchpad(DataType::S8, Shape({}), {}, "");
Pool2DParams params{};
params.padding = Padding::VALID;
@@ -183,8 +192,9 @@ TEST_F(AveragePool2DTest, SInt8)
params.stride_width = 2;
params.activation = Activation::RELU6;
- AveragePool2D kernel(&input_tensor, &output_tensor, params);
+ AveragePool2D kernel(&input_tensor, &output_tensor, &scratchpad, params);
kernel.configure();
+ _memory_manager->allocate_memory(scratchpad);
_memory_manager->allocate_memory(output_tensor);
kernel.execute();
@@ -203,6 +213,7 @@ TEST_F(AveragePool2DTest, Invalid_Input_Shape_NEG)
Tensor input_tensor =
makeInputTensor<DataType::FLOAT32>(input_shape, input_data, _memory_manager.get());
Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);
+ Tensor scratchpad(DataType::FLOAT32, Shape({}), {}, "");
Pool2DParams params{};
params.padding = Padding::VALID;
@@ -212,7 +223,7 @@ TEST_F(AveragePool2DTest, Invalid_Input_Shape_NEG)
params.stride_width = 2;
params.activation = Activation::RELU6;
- AveragePool2D kernel(&input_tensor, &output_tensor, params);
+ AveragePool2D kernel(&input_tensor, &output_tensor, &scratchpad, params);
EXPECT_ANY_THROW(kernel.configure());
}
@@ -227,6 +238,7 @@ TEST_F(AveragePool2DTest, In_Out_Type_NEG)
Tensor input_tensor =
makeInputTensor<DataType::FLOAT32>(input_shape, input_data, _memory_manager.get());
Tensor output_tensor = makeOutputTensor(DataType::U8);
+ Tensor scratchpad(DataType::FLOAT32, Shape({}), {}, "");
Pool2DParams params{};
params.padding = Padding::VALID;
@@ -236,7 +248,7 @@ TEST_F(AveragePool2DTest, In_Out_Type_NEG)
params.stride_width = 2;
params.activation = Activation::RELU6;
- AveragePool2D kernel(&input_tensor, &output_tensor, params);
+ AveragePool2D kernel(&input_tensor, &output_tensor, &scratchpad, params);
EXPECT_ANY_THROW(kernel.configure());
}
@@ -252,6 +264,7 @@ TEST_F(AveragePool2DTest, Quant_Param_NEG)
Tensor input_tensor = makeInputTensor<DataType::U8>(
{1, 2, 4, 1}, quant_param1.first, quant_param1.second, input_data, _memory_manager.get());
Tensor output_tensor = makeOutputTensor(DataType::U8, quant_param2.first, quant_param2.second);
+ Tensor scratchpad(DataType::U8, Shape({}), {}, "");
Pool2DParams params{};
params.padding = Padding::VALID;
@@ -261,7 +274,7 @@ TEST_F(AveragePool2DTest, Quant_Param_NEG)
params.stride_width = 2;
params.activation = Activation::RELU6;
- AveragePool2D kernel(&input_tensor, &output_tensor, params);
+ AveragePool2D kernel(&input_tensor, &output_tensor, &scratchpad, params);
EXPECT_ANY_THROW(kernel.configure());
}
diff --git a/compiler/luci-interpreter/src/kernels/BatchMatMul.cpp b/compiler/luci-interpreter/src/kernels/BatchMatMul.cpp
new file mode 100644
index 000000000..24ca22996
--- /dev/null
+++ b/compiler/luci-interpreter/src/kernels/BatchMatMul.cpp
@@ -0,0 +1,188 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "kernels/BatchMatMul.h"
+#include "kernels/Utils.h"
+
+#include "PALBatchMatMul.h"
+
+#include <tensorflow/lite/kernels/internal/reference/transpose.h>
+
+#include <stdexcept>
+
+namespace
+{
+
+tflite::RuntimeShape SwapRowColumnDims(const tflite::RuntimeShape &shape)
+{
+ tflite::RuntimeShape swapped_shape(shape);
+ const int32_t dims = shape.DimensionsCount();
+ swapped_shape.SetDim(dims - 2, shape.Dims(dims - 1));
+ swapped_shape.SetDim(dims - 1, shape.Dims(dims - 2));
+ return swapped_shape;
+}
+
+} // namespace
+
+namespace luci_interpreter
+{
+namespace kernels
+{
+
+BatchMatMul::BatchMatMul(const Tensor *x, const Tensor *y, Tensor *output, Tensor *x_tmp,
+ Tensor *y_tmp, const BatchMatMulParams &params)
+ : KernelWithParams({x, y}, {output, x_tmp, y_tmp}, params)
+{
+}
+
+void BatchMatMul::configure()
+{
+ auto lhs = x();
+ auto rhs = y();
+ auto adj_x = params().adj_x;
+ auto adj_y = params().adj_y;
+
+ // TODO Support non-float types
+ if (lhs->element_type() != DataType::FLOAT32 || rhs->element_type() != DataType::FLOAT32)
+ throw std::runtime_error("Unsupported type.");
+
+ LUCI_INTERPRETER_CHECK(lhs->element_type() == rhs->element_type());
+
+ auto lhs_rank = lhs->shape().num_dims();
+ auto rhs_rank = rhs->shape().num_dims();
+ LUCI_INTERPRETER_CHECK(lhs_rank >= 2 && lhs_rank <= 4);
+ LUCI_INTERPRETER_CHECK(rhs_rank >= 2 && rhs_rank <= 4);
+
+ auto lhs_scratchpad = temp_lhs();
+ auto rhs_scratchpad = temp_rhs();
+ luci_interpreter_pal::SetupScratchpadTensor(lhs_scratchpad, rhs_scratchpad, getTensorShape(lhs),
+ getTensorShape(rhs));
+
+ auto output_rank = std::max(lhs_rank, rhs_rank);
+
+ auto extended_lhs_shape = tflite::RuntimeShape::ExtendedShape(output_rank, getTensorShape(lhs));
+ auto extended_rhs_shape = tflite::RuntimeShape::ExtendedShape(output_rank, getTensorShape(rhs));
+
+ // Ensure any batch dimensions obey broacasting rules.
+ for (int i = 0; i < output_rank - 2; ++i)
+ {
+ const int lhs_dim = extended_lhs_shape.Dims(i);
+ const int rhs_dim = extended_rhs_shape.Dims(i);
+ if (lhs_dim != rhs_dim)
+ {
+ if (lhs_dim != 1)
+ {
+ LUCI_INTERPRETER_CHECK(rhs_dim == 1);
+ }
+ }
+ }
+
+ // Ensure other dimensions work for matrix multiplication.
+ int accum_dim_lhs =
+ adj_x ? extended_lhs_shape.Dims(output_rank - 2) : extended_lhs_shape.Dims(output_rank - 1);
+ int accum_dim_rhs =
+ adj_y ? extended_rhs_shape.Dims(output_rank - 1) : extended_rhs_shape.Dims(output_rank - 2);
+ LUCI_INTERPRETER_CHECK(accum_dim_lhs == accum_dim_rhs);
+
+ Shape output_shape(output_rank);
+ // Fill in any broadcast dimensions.
+ for (int i = 0; i < output_rank - 2; ++i)
+ {
+ const int lhs_dim = extended_lhs_shape.Dims(i);
+ const int rhs_dim = extended_rhs_shape.Dims(i);
+ int broadcast_dim = lhs_dim;
+ if ((lhs_dim != rhs_dim) && (lhs_dim == 1))
+ {
+ broadcast_dim = rhs_dim;
+ }
+ output_shape.dim(i) = broadcast_dim;
+ }
+ // Fill in the matmul dimensions.
+ int lhs_rows_index = adj_x ? output_rank - 1 : output_rank - 2;
+ int rhs_cols_index = adj_y ? output_rank - 2 : output_rank - 1;
+
+ output_shape.dim(output_rank - 2) = extended_lhs_shape.Dims(lhs_rows_index);
+ output_shape.dim(output_rank - 1) = extended_rhs_shape.Dims(rhs_cols_index);
+
+ output()->resize(output_shape);
+}
+
+void TransposeRowsColumns(const Tensor *tensor_in, Tensor *tensor_out)
+{
+ tflite::RuntimeShape transposed_shape(getTensorShape(tensor_in));
+ tflite::RuntimeShape shape(getTensorShape(tensor_in));
+ tflite::TransposeParams params;
+ int rank = shape.DimensionsCount();
+ params.perm_count = rank;
+ for (int i = 0; i < rank - 2; ++i)
+ {
+ params.perm[i] = i;
+ }
+ // Transpose the last two dimensions.
+ params.perm[rank - 2] = rank - 1;
+ params.perm[rank - 1] = rank - 2;
+ transposed_shape.SetDim(rank - 1, shape.Dims(rank - 2));
+ transposed_shape.SetDim(rank - 2, shape.Dims(rank - 1));
+ switch (tensor_in->element_type())
+ {
+ case DataType::FLOAT32:
+ tflite::reference_ops::Transpose(params, shape, getTensorData<float>(tensor_in),
+ transposed_shape, getTensorData<float>(tensor_out));
+ break;
+ default:
+ throw std::runtime_error("Only suppport fp32 BatchMatMul for now.");
+ }
+}
+
+void BatchMatMul::execute() const
+{
+ auto lhs = x();
+ auto rhs = y();
+
+ bool adj_x = params().adj_x;
+ bool adj_y = params().adj_y;
+
+ auto orig_lhs_shape = getTensorShape(lhs);
+ auto orig_rhs_shape = getTensorShape(rhs);
+
+ auto rhs_tensor = adj_y ? rhs : temp_rhs();
+ auto lhs_tensor = adj_x ? temp_lhs() : lhs;
+ if (not adj_y)
+ {
+ TransposeRowsColumns(rhs, temp_rhs());
+ }
+ if (adj_x)
+ {
+ TransposeRowsColumns(lhs, temp_lhs());
+ }
+ tflite::RuntimeShape rhs_shape = adj_y ? orig_rhs_shape : SwapRowColumnDims(orig_rhs_shape);
+ tflite::RuntimeShape lhs_shape = adj_x ? orig_lhs_shape : SwapRowColumnDims(orig_lhs_shape);
+
+ switch (x()->element_type())
+ {
+ case DataType::FLOAT32:
+ luci_interpreter_pal::BatchMatMul(rhs_shape, getTensorData<float>(rhs_tensor), lhs_shape,
+ getTensorData<float>(lhs_tensor), getTensorShape(output()),
+ getTensorData<float>(output()));
+ break;
+ default:
+ throw std::runtime_error("Unsupported type.");
+ }
+}
+
+} // namespace kernels
+} // namespace luci_interpreter
diff --git a/compiler/luci-interpreter/src/kernels/BatchMatMul.h b/compiler/luci-interpreter/src/kernels/BatchMatMul.h
new file mode 100644
index 000000000..744f49795
--- /dev/null
+++ b/compiler/luci-interpreter/src/kernels/BatchMatMul.h
@@ -0,0 +1,49 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LUCI_INTERPRETER_KERNELS_BATCHMATMUL_H
+#define LUCI_INTERPRETER_KERNELS_BATCHMATMUL_H
+
+#include "core/Kernel.h"
+#include "core/KernelParams.h"
+
+namespace luci_interpreter
+{
+namespace kernels
+{
+
+class BatchMatMul : public KernelWithParams<BatchMatMulParams>
+{
+public:
+ BatchMatMul(const Tensor *x, const Tensor *y, Tensor *output, Tensor *x_tmp, Tensor *y_tmp,
+ const BatchMatMulParams &params);
+
+ const Tensor *x() const { return _inputs[0]; }
+ const Tensor *y() const { return _inputs[1]; }
+ Tensor *output() const { return _outputs[0]; }
+
+ void configure() override;
+ void execute() const override;
+
+private:
+ Tensor *temp_lhs() const { return _outputs[1]; }
+ Tensor *temp_rhs() const { return _outputs[2]; }
+};
+
+} // namespace kernels
+} // namespace luci_interpreter
+
+#endif // LUCI_INTERPRETER_KERNELS_BATCHMATMUL_H
diff --git a/compiler/luci-interpreter/src/kernels/BatchMatMul.test.cpp b/compiler/luci-interpreter/src/kernels/BatchMatMul.test.cpp
new file mode 100644
index 000000000..edfa3a685
--- /dev/null
+++ b/compiler/luci-interpreter/src/kernels/BatchMatMul.test.cpp
@@ -0,0 +1,272 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "kernels/BatchMatMul.h"
+#include "kernels/TestUtils.h"
+#include "luci_interpreter/TestMemoryManager.h"
+
+namespace luci_interpreter
+{
+namespace kernels
+{
+namespace
+{
+
+using namespace testing;
+
+class BatchMatMulTest : public ::testing::Test
+{
+protected:
+ void SetUp() override { _memory_manager = std::make_unique<TestMemoryManager>(); }
+
+ std::unique_ptr<IMemoryManager> _memory_manager;
+};
+
+TEST_F(BatchMatMulTest, Float)
+{
+ std::vector<float> lhs_data = {1, 2, 3, 4, 5, 6};
+ std::vector<float> rhs_data = {7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18};
+ Tensor lhs_tensor =
+ makeInputTensor<DataType::FLOAT32>({1, 2, 3}, lhs_data, _memory_manager.get());
+ Tensor rhs_tensor =
+ makeInputTensor<DataType::FLOAT32>({1, 3, 4}, rhs_data, _memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);
+ Tensor lhs_scratch(DataType::FLOAT32, Shape({}), {}, "");
+ Tensor rhs_scratch(DataType::FLOAT32, Shape({}), {}, "");
+
+ BatchMatMulParams params;
+ params.adj_x = false;
+ params.adj_y = false;
+
+ BatchMatMul kernel(&lhs_tensor, &rhs_tensor, &output_tensor, &lhs_scratch, &rhs_scratch, params);
+ kernel.configure();
+ _memory_manager->allocate_memory(lhs_scratch);
+ _memory_manager->allocate_memory(rhs_scratch);
+ _memory_manager->allocate_memory(output_tensor);
+ kernel.execute();
+
+ EXPECT_THAT(extractTensorData<float>(output_tensor),
+ FloatArrayNear({74., 80., 86., 92., 173., 188., 203., 218.}));
+ EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray({1, 2, 4}));
+}
+
+TEST_F(BatchMatMulTest, Float_SimpleRHSAdjoint)
+{
+ std::vector<float> lhs_data = {1, 2, 3, 4, 5, 6};
+ std::vector<float> rhs_data = {7, 11, 15, 8, 12, 16, 9, 13, 17, 10, 14, 18};
+ Tensor lhs_tensor =
+ makeInputTensor<DataType::FLOAT32>({1, 2, 3}, lhs_data, _memory_manager.get());
+ Tensor rhs_tensor =
+ makeInputTensor<DataType::FLOAT32>({1, 4, 3}, rhs_data, _memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);
+ Tensor lhs_scratch(DataType::FLOAT32, Shape({}), {}, "");
+ Tensor rhs_scratch(DataType::FLOAT32, Shape({}), {}, "");
+
+ BatchMatMulParams params;
+ params.adj_x = false;
+ params.adj_y = true;
+
+ BatchMatMul kernel(&lhs_tensor, &rhs_tensor, &output_tensor, &lhs_scratch, &rhs_scratch, params);
+ kernel.configure();
+ _memory_manager->allocate_memory(lhs_scratch);
+ _memory_manager->allocate_memory(rhs_scratch);
+ _memory_manager->allocate_memory(output_tensor);
+ kernel.execute();
+
+ EXPECT_THAT(extractTensorData<float>(output_tensor),
+ FloatArrayNear({74., 80., 86., 92., 173., 188., 203., 218.}));
+ EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray({1, 2, 4}));
+}
+
+TEST_F(BatchMatMulTest, Float_SimpleLHSAdjoint)
+{
+ std::vector<float> lhs_data = {1, 4, 2, 5, 3, 6};
+ std::vector<float> rhs_data = {7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18};
+ Tensor lhs_tensor =
+ makeInputTensor<DataType::FLOAT32>({1, 3, 2}, lhs_data, _memory_manager.get());
+ Tensor rhs_tensor =
+ makeInputTensor<DataType::FLOAT32>({1, 3, 4}, rhs_data, _memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);
+ Tensor lhs_scratch(DataType::FLOAT32, Shape({}), {}, "");
+ Tensor rhs_scratch(DataType::FLOAT32, Shape({}), {}, "");
+
+ BatchMatMulParams params;
+ params.adj_x = true;
+ params.adj_y = false;
+
+ BatchMatMul kernel(&lhs_tensor, &rhs_tensor, &output_tensor, &lhs_scratch, &rhs_scratch, params);
+ kernel.configure();
+ _memory_manager->allocate_memory(lhs_scratch);
+ _memory_manager->allocate_memory(rhs_scratch);
+ _memory_manager->allocate_memory(output_tensor);
+ kernel.execute();
+
+ EXPECT_THAT(extractTensorData<float>(output_tensor),
+ FloatArrayNear({74., 80., 86., 92., 173., 188., 203., 218.}));
+ EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray({1, 2, 4}));
+}
+
+TEST_F(BatchMatMulTest, Float_BatchSizeTwo)
+{
+ std::vector<float> lhs_data = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
+ std::vector<float> rhs_data = {7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
+ 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30};
+ Tensor lhs_tensor =
+ makeInputTensor<DataType::FLOAT32>({2, 2, 3}, lhs_data, _memory_manager.get());
+ Tensor rhs_tensor =
+ makeInputTensor<DataType::FLOAT32>({2, 3, 4}, rhs_data, _memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);
+ Tensor lhs_scratch(DataType::FLOAT32, Shape({}), {}, "");
+ Tensor rhs_scratch(DataType::FLOAT32, Shape({}), {}, "");
+
+ BatchMatMulParams params;
+ params.adj_x = false;
+ params.adj_y = false;
+
+ BatchMatMul kernel(&lhs_tensor, &rhs_tensor, &output_tensor, &lhs_scratch, &rhs_scratch, params);
+ kernel.configure();
+ _memory_manager->allocate_memory(lhs_scratch);
+ _memory_manager->allocate_memory(rhs_scratch);
+ _memory_manager->allocate_memory(output_tensor);
+ kernel.execute();
+
+ EXPECT_THAT(extractTensorData<float>(output_tensor),
+ FloatArrayNear({74., 80., 86., 92., 173., 188., 203., 218., 560., 584., 608., 632.,
+ 767., 800., 833., 866.}));
+ EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray({2, 2, 4}));
+}
+
+TEST_F(BatchMatMulTest, Float_DiffBatch)
+{
+ std::vector<float> lhs_data = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
+ std::vector<float> rhs_data = {7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
+ 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30};
+ Tensor lhs_tensor =
+ makeInputTensor<DataType::FLOAT32>({2, 1, 6}, lhs_data, _memory_manager.get());
+ Tensor rhs_tensor =
+ makeInputTensor<DataType::FLOAT32>({1, 6, 4}, rhs_data, _memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);
+ Tensor lhs_scratch(DataType::FLOAT32, Shape({}), {}, "");
+ Tensor rhs_scratch(DataType::FLOAT32, Shape({}), {}, "");
+
+ BatchMatMulParams params;
+ params.adj_x = false;
+ params.adj_y = false;
+
+ BatchMatMul kernel(&lhs_tensor, &rhs_tensor, &output_tensor, &lhs_scratch, &rhs_scratch, params);
+ kernel.configure();
+ _memory_manager->allocate_memory(lhs_scratch);
+ _memory_manager->allocate_memory(rhs_scratch);
+ _memory_manager->allocate_memory(output_tensor);
+ kernel.execute();
+
+ EXPECT_THAT(extractTensorData<float>(output_tensor),
+ FloatArrayNear({427., 448., 469., 490., 1039., 1096., 1153., 1210.}));
+ EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray({2, 1, 4}));
+}
+
+TEST_F(BatchMatMulTest, Invalid_Shape_NEG)
+{
+ Tensor lhs_tensor =
+ makeInputTensor<DataType::FLOAT32>({1, 2, 2}, {1, 2, 3, 4}, _memory_manager.get());
+ Tensor rhs_tensor =
+ makeInputTensor<DataType::FLOAT32>({1, 3, 2}, {5, 6, 7, 8, 9, 10}, _memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);
+ Tensor lhs_scratch(DataType::FLOAT32, Shape({}), {}, "");
+ Tensor rhs_scratch(DataType::FLOAT32, Shape({}), {}, "");
+
+ BatchMatMulParams params;
+ params.adj_x = false;
+ params.adj_y = false;
+
+ BatchMatMul kernel(&lhs_tensor, &rhs_tensor, &output_tensor, &lhs_scratch, &rhs_scratch, params);
+ EXPECT_ANY_THROW(kernel.configure());
+}
+
+TEST_F(BatchMatMulTest, Invalid_Batch_NEG)
+{
+ Tensor lhs_tensor =
+ makeInputTensor<DataType::FLOAT32>({2, 1, 3}, {1, 2, 3, 4, 5, 6}, _memory_manager.get());
+ Tensor rhs_tensor = makeInputTensor<DataType::FLOAT32>({3, 3, 1}, {5, 6, 7, 8, 9, 10, 11, 12, 13},
+ _memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);
+ Tensor lhs_scratch(DataType::FLOAT32, Shape({}), {}, "");
+ Tensor rhs_scratch(DataType::FLOAT32, Shape({}), {}, "");
+
+ BatchMatMulParams params;
+ params.adj_x = false;
+ params.adj_y = false;
+
+ BatchMatMul kernel(&lhs_tensor, &rhs_tensor, &output_tensor, &lhs_scratch, &rhs_scratch, params);
+ EXPECT_ANY_THROW(kernel.configure());
+}
+
+TEST_F(BatchMatMulTest, Invalid_Rank_NEG)
+{
+ Tensor lhs_tensor = makeInputTensor<DataType::FLOAT32>({4}, {1, 2, 3, 4}, _memory_manager.get());
+ Tensor rhs_tensor = makeInputTensor<DataType::FLOAT32>({1, 4, 2}, {5, 6, 7, 8, 9, 10, 11, 12},
+ _memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);
+ Tensor lhs_scratch(DataType::FLOAT32, Shape({}), {}, "");
+ Tensor rhs_scratch(DataType::FLOAT32, Shape({}), {}, "");
+
+ BatchMatMulParams params;
+ params.adj_x = false;
+ params.adj_y = false;
+
+ BatchMatMul kernel(&lhs_tensor, &rhs_tensor, &output_tensor, &lhs_scratch, &rhs_scratch, params);
+ EXPECT_ANY_THROW(kernel.configure());
+}
+
+TEST_F(BatchMatMulTest, Invalid_Rank2_NEG)
+{
+ Tensor lhs_tensor =
+ makeInputTensor<DataType::FLOAT32>({1, 1, 1, 1, 4}, {1, 2, 3, 4}, _memory_manager.get());
+ Tensor rhs_tensor = makeInputTensor<DataType::FLOAT32>({1, 4, 2}, {5, 6, 7, 8, 9, 10, 11, 12},
+ _memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);
+ Tensor lhs_scratch(DataType::FLOAT32, Shape({}), {}, "");
+ Tensor rhs_scratch(DataType::FLOAT32, Shape({}), {}, "");
+
+ BatchMatMulParams params;
+ params.adj_x = false;
+ params.adj_y = false;
+
+ BatchMatMul kernel(&lhs_tensor, &rhs_tensor, &output_tensor, &lhs_scratch, &rhs_scratch, params);
+ EXPECT_ANY_THROW(kernel.configure());
+}
+
+TEST_F(BatchMatMulTest, TypeMisMatch_NEG)
+{
+ Tensor lhs_tensor =
+ makeInputTensor<DataType::U8>({1, 2, 3}, {1, 2, 3, 4, 5, 6}, _memory_manager.get());
+ Tensor rhs_tensor =
+ makeInputTensor<DataType::FLOAT32>({1, 3, 2}, {5, 6, 7, 8, 9, 10}, _memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);
+ Tensor lhs_scratch(DataType::U8, Shape({}), {}, "");
+ Tensor rhs_scratch(DataType::FLOAT32, Shape({}), {}, "");
+
+ BatchMatMulParams params;
+ params.adj_x = false;
+ params.adj_y = false;
+
+ BatchMatMul kernel(&lhs_tensor, &rhs_tensor, &output_tensor, &lhs_scratch, &rhs_scratch, params);
+ EXPECT_ANY_THROW(kernel.configure());
+}
+
+} // namespace
+} // namespace kernels
+} // namespace luci_interpreter
diff --git a/compiler/luci-interpreter/src/kernels/BatchToSpaceND.test.cpp b/compiler/luci-interpreter/src/kernels/BatchToSpaceND.test.cpp
index f3a344974..52647a763 100644
--- a/compiler/luci-interpreter/src/kernels/BatchToSpaceND.test.cpp
+++ b/compiler/luci-interpreter/src/kernels/BatchToSpaceND.test.cpp
@@ -58,7 +58,7 @@ template <typename T> class BatchToSpaceNDTest : public ::testing::Test
};
using DataTypes = ::testing::Types<float, uint8_t>;
-TYPED_TEST_CASE(BatchToSpaceNDTest, DataTypes);
+TYPED_TEST_SUITE(BatchToSpaceNDTest, DataTypes);
TYPED_TEST(BatchToSpaceNDTest, Simple)
{
diff --git a/compiler/luci-interpreter/src/kernels/CMakeLists.txt b/compiler/luci-interpreter/src/kernels/CMakeLists.txt
index 1b7d0f66a..9f4ba0e0b 100644
--- a/compiler/luci-interpreter/src/kernels/CMakeLists.txt
+++ b/compiler/luci-interpreter/src/kernels/CMakeLists.txt
@@ -15,7 +15,9 @@ endmacro(REGISTER_KERNEL)
include(${KERNEL_REGISTER_FILE})
add_library(${LUCI_INTERPRETER_KERNELS} STATIC ${SOURCES})
-set_target_properties(${LUCI_INTERPRETER_KERNELS} PROPERTIES POSITION_INDEPENDENT_CODE ON)
+if (NOT NNCC_LIBRARY_NO_PIC)
+ set_target_properties(${LUCI_INTERPRETER_KERNELS} PROPERTIES POSITION_INDEPENDENT_CODE ON)
+endif(NOT NNCC_LIBRARY_NO_PIC)
target_include_directories(${LUCI_INTERPRETER_KERNELS} PUBLIC ${LUCI_INTERPRETER_SOURCE_DIR})
target_link_libraries(${LUCI_INTERPRETER_KERNELS} PUBLIC ${LUCI_INTERPRETER_CORE})
diff --git a/compiler/luci-interpreter/src/kernels/Cast.test.cpp b/compiler/luci-interpreter/src/kernels/Cast.test.cpp
index 731260522..4713ad34c 100644
--- a/compiler/luci-interpreter/src/kernels/Cast.test.cpp
+++ b/compiler/luci-interpreter/src/kernels/Cast.test.cpp
@@ -79,7 +79,7 @@ template <typename T> class CastTest : public ::testing::Test
using IntDataTypes =
::testing::Types<uint8_t, uint16_t, uint32_t, uint64_t, int8_t, int16_t, int32_t, int64_t>;
-TYPED_TEST_CASE(CastTest, IntDataTypes);
+TYPED_TEST_SUITE(CastTest, IntDataTypes);
TYPED_TEST(CastTest, FloatToInt)
{
diff --git a/compiler/luci-interpreter/src/kernels/Concatenation.cpp b/compiler/luci-interpreter/src/kernels/Concatenation.cpp
index 7cfdf34b9..46ee5941e 100644
--- a/compiler/luci-interpreter/src/kernels/Concatenation.cpp
+++ b/compiler/luci-interpreter/src/kernels/Concatenation.cpp
@@ -69,11 +69,21 @@ void Concatenation::configure()
Shape output_shape = t0->shape();
output_shape.dim(axis) = sum_axis;
- // TODO S8 type needs more checking: quantization parameters of all input tensors and the output
- // tensor should be the same. Note that there is no such requirement for U8 type.
- if (t0->element_type() == DataType::S8)
- throw std::runtime_error("Unsupported type.");
+ // If input tensors are INT8 type then quantization parameters of all input tensors and the output
+ // should be the same
+ for (auto current_tensor : _inputs)
+ {
+ if (current_tensor->element_type() == DataType::S8)
+ {
+ LUCI_INTERPRETER_CHECK(current_tensor->quantized_dimension() ==
+ output()->quantized_dimension());
+ LUCI_INTERPRETER_CHECK(current_tensor->zero_points().size() ==
+ current_tensor->scales().size());
+ LUCI_INTERPRETER_CHECK(current_tensor->zero_points() == output()->zero_points());
+ LUCI_INTERPRETER_CHECK(current_tensor->scales() == output()->scales());
+ }
+ }
output()->resize(output_shape);
}
diff --git a/compiler/luci-interpreter/src/kernels/Concatenation.test.cpp b/compiler/luci-interpreter/src/kernels/Concatenation.test.cpp
index e4b50611a..f893b38fd 100644
--- a/compiler/luci-interpreter/src/kernels/Concatenation.test.cpp
+++ b/compiler/luci-interpreter/src/kernels/Concatenation.test.cpp
@@ -183,12 +183,12 @@ TEST_F(ConcatenationTest, Mismatching_Input_Dimension_NEG)
EXPECT_ANY_THROW(kernel.configure());
}
-TEST_F(ConcatenationTest, Unsupported_Configure_Type_NEG)
+TEST_F(ConcatenationTest, Int8_Mismatching_Input_Type_NEG)
{
- std::vector<int8_t> input1_data{1, 2, 3, 4, 5, 6};
- std::vector<int8_t> input2_data{7, 8, 9, 10, 11, 12};
- Tensor input1_tensor = makeInputTensor<DataType::S8>({2, 3}, input1_data, _memory_manager.get());
- Tensor input2_tensor = makeInputTensor<DataType::S8>({2, 3}, input2_data, _memory_manager.get());
+ std::vector<uint8_t> input1_data{1, 2, 3, 4};
+ std::vector<int8_t> input2_data{5, 6, 7, 8};
+ Tensor input1_tensor = makeInputTensor<DataType::U8>({2, 2}, input1_data, _memory_manager.get());
+ Tensor input2_tensor = makeInputTensor<DataType::S8>({2, 2}, input2_data, _memory_manager.get());
Tensor output_tensor = makeOutputTensor(DataType::S8);
ConcatenationParams params{};
@@ -199,6 +199,51 @@ TEST_F(ConcatenationTest, Unsupported_Configure_Type_NEG)
EXPECT_ANY_THROW(kernel.configure());
}
+TEST_F(ConcatenationTest, Int8_Mismatching_Input_Output_Quant_Params_NEG)
+{
+ std::vector<float> input1_data{1, 2, 3, 4, 5, 6};
+ std::vector<float> input2_data{7, 8, 9, 10, 11, 12};
+ int quantized_dimension = 3;
+ std::vector<float> scales{0.1, 0.2, 0.3};
+ std::vector<int32_t> zero_points{1, -1, 1};
+
+ Tensor input1_tensor = makeInputTensor<DataType::S8>(
+ {1, 1, 2, 3}, scales, zero_points, quantized_dimension, input1_data, _memory_manager.get());
+ Tensor input2_tensor = makeInputTensor<DataType::S8>(
+ {1, 1, 2, 3}, scales, zero_points, quantized_dimension, input2_data, _memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::S8, scales.at(0), zero_points.at(0));
+ ConcatenationParams params{};
+
+ params.axis = -1;
+ params.activation = luci::FusedActFunc::NONE;
+
+ Concatenation kernel({&input1_tensor, &input2_tensor}, &output_tensor, params);
+ EXPECT_ANY_THROW(kernel.configure());
+}
+
+TEST_F(ConcatenationTest, Int8_Mismatching_Zero_Point_NEG)
+{
+ std::vector<float> input1_data{1, 2, 3, 4};
+ std::vector<float> input2_data{5, 6, 7, 8};
+ float scale = 0.1;
+ int32_t zero_point_1 = 1;
+ int32_t zero_point_2 = -1;
+
+ Tensor input1_tensor =
+ makeInputTensor<DataType::S8>({2, 2}, scale, zero_point_1, input1_data, _memory_manager.get());
+ Tensor input2_tensor =
+ makeInputTensor<DataType::S8>({2, 2}, scale, zero_point_2, input2_data, _memory_manager.get());
+
+ Tensor output_tensor = makeOutputTensor(DataType::S8, scale, zero_point_1);
+ ConcatenationParams params{};
+
+ params.axis = -1;
+ params.activation = luci::FusedActFunc::NONE;
+
+ Concatenation kernel({&input1_tensor, &input2_tensor}, &output_tensor, params);
+ EXPECT_ANY_THROW(kernel.configure());
+}
+
// TODO: Remove this test when concat w/ fused_activation is supported
TEST_F(ConcatenationTest, With_Fused_Activation_NEG)
{
diff --git a/compiler/luci-interpreter/src/kernels/Conv2D.cpp b/compiler/luci-interpreter/src/kernels/Conv2D.cpp
index 5647f4c44..234f95425 100644
--- a/compiler/luci-interpreter/src/kernels/Conv2D.cpp
+++ b/compiler/luci-interpreter/src/kernels/Conv2D.cpp
@@ -30,8 +30,8 @@ namespace kernels
{
Conv2D::Conv2D(const Tensor *input, const Tensor *filter, const Tensor *bias, Tensor *output,
- Tensor *im2col, const Conv2DParams &params)
- : KernelWithParams<Conv2DParams>({input, filter, bias}, {output, im2col}, params)
+ Tensor *scratchpad, const Conv2DParams &params)
+ : KernelWithParams<Conv2DParams>({input, filter, bias}, {output, scratchpad}, params)
{
}
@@ -108,27 +108,18 @@ void Conv2D::configure()
output()->resize({batches, output_height, output_width, output_depth});
- // Allocate tensor for Im2Col, if needed.
- // The checks here should be aligned with the actual implementation.
- const bool need_dilated_im2col =
- _params.dilation_height_factor != 1 || _params.dilation_width_factor != 1;
- const bool need_non_dilated_im2col = _params.stride_height != 1 || _params.stride_width != 1 ||
- filter_height != 1 || filter_width != 1;
- _need_im2col =
- input()->element_type() != DataType::S16 && (need_dilated_im2col || need_non_dilated_im2col);
- if (_need_im2col)
- {
- const int input_depth = input_shape.dim(3);
- Shape im2col_shape{batches, output_height, output_width,
- input_depth * filter_height * filter_width};
- auto im2col = getOutputTensors()[1];
- im2col->resize(im2col_shape);
- }
- else
- {
- auto im2col = getOutputTensors()[1];
- im2col->set_allocatable(false);
- }
+ // Allocate tensor for scratchpad, if needed.
+ tflite::ConvParams params{};
+ params.padding_values.height = _padding_height;
+ params.padding_values.width = _padding_width;
+ params.stride_height = _params.stride_height;
+ params.stride_width = _params.stride_width;
+ params.dilation_height_factor = _params.dilation_height_factor;
+ params.dilation_width_factor = _params.dilation_width_factor;
+ auto scratchpad = getOutputTensors()[1];
+ luci_interpreter_pal::SetupScratchpadTensor(scratchpad, input()->element_type(), params,
+ getTensorShape(input()), getTensorShape(filter()),
+ getTensorShape(output()));
switch (_params.activation)
{
@@ -193,16 +184,16 @@ void Conv2D::evalFloat() const
params.float_activation_min = activation_min;
params.float_activation_max = activation_max;
- float *im2col_data = nullptr;
- auto im2col = getOutputTensors()[1];
- if (_need_im2col)
- {
- im2col_data = im2col->data<float>();
- }
- luci_interpreter_pal::Conv(
- params, getTensorShape(input()), getTensorData<float>(input()), getTensorShape(filter()),
- getTensorData<float>(filter()), getTensorShape(bias()), getTensorData<float>(bias()),
- getTensorShape(output()), getTensorData<float>(output()), getTensorShape(im2col), im2col_data);
+ auto scratchpad = getOutputTensors()[1];
+ float *scratchpad_data = nullptr;
+ if (scratchpad->is_allocatable())
+ scratchpad_data = scratchpad->data<float>();
+
+ luci_interpreter_pal::Conv(params, getTensorShape(input()), getTensorData<float>(input()),
+ getTensorShape(filter()), getTensorData<float>(filter()),
+ getTensorShape(bias()), getTensorData<float>(bias()),
+ getTensorShape(output()), getTensorData<float>(output()),
+ getTensorShape(scratchpad), scratchpad_data);
}
void Conv2D::evalQuantized() const
@@ -236,12 +227,12 @@ void Conv2D::evalQuantized() const
params.quantized_activation_min = activation_min;
params.quantized_activation_max = activation_max;
- auto im2col = getOutputTensors()[1];
+ auto scratchpad = getOutputTensors()[1];
luci_interpreter_pal::Conv(params, getTensorShape(input()), getTensorData<uint8_t>(input()),
getTensorShape(filter()), getTensorData<uint8_t>(filter()),
getTensorShape(bias()), getTensorData<int32_t>(bias()),
getTensorShape(output()), getTensorData<uint8_t>(output()),
- getTensorShape(im2col), getTensorData<uint8_t>(im2col));
+ getTensorShape(scratchpad), getTensorData<uint8_t>(scratchpad));
}
void Conv2D::evalQuantizedPerChannel() const
@@ -364,18 +355,16 @@ void Conv2D::evalQuantizedS8PerChannel() const
std::back_inserter(multipliers),
[](ChannelQuantMultipliers cm) { return cm.multiplier; });
- int8_t *im2col_data = nullptr;
- auto im2col = getOutputTensors()[1];
- if (_need_im2col)
- {
- im2col_data = im2col->data<int8_t>();
- }
+ auto scratchpad = getOutputTensors()[1];
+ int8_t *scratchpad_data = nullptr;
+ if (scratchpad->is_allocatable())
+ scratchpad_data = scratchpad->data<int8_t>();
luci_interpreter_pal::ConvPerChannel(
params, multipliers.data(), shifts.data(), getTensorShape(input()),
getTensorData<int8_t>(input()), getTensorShape(filter()), getTensorData<int8_t>(filter()),
getTensorShape(bias()), getTensorData<int32_t>(bias()), getTensorShape(output()),
- getTensorData<int8_t>(output()), getTensorShape(im2col), im2col_data);
+ getTensorData<int8_t>(output()), getTensorShape(scratchpad), scratchpad_data);
}
void Conv2D::evalQuantizedS16() const
diff --git a/compiler/luci-interpreter/src/kernels/Conv2D.h b/compiler/luci-interpreter/src/kernels/Conv2D.h
index 5f1317638..330bf3a2a 100644
--- a/compiler/luci-interpreter/src/kernels/Conv2D.h
+++ b/compiler/luci-interpreter/src/kernels/Conv2D.h
@@ -31,7 +31,7 @@ class Conv2D : public KernelWithParams<Conv2DParams>
{
public:
Conv2D(const Tensor *input, const Tensor *filter, const Tensor *bias, Tensor *output,
- Tensor *im2col, const Conv2DParams &params);
+ Tensor *scratchpad, const Conv2DParams &params);
const Tensor *input() const { return _inputs[0]; }
const Tensor *filter() const { return _inputs[1]; }
@@ -49,7 +49,6 @@ private:
void evalQuantizedS16() const;
private:
- bool _need_im2col = false;
int32_t _padding_height{};
int32_t _padding_width{};
};
diff --git a/compiler/luci-interpreter/src/kernels/DepthToSpace.test.cpp b/compiler/luci-interpreter/src/kernels/DepthToSpace.test.cpp
index 9b1c09ba9..88e6e07f1 100644
--- a/compiler/luci-interpreter/src/kernels/DepthToSpace.test.cpp
+++ b/compiler/luci-interpreter/src/kernels/DepthToSpace.test.cpp
@@ -32,7 +32,7 @@ template <typename T> class DepthToSpaceTest : public ::testing::Test
};
using DataTypes = ::testing::Types<float, uint8_t>;
-TYPED_TEST_CASE(DepthToSpaceTest, DataTypes);
+TYPED_TEST_SUITE(DepthToSpaceTest, DataTypes);
TYPED_TEST(DepthToSpaceTest, SimpleCase)
{
diff --git a/compiler/luci-interpreter/src/kernels/DepthwiseConv2D.cpp b/compiler/luci-interpreter/src/kernels/DepthwiseConv2D.cpp
index f2dbf6c68..c554c309d 100644
--- a/compiler/luci-interpreter/src/kernels/DepthwiseConv2D.cpp
+++ b/compiler/luci-interpreter/src/kernels/DepthwiseConv2D.cpp
@@ -18,9 +18,7 @@
#include "kernels/Utils.h"
-#include <tensorflow/lite/kernels/internal/reference/depthwiseconv_float.h>
-#include <tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h>
-#include <tensorflow/lite/kernels/internal/reference/integer_ops/depthwise_conv.h>
+#include "PALDepthwiseConv2d.h"
#include <stdexcept>
@@ -30,8 +28,9 @@ namespace kernels
{
DepthwiseConv2D::DepthwiseConv2D(const Tensor *input, const Tensor *filter, const Tensor *bias,
- Tensor *output, const DepthwiseConv2DParams &params)
- : KernelWithParams<DepthwiseConv2DParams>({input, filter, bias}, {output}, params)
+ Tensor *output, Tensor *scratchpad,
+ const DepthwiseConv2DParams &params)
+ : KernelWithParams<DepthwiseConv2DParams>({input, filter, bias}, {output, scratchpad}, params)
{
}
@@ -109,6 +108,16 @@ void DepthwiseConv2D::configure()
filter_width, output_width);
output()->resize({batches, output_height, output_width, channels_out});
+
+ tflite::DepthwiseParams params{};
+
+ params.dilation_height_factor = _params.dilation_height_factor;
+ params.dilation_width_factor = _params.dilation_width_factor;
+
+ auto scratchpad = getOutputTensors()[1];
+ luci_interpreter_pal::SetupScratchpadTensor(scratchpad, params, input()->element_type(),
+ getTensorShape(input()), getTensorShape(filter()),
+ getTensorShape(output()));
}
void DepthwiseConv2D::execute() const
@@ -337,11 +346,16 @@ void DepthwiseConv2D::evalQuantizedS8PerChannel() const
std::back_inserter(multipliers),
[](ChannelQuantMultipliers cm) { return cm.multiplier; });
- tflite::reference_integer_ops::DepthwiseConvPerChannel(
+ auto scratchpad = getOutputTensors()[1];
+ int8_t *scratchpad_data = nullptr;
+ if (scratchpad->is_allocatable())
+ scratchpad_data = scratchpad->data<int8_t>();
+
+ luci_interpreter_pal::DepthwiseConvPerChannel<int8_t>(
params, multipliers.data(), shifts.data(), getTensorShape(input()),
getTensorData<int8_t>(input()), getTensorShape(filter()), getTensorData<int8_t>(filter()),
getTensorShape(bias()), getTensorData<int32_t>(bias()), getTensorShape(output()),
- getTensorData<int8_t>(output()));
+ getTensorData<int8_t>(output()), getTensorShape(scratchpad), scratchpad_data);
}
void DepthwiseConv2D::evalQuantizedS16() const
diff --git a/compiler/luci-interpreter/src/kernels/DepthwiseConv2D.h b/compiler/luci-interpreter/src/kernels/DepthwiseConv2D.h
index 6cffd6583..3d1faf6c1 100644
--- a/compiler/luci-interpreter/src/kernels/DepthwiseConv2D.h
+++ b/compiler/luci-interpreter/src/kernels/DepthwiseConv2D.h
@@ -29,7 +29,7 @@ class DepthwiseConv2D : public KernelWithParams<DepthwiseConv2DParams>
{
public:
DepthwiseConv2D(const Tensor *input, const Tensor *filter, const Tensor *bias, Tensor *output,
- const DepthwiseConv2DParams &params);
+ Tensor *scratchpad, const DepthwiseConv2DParams &params);
const Tensor *input() const { return _inputs[0]; }
const Tensor *filter() const { return _inputs[1]; }
diff --git a/compiler/luci-interpreter/src/kernels/DepthwiseConv2D.test.cpp b/compiler/luci-interpreter/src/kernels/DepthwiseConv2D.test.cpp
index 74975899a..6b4673f3e 100644
--- a/compiler/luci-interpreter/src/kernels/DepthwiseConv2D.test.cpp
+++ b/compiler/luci-interpreter/src/kernels/DepthwiseConv2D.test.cpp
@@ -59,6 +59,7 @@ TEST_F(DepthwiseConv2DTest, Float)
makeInputTensor<DataType::FLOAT32>(filter_shape, filter_data, _memory_manager.get());
Tensor bias_tensor =
makeInputTensor<DataType::FLOAT32>(bias_shape, bias_data, _memory_manager.get());
+ Tensor scratchpad(DataType::FLOAT32, Shape({}), {}, "");
Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);
DepthwiseConv2DParams params{};
@@ -70,8 +71,10 @@ TEST_F(DepthwiseConv2DTest, Float)
params.dilation_width_factor = 1;
params.activation = Activation::RELU;
- DepthwiseConv2D kernel(&input_tensor, &filter_tensor, &bias_tensor, &output_tensor, params);
+ DepthwiseConv2D kernel(&input_tensor, &filter_tensor, &bias_tensor, &output_tensor, &scratchpad,
+ params);
kernel.configure();
+ _memory_manager->allocate_memory(scratchpad);
_memory_manager->allocate_memory(output_tensor);
kernel.execute();
@@ -111,6 +114,7 @@ TEST_F(DepthwiseConv2DTest, Uint8)
{4}, input_quant_param.first * input_quant_param.first, 0, bias_data, _memory_manager.get());
Tensor output_tensor =
makeOutputTensor(DataType::U8, output_quant_param.first, output_quant_param.second);
+ Tensor scratchpad(DataType::FLOAT32, Shape({}), {}, "");
DepthwiseConv2DParams params{};
params.padding = Padding::VALID;
@@ -121,9 +125,11 @@ TEST_F(DepthwiseConv2DTest, Uint8)
params.dilation_width_factor = 1;
params.activation = Activation::NONE;
- DepthwiseConv2D kernel(&input_tensor, &filter_tensor, &bias_tensor, &output_tensor, params);
+ DepthwiseConv2D kernel(&input_tensor, &filter_tensor, &bias_tensor, &output_tensor, &scratchpad,
+ params);
kernel.configure();
_memory_manager->allocate_memory(output_tensor);
+ _memory_manager->allocate_memory(scratchpad);
kernel.execute();
std::vector<float> ref_output_data{
@@ -166,6 +172,7 @@ TEST_F(DepthwiseConv2DTest, SInt16)
Tensor bias_tensor =
makeInputTensor<DataType::S64>(bias_shape, 0.25 * 0.2, 0, bias_data, _memory_manager.get());
Tensor output_tensor = makeOutputTensor(DataType::S16, 0.5, 0);
+ Tensor scratchpad(DataType::S64, Shape({}), {}, "");
DepthwiseConv2DParams params{};
params.padding = Padding::VALID;
@@ -176,9 +183,11 @@ TEST_F(DepthwiseConv2DTest, SInt16)
params.dilation_width_factor = 1;
params.activation = Activation::RELU;
- DepthwiseConv2D kernel(&input_tensor, &filter_tensor, &bias_tensor, &output_tensor, params);
+ DepthwiseConv2D kernel(&input_tensor, &filter_tensor, &bias_tensor, &output_tensor, &scratchpad,
+ params);
kernel.configure();
_memory_manager->allocate_memory(output_tensor);
+ _memory_manager->allocate_memory(scratchpad);
kernel.execute();
EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray(ref_output_shape));
@@ -224,6 +233,7 @@ TEST_F(DepthwiseConv2DTest, SInt16_CWQ_weights)
Tensor bias_tensor = makeInputTensor<DataType::S64>(bias_shape, bias_scales, zerop, 0, bias_data,
_memory_manager.get());
Tensor output_tensor = makeOutputTensor(DataType::S16, 0.5, 0);
+ Tensor scratchpad(DataType::S16, Shape({}), {}, "");
DepthwiseConv2DParams params{};
params.padding = Padding::VALID;
@@ -234,9 +244,11 @@ TEST_F(DepthwiseConv2DTest, SInt16_CWQ_weights)
params.dilation_width_factor = 1;
params.activation = Activation::RELU;
- DepthwiseConv2D kernel(&input_tensor, &filter_tensor, &bias_tensor, &output_tensor, params);
+ DepthwiseConv2D kernel(&input_tensor, &filter_tensor, &bias_tensor, &output_tensor, &scratchpad,
+ params);
kernel.configure();
_memory_manager->allocate_memory(output_tensor);
+ _memory_manager->allocate_memory(scratchpad);
kernel.execute();
EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray(ref_output_shape));
@@ -299,6 +311,7 @@ TEST_F(DepthwiseConv2DTest, Uint8_CWQ_weights)
_memory_manager.get());
Tensor output_tensor =
makeOutputTensor(DataType::U8, output_quant_param.first, output_quant_param.second);
+ Tensor scratchpad(DataType::U8, Shape({}), {}, "");
DepthwiseConv2DParams params{};
params.padding = Padding::VALID;
@@ -309,9 +322,11 @@ TEST_F(DepthwiseConv2DTest, Uint8_CWQ_weights)
params.dilation_width_factor = 1;
params.activation = Activation::NONE;
- DepthwiseConv2D kernel(&input_tensor, &filter_tensor, &bias_tensor, &output_tensor, params);
+ DepthwiseConv2D kernel(&input_tensor, &filter_tensor, &bias_tensor, &output_tensor, &scratchpad,
+ params);
kernel.configure();
_memory_manager->allocate_memory(output_tensor);
+ _memory_manager->allocate_memory(scratchpad);
kernel.execute();
EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray(ref_output_shape));
@@ -375,6 +390,7 @@ TEST_F(DepthwiseConv2DTest, SInt8_CWQ_weights)
_memory_manager.get());
Tensor output_tensor =
makeOutputTensor(DataType::S8, output_quant_param.first, output_quant_param.second);
+ Tensor scratchpad(DataType::S8, Shape({}), {}, "");
DepthwiseConv2DParams params{};
params.padding = Padding::VALID;
@@ -385,9 +401,11 @@ TEST_F(DepthwiseConv2DTest, SInt8_CWQ_weights)
params.dilation_width_factor = 1;
params.activation = Activation::NONE;
- DepthwiseConv2D kernel(&input_tensor, &filter_tensor, &bias_tensor, &output_tensor, params);
+ DepthwiseConv2D kernel(&input_tensor, &filter_tensor, &bias_tensor, &output_tensor, &scratchpad,
+ params);
kernel.configure();
_memory_manager->allocate_memory(output_tensor);
+ _memory_manager->allocate_memory(scratchpad);
kernel.execute();
EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray(ref_output_shape));
@@ -419,6 +437,7 @@ TEST_F(DepthwiseConv2DTest, InvalidBiasType_NEG)
makeInputTensor<DataType::FLOAT32>(filter_shape, filter_data, _memory_manager.get());
Tensor bias_tensor = makeInputTensor<DataType::S32>(bias_shape, bias_data, _memory_manager.get());
Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);
+ Tensor scratchpad(DataType::FLOAT32, Shape({}), {}, "");
DepthwiseConv2DParams params{};
params.padding = Padding::VALID;
@@ -429,7 +448,8 @@ TEST_F(DepthwiseConv2DTest, InvalidBiasType_NEG)
params.dilation_width_factor = 1;
params.activation = Activation::RELU;
- DepthwiseConv2D kernel(&input_tensor, &filter_tensor, &bias_tensor, &output_tensor, params);
+ DepthwiseConv2D kernel(&input_tensor, &filter_tensor, &bias_tensor, &output_tensor, &scratchpad,
+ params);
EXPECT_ANY_THROW(kernel.configure());
}
@@ -458,6 +478,7 @@ TEST_F(DepthwiseConv2DTest, InOutTypeMismatch_NEG)
Tensor bias_tensor =
makeInputTensor<DataType::FLOAT32>(bias_shape, bias_data, _memory_manager.get());
Tensor output_tensor = makeOutputTensor(DataType::U8);
+ Tensor scratchpad(DataType::U8, Shape({}), {}, "");
DepthwiseConv2DParams params{};
params.padding = Padding::VALID;
@@ -468,7 +489,8 @@ TEST_F(DepthwiseConv2DTest, InOutTypeMismatch_NEG)
params.dilation_width_factor = 1;
params.activation = Activation::RELU;
- DepthwiseConv2D kernel(&input_tensor, &filter_tensor, &bias_tensor, &output_tensor, params);
+ DepthwiseConv2D kernel(&input_tensor, &filter_tensor, &bias_tensor, &output_tensor, &scratchpad,
+ params);
EXPECT_ANY_THROW(kernel.configure());
}
@@ -497,6 +519,7 @@ TEST_F(DepthwiseConv2DTest, InvalidInputShape_NEG)
Tensor bias_tensor =
makeInputTensor<DataType::FLOAT32>(bias_shape, bias_data, _memory_manager.get());
Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);
+ Tensor scratchpad(DataType::FLOAT32, Shape({}), {}, "");
DepthwiseConv2DParams params{};
params.padding = Padding::VALID;
@@ -507,7 +530,8 @@ TEST_F(DepthwiseConv2DTest, InvalidInputShape_NEG)
params.dilation_width_factor = 1;
params.activation = Activation::RELU;
- DepthwiseConv2D kernel(&input_tensor, &filter_tensor, &bias_tensor, &output_tensor, params);
+ DepthwiseConv2D kernel(&input_tensor, &filter_tensor, &bias_tensor, &output_tensor, &scratchpad,
+ params);
EXPECT_ANY_THROW(kernel.configure());
}
@@ -536,6 +560,7 @@ TEST_F(DepthwiseConv2DTest, InvalidFilterShape_NEG)
Tensor bias_tensor =
makeInputTensor<DataType::FLOAT32>(bias_shape, bias_data, _memory_manager.get());
Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);
+ Tensor scratchpad(DataType::FLOAT32, Shape({}), {}, "");
DepthwiseConv2DParams params{};
params.padding = Padding::VALID;
@@ -546,7 +571,8 @@ TEST_F(DepthwiseConv2DTest, InvalidFilterShape_NEG)
params.dilation_width_factor = 1;
params.activation = Activation::RELU;
- DepthwiseConv2D kernel(&input_tensor, &filter_tensor, &bias_tensor, &output_tensor, params);
+ DepthwiseConv2D kernel(&input_tensor, &filter_tensor, &bias_tensor, &output_tensor, &scratchpad,
+ params);
EXPECT_ANY_THROW(kernel.configure());
}
@@ -575,6 +601,7 @@ TEST_F(DepthwiseConv2DTest, InvalidBiasDim_NEG)
Tensor bias_tensor =
makeInputTensor<DataType::FLOAT32>(bias_shape, bias_data, _memory_manager.get());
Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);
+ Tensor scratchpad(DataType::FLOAT32, Shape({}), {}, "");
DepthwiseConv2DParams params{};
params.padding = Padding::VALID;
@@ -585,7 +612,8 @@ TEST_F(DepthwiseConv2DTest, InvalidBiasDim_NEG)
params.dilation_width_factor = 1;
params.activation = Activation::RELU;
- DepthwiseConv2D kernel(&input_tensor, &filter_tensor, &bias_tensor, &output_tensor, params);
+ DepthwiseConv2D kernel(&input_tensor, &filter_tensor, &bias_tensor, &output_tensor, &scratchpad,
+ params);
EXPECT_ANY_THROW(kernel.configure());
}
diff --git a/compiler/luci-interpreter/src/kernels/Dequantize.cpp b/compiler/luci-interpreter/src/kernels/Dequantize.cpp
new file mode 100644
index 000000000..96399e5c7
--- /dev/null
+++ b/compiler/luci-interpreter/src/kernels/Dequantize.cpp
@@ -0,0 +1,79 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "kernels/Dequantize.h"
+#include "kernels/Utils.h"
+#include "PALDequantize.h"
+
+namespace luci_interpreter
+{
+namespace kernels
+{
+
+Dequantize::Dequantize(const Tensor *input, Tensor *output) : Kernel({input}, {output}) {}
+
+void Dequantize::configure()
+{
+ LUCI_INTERPRETER_CHECK(input()->element_type() == loco::DataType::S8 ||
+ input()->element_type() == loco::DataType::U8 ||
+ input()->element_type() == loco::DataType::S16);
+
+ LUCI_INTERPRETER_CHECK(input()->scales().size() == 1);
+
+ if (input()->element_type() == loco::DataType::S16)
+ LUCI_INTERPRETER_CHECK(input()->zero_point() == 0);
+
+ LUCI_INTERPRETER_CHECK(output()->element_type() == loco::DataType::FLOAT32);
+
+ output()->resize(input()->shape());
+}
+
+void Dequantize::execute() const
+{
+ tflite::DequantizationParams op_params;
+ op_params.zero_point = input()->zero_point();
+ op_params.scale = input()->scale();
+
+ switch (input()->element_type())
+ {
+ case loco::DataType::U8:
+ {
+ luci_interpreter_pal::Dequantize(op_params, getTensorShape(input()),
+ getTensorData<uint8_t>(input()), getTensorShape(output()),
+ getTensorData<float>(output()));
+ break;
+ }
+ case loco::DataType::S8:
+ {
+ luci_interpreter_pal::Dequantize(op_params, getTensorShape(input()),
+ getTensorData<int8_t>(input()), getTensorShape(output()),
+ getTensorData<float>(output()));
+ break;
+ }
+ case loco::DataType::S16:
+ {
+ luci_interpreter_pal::Dequantize(op_params, getTensorShape(input()),
+ getTensorData<int16_t>(input()), getTensorShape(output()),
+ getTensorData<float>(output()));
+ break;
+ }
+ default:
+ throw std::runtime_error("Unsupported type.");
+ }
+}
+
+} // namespace kernels
+} // namespace luci_interpreter
diff --git a/compiler/luci-interpreter/src/kernels/Dequantize.h b/compiler/luci-interpreter/src/kernels/Dequantize.h
new file mode 100644
index 000000000..5565df0e4
--- /dev/null
+++ b/compiler/luci-interpreter/src/kernels/Dequantize.h
@@ -0,0 +1,43 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LUCI_INTERPRETER_KERNELS_DEQUANTIZE_H
+#define LUCI_INTERPRETER_KERNELS_DEQUANTIZE_H
+
+#include "core/Kernel.h"
+#include "core/KernelParams.h"
+
+namespace luci_interpreter
+{
+namespace kernels
+{
+
+class Dequantize : public Kernel
+{
+public:
+ Dequantize(const Tensor *input, Tensor *output);
+
+ const Tensor *input() const { return _inputs[0]; }
+ Tensor *output() const { return _outputs[0]; }
+
+ void configure() override;
+ void execute() const override;
+};
+
+} // namespace kernels
+} // namespace luci_interpreter
+
+#endif // LUCI_INTERPRETER_KERNELS_DEQUANTIZE_H
diff --git a/compiler/luci-interpreter/src/kernels/Dequantize.test.cpp b/compiler/luci-interpreter/src/kernels/Dequantize.test.cpp
new file mode 100644
index 000000000..0cab633d6
--- /dev/null
+++ b/compiler/luci-interpreter/src/kernels/Dequantize.test.cpp
@@ -0,0 +1,149 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "kernels/Dequantize.h"
+#include "kernels/TestUtils.h"
+#include "luci_interpreter/TestMemoryManager.h"
+
+namespace luci_interpreter
+{
+namespace kernels
+{
+namespace
+{
+
+using namespace testing;
+
+class DequantizeTest : public ::testing::Test
+{
+protected:
+ void SetUp() override { _memory_manager = std::make_unique<TestMemoryManager>(); }
+
+ std::unique_ptr<IMemoryManager> _memory_manager;
+};
+
+TEST_F(DequantizeTest, Uint8)
+{
+ std::vector<uint8_t> input_data{0, 1, 2, 3, 4, 251, 252, 253, 254, 255};
+
+ std::vector<float> ref_output_data{-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64};
+
+ Tensor input_tensor(loco::DataType::U8, {2, 5}, {{0.5}, {127}}, "");
+
+ _memory_manager->allocate_memory(input_tensor);
+ input_tensor.writeData(input_data.data(), input_data.size() * sizeof(uint8_t));
+
+ Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);
+
+ Dequantize kernel(&input_tensor, &output_tensor);
+ kernel.configure();
+ _memory_manager->allocate_memory(output_tensor);
+ kernel.execute();
+
+ EXPECT_THAT(extractTensorData<float>(output_tensor),
+ ::testing::ElementsAreArray(ref_output_data));
+ EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray({2, 5}));
+}
+
+TEST_F(DequantizeTest, Sint8)
+{
+ std::vector<int8_t> input_data{-128, -127, -126, -125, -124, 123, 124, 125, 126, 127};
+
+ std::vector<float> ref_output_data{-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64};
+
+ Tensor input_tensor(loco::DataType::S8, {2, 5}, {{0.5}, {-1}}, "");
+
+ _memory_manager->allocate_memory(input_tensor);
+ input_tensor.writeData(input_data.data(), input_data.size() * sizeof(int8_t));
+
+ Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);
+
+ Dequantize kernel(&input_tensor, &output_tensor);
+ kernel.configure();
+ _memory_manager->allocate_memory(output_tensor);
+ kernel.execute();
+
+ EXPECT_THAT(extractTensorData<float>(output_tensor),
+ ::testing::ElementsAreArray(ref_output_data));
+ EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray({2, 5}));
+}
+
+TEST_F(DequantizeTest, Sint16)
+{
+ std::vector<int16_t> input_data{-129, -126, -125, -124, -123, 124, 125, 126, 127, 131};
+
+ std::vector<float> ref_output_data{-64.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 65.5};
+
+ Tensor input_tensor(loco::DataType::S16, {2, 5}, {{0.5}, {0}}, "");
+
+ _memory_manager->allocate_memory(input_tensor);
+ input_tensor.writeData(input_data.data(), input_data.size() * sizeof(int16_t));
+
+ Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);
+
+ Dequantize kernel(&input_tensor, &output_tensor);
+ kernel.configure();
+ _memory_manager->allocate_memory(output_tensor);
+ kernel.execute();
+
+ EXPECT_THAT(extractTensorData<float>(output_tensor),
+ ::testing::ElementsAreArray(ref_output_data));
+ EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray({2, 5}));
+}
+
+TEST_F(DequantizeTest, InvalidInputType_NEG)
+{
+ std::vector<float> input_data{-129, -126, -125, -124, -123, 124, 125, 126, 127, 131};
+
+ Tensor input_tensor =
+ makeInputTensor<DataType::FLOAT32>({2, 5}, input_data, _memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);
+
+ Dequantize kernel(&input_tensor, &output_tensor);
+ EXPECT_ANY_THROW(kernel.configure());
+}
+
+TEST_F(DequantizeTest, InvalidOutputType_NEG)
+{
+ std::vector<int16_t> input_data{-129, -126, -125, -124, -123, 124, 125, 126, 127, 131};
+
+ Tensor input_tensor(loco::DataType::S16, {2, 5}, {{0.5}, {0}}, "");
+
+ _memory_manager->allocate_memory(input_tensor);
+ input_tensor.writeData(input_data.data(), input_data.size() * sizeof(int16_t));
+
+ Tensor output_tensor = makeOutputTensor(DataType::S8, /*scale*/ 0.5, /*zero_point*/ -1);
+
+ Dequantize kernel(&input_tensor, &output_tensor);
+ EXPECT_ANY_THROW(kernel.configure());
+}
+
+TEST_F(DequantizeTest, InvalidInputZeroPoint_NEG)
+{
+ std::vector<float> input_data{1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
+
+ Tensor input_tensor =
+ makeInputTensor<DataType::S16>({2, 5}, 0.5, -1, input_data, _memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);
+
+ Dequantize kernel(&input_tensor, &output_tensor);
+ EXPECT_ANY_THROW(kernel.configure());
+}
+
+} // namespace
+} // namespace kernels
+} // namespace luci_interpreter
diff --git a/compiler/luci-interpreter/src/kernels/Div.cpp b/compiler/luci-interpreter/src/kernels/Div.cpp
index 0e52ba1f0..dd1532278 100644
--- a/compiler/luci-interpreter/src/kernels/Div.cpp
+++ b/compiler/luci-interpreter/src/kernels/Div.cpp
@@ -46,6 +46,12 @@ void Div::execute() const
case DataType::FLOAT32:
evalFloat();
break;
+ case DataType::S64:
+ evalInteger<int64_t>();
+ break;
+ case DataType::S32:
+ evalInteger<int32_t>();
+ break;
case DataType::U8:
evalQuantized();
break;
@@ -56,13 +62,9 @@ void Div::execute() const
void Div::evalFloat() const
{
- float activation_min{};
- float activation_max{};
- calculateActivationRange(_params.activation, &activation_min, &activation_max);
-
tflite::ArithmeticParams params{};
- params.float_activation_min = activation_min;
- params.float_activation_max = activation_max;
+ fillArithmeticActivationRange<float>(params, _params.activation);
+
const bool need_broadcast = tflite::reference_ops::ProcessBroadcastShapes(
getTensorShape(input1()), getTensorShape(input2()), &params);
@@ -80,6 +82,28 @@ void Div::evalFloat() const
}
}
+template <typename T> void Div::evalInteger() const
+{
+ tflite::ArithmeticParams params{};
+ fillArithmeticActivationRange<T>(params, _params.activation);
+
+ const bool need_broadcast = tflite::reference_ops::ProcessBroadcastShapes(
+ getTensorShape(input1()), getTensorShape(input2()), &params);
+
+ if (need_broadcast)
+ {
+ tflite::reference_ops::BroadcastDivSlow(
+ params, getTensorShape(input1()), getTensorData<T>(input1()), getTensorShape(input2()),
+ getTensorData<T>(input2()), getTensorShape(output()), getTensorData<T>(output()));
+ }
+ else
+ {
+ tflite::reference_ops::Div(params, getTensorShape(input1()), getTensorData<T>(input1()),
+ getTensorShape(input2()), getTensorData<T>(input2()),
+ getTensorShape(output()), getTensorData<T>(output()));
+ }
+}
+
void Div::evalQuantized() const
{
const auto input1_scale = static_cast<double>(input1()->scale());
diff --git a/compiler/luci-interpreter/src/kernels/Div.h b/compiler/luci-interpreter/src/kernels/Div.h
index 6040cdd02..c1bf3e10b 100644
--- a/compiler/luci-interpreter/src/kernels/Div.h
+++ b/compiler/luci-interpreter/src/kernels/Div.h
@@ -39,6 +39,7 @@ public:
private:
void evalFloat() const;
+ template <typename T> void evalInteger() const;
void evalQuantized() const;
};
diff --git a/compiler/luci-interpreter/src/kernels/Div.test.cpp b/compiler/luci-interpreter/src/kernels/Div.test.cpp
index 021d68d06..85cd8b90a 100644
--- a/compiler/luci-interpreter/src/kernels/Div.test.cpp
+++ b/compiler/luci-interpreter/src/kernels/Div.test.cpp
@@ -134,6 +134,56 @@ TEST_F(DivTest, Uint8)
EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray(output_shape));
}
+template <loco::DataType DType> void checkInteger(luci_interpreter::IMemoryManager *memory_manager)
+{
+ using dtype = typename loco::DataTypeImpl<DType>::Type;
+ Shape base_shape = {2, 3, 1, 2};
+ std::vector<Shape> test_shapes{{1, 1, 3, 2}, {1, 3, 1, 2}, {2, 1, 3, 1}, {2, 3, 1, 1}};
+
+ std::vector<std::vector<dtype>> test_outputs = {{5, 6, 2, 0, 10, 3, //
+ 10, 0, 4, 5, 20, 0, //
+ 0, 0, 0, 2, 0, 0, //
+ 2, 0, 1, 10, 5, 0, //
+ 2, 3, 1, 0, 5, 1, //
+ 18, 20, 7, 0, 37, 10},
+ {5, 6, 4, 5, 0, 0, 2, 0, 1, 0, 37, 10},
+ {5, 7, 4, 6, 2, 3, 10, 0, 8, 0, 4, 0,
+ 0, 0, 0, 0, 0, 0, 0, 10, 5, 0, 1, 0,
+ 0, 0, 5, 9, 1, 1, 0, 0, 37, 50, 7, 10},
+ {5, 7, 8, 0, 0, 0, 0, 10, 5, 9, 7, 10}};
+ std::vector<dtype> input1_data{20, 30, 40, -17, -4, -7, 11, -31, 10, 19, 75, 100};
+ std::vector<dtype> input2_data{4, 5, 10, -3, 2, 10};
+ for (size_t i = 0; i < test_shapes.size(); ++i)
+ {
+ Tensor input1_tensor = makeInputTensor<DType>(base_shape, input1_data, memory_manager);
+ Tensor input2_tensor = makeInputTensor<DType>(test_shapes[i], input2_data, memory_manager);
+ Tensor output_tensor = makeOutputTensor(DType);
+
+ DivParams params{};
+ params.activation = Activation::RELU;
+
+ Div kernel(&input1_tensor, &input2_tensor, &output_tensor, params);
+ kernel.configure();
+ memory_manager->allocate_memory(output_tensor);
+ kernel.execute();
+
+ EXPECT_THAT(extractTensorData<dtype>(output_tensor), test_outputs[i])
+ << "With shape number " << i;
+ }
+}
+
+TEST_F(DivTest, SInt64)
+{
+ checkInteger<loco::DataType::S64>(_memory_manager.get());
+ SUCCEED();
+}
+
+TEST_F(DivTest, SInt32)
+{
+ checkInteger<loco::DataType::S32>(_memory_manager.get());
+ SUCCEED();
+}
+
TEST_F(DivTest, Input_Output_Type_NEG)
{
Tensor input1_tensor = makeInputTensor<DataType::FLOAT32>({1}, {1.f}, _memory_manager.get());
@@ -149,9 +199,9 @@ TEST_F(DivTest, Input_Output_Type_NEG)
TEST_F(DivTest, Invalid_Input_Type_NEG)
{
- Tensor input1_tensor = makeInputTensor<DataType::S64>({1}, {1}, _memory_manager.get());
- Tensor input2_tensor = makeInputTensor<DataType::S64>({1}, {2}, _memory_manager.get());
- Tensor output_tensor = makeOutputTensor(DataType::S64);
+ Tensor input1_tensor = makeInputTensor<DataType::U64>({1}, {1}, _memory_manager.get());
+ Tensor input2_tensor = makeInputTensor<DataType::U64>({1}, {2}, _memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::U64);
DivParams params{};
params.activation = Activation::RELU;
@@ -162,6 +212,19 @@ TEST_F(DivTest, Invalid_Input_Type_NEG)
EXPECT_ANY_THROW(kernel.execute());
}
+TEST_F(DivTest, Invalid_Output_Type_NEG)
+{
+ Tensor input1_tensor = makeInputTensor<DataType::S32>({1}, {1}, _memory_manager.get());
+ Tensor input2_tensor = makeInputTensor<DataType::S32>({1}, {2}, _memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::S64);
+
+ DivParams params{};
+ params.activation = Activation::RELU;
+
+ Div kernel(&input1_tensor, &input2_tensor, &output_tensor, params);
+ EXPECT_ANY_THROW(kernel.configure());
+}
+
} // namespace
} // namespace kernels
} // namespace luci_interpreter
diff --git a/compiler/luci-interpreter/src/kernels/Equal.cpp b/compiler/luci-interpreter/src/kernels/Equal.cpp
index f58de1250..a57e127b7 100644
--- a/compiler/luci-interpreter/src/kernels/Equal.cpp
+++ b/compiler/luci-interpreter/src/kernels/Equal.cpp
@@ -49,6 +49,12 @@ void Equal::execute() const
case DataType::FLOAT32:
evalFloat();
break;
+ case DataType::S64:
+ evalInteger<int64_t>();
+ break;
+ case DataType::S32:
+ evalInteger<int32_t>();
+ break;
case DataType::U8:
evalQuantized();
break;
@@ -79,6 +85,29 @@ void Equal::evalFloat() const
}
}
+template <typename T> void Equal::evalInteger() const
+{
+ const auto x_data = getTensorData<T>(x());
+ const auto y_data = getTensorData<T>(y());
+ auto output_data = getTensorData<bool>(output());
+
+ tflite::ComparisonParams op_params;
+ op_params.is_broadcast = x()->shape() != y()->shape();
+
+ if (op_params.is_broadcast)
+ {
+ tflite::reference_ops::Broadcast4DSlowEqualNoScaling(op_params, getTensorShape(x()), x_data,
+ getTensorShape(y()), y_data,
+ getTensorShape(output()), output_data);
+ }
+ else
+ {
+ tflite::reference_ops::EqualNoScaling(op_params, getTensorShape(x()), x_data,
+ getTensorShape(y()), y_data, getTensorShape(output()),
+ output_data);
+ }
+}
+
void Equal::evalQuantized() const
{
const auto x_data = getTensorData<uint8_t>(x());
diff --git a/compiler/luci-interpreter/src/kernels/Equal.h b/compiler/luci-interpreter/src/kernels/Equal.h
index 11f025eac..c9be32cc0 100644
--- a/compiler/luci-interpreter/src/kernels/Equal.h
+++ b/compiler/luci-interpreter/src/kernels/Equal.h
@@ -38,6 +38,7 @@ public:
private:
void evalFloat() const;
+ template <typename T> void evalInteger() const;
void evalQuantized() const;
private:
diff --git a/compiler/luci-interpreter/src/kernels/Equal.test.cpp b/compiler/luci-interpreter/src/kernels/Equal.test.cpp
index 46a0f97d8..5870e5460 100644
--- a/compiler/luci-interpreter/src/kernels/Equal.test.cpp
+++ b/compiler/luci-interpreter/src/kernels/Equal.test.cpp
@@ -99,6 +99,82 @@ TEST_F(EqualTest, FloatBroardcast)
EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray({4, 3}));
}
+template <loco::DataType DType>
+void checkIntegerSimple(luci_interpreter::IMemoryManager *memory_manager)
+{
+ using dtype = typename loco::DataTypeImpl<DType>::Type;
+ dtype min_value = std::numeric_limits<dtype>::min();
+ dtype max_value = std::numeric_limits<dtype>::max();
+ std::vector<dtype> x_data{min_value, 2, max_value};
+
+ std::vector<dtype> y_data{min_value, -2, max_value};
+
+ std::vector<bool> ref_output_data{true, false, true};
+
+ Tensor x_tensor = makeInputTensor<DType>({3}, x_data, memory_manager);
+ Tensor y_tensor = makeInputTensor<DType>({3}, y_data, memory_manager);
+ Tensor output_tensor = makeOutputTensor(DataType::BOOL);
+
+ Equal kernel(&x_tensor, &y_tensor, &output_tensor);
+ kernel.configure();
+ memory_manager->allocate_memory(output_tensor);
+ kernel.execute();
+
+ EXPECT_THAT(extractTensorData<bool>(output_tensor), ::testing::ElementsAreArray(ref_output_data));
+ EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray({3}));
+}
+
+template <loco::DataType DType>
+void checkIntegerBroadcast(luci_interpreter::IMemoryManager *memory_manager)
+{
+ using dtype = typename loco::DataTypeImpl<DType>::Type;
+ dtype min_value = std::numeric_limits<dtype>::min();
+ dtype max_value = std::numeric_limits<dtype>::max();
+ std::vector<dtype> x_data{
+ min_value, 2, 3, // Row 1
+ 4, 5, max_value, // Row 2
+ -1, -2, -3, // Row 3
+ min_value, -2, max_value, // Row 4
+ };
+
+ std::vector<dtype> y_data{
+ min_value, -2, max_value, // Row 1
+ };
+
+ std::vector<bool> ref_output_data{
+ true, false, false, // Row 1
+ false, false, true, // Row 2
+ false, true, false, // Row 3
+ true, true, true, // Row 4
+ };
+
+ Tensor x_tensor = makeInputTensor<DType>({4, 3}, x_data, memory_manager);
+ Tensor y_tensor = makeInputTensor<DType>({3}, y_data, memory_manager);
+ Tensor output_tensor = makeOutputTensor(DataType::BOOL);
+
+ Equal kernel(&x_tensor, &y_tensor, &output_tensor);
+ kernel.configure();
+ memory_manager->allocate_memory(output_tensor);
+ kernel.execute();
+
+ EXPECT_THAT(extractTensorData<bool>(output_tensor), ::testing::ElementsAreArray(ref_output_data));
+ EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray({4, 3}));
+}
+
+TEST_F(EqualTest, Int32)
+{
+ checkIntegerSimple<loco::DataType::S32>(_memory_manager.get());
+ checkIntegerBroadcast<loco::DataType::S32>(_memory_manager.get());
+ SUCCEED();
+}
+
+TEST_F(EqualTest, Int64)
+{
+ checkIntegerSimple<loco::DataType::S64>(_memory_manager.get());
+ checkIntegerBroadcast<loco::DataType::S64>(_memory_manager.get());
+ SUCCEED();
+}
+
// Choose min / max in such a way that there are exactly 256 units to avoid rounding errors.
const float F_MIN = -128.0 / 128.0;
const float F_MAX = 127.0 / 128.0;
@@ -195,6 +271,36 @@ TEST_F(EqualTest, Input_Output_Type_NEG)
EXPECT_ANY_THROW(kernel.configure());
}
+TEST_F(EqualTest, Float_Broadcast_NEG)
+{
+ Tensor x_tensor = makeInputTensor<DataType::FLOAT32>({2}, {1.f, 2.f}, _memory_manager.get());
+ Tensor y_tensor = makeInputTensor<DataType::FLOAT32>({3}, {1.f, 2.f, 3.f}, _memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::BOOL);
+
+ Equal kernel(&x_tensor, &y_tensor, &output_tensor);
+ ASSERT_ANY_THROW(kernel.configure());
+}
+
+TEST_F(EqualTest, Int32_Broadcast_NEG)
+{
+ Tensor x_tensor = makeInputTensor<DataType::S32>({2}, {1, 2}, _memory_manager.get());
+ Tensor y_tensor = makeInputTensor<DataType::S32>({3}, {1, 2, 3}, _memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::BOOL);
+
+ Equal kernel(&x_tensor, &y_tensor, &output_tensor);
+ ASSERT_ANY_THROW(kernel.configure());
+}
+
+TEST_F(EqualTest, Int64_Broadcast_NEG)
+{
+ Tensor x_tensor = makeInputTensor<DataType::S64>({2}, {1, 2}, _memory_manager.get());
+ Tensor y_tensor = makeInputTensor<DataType::S64>({3}, {1, 2, 3}, _memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::BOOL);
+
+ Equal kernel(&x_tensor, &y_tensor, &output_tensor);
+ ASSERT_ANY_THROW(kernel.configure());
+}
+
} // namespace
} // namespace kernels
} // namespace luci_interpreter
diff --git a/compiler/luci-interpreter/src/kernels/ExpandDims.cpp b/compiler/luci-interpreter/src/kernels/ExpandDims.cpp
new file mode 100644
index 000000000..ba35c99fa
--- /dev/null
+++ b/compiler/luci-interpreter/src/kernels/ExpandDims.cpp
@@ -0,0 +1,88 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "kernels/ExpandDims.h"
+#include "kernels/Utils.h"
+
+namespace luci_interpreter
+{
+namespace kernels
+{
+
+ExpandDims::ExpandDims(const Tensor *input, const Tensor *axis, Tensor *output)
+ : Kernel({input, axis}, {output})
+{
+}
+
+void ExpandDims::configure()
+{
+ int32_t axis_value;
+
+ switch (axis()->element_type())
+ {
+ case loco::DataType::S32:
+ axis_value = *getTensorData<int32_t>(axis());
+ break;
+ case loco::DataType::S64:
+ axis_value = static_cast<int32_t>(*getTensorData<int64_t>(axis()));
+ break;
+ default:
+ throw std::runtime_error("Unsupported type.");
+ }
+
+ const auto input_shape = input()->shape();
+
+ if (axis_value < 0)
+ {
+ axis_value += input_shape.num_dims() + 1;
+ }
+
+ LUCI_INTERPRETER_CHECK(axis_value <= input_shape.num_dims() and axis_value >= 0);
+
+ Shape output_shape(input_shape.num_dims() + 1);
+ for (int32_t i = 0; i < output_shape.num_dims(); ++i)
+ {
+ if (i < axis_value)
+ {
+ output_shape.dim(i) = input_shape.dim(i);
+ }
+ else if (i == axis_value)
+ {
+ output_shape.dim(i) = 1;
+ }
+ else
+ {
+ LUCI_INTERPRETER_CHECK(i >= 1);
+ output_shape.dim(i) = input_shape.dim(i - 1);
+ }
+ }
+
+ output()->resize(output_shape);
+}
+
+void ExpandDims::execute() const
+{
+ // Just copy input to output
+ const auto *input_data = input()->data<void>();
+ auto *output_data = output()->data<void>();
+
+ const size_t element_size = getDataTypeSize(input()->element_type());
+ const int32_t num_elements = input()->shape().num_elements();
+ std::memcpy(output_data, input_data, num_elements * element_size);
+}
+
+} // namespace kernels
+} // namespace luci_interpreter
diff --git a/compiler/luci-interpreter/src/kernels/ExpandDims.h b/compiler/luci-interpreter/src/kernels/ExpandDims.h
new file mode 100644
index 000000000..e510b1160
--- /dev/null
+++ b/compiler/luci-interpreter/src/kernels/ExpandDims.h
@@ -0,0 +1,44 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LUCI_INTERPRETER_KERNELS_EXPAND_DIMS_H
+#define LUCI_INTERPRETER_KERNELS_EXPAND_DIMS_H
+
+#include "core/Kernel.h"
+#include "core/KernelParams.h"
+
+namespace luci_interpreter
+{
+namespace kernels
+{
+
+class ExpandDims : public Kernel
+{
+public:
+ ExpandDims(const Tensor *input, const Tensor *axis, Tensor *output);
+
+ const Tensor *input() const { return _inputs[0]; }
+ const Tensor *axis() const { return _inputs[1]; }
+ Tensor *output() const { return _outputs[0]; }
+
+ void configure() override;
+ void execute() const override;
+};
+
+} // namespace kernels
+} // namespace luci_interpreter
+
+#endif // LUCI_INTERPRETER_KERNELS_EXPAND_DIMS_H
diff --git a/compiler/luci-interpreter/src/kernels/ExpandDims.test.cpp b/compiler/luci-interpreter/src/kernels/ExpandDims.test.cpp
new file mode 100644
index 000000000..df9eaccc0
--- /dev/null
+++ b/compiler/luci-interpreter/src/kernels/ExpandDims.test.cpp
@@ -0,0 +1,115 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "kernels/ExpandDims.h"
+#include "kernels/TestUtils.h"
+#include "luci_interpreter/TestMemoryManager.h"
+
+namespace luci_interpreter
+{
+namespace kernels
+{
+namespace
+{
+
+using namespace testing;
+
+class ExpandDimsTest : public ::testing::Test
+{
+protected:
+ void SetUp() override { _memory_manager = std::make_unique<TestMemoryManager>(); }
+
+ std::unique_ptr<IMemoryManager> _memory_manager;
+};
+
+TEST_F(ExpandDimsTest, PositiveAxis)
+{
+ std::vector<int32_t> input_data{-1, 1, -2, 2};
+ std::initializer_list<int32_t> input_shape = {2, 2};
+
+ std::initializer_list<int32_t> axis_value = {0};
+
+ Tensor input_tensor =
+ makeInputTensor<DataType::S32>(input_shape, input_data, _memory_manager.get());
+ Tensor axis_tensor = makeInputTensor<DataType::S32>({1}, axis_value, _memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::S32);
+
+ ExpandDims kernel(&input_tensor, &axis_tensor, &output_tensor);
+ kernel.configure();
+ _memory_manager->allocate_memory(output_tensor);
+ kernel.execute();
+
+ EXPECT_THAT(extractTensorData<int32_t>(output_tensor), ::testing::ElementsAreArray(input_data));
+ EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray({1, 2, 2}));
+}
+
+TEST_F(ExpandDimsTest, NegAxis)
+{
+ std::vector<int32_t> input_data{-1, 1, -2, 2};
+ std::initializer_list<int32_t> input_shape = {2, 2};
+
+ std::initializer_list<int32_t> axis_value = {-1};
+
+ Tensor input_tensor =
+ makeInputTensor<DataType::S32>(input_shape, input_data, _memory_manager.get());
+ Tensor axis_tensor = makeInputTensor<DataType::S32>({1}, axis_value, _memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::S32);
+
+ ExpandDims kernel(&input_tensor, &axis_tensor, &output_tensor);
+ kernel.configure();
+ _memory_manager->allocate_memory(output_tensor);
+ kernel.execute();
+
+ EXPECT_THAT(extractTensorData<int32_t>(output_tensor), ::testing::ElementsAreArray(input_data));
+ EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray({2, 2, 1}));
+}
+
+TEST_F(ExpandDimsTest, InvalidAxisType_NEG)
+{
+ std::vector<int32_t> input_data{-1, 1, -2, 2};
+ std::initializer_list<int32_t> input_shape = {2, 2};
+
+ std::initializer_list<float> axis_value = {1.0};
+
+ Tensor input_tensor =
+ makeInputTensor<DataType::S32>(input_shape, input_data, _memory_manager.get());
+ Tensor axis_tensor = makeInputTensor<DataType::FLOAT32>({1}, axis_value, _memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::S32);
+
+ ExpandDims kernel(&input_tensor, &axis_tensor, &output_tensor);
+ EXPECT_ANY_THROW(kernel.configure());
+}
+
+TEST_F(ExpandDimsTest, InvalidAxisValue_NEG)
+{
+ std::vector<int32_t> input_data{-1, 1, -2, 2};
+ std::initializer_list<int32_t> input_shape = {2, 2};
+
+ std::initializer_list<int32_t> axis_value = {3};
+
+ Tensor input_tensor =
+ makeInputTensor<DataType::S32>(input_shape, input_data, _memory_manager.get());
+ Tensor axis_tensor = makeInputTensor<DataType::S32>({1}, axis_value, _memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::S32);
+
+ ExpandDims kernel(&input_tensor, &axis_tensor, &output_tensor);
+ EXPECT_ANY_THROW(kernel.configure());
+}
+
+} // namespace
+} // namespace kernels
+} // namespace luci_interpreter
diff --git a/compiler/luci-interpreter/src/kernels/FullyConnected.cpp b/compiler/luci-interpreter/src/kernels/FullyConnected.cpp
index cfe8f8bf2..bd2bb2f35 100644
--- a/compiler/luci-interpreter/src/kernels/FullyConnected.cpp
+++ b/compiler/luci-interpreter/src/kernels/FullyConnected.cpp
@@ -18,8 +18,7 @@
#include "kernels/Utils.h"
-#include <tensorflow/lite/kernels/internal/reference/fully_connected.h>
-#include <tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h>
+#include "PALFullyConnected.h"
#include <stdexcept>
@@ -74,7 +73,18 @@ void FullyConnected::configure()
if (bias())
LUCI_INTERPRETER_CHECK(bias()->shape().num_elements() == weights()->shape().dim(0));
- output()->resize({batch_size, num_units});
+ if (params().keep_num_dims == false)
+ {
+ output()->resize({batch_size, num_units});
+ }
+ else
+ {
+ luci_interpreter::Shape output_shape(input_shape.num_dims());
+ for (int i = 0; i < input_shape.num_dims(); ++i)
+ output_shape.dim(i) = input_shape.dim(i);
+ output_shape.dim(input_shape.num_dims() - 1) = num_units;
+ output()->resize(output_shape);
+ }
}
void FullyConnected::execute() const
@@ -172,7 +182,7 @@ void FullyConnected::evalQuantizedS8() const
op_params.quantized_activation_max = output_activation_max;
op_params.lhs_cacheable = false;
op_params.rhs_cacheable = false;
- tflite::reference_integer_ops::FullyConnected(
+ luci_interpreter_pal::FullyConnected<int8_t>(
op_params, getTensorShape(input()), getTensorData<int8_t>(input()), getTensorShape(weights()),
getTensorData<int8_t>(weights()), getTensorShape(bias()), getTensorData<int32_t>(bias()),
getTensorShape(output()), getTensorData<int8_t>(output()));
diff --git a/compiler/luci-interpreter/src/kernels/FullyConnected.test.cpp b/compiler/luci-interpreter/src/kernels/FullyConnected.test.cpp
index b0eda0145..4474cc4fb 100644
--- a/compiler/luci-interpreter/src/kernels/FullyConnected.test.cpp
+++ b/compiler/luci-interpreter/src/kernels/FullyConnected.test.cpp
@@ -133,7 +133,7 @@ template <typename T> class FullyConnectedTest : public ::testing::Test
};
using DataTypes = ::testing::Types<float, uint8_t, int8_t>;
-TYPED_TEST_CASE(FullyConnectedTest, DataTypes);
+TYPED_TEST_SUITE(FullyConnectedTest, DataTypes);
TYPED_TEST(FullyConnectedTest, Simple)
{
diff --git a/compiler/luci-interpreter/src/kernels/Gather.cpp b/compiler/luci-interpreter/src/kernels/Gather.cpp
new file mode 100644
index 000000000..f1256660f
--- /dev/null
+++ b/compiler/luci-interpreter/src/kernels/Gather.cpp
@@ -0,0 +1,139 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "kernels/Gather.h"
+#include "kernels/Utils.h"
+#include "PALGather.h"
+
+#include <stdexcept>
+#include <cassert>
+
+namespace luci_interpreter
+{
+
+namespace kernels
+{
+
+Gather::Gather(const Tensor *params, const Tensor *indices, Tensor *output,
+ const GatherParams &gparams)
+ : KernelWithParams<GatherParams>({params, indices}, {output}, gparams)
+{
+}
+
+void Gather::configure()
+{
+ if (params()->element_type() == DataType::FLOAT32)
+ {
+ LUCI_INTERPRETER_CHECK(output()->element_type() == DataType::FLOAT32);
+ }
+ else
+ {
+ throw std::runtime_error("Unsupported type.");
+ }
+
+ LUCI_INTERPRETER_CHECK(indices()->element_type() == DataType::S32 ||
+ indices()->element_type() == DataType::S64);
+
+ // refer tensorflow/lite/kernels/gather.cc
+
+ const Shape &params_shape = params()->shape();
+ const Shape &indices_shape = indices()->shape();
+
+ int axis = _params.axis;
+ if (axis < 0)
+ {
+ axis += params_shape.num_dims();
+ }
+ LUCI_INTERPRETER_CHECK(0 <= axis && axis < params_shape.num_dims());
+
+ int batch_dims = _params.batch_dims;
+ // batch_dims should be in range: [-rank(indices), rank(indices)].
+ // Negative batch_dims is added with rank of positions.
+ if (batch_dims < 0)
+ {
+ batch_dims += indices_shape.num_dims();
+ }
+ LUCI_INTERPRETER_CHECK(batch_dims <= axis);
+ LUCI_INTERPRETER_CHECK(0 <= batch_dims && batch_dims < params_shape.num_dims());
+ LUCI_INTERPRETER_CHECK(batch_dims <= indices_shape.num_dims());
+ for (int i = 0; i < batch_dims; ++i)
+ {
+ LUCI_INTERPRETER_CHECK(params_shape.dim(i) == indices_shape.dim(i));
+ }
+
+ const int num_dimensions = params_shape.num_dims() + indices_shape.num_dims() - 1 - batch_dims;
+
+ Shape output_shape(num_dimensions);
+ int output_index = 0;
+ for (int i = 0; i < axis; ++i)
+ {
+ output_shape.dim(output_index++) = params_shape.dim(i);
+ }
+ for (int i = batch_dims; i < indices_shape.num_dims(); ++i)
+ {
+ output_shape.dim(output_index++) = indices_shape.dim(i);
+ }
+ for (int i = axis + 1; i < params_shape.num_dims(); ++i)
+ {
+ output_shape.dim(output_index++) = params_shape.dim(i);
+ }
+ output()->resize(output_shape);
+}
+
+void Gather::execute() const
+{
+ switch (params()->element_type())
+ {
+ case DataType::FLOAT32:
+ evalFloat();
+ break;
+ default:
+ throw std::runtime_error("Unsupported type.");
+ }
+}
+
+void Gather::evalFloat() const
+{
+ assert(indices()->element_type() == DataType::S32 || indices()->element_type() == DataType::S64);
+
+ const auto params_data = getTensorData<float>(params());
+ auto output_data = getTensorData<float>(output());
+
+ tflite::GatherParams tparams;
+ tparams.axis = _params.axis;
+ tparams.batch_dims = _params.batch_dims;
+
+ if (indices()->element_type() == DataType::S32)
+ {
+ const auto indices_data = getTensorData<int32_t>(indices());
+
+ luci_interpreter_pal::Gather<float, int32_t>(tparams, getTensorShape(params()), params_data,
+ getTensorShape(indices()), indices_data,
+ getTensorShape(output()), output_data);
+ }
+ else
+ {
+ const auto indices_data = getTensorData<int64_t>(indices());
+
+ luci_interpreter_pal::Gather<float, int64_t>(tparams, getTensorShape(params()), params_data,
+ getTensorShape(indices()), indices_data,
+ getTensorShape(output()), output_data);
+ }
+}
+
+} // namespace kernels
+} // namespace luci_interpreter
diff --git a/compiler/luci-interpreter/src/kernels/Gather.h b/compiler/luci-interpreter/src/kernels/Gather.h
new file mode 100644
index 000000000..cc02d64fb
--- /dev/null
+++ b/compiler/luci-interpreter/src/kernels/Gather.h
@@ -0,0 +1,47 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LUCI_INTERPRETER_KERNELS_GATHER_H
+#define LUCI_INTERPRETER_KERNELS_GATHER_H
+
+#include "core/Kernel.h"
+#include "core/KernelParams.h"
+
+namespace luci_interpreter
+{
+namespace kernels
+{
+
+class Gather : public KernelWithParams<GatherParams>
+{
+public:
+ Gather(const Tensor *params, const Tensor *indices, Tensor *output, const GatherParams &gparams);
+
+ const Tensor *params() const { return _inputs[0]; }
+ const Tensor *indices() const { return _inputs[1]; }
+ Tensor *output() const { return _outputs[0]; }
+
+ void configure() override;
+ void execute() const override;
+
+private:
+ void evalFloat() const;
+};
+
+} // namespace kernels
+} // namespace luci_interpreter
+
+#endif // LUCI_INTERPRETER_KERNELS_GATHER_H
diff --git a/compiler/luci-interpreter/src/kernels/Gather.test.cpp b/compiler/luci-interpreter/src/kernels/Gather.test.cpp
new file mode 100644
index 000000000..4b3dda708
--- /dev/null
+++ b/compiler/luci-interpreter/src/kernels/Gather.test.cpp
@@ -0,0 +1,137 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "kernels/Gather.h"
+#include "kernels/TestUtils.h"
+#include "luci_interpreter/TestMemoryManager.h"
+
+namespace luci_interpreter
+{
+namespace kernels
+{
+namespace
+{
+
+using namespace testing;
+
+class GatherTest : public ::testing::Test
+{
+protected:
+ void SetUp() override { _memory_manager = std::make_unique<TestMemoryManager>(); }
+
+ std::unique_ptr<IMemoryManager> _memory_manager;
+};
+
+TEST_F(GatherTest, Simple)
+{
+ std::vector<float> params_data{1.f, 2.f, 3.f, 4.f, 5.f, 6.f};
+ std::vector<int32_t> indices_data{1, 0, 1, 5};
+ std::vector<float> ref_output_data{2.f, 1.f, 2.f, 6.f};
+
+ Tensor params_tensor =
+ makeInputTensor<DataType::FLOAT32>({1, 6}, params_data, _memory_manager.get());
+ Tensor indices_tensor = makeInputTensor<DataType::S32>({4}, indices_data, _memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);
+ GatherParams gparams;
+
+ gparams.axis = 1;
+ gparams.batch_dims = 0;
+
+ Gather kernel(&params_tensor, &indices_tensor, &output_tensor, gparams);
+ kernel.configure();
+ _memory_manager->allocate_memory(output_tensor);
+ kernel.execute();
+
+ EXPECT_THAT(extractTensorData<float>(output_tensor),
+ ::testing::ElementsAreArray(ref_output_data));
+ EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray({1, 4}));
+}
+
+TEST_F(GatherTest, Simple_Batch)
+{
+ Shape params_shape = {3, 5};
+ Shape indices_shape = {3, 2};
+ std::vector<float> params_data{0., 0., 1., 0., 2., 3., 0., 0., 0., 4., 0., 5., 0., 6., 0.};
+ std::vector<int32_t> indices_data{2, 4, 0, 4, 1, 3};
+ std::vector<float> ref_output_data{1., 2., 3., 4., 5., 6.};
+
+ Tensor params_tensor =
+ makeInputTensor<DataType::FLOAT32>(params_shape, params_data, _memory_manager.get());
+ Tensor indices_tensor =
+ makeInputTensor<DataType::S32>(indices_shape, indices_data, _memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);
+ GatherParams gparams;
+
+ gparams.axis = 1;
+ gparams.batch_dims = 1;
+
+ Gather kernel(&params_tensor, &indices_tensor, &output_tensor, gparams);
+ kernel.configure();
+ _memory_manager->allocate_memory(output_tensor);
+ kernel.execute();
+
+ EXPECT_THAT(extractTensorData<float>(output_tensor),
+ ::testing::ElementsAreArray(ref_output_data));
+ EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray({3, 2}));
+}
+
+TEST_F(GatherTest, Simple_NEG)
+{
+ Tensor params_tensor = makeInputTensor<DataType::S32>({1}, {1}, _memory_manager.get());
+ Tensor indices_tensor = makeInputTensor<DataType::S32>({1}, {0}, _memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);
+ GatherParams gparams;
+
+ Gather kernel(&params_tensor, &indices_tensor, &output_tensor, gparams);
+ EXPECT_ANY_THROW(kernel.configure());
+}
+
+TEST_F(GatherTest, Axis_NEG)
+{
+ Tensor params_tensor = makeInputTensor<DataType::FLOAT32>({1}, {1.f}, _memory_manager.get());
+ Tensor indices_tensor = makeInputTensor<DataType::S32>({1}, {0}, _memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);
+ GatherParams gparams;
+
+ gparams.axis = 100;
+ gparams.batch_dims = 0;
+
+ Gather kernel(&params_tensor, &indices_tensor, &output_tensor, gparams);
+ EXPECT_ANY_THROW(kernel.configure());
+}
+
+TEST_F(GatherTest, Batch_NEG)
+{
+ std::vector<float> params_data{1.f, 2.f, 3.f, 4.f, 5.f, 6.f};
+ std::vector<int32_t> indices_data{1, 0, 1, 5};
+ std::vector<float> ref_output_data{2.f, 1.f, 2.f, 6.f};
+
+ Tensor params_tensor =
+ makeInputTensor<DataType::FLOAT32>({1, 6}, params_data, _memory_manager.get());
+ Tensor indices_tensor = makeInputTensor<DataType::S32>({4}, indices_data, _memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);
+ GatherParams gparams;
+
+ gparams.axis = 0;
+ gparams.batch_dims = 1;
+
+ Gather kernel(&params_tensor, &indices_tensor, &output_tensor, gparams);
+ EXPECT_ANY_THROW(kernel.configure());
+}
+
+} // namespace
+} // namespace kernels
+} // namespace luci_interpreter
diff --git a/compiler/luci-interpreter/src/kernels/Greater.cpp b/compiler/luci-interpreter/src/kernels/Greater.cpp
index f0dd2db36..5ccae3c38 100644
--- a/compiler/luci-interpreter/src/kernels/Greater.cpp
+++ b/compiler/luci-interpreter/src/kernels/Greater.cpp
@@ -49,6 +49,12 @@ void Greater::execute() const
case DataType::FLOAT32:
evalFloat();
break;
+ case DataType::S64:
+ evalInteger<int64_t>();
+ break;
+ case DataType::S32:
+ evalInteger<int32_t>();
+ break;
case DataType::U8:
evalQuantized();
break;
@@ -79,6 +85,29 @@ void Greater::evalFloat() const
}
}
+template <typename T> void Greater::evalInteger() const
+{
+ const auto x_data = getTensorData<T>(x());
+ const auto y_data = getTensorData<T>(y());
+ auto output_data = getTensorData<bool>(output());
+
+ tflite::ComparisonParams op_params;
+ op_params.is_broadcast = x()->shape() != y()->shape();
+
+ if (op_params.is_broadcast)
+ {
+ tflite::reference_ops::Broadcast4DSlowGreaterNoScaling(op_params, getTensorShape(x()), x_data,
+ getTensorShape(y()), y_data,
+ getTensorShape(output()), output_data);
+ }
+ else
+ {
+ tflite::reference_ops::GreaterNoScaling(op_params, getTensorShape(x()), x_data,
+ getTensorShape(y()), y_data, getTensorShape(output()),
+ output_data);
+ }
+}
+
void Greater::evalQuantized() const
{
const auto x_data = getTensorData<uint8_t>(x());
diff --git a/compiler/luci-interpreter/src/kernels/Greater.h b/compiler/luci-interpreter/src/kernels/Greater.h
index 877c139c9..065f76d7b 100644
--- a/compiler/luci-interpreter/src/kernels/Greater.h
+++ b/compiler/luci-interpreter/src/kernels/Greater.h
@@ -38,6 +38,7 @@ public:
private:
void evalFloat() const;
+ template <typename T> void evalInteger() const;
void evalQuantized() const;
private:
diff --git a/compiler/luci-interpreter/src/kernels/Greater.test.cpp b/compiler/luci-interpreter/src/kernels/Greater.test.cpp
index ba3925f17..a48080124 100644
--- a/compiler/luci-interpreter/src/kernels/Greater.test.cpp
+++ b/compiler/luci-interpreter/src/kernels/Greater.test.cpp
@@ -97,6 +97,82 @@ TEST_F(GreaterTest, FloatBroardcast)
EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray({3, 3}));
}
+template <loco::DataType DType>
+void checkIntegerSimple(luci_interpreter::IMemoryManager *memory_manager)
+{
+ using dtype = typename loco::DataTypeImpl<DType>::Type;
+ dtype min_value = std::numeric_limits<dtype>::min();
+ dtype max_value = std::numeric_limits<dtype>::max();
+ std::vector<dtype> x_data{min_value, 2, max_value};
+
+ std::vector<dtype> y_data{min_value + 1, -2, max_value};
+
+ std::vector<bool> ref_output_data{false, true, false};
+
+ Tensor x_tensor = makeInputTensor<DType>({3}, x_data, memory_manager);
+ Tensor y_tensor = makeInputTensor<DType>({3}, y_data, memory_manager);
+ Tensor output_tensor = makeOutputTensor(DataType::BOOL);
+
+ Greater kernel(&x_tensor, &y_tensor, &output_tensor);
+ kernel.configure();
+ memory_manager->allocate_memory(output_tensor);
+ kernel.execute();
+
+ EXPECT_THAT(extractTensorData<bool>(output_tensor), ::testing::ElementsAreArray(ref_output_data));
+ EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray({3}));
+}
+
+template <loco::DataType DType>
+void checkIntegerBroadcast(luci_interpreter::IMemoryManager *memory_manager)
+{
+ using dtype = typename loco::DataTypeImpl<DType>::Type;
+ dtype min_value = std::numeric_limits<dtype>::min();
+ dtype max_value = std::numeric_limits<dtype>::max();
+ std::vector<dtype> x_data{
+ min_value, 2, 3, // Row 1
+ 4, 5, max_value, // Row 2
+ -1, -4, -3, // Row 3
+ min_value, -2, max_value, // Row 4
+ };
+
+ std::vector<dtype> y_data{
+ min_value + 1, -2, max_value - 1, // Row 1
+ };
+
+ std::vector<bool> ref_output_data{
+ false, true, false, // Row 1
+ true, true, true, // Row 2
+ true, false, false, // Row 3
+ false, false, true, // Row 4
+ };
+
+ Tensor x_tensor = makeInputTensor<DType>({4, 3}, x_data, memory_manager);
+ Tensor y_tensor = makeInputTensor<DType>({3}, y_data, memory_manager);
+ Tensor output_tensor = makeOutputTensor(DataType::BOOL);
+
+ Greater kernel(&x_tensor, &y_tensor, &output_tensor);
+ kernel.configure();
+ memory_manager->allocate_memory(output_tensor);
+ kernel.execute();
+
+ EXPECT_THAT(extractTensorData<bool>(output_tensor), ::testing::ElementsAreArray(ref_output_data));
+ EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray({4, 3}));
+}
+
+TEST_F(GreaterTest, Int32)
+{
+ checkIntegerSimple<loco::DataType::S32>(_memory_manager.get());
+ checkIntegerBroadcast<loco::DataType::S32>(_memory_manager.get());
+ SUCCEED();
+}
+
+TEST_F(GreaterTest, Int64)
+{
+ checkIntegerSimple<loco::DataType::S64>(_memory_manager.get());
+ checkIntegerBroadcast<loco::DataType::S64>(_memory_manager.get());
+ SUCCEED();
+}
+
// Choose min / max in such a way that there are exactly 256 units to avoid rounding errors.
const float F_MIN = -128.0 / 128.0;
const float F_MAX = 127.0 / 128.0;
@@ -223,6 +299,36 @@ TEST_F(GreaterTest, Input_Output_Type_NEG)
EXPECT_ANY_THROW(kernel.configure());
}
+TEST_F(GreaterTest, Float_Broadcast_NEG)
+{
+ Tensor x_tensor = makeInputTensor<DataType::FLOAT32>({2}, {1.f, 2.f}, _memory_manager.get());
+ Tensor y_tensor = makeInputTensor<DataType::FLOAT32>({3}, {1.f, 2.f, 3.f}, _memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::BOOL);
+
+ Greater kernel(&x_tensor, &y_tensor, &output_tensor);
+ EXPECT_ANY_THROW(kernel.configure());
+}
+
+TEST_F(GreaterTest, Int32_Broadcast_NEG)
+{
+ Tensor x_tensor = makeInputTensor<DataType::S32>({2}, {1, 2}, _memory_manager.get());
+ Tensor y_tensor = makeInputTensor<DataType::S32>({3}, {1, 2, 3}, _memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::BOOL);
+
+ Greater kernel(&x_tensor, &y_tensor, &output_tensor);
+ EXPECT_ANY_THROW(kernel.configure());
+}
+
+TEST_F(GreaterTest, Int64_Broadcast_NEG)
+{
+ Tensor x_tensor = makeInputTensor<DataType::S64>({2}, {1, 2}, _memory_manager.get());
+ Tensor y_tensor = makeInputTensor<DataType::S64>({3}, {1, 2, 3}, _memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::BOOL);
+
+ Greater kernel(&x_tensor, &y_tensor, &output_tensor);
+ EXPECT_ANY_THROW(kernel.configure());
+}
+
} // namespace
} // namespace kernels
} // namespace luci_interpreter
diff --git a/compiler/luci-interpreter/src/kernels/GreaterEqual.cpp b/compiler/luci-interpreter/src/kernels/GreaterEqual.cpp
index e7c1b4afe..27e42c971 100644
--- a/compiler/luci-interpreter/src/kernels/GreaterEqual.cpp
+++ b/compiler/luci-interpreter/src/kernels/GreaterEqual.cpp
@@ -52,6 +52,12 @@ void GreaterEqual::execute() const
case DataType::FLOAT32:
evalFloat();
break;
+ case DataType::S64:
+ evalInteger<int64_t>();
+ break;
+ case DataType::S32:
+ evalInteger<int32_t>();
+ break;
case DataType::U8:
evalQuantized();
break;
@@ -82,6 +88,29 @@ void GreaterEqual::evalFloat() const
}
}
+template <typename T> void GreaterEqual::evalInteger() const
+{
+ const auto x_data = getTensorData<T>(x());
+ const auto y_data = getTensorData<T>(y());
+ auto output_data = getTensorData<bool>(output());
+
+ tflite::ComparisonParams op_params;
+ op_params.is_broadcast = x()->shape() != y()->shape();
+
+ if (op_params.is_broadcast)
+ {
+ tflite::reference_ops::Broadcast4DSlowGreaterEqualNoScaling(
+ op_params, getTensorShape(x()), x_data, getTensorShape(y()), y_data, getTensorShape(output()),
+ output_data);
+ }
+ else
+ {
+ tflite::reference_ops::GreaterEqualNoScaling(op_params, getTensorShape(x()), x_data,
+ getTensorShape(y()), y_data,
+ getTensorShape(output()), output_data);
+ }
+}
+
void GreaterEqual::evalQuantized() const
{
const auto x_data = getTensorData<uint8_t>(x());
diff --git a/compiler/luci-interpreter/src/kernels/GreaterEqual.h b/compiler/luci-interpreter/src/kernels/GreaterEqual.h
index 4a0f48748..e333c30a6 100644
--- a/compiler/luci-interpreter/src/kernels/GreaterEqual.h
+++ b/compiler/luci-interpreter/src/kernels/GreaterEqual.h
@@ -38,6 +38,7 @@ public:
private:
void evalFloat() const;
+ template <typename T> void evalInteger() const;
void evalQuantized() const;
private:
diff --git a/compiler/luci-interpreter/src/kernels/GreaterEqual.test.cpp b/compiler/luci-interpreter/src/kernels/GreaterEqual.test.cpp
index a9d172301..35bf88eab 100644
--- a/compiler/luci-interpreter/src/kernels/GreaterEqual.test.cpp
+++ b/compiler/luci-interpreter/src/kernels/GreaterEqual.test.cpp
@@ -96,6 +96,81 @@ TEST_F(GreaterEqualTest, FloatBroardcast)
EXPECT_THAT(extractTensorData<bool>(output_tensor), ::testing::ElementsAreArray(ref_output_data));
EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray({3, 3}));
}
+template <loco::DataType DType>
+void checkIntegerSimple(luci_interpreter::IMemoryManager *memory_manager)
+{
+ using dtype = typename loco::DataTypeImpl<DType>::Type;
+ dtype min_value = std::numeric_limits<dtype>::min();
+ dtype max_value = std::numeric_limits<dtype>::max();
+ std::vector<dtype> x_data{min_value, 2, max_value};
+
+ std::vector<dtype> y_data{min_value + 1, -2, max_value};
+
+ std::vector<bool> ref_output_data{false, true, true};
+
+ Tensor x_tensor = makeInputTensor<DType>({3}, x_data, memory_manager);
+ Tensor y_tensor = makeInputTensor<DType>({3}, y_data, memory_manager);
+ Tensor output_tensor = makeOutputTensor(DataType::BOOL);
+
+ GreaterEqual kernel(&x_tensor, &y_tensor, &output_tensor);
+ kernel.configure();
+ memory_manager->allocate_memory(output_tensor);
+ kernel.execute();
+
+ EXPECT_THAT(extractTensorData<bool>(output_tensor), ::testing::ElementsAreArray(ref_output_data));
+ EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray({3}));
+}
+
+template <loco::DataType DType>
+void checkIntegerBroadcast(luci_interpreter::IMemoryManager *memory_manager)
+{
+ using dtype = typename loco::DataTypeImpl<DType>::Type;
+ dtype min_value = std::numeric_limits<dtype>::min();
+ dtype max_value = std::numeric_limits<dtype>::max();
+ std::vector<dtype> x_data{
+ min_value, 2, 3, // Row 1
+ 4, 5, max_value, // Row 2
+ -1, -4, -3, // Row 3
+ min_value, -2, max_value - 1, // Row 4
+ };
+
+ std::vector<dtype> y_data{
+ min_value + 1, -2, max_value - 1, // Row 1
+ };
+
+ std::vector<bool> ref_output_data{
+ false, true, false, // Row 1
+ true, true, true, // Row 2
+ true, false, false, // Row 3
+ false, true, true, // Row 4
+ };
+
+ Tensor x_tensor = makeInputTensor<DType>({4, 3}, x_data, memory_manager);
+ Tensor y_tensor = makeInputTensor<DType>({3}, y_data, memory_manager);
+ Tensor output_tensor = makeOutputTensor(DataType::BOOL);
+
+ GreaterEqual kernel(&x_tensor, &y_tensor, &output_tensor);
+ kernel.configure();
+ memory_manager->allocate_memory(output_tensor);
+ kernel.execute();
+
+ EXPECT_THAT(extractTensorData<bool>(output_tensor), ::testing::ElementsAreArray(ref_output_data));
+ EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray({4, 3}));
+}
+
+TEST_F(GreaterEqualTest, Int32)
+{
+ checkIntegerSimple<loco::DataType::S32>(_memory_manager.get());
+ checkIntegerBroadcast<loco::DataType::S32>(_memory_manager.get());
+ SUCCEED();
+}
+
+TEST_F(GreaterEqualTest, Int64)
+{
+ checkIntegerSimple<loco::DataType::S64>(_memory_manager.get());
+ checkIntegerBroadcast<loco::DataType::S64>(_memory_manager.get());
+ SUCCEED();
+}
// Choose min / max in such a way that there are exactly 256 units to avoid rounding errors.
const float F_MIN = -128.0 / 128.0;
@@ -223,6 +298,36 @@ TEST_F(GreaterEqualTest, Input_Output_Type_NEG)
EXPECT_ANY_THROW(kernel.configure());
}
+TEST_F(GreaterEqualTest, Float_Broadcast_NEG)
+{
+ Tensor x_tensor = makeInputTensor<DataType::FLOAT32>({2}, {1.f, 2.f}, _memory_manager.get());
+ Tensor y_tensor = makeInputTensor<DataType::FLOAT32>({3}, {1.f, 2.f, 3.f}, _memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::BOOL);
+
+ GreaterEqual kernel(&x_tensor, &y_tensor, &output_tensor);
+ EXPECT_ANY_THROW(kernel.configure());
+}
+
+TEST_F(GreaterEqualTest, Int32_Broadcast_NEG)
+{
+ Tensor x_tensor = makeInputTensor<DataType::S32>({2}, {1, 2}, _memory_manager.get());
+ Tensor y_tensor = makeInputTensor<DataType::S32>({3}, {1, 2, 3}, _memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::BOOL);
+
+ GreaterEqual kernel(&x_tensor, &y_tensor, &output_tensor);
+ EXPECT_ANY_THROW(kernel.configure());
+}
+
+TEST_F(GreaterEqualTest, Int64_Broadcast_NEG)
+{
+ Tensor x_tensor = makeInputTensor<DataType::S64>({2}, {1, 2}, _memory_manager.get());
+ Tensor y_tensor = makeInputTensor<DataType::S64>({3}, {1, 2, 3}, _memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::BOOL);
+
+ GreaterEqual kernel(&x_tensor, &y_tensor, &output_tensor);
+ EXPECT_ANY_THROW(kernel.configure());
+}
+
} // namespace
} // namespace kernels
} // namespace luci_interpreter
diff --git a/compiler/luci-interpreter/src/kernels/L2Normalize.test.cpp b/compiler/luci-interpreter/src/kernels/L2Normalize.test.cpp
index 1e565e358..6f960e8b4 100644
--- a/compiler/luci-interpreter/src/kernels/L2Normalize.test.cpp
+++ b/compiler/luci-interpreter/src/kernels/L2Normalize.test.cpp
@@ -81,7 +81,7 @@ template <typename T> class L2NormalizeTest : public ::testing::Test
};
using DataTypes = ::testing::Types<float, uint8_t>;
-TYPED_TEST_CASE(L2NormalizeTest, DataTypes);
+TYPED_TEST_SUITE(L2NormalizeTest, DataTypes);
TYPED_TEST(L2NormalizeTest, Simple)
{
diff --git a/compiler/luci-interpreter/src/kernels/L2Pool2D.test.cpp b/compiler/luci-interpreter/src/kernels/L2Pool2D.test.cpp
index 289742a50..7245456cb 100644
--- a/compiler/luci-interpreter/src/kernels/L2Pool2D.test.cpp
+++ b/compiler/luci-interpreter/src/kernels/L2Pool2D.test.cpp
@@ -206,7 +206,8 @@ TEST_F(L2Pool2DTest, FloatPaddingSameStride)
kernel.execute();
std::vector<float> ref_output_data{3.5, 6.0, 6.5, 5.70088, 2.54951, 7.2111, 8.63134, 7.0};
- EXPECT_THAT(extractTensorData<float>(output_tensor), FloatArrayNear(ref_output_data));
+ // NOTE with NEON+ruy, error is #1=-1.14441e-05, #6=-1.81198e-05
+ EXPECT_THAT(extractTensorData<float>(output_tensor), FloatArrayNear(ref_output_data, 1.0e-4f));
// TODO make a Shape checking of output_tensor.
}
diff --git a/compiler/luci-interpreter/src/kernels/LeakyRelu.test.cpp b/compiler/luci-interpreter/src/kernels/LeakyRelu.test.cpp
index 6ec8a348a..0f6263b57 100644
--- a/compiler/luci-interpreter/src/kernels/LeakyRelu.test.cpp
+++ b/compiler/luci-interpreter/src/kernels/LeakyRelu.test.cpp
@@ -83,7 +83,7 @@ template <typename T> class LeakReluTest : public ::testing::Test
};
using DataTypes = ::testing::Types<float, uint8_t>;
-TYPED_TEST_CASE(LeakReluTest, DataTypes);
+TYPED_TEST_SUITE(LeakReluTest, DataTypes);
TYPED_TEST(LeakReluTest, Simple)
{
diff --git a/compiler/luci-interpreter/src/kernels/Less.cpp b/compiler/luci-interpreter/src/kernels/Less.cpp
index 041444926..8d26ff297 100644
--- a/compiler/luci-interpreter/src/kernels/Less.cpp
+++ b/compiler/luci-interpreter/src/kernels/Less.cpp
@@ -49,6 +49,12 @@ void Less::execute() const
case DataType::FLOAT32:
evalFloat();
break;
+ case DataType::S64:
+ evalInteger<int64_t>();
+ break;
+ case DataType::S32:
+ evalInteger<int32_t>();
+ break;
case DataType::U8:
evalQuantized();
break;
@@ -79,6 +85,29 @@ void Less::evalFloat() const
}
}
+template <typename T> void Less::evalInteger() const
+{
+ const auto x_data = getTensorData<T>(x());
+ const auto y_data = getTensorData<T>(y());
+ auto output_data = getTensorData<bool>(output());
+
+ tflite::ComparisonParams op_params;
+ op_params.is_broadcast = x()->shape() != y()->shape();
+
+ if (op_params.is_broadcast)
+ {
+ tflite::reference_ops::Broadcast4DSlowLessNoScaling(op_params, getTensorShape(x()), x_data,
+ getTensorShape(y()), y_data,
+ getTensorShape(output()), output_data);
+ }
+ else
+ {
+ tflite::reference_ops::LessNoScaling(op_params, getTensorShape(x()), x_data,
+ getTensorShape(y()), y_data, getTensorShape(output()),
+ output_data);
+ }
+}
+
void Less::evalQuantized() const
{
const auto x_data = getTensorData<uint8_t>(x());
diff --git a/compiler/luci-interpreter/src/kernels/Less.h b/compiler/luci-interpreter/src/kernels/Less.h
index 293740e72..e27bb689c 100644
--- a/compiler/luci-interpreter/src/kernels/Less.h
+++ b/compiler/luci-interpreter/src/kernels/Less.h
@@ -38,6 +38,7 @@ public:
private:
void evalFloat() const;
+ template <typename T> void evalInteger() const;
void evalQuantized() const;
private:
diff --git a/compiler/luci-interpreter/src/kernels/Less.test.cpp b/compiler/luci-interpreter/src/kernels/Less.test.cpp
index e9d09b288..8c5963363 100644
--- a/compiler/luci-interpreter/src/kernels/Less.test.cpp
+++ b/compiler/luci-interpreter/src/kernels/Less.test.cpp
@@ -97,6 +97,82 @@ TEST_F(LessTest, FloatBroardcast)
EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray({3, 3}));
}
+template <loco::DataType DType>
+void checkIntegerSimple(luci_interpreter::IMemoryManager *memory_manager)
+{
+ using dtype = typename loco::DataTypeImpl<DType>::Type;
+ dtype min_value = std::numeric_limits<dtype>::min();
+ dtype max_value = std::numeric_limits<dtype>::max();
+ std::vector<dtype> x_data{min_value, 2, max_value};
+
+ std::vector<dtype> y_data{min_value + 1, -2, max_value};
+
+ std::vector<bool> ref_output_data{true, false, false};
+
+ Tensor x_tensor = makeInputTensor<DType>({3}, x_data, memory_manager);
+ Tensor y_tensor = makeInputTensor<DType>({3}, y_data, memory_manager);
+ Tensor output_tensor = makeOutputTensor(DataType::BOOL);
+
+ Less kernel(&x_tensor, &y_tensor, &output_tensor);
+ kernel.configure();
+ memory_manager->allocate_memory(output_tensor);
+ kernel.execute();
+
+ EXPECT_THAT(extractTensorData<bool>(output_tensor), ::testing::ElementsAreArray(ref_output_data));
+ EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray({3}));
+}
+
+template <loco::DataType DType>
+void checkIntegerBroadcast(luci_interpreter::IMemoryManager *memory_manager)
+{
+ using dtype = typename loco::DataTypeImpl<DType>::Type;
+ dtype min_value = std::numeric_limits<dtype>::min();
+ dtype max_value = std::numeric_limits<dtype>::max();
+ std::vector<dtype> x_data{
+ min_value, 2, 3, // Row 1
+ 4, 5, max_value, // Row 2
+ -1, -4, -3, // Row 3
+ min_value, -2, max_value, // Row 4
+ };
+
+ std::vector<dtype> y_data{
+ min_value + 1, -2, max_value - 1, // Row 1
+ };
+
+ std::vector<bool> ref_output_data{
+ true, false, true, // Row 1
+ false, false, false, // Row 2
+ false, true, true, // Row 3
+ true, false, false, // Row 4
+ };
+
+ Tensor x_tensor = makeInputTensor<DType>({4, 3}, x_data, memory_manager);
+ Tensor y_tensor = makeInputTensor<DType>({3}, y_data, memory_manager);
+ Tensor output_tensor = makeOutputTensor(DataType::BOOL);
+
+ Less kernel(&x_tensor, &y_tensor, &output_tensor);
+ kernel.configure();
+ memory_manager->allocate_memory(output_tensor);
+ kernel.execute();
+
+ EXPECT_THAT(extractTensorData<bool>(output_tensor), ::testing::ElementsAreArray(ref_output_data));
+ EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray({4, 3}));
+}
+
+TEST_F(LessTest, Int32)
+{
+ checkIntegerSimple<loco::DataType::S32>(_memory_manager.get());
+ checkIntegerBroadcast<loco::DataType::S32>(_memory_manager.get());
+ SUCCEED();
+}
+
+TEST_F(LessTest, Int64)
+{
+ checkIntegerSimple<loco::DataType::S64>(_memory_manager.get());
+ checkIntegerBroadcast<loco::DataType::S64>(_memory_manager.get());
+ SUCCEED();
+}
+
// Choose min / max in such a way that there are exactly 256 units to avoid rounding errors.
const float F_MIN = -128.0 / 128.0;
const float F_MAX = 127.0 / 128.0;
@@ -223,6 +299,36 @@ TEST_F(LessTest, Input_Output_Type_NEG)
EXPECT_ANY_THROW(kernel.configure());
}
+TEST_F(LessTest, Float_Broadcast_NEG)
+{
+ Tensor x_tensor = makeInputTensor<DataType::FLOAT32>({2}, {1.f, 2.f}, _memory_manager.get());
+ Tensor y_tensor = makeInputTensor<DataType::FLOAT32>({3}, {1.f, 2.f, 3.f}, _memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::BOOL);
+
+ Less kernel(&x_tensor, &y_tensor, &output_tensor);
+ ASSERT_ANY_THROW(kernel.configure());
+}
+
+TEST_F(LessTest, Int32_Broadcast_NEG)
+{
+ Tensor x_tensor = makeInputTensor<DataType::S32>({2}, {1, 2}, _memory_manager.get());
+ Tensor y_tensor = makeInputTensor<DataType::S32>({3}, {1, 2, 3}, _memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::BOOL);
+
+ Less kernel(&x_tensor, &y_tensor, &output_tensor);
+ ASSERT_ANY_THROW(kernel.configure());
+}
+
+TEST_F(LessTest, Int64_Broadcast_NEG)
+{
+ Tensor x_tensor = makeInputTensor<DataType::S64>({2}, {1, 2}, _memory_manager.get());
+ Tensor y_tensor = makeInputTensor<DataType::S64>({3}, {1, 2, 3}, _memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::BOOL);
+
+ Less kernel(&x_tensor, &y_tensor, &output_tensor);
+ ASSERT_ANY_THROW(kernel.configure());
+}
+
} // namespace
} // namespace kernels
} // namespace luci_interpreter
diff --git a/compiler/luci-interpreter/src/kernels/LessEqual.cpp b/compiler/luci-interpreter/src/kernels/LessEqual.cpp
index 5f4c7f7aa..b474bc47a 100644
--- a/compiler/luci-interpreter/src/kernels/LessEqual.cpp
+++ b/compiler/luci-interpreter/src/kernels/LessEqual.cpp
@@ -49,6 +49,12 @@ void LessEqual::execute() const
case DataType::FLOAT32:
evalFloat();
break;
+ case DataType::S64:
+ evalInteger<int64_t>();
+ break;
+ case DataType::S32:
+ evalInteger<int32_t>();
+ break;
case DataType::U8:
evalQuantized();
break;
@@ -79,6 +85,29 @@ void LessEqual::evalFloat() const
}
}
+template <typename T> void LessEqual::evalInteger() const
+{
+ const auto x_data = getTensorData<T>(x());
+ const auto y_data = getTensorData<T>(y());
+ auto output_data = getTensorData<bool>(output());
+
+ tflite::ComparisonParams op_params;
+ op_params.is_broadcast = x()->shape() != y()->shape();
+
+ if (op_params.is_broadcast)
+ {
+ tflite::reference_ops::Broadcast4DSlowLessEqualNoScaling(op_params, getTensorShape(x()), x_data,
+ getTensorShape(y()), y_data,
+ getTensorShape(output()), output_data);
+ }
+ else
+ {
+ tflite::reference_ops::LessEqualNoScaling(op_params, getTensorShape(x()), x_data,
+ getTensorShape(y()), y_data, getTensorShape(output()),
+ output_data);
+ }
+}
+
void LessEqual::evalQuantized() const
{
const auto x_data = getTensorData<uint8_t>(x());
diff --git a/compiler/luci-interpreter/src/kernels/LessEqual.h b/compiler/luci-interpreter/src/kernels/LessEqual.h
index b6da1a2a8..f82ea90d4 100644
--- a/compiler/luci-interpreter/src/kernels/LessEqual.h
+++ b/compiler/luci-interpreter/src/kernels/LessEqual.h
@@ -38,6 +38,7 @@ public:
private:
void evalFloat() const;
+ template <typename T> void evalInteger() const;
void evalQuantized() const;
private:
diff --git a/compiler/luci-interpreter/src/kernels/LessEqual.test.cpp b/compiler/luci-interpreter/src/kernels/LessEqual.test.cpp
index 0558003dd..b2e2fa7a1 100644
--- a/compiler/luci-interpreter/src/kernels/LessEqual.test.cpp
+++ b/compiler/luci-interpreter/src/kernels/LessEqual.test.cpp
@@ -97,6 +97,82 @@ TEST_F(LessEqualTest, FloatBroardcast)
EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray({3, 3}));
}
+template <loco::DataType DType>
+void checkIntegerSimple(luci_interpreter::IMemoryManager *memory_manager)
+{
+ using dtype = typename loco::DataTypeImpl<DType>::Type;
+ dtype min_value = std::numeric_limits<dtype>::min();
+ dtype max_value = std::numeric_limits<dtype>::max();
+ std::vector<dtype> x_data{min_value, 2, max_value};
+
+ std::vector<dtype> y_data{min_value + 1, -2, max_value};
+
+ std::vector<bool> ref_output_data{true, false, true};
+
+ Tensor x_tensor = makeInputTensor<DType>({3}, x_data, memory_manager);
+ Tensor y_tensor = makeInputTensor<DType>({3}, y_data, memory_manager);
+ Tensor output_tensor = makeOutputTensor(DataType::BOOL);
+
+ LessEqual kernel(&x_tensor, &y_tensor, &output_tensor);
+ kernel.configure();
+ memory_manager->allocate_memory(output_tensor);
+ kernel.execute();
+
+ EXPECT_THAT(extractTensorData<bool>(output_tensor), ::testing::ElementsAreArray(ref_output_data));
+ EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray({3}));
+}
+
+template <loco::DataType DType>
+void checkIntegerBroadcast(luci_interpreter::IMemoryManager *memory_manager)
+{
+ using dtype = typename loco::DataTypeImpl<DType>::Type;
+ dtype min_value = std::numeric_limits<dtype>::min();
+ dtype max_value = std::numeric_limits<dtype>::max();
+ std::vector<dtype> x_data{
+ min_value, 2, 3, // Row 1
+ 4, 5, max_value, // Row 2
+ -1, -4, -3, // Row 3
+ min_value, -2, max_value, // Row 4
+ };
+
+ std::vector<dtype> y_data{
+ min_value + 1, -2, max_value - 1, // Row 1
+ };
+
+ std::vector<bool> ref_output_data{
+ true, false, true, // Row 1
+ false, false, false, // Row 2
+ false, true, true, // Row 3
+ true, true, false, // Row 4
+ };
+
+ Tensor x_tensor = makeInputTensor<DType>({4, 3}, x_data, memory_manager);
+ Tensor y_tensor = makeInputTensor<DType>({3}, y_data, memory_manager);
+ Tensor output_tensor = makeOutputTensor(DataType::BOOL);
+
+ LessEqual kernel(&x_tensor, &y_tensor, &output_tensor);
+ kernel.configure();
+ memory_manager->allocate_memory(output_tensor);
+ kernel.execute();
+
+ EXPECT_THAT(extractTensorData<bool>(output_tensor), ::testing::ElementsAreArray(ref_output_data));
+ EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray({4, 3}));
+}
+
+TEST_F(LessEqualTest, Int32)
+{
+ checkIntegerSimple<loco::DataType::S32>(_memory_manager.get());
+ checkIntegerBroadcast<loco::DataType::S32>(_memory_manager.get());
+ SUCCEED();
+}
+
+TEST_F(LessEqualTest, Int64)
+{
+ checkIntegerSimple<loco::DataType::S64>(_memory_manager.get());
+ checkIntegerBroadcast<loco::DataType::S64>(_memory_manager.get());
+ SUCCEED();
+}
+
// Choose min / max in such a way that there are exactly 256 units to avoid rounding errors.
const float F_MIN = -128.0 / 128.0;
const float F_MAX = 127.0 / 128.0;
@@ -223,6 +299,36 @@ TEST_F(LessEqualTest, Input_Output_Type_NEG)
EXPECT_ANY_THROW(kernel.configure());
}
+TEST_F(LessEqualTest, Float_Broadcast_NEG)
+{
+ Tensor x_tensor = makeInputTensor<DataType::FLOAT32>({2}, {1.f, 2.f}, _memory_manager.get());
+ Tensor y_tensor = makeInputTensor<DataType::FLOAT32>({3}, {1.f, 2.f, 3.f}, _memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::BOOL);
+
+ LessEqual kernel(&x_tensor, &y_tensor, &output_tensor);
+ ASSERT_ANY_THROW(kernel.configure());
+}
+
+TEST_F(LessEqualTest, Int32_Broadcast_NEG)
+{
+ Tensor x_tensor = makeInputTensor<DataType::S32>({2}, {1, 2}, _memory_manager.get());
+ Tensor y_tensor = makeInputTensor<DataType::S32>({3}, {1, 2, 3}, _memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::BOOL);
+
+ LessEqual kernel(&x_tensor, &y_tensor, &output_tensor);
+ ASSERT_ANY_THROW(kernel.configure());
+}
+
+TEST_F(LessEqualTest, Int64_Broadcast_NEG)
+{
+ Tensor x_tensor = makeInputTensor<DataType::S64>({2}, {1, 2}, _memory_manager.get());
+ Tensor y_tensor = makeInputTensor<DataType::S64>({3}, {1, 2, 3}, _memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::BOOL);
+
+ LessEqual kernel(&x_tensor, &y_tensor, &output_tensor);
+ ASSERT_ANY_THROW(kernel.configure());
+}
+
} // namespace
} // namespace kernels
} // namespace luci_interpreter
diff --git a/compiler/luci-interpreter/src/kernels/Logistic.test.cpp b/compiler/luci-interpreter/src/kernels/Logistic.test.cpp
index 70227563f..5a1ea669c 100644
--- a/compiler/luci-interpreter/src/kernels/Logistic.test.cpp
+++ b/compiler/luci-interpreter/src/kernels/Logistic.test.cpp
@@ -76,7 +76,7 @@ template <typename T> class LogisticTest : public ::testing::Test
};
using DataTypes = ::testing::Types<float, uint8_t>;
-TYPED_TEST_CASE(LogisticTest, DataTypes);
+TYPED_TEST_SUITE(LogisticTest, DataTypes);
TYPED_TEST(LogisticTest, Simple)
{
diff --git a/compiler/luci-interpreter/src/kernels/MirrorPad.cpp b/compiler/luci-interpreter/src/kernels/MirrorPad.cpp
index 89049c96c..2fbeefce4 100644
--- a/compiler/luci-interpreter/src/kernels/MirrorPad.cpp
+++ b/compiler/luci-interpreter/src/kernels/MirrorPad.cpp
@@ -1,5 +1,6 @@
/*
* Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -18,8 +19,6 @@
#include "kernels/Utils.h"
-#include <tensorflow/lite/kernels/internal/reference/pad.h>
-
namespace luci_interpreter
{
namespace kernels
@@ -59,44 +58,25 @@ void MirrorPad::configure()
output()->resize(output_shape);
}
+template <typename T>
+inline void MirrorPadImpl(const Tensor &input, const Tensor &paddings, MirrorPadMode mode,
+ Tensor &output);
+
void MirrorPad::execute() const
{
- const int num_dims = input()->shape().num_dims();
-
- tflite::PadParams params{};
- params.left_padding_count = num_dims;
- params.right_padding_count = num_dims;
-
- const auto *paddings_data = getTensorData<int32_t>(paddings());
- for (int i = num_dims - 1; i >= 0; --i)
- {
- params.left_padding[i] = paddings_data[i * 2];
- params.right_padding[i] = paddings_data[i * 2 + 1];
- }
-
switch (input()->element_type())
{
case DataType::FLOAT32:
{
- const float pad_value = 0;
-
- // NOTE: this implementation only obtains min-max values for quantization
- // TODO: calculate proper inference values
- tflite::reference_ops::Pad(params, getTensorShape(input()), getTensorData<float>(input()),
- &pad_value, getTensorShape(output()),
- getTensorData<float>(output()));
+ MirrorPadImpl<float>(*input(), *paddings(), params().mode, *output());
break;
}
case DataType::U8:
{
- // NOTE: this implementation only obtains min-max values for quantization
- // TODO: calculate proper inference values
assert(output()->zero_point() >= std::numeric_limits<uint8_t>::min());
assert(output()->zero_point() <= std::numeric_limits<uint8_t>::max());
- const auto pad_value = static_cast<uint8_t>(output()->zero_point());
- tflite::reference_ops::Pad(params, getTensorShape(input()), getTensorData<uint8_t>(input()),
- &pad_value, getTensorShape(output()),
- getTensorData<uint8_t>(output()));
+
+ MirrorPadImpl<uint8_t>(*input(), *paddings(), params().mode, *output());
break;
}
default:
@@ -104,5 +84,87 @@ void MirrorPad::execute() const
}
}
+template <typename T>
+inline void MirrorPadImpl(const Tensor &input, const Tensor &paddings, MirrorPadMode mode,
+ Tensor &output)
+{
+ auto const input_dims = input.shape().num_dims();
+ auto const input_data = input.data<T>();
+ auto const paddings_data = paddings.data<int32_t>();
+ auto const output_data = output.data<T>();
+
+ auto const input_b = input_dims > 3 ? input.shape().dim(input_dims - 4) : 1;
+ auto const input_h = input_dims > 2 ? input.shape().dim(input_dims - 3) : 1;
+ auto const input_w = input_dims > 1 ? input.shape().dim(input_dims - 2) : 1;
+ auto const input_d = input.shape().dim(input_dims - 1);
+
+ auto const input_h_offset = input_d * input_w;
+ auto const input_b_offset = input_h_offset * input_h;
+
+ auto const output_b = input_dims > 3 ? output.shape().dim(input_dims - 4) : 1;
+ auto const output_h = input_dims > 2 ? output.shape().dim(input_dims - 3) : 1;
+ auto const output_w = input_dims > 1 ? output.shape().dim(input_dims - 2) : 1;
+ auto const output_d = output.shape().dim(input_dims - 1);
+
+ auto const left_b_pad = paddings_data[2 * (input_dims - 4)];
+ auto const left_h_pad = paddings_data[2 * (input_dims - 3)];
+ auto const left_w_pad = paddings_data[2 * (input_dims - 2)];
+ auto const left_d_pad = paddings_data[2 * (input_dims - 1)];
+
+ auto const right_b_pad = paddings_data[2 * (input_dims - 4) + 1];
+ auto const right_h_pad = paddings_data[2 * (input_dims - 3) + 1];
+ auto const right_w_pad = paddings_data[2 * (input_dims - 2) + 1];
+ auto const right_d_pad = paddings_data[2 * (input_dims - 1) + 1];
+
+ const auto positive_mod = [](auto a, auto b) { return (a % b + b) % b; };
+ const auto offset_index = [input_d, input_h_offset, input_b_offset](auto d, auto w, auto h,
+ auto b) {
+ return d + w * input_d + h * input_h_offset + b * input_b_offset;
+ };
+
+ const auto symmetric_dim = [&positive_mod](auto i, auto left_pad, auto input) {
+ bool reflected = (((i < left_pad ? i + 1 - input : i) - left_pad) / input & 1) == 1;
+ return positive_mod(reflected ? input + left_pad - i - 1 : i - left_pad, input);
+ };
+
+ const T *in_ptr = input_data;
+ T *out_ptr = output_data;
+
+ for (int32_t b = 0; b < output_b; ++b)
+ {
+ for (int32_t h = 0; h < output_h; ++h)
+ {
+ for (int32_t w = 0; w < output_w; ++w)
+ {
+ for (int32_t d = 0; d < output_d; ++d)
+ {
+ if (b < left_b_pad || b >= output_b - right_b_pad || //
+ h < left_h_pad || h >= output_h - right_h_pad || //
+ w < left_w_pad || w >= output_w - right_w_pad || //
+ d < left_d_pad || d >= output_d - right_d_pad)
+ {
+ if (mode == MirrorPadMode::REFLECT)
+ {
+ *out_ptr++ = input_data[offset_index(
+ positive_mod(d - left_d_pad, input_d), positive_mod(w - left_w_pad, input_w),
+ positive_mod(h - left_h_pad, input_h), positive_mod(b - left_b_pad, input_b))];
+ }
+ else
+ {
+ *out_ptr++ = input_data[offset_index(
+ symmetric_dim(d, left_d_pad, input_d), symmetric_dim(w, left_w_pad, input_w),
+ symmetric_dim(h, left_h_pad, input_h), symmetric_dim(b, left_b_pad, input_b))];
+ }
+ }
+ else
+ {
+ *out_ptr++ = *in_ptr++;
+ }
+ }
+ }
+ }
+ }
+}
+
} // namespace kernels
} // namespace luci_interpreter
diff --git a/compiler/luci-interpreter/src/kernels/MirrorPad.test.cpp b/compiler/luci-interpreter/src/kernels/MirrorPad.test.cpp
index de9da5051..740d8cb22 100644
--- a/compiler/luci-interpreter/src/kernels/MirrorPad.test.cpp
+++ b/compiler/luci-interpreter/src/kernels/MirrorPad.test.cpp
@@ -14,4 +14,212 @@
* limitations under the License.
*/
-// TODO: Add tests for MirrorPad
+#include "kernels/MirrorPad.h"
+#include "kernels/TestUtils.h"
+#include "luci_interpreter/TestMemoryManager.h"
+
+namespace luci_interpreter
+{
+namespace kernels
+{
+namespace
+{
+
+using namespace testing;
+
+class MirrorPadTest : public ::testing::Test
+{
+protected:
+ void SetUp() override { _memory_manager = std::make_unique<TestMemoryManager>(); }
+
+ void Execute(const Tensor &input, const Tensor &padding, Tensor &output, MirrorPadMode mode)
+ {
+ MirrorPadParams params{};
+ params.mode = mode;
+
+ MirrorPad kernel(&input, &padding, &output, params);
+ kernel.configure();
+ _memory_manager->allocate_memory(output);
+ kernel.execute();
+ }
+
+ std::unique_ptr<IMemoryManager> _memory_manager;
+};
+
+TEST_F(MirrorPadTest, FloatReflect)
+{
+ Shape input_shape = {1, 2, 2, 1};
+ Shape padding_shape = {4, 2};
+
+ std::vector<float> input_data{1.0f, 2.0f, //
+ 3.0f, 4.0f}; //
+ std::vector<int> padding_data{0, 0, 2, 1, 1, 2, 0, 0};
+
+ Tensor input_tensor =
+ makeInputTensor<DataType::FLOAT32>(input_shape, input_data, _memory_manager.get());
+ Tensor padding_tensor =
+ makeInputTensor<DataType::S32>(padding_shape, padding_data, _memory_manager.get());
+
+ Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);
+
+ Execute(input_tensor, padding_tensor, output_tensor, MirrorPadMode::REFLECT);
+
+ std::vector<float> ref_output_data{2.0f, 1.0f, 2.0f, 1.0f, 2.0f, //
+ 4.0f, 3.0f, 4.0f, 3.0f, 4.0f, //
+ 2.0f, 1.0f, 2.0f, 1.0f, 2.0f, //
+ 4.0f, 3.0f, 4.0f, 3.0f, 4.0f, //
+ 2.0f, 1.0f, 2.0f, 1.0f, 2.0f}; //
+ std::initializer_list<int32_t> ref_output_shape{1, 5, 5, 1};
+
+ EXPECT_THAT(extractTensorData<float>(output_tensor), FloatArrayNear(ref_output_data));
+ EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray(ref_output_shape));
+}
+
+TEST_F(MirrorPadTest, FloatSymmetric)
+{
+ Shape input_shape = {1, 2, 2, 1};
+ Shape padding_shape = {4, 2};
+
+ std::vector<float> input_data{1.0f, 2.0f, //
+ 3.0f, 4.0f}; //
+ std::vector<int> padding_data{0, 0, 2, 1, 1, 2, 0, 0};
+
+ Tensor input_tensor =
+ makeInputTensor<DataType::FLOAT32>(input_shape, input_data, _memory_manager.get());
+ Tensor padding_tensor =
+ makeInputTensor<DataType::S32>(padding_shape, padding_data, _memory_manager.get());
+
+ Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);
+
+ Execute(input_tensor, padding_tensor, output_tensor, MirrorPadMode::SYMMETRIC);
+
+ std::vector<float> ref_output_data{3.0, 3.0, 4.0, 4.0, 3.0, //
+ 1.0, 1.0, 2.0, 2.0, 1.0, //
+ 1.0, 1.0, 2.0, 2.0, 1.0, //
+ 3.0, 3.0, 4.0, 4.0, 3.0, //
+ 3.0, 3.0, 4.0, 4.0, 3.0}; //
+ std::initializer_list<int32_t> ref_output_shape{1, 5, 5, 1};
+
+ EXPECT_THAT(extractTensorData<float>(output_tensor), FloatArrayNear(ref_output_data));
+ EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray(ref_output_shape));
+}
+
+TEST_F(MirrorPadTest, FloatSymmetric2Dim)
+{
+ Shape input_shape = {3, 1};
+ Shape padding_shape = {2, 2};
+
+ std::vector<float> input_data{1.0f, 2.0f, 3.0f};
+ std::vector<int> padding_data{1, 2, 0, 0};
+
+ Tensor input_tensor =
+ makeInputTensor<DataType::FLOAT32>(input_shape, input_data, _memory_manager.get());
+ Tensor padding_tensor =
+ makeInputTensor<DataType::S32>(padding_shape, padding_data, _memory_manager.get());
+
+ Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);
+
+ Execute(input_tensor, padding_tensor, output_tensor, MirrorPadMode::SYMMETRIC);
+
+ std::vector<float> ref_output_data{1.0, 1.0, 2.0, 3.0, 3.0, 2.0};
+ std::initializer_list<int32_t> ref_output_shape{6, 1};
+
+ EXPECT_THAT(extractTensorData<float>(output_tensor), FloatArrayNear(ref_output_data));
+ EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray(ref_output_shape));
+}
+
+TEST_F(MirrorPadTest, Uint8Reflect)
+{
+ Shape input_shape = {1, 2, 3, 1};
+ Shape padding_shape = {4, 2};
+
+ float quant_tolerance = getTolerance(0.0f, 6.0f, 255);
+ std::pair<float, int32_t> quant_param = quantizationParams<uint8_t>(0.0f, 6.0f);
+
+ std::vector<float> input_data{1.0f, 2.0f, 3.0f, //
+ 4.0f, 5.0f, 6.0f}; //
+ std::vector<int> padding_data{0, 0, 2, 1, 1, 3, 0, 0};
+
+ Tensor input_tensor = makeInputTensor<DataType::U8>(
+ input_shape, quant_param.first, quant_param.second, input_data, _memory_manager.get());
+
+ Tensor padding_tensor =
+ makeInputTensor<DataType::S32>(padding_shape, padding_data, _memory_manager.get());
+
+ Tensor output_tensor = makeOutputTensor(DataType::U8, quant_param.first, quant_param.second);
+
+ Execute(input_tensor, padding_tensor, output_tensor, MirrorPadMode::REFLECT);
+
+ std::vector<float> ref_output_data{
+ 3.0f, 1.0f, 2.0f, 3.0f, 1.0f, 2.0f, 3.0f, //
+ 6.0f, 4.0f, 5.0f, 6.0f, 4.0f, 5.0f, 6.0f, //
+ 3.0f, 1.0f, 2.0f, 3.0f, 1.0f, 2.0f, 3.0f, //
+ 6.0f, 4.0f, 5.0f, 6.0f, 4.0f, 5.0f, 6.0f, //
+ 3.0f, 1.0f, 2.0f, 3.0f, 1.0f, 2.0f, 3.0f, //
+ };
+ std::initializer_list<int32_t> ref_output_shape{1, 5, 7, 1};
+
+ EXPECT_THAT(dequantizeTensorData(output_tensor),
+ FloatArrayNear(ref_output_data, quant_tolerance));
+ EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray(ref_output_shape));
+}
+
+TEST_F(MirrorPadTest, Uint8Symmetric)
+{
+ Shape input_shape = {1, 2, 3, 1};
+ Shape padding_shape = {4, 2};
+
+ float quant_tolerance = getTolerance(0.0f, 6.0f, 255);
+ std::pair<float, int32_t> quant_param = quantizationParams<uint8_t>(0.0f, 6.0f);
+
+ std::vector<float> input_data{1.0f, 2.0f, 3.0f, //
+ 4.0f, 5.0f, 6.0f}; //
+ std::vector<int> padding_data{0, 0, 2, 1, 1, 3, 0, 0};
+
+ Tensor input_tensor = makeInputTensor<DataType::U8>(
+ input_shape, quant_param.first, quant_param.second, input_data, _memory_manager.get());
+
+ Tensor padding_tensor =
+ makeInputTensor<DataType::S32>(padding_shape, padding_data, _memory_manager.get());
+
+ Tensor output_tensor = makeOutputTensor(DataType::U8, quant_param.first, quant_param.second);
+
+ Execute(input_tensor, padding_tensor, output_tensor, MirrorPadMode::SYMMETRIC);
+
+ std::vector<float> ref_output_data{
+ 4.0f, 4.0f, 5.0f, 6.0f, 6.0f, 5.0f, 4.0f, //
+ 1.0f, 1.0f, 2.0f, 3.0f, 3.0f, 2.0f, 1.0f, //
+ 1.0f, 1.0f, 2.0f, 3.0f, 3.0f, 2.0f, 1.0f, //
+ 4.0f, 4.0f, 5.0f, 6.0f, 6.0f, 5.0f, 4.0f, //
+ 4.0f, 4.0f, 5.0f, 6.0f, 6.0f, 5.0f, 4.0f, //
+ };
+ std::initializer_list<int32_t> ref_output_shape{1, 5, 7, 1};
+
+ EXPECT_THAT(dequantizeTensorData(output_tensor),
+ FloatArrayNear(ref_output_data, quant_tolerance));
+ EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray(ref_output_shape));
+}
+
+TEST_F(MirrorPadTest, UnsupportedDim_NEG)
+{
+ Tensor input_tensor =
+ makeInputTensor<DataType::FLOAT32>({1, 1, 1, 1, 1}, {1.0f}, _memory_manager.get());
+ Tensor padding_tensor =
+ makeInputTensor<DataType::S32>({5, 2}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, _memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);
+
+ EXPECT_ANY_THROW(Execute(input_tensor, padding_tensor, output_tensor, MirrorPadMode::REFLECT));
+}
+
+TEST_F(MirrorPadTest, InvalidInputType_NEG)
+{
+ Tensor input_tensor = makeInputTensor<DataType::S64>({1}, {1}, _memory_manager.get());
+ Tensor padding_tensor = makeInputTensor<DataType::S32>({1, 2}, {0, 0}, _memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::S64);
+
+ EXPECT_ANY_THROW(Execute(input_tensor, padding_tensor, output_tensor, MirrorPadMode::REFLECT));
+}
+
+} // namespace
+} // namespace kernels
+} // namespace luci_interpreter
diff --git a/compiler/luci-interpreter/src/kernels/Mul.cpp b/compiler/luci-interpreter/src/kernels/Mul.cpp
index bc855de0f..531fb4fa1 100644
--- a/compiler/luci-interpreter/src/kernels/Mul.cpp
+++ b/compiler/luci-interpreter/src/kernels/Mul.cpp
@@ -42,6 +42,8 @@ void Mul::configure()
LUCI_INTERPRETER_CHECK(output()->element_type() == input1()->element_type());
if (input1()->element_type() == DataType::S16)
{
+ LUCI_INTERPRETER_CHECK(input1()->zero_points().size() == 1 &&
+ input2()->zero_points().size() == 1)
LUCI_INTERPRETER_CHECK(input1()->zero_point() == 0 && input2()->zero_point() == 0 &&
output()->zero_point() == 0);
}
@@ -56,6 +58,12 @@ void Mul::execute() const
case DataType::FLOAT32:
evalFloat();
break;
+ case DataType::S64:
+ evalInteger<int64_t>();
+ break;
+ case DataType::S32:
+ evalInteger<int32_t>();
+ break;
case DataType::S16:
evalQuantizedS16();
break;
@@ -66,13 +74,8 @@ void Mul::execute() const
void Mul::evalFloat() const
{
- float activation_min{};
- float activation_max{};
- calculateActivationRange(_params.activation, &activation_min, &activation_max);
-
tflite::ArithmeticParams params{};
- params.float_activation_min = activation_min;
- params.float_activation_max = activation_max;
+ fillArithmeticActivationRange<float>(params, _params.activation);
const bool need_broadcast = tflite::reference_ops::ProcessBroadcastShapes(
getTensorShape(input1()), getTensorShape(input2()), &params);
@@ -91,6 +94,28 @@ void Mul::evalFloat() const
}
}
+template <typename T> void Mul::evalInteger() const
+{
+ tflite::ArithmeticParams params{};
+ fillArithmeticActivationRange<T>(params, _params.activation);
+
+ const bool need_broadcast = tflite::reference_ops::ProcessBroadcastShapes(
+ getTensorShape(input1()), getTensorShape(input2()), &params);
+
+ if (need_broadcast)
+ {
+ luci_interpreter_pal::BroadcastMul4DSlow(
+ params, getTensorShape(input1()), getTensorData<T>(input1()), getTensorShape(input2()),
+ getTensorData<T>(input2()), getTensorShape(output()), getTensorData<T>(output()));
+ }
+ else
+ {
+ luci_interpreter_pal::Mul(params, getTensorShape(input1()), getTensorData<T>(input1()),
+ getTensorShape(input2()), getTensorData<T>(input2()),
+ getTensorShape(output()), getTensorData<T>(output()));
+ }
+}
+
void Mul::evalQuantizedS16() const
{
const auto input1_scale = static_cast<double>(input1()->scale());
diff --git a/compiler/luci-interpreter/src/kernels/Mul.h b/compiler/luci-interpreter/src/kernels/Mul.h
index 2ccf60f3a..c0cf817df 100644
--- a/compiler/luci-interpreter/src/kernels/Mul.h
+++ b/compiler/luci-interpreter/src/kernels/Mul.h
@@ -42,6 +42,7 @@ public:
private:
void evalFloat() const;
+ template <typename T> void evalInteger() const;
void evalQuantizedS16() const;
};
diff --git a/compiler/luci-interpreter/src/kernels/Mul.test.cpp b/compiler/luci-interpreter/src/kernels/Mul.test.cpp
index 471f6ac86..fc0e60614 100644
--- a/compiler/luci-interpreter/src/kernels/Mul.test.cpp
+++ b/compiler/luci-interpreter/src/kernels/Mul.test.cpp
@@ -93,6 +93,78 @@ TEST_F(MulTest, Float)
}
}
+template <loco::DataType DType> void checkInteger(luci_interpreter::IMemoryManager *memory_manager)
+{
+ using dtype = typename loco::DataTypeImpl<DType>::Type;
+ Shape base_shape = {2, 3, 1, 2};
+ std::vector<Shape> test_shapes{{1, 1, 3, 2}, {1, 3, 1, 2}, {2, 1, 3, 1}, {2, 3, 1, 1}};
+
+ dtype max_value = std::numeric_limits<dtype>::max();
+ dtype res_max = max_value - max_value % 10;
+
+ std::vector<std::vector<dtype>> test_outputs = {
+ {8, 0, 20, 0, 4, 30, //
+ 16, 0, 40, 3, 8, 0, //
+ 0, 0, 0, 6, 0, 0, //
+ 4, 0, 10, 9, 2, 0, //
+ 40, 0, 100, 0, 20, 150, //
+ 28, 0, 70, 0, 14, res_max},
+ {8, 0, 40, 3, 0, 0, 4, 0, 100, 0, 14, res_max},
+ {8, 12, 0, 0, 20, 30, 16, 0, 0, 0, 40, 0, 0, 0, 0, 0, 0,
+ 0, 0, 9, 2, 0, 10, 0, 0, 0, 20, 30, 100, 150, 0, 0, 14, max_value / 10 * 2,
+ 70, res_max},
+ {8, 12, 0, 0, 0, 0, 0, 9, 20, 30, 70, res_max}};
+ std::vector<dtype> input1_data{2, 3, 4, -1, -3, -2, 1, -3, 10, 15, 7, max_value / 10};
+ std::vector<dtype> input2_data{4, 0, 10, -3, 2, 10};
+ for (size_t i = 0; i < test_shapes.size(); ++i)
+ {
+ Tensor input1_tensor = makeInputTensor<DType>(base_shape, input1_data, memory_manager);
+ Tensor input2_tensor = makeInputTensor<DType>(test_shapes[i], input2_data, memory_manager);
+ Tensor output_tensor = makeOutputTensor(DType);
+
+ MulParams params{};
+ params.activation = Activation::RELU;
+
+ Mul kernel(&input1_tensor, &input2_tensor, &output_tensor, params);
+ kernel.configure();
+ memory_manager->allocate_memory(output_tensor);
+ kernel.execute();
+
+ EXPECT_THAT(extractTensorData<dtype>(output_tensor), test_outputs[i])
+ << "With shape number " << i;
+ }
+ // Re-run with exchanged inputs.
+ for (size_t i = 0; i < test_shapes.size(); ++i)
+ {
+ Tensor input1_tensor = makeInputTensor<DType>(test_shapes[i], input2_data, memory_manager);
+ Tensor input2_tensor = makeInputTensor<DType>(base_shape, input1_data, memory_manager);
+ Tensor output_tensor = makeOutputTensor(DType);
+
+ MulParams params{};
+ params.activation = Activation::RELU;
+
+ Mul kernel(&input1_tensor, &input2_tensor, &output_tensor, params);
+ kernel.configure();
+ memory_manager->allocate_memory(output_tensor);
+ kernel.execute();
+
+ EXPECT_THAT(extractTensorData<dtype>(output_tensor), test_outputs[i])
+ << "With shape number " << i;
+ }
+}
+
+TEST_F(MulTest, SInt64)
+{
+ checkInteger<loco::DataType::S64>(_memory_manager.get());
+ SUCCEED();
+}
+
+TEST_F(MulTest, SInt32)
+{
+ checkInteger<loco::DataType::S32>(_memory_manager.get());
+ SUCCEED();
+}
+
TEST_F(MulTest, SInt16)
{
Shape base_shape = {2, 3, 1, 2};
@@ -161,6 +233,60 @@ TEST_F(MulTest, SInt16)
}
}
+TEST_F(MulTest, Input_Output_Type_NEG)
+{
+ Tensor input1_tensor = makeInputTensor<DataType::FLOAT32>({1}, {1.f}, _memory_manager.get());
+ Tensor input2_tensor = makeInputTensor<DataType::S32>({1}, {2}, _memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);
+
+ MulParams params{};
+ params.activation = Activation::RELU;
+
+ Mul kernel(&input1_tensor, &input2_tensor, &output_tensor, params);
+ EXPECT_ANY_THROW(kernel.configure());
+}
+
+TEST_F(MulTest, Invalid_Output_Type_NEG)
+{
+ Tensor input1_tensor = makeInputTensor<DataType::S64>({1}, {1}, _memory_manager.get());
+ Tensor input2_tensor = makeInputTensor<DataType::S64>({1}, {2}, _memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::S32);
+
+ MulParams params{};
+ params.activation = Activation::RELU;
+
+ Mul kernel(&input1_tensor, &input2_tensor, &output_tensor, params);
+ EXPECT_ANY_THROW(kernel.configure());
+}
+
+TEST_F(MulTest, Invalid_Input_Type_NEG)
+{
+ Tensor input1_tensor = makeInputTensor<DataType::U64>({1}, {1}, _memory_manager.get());
+ Tensor input2_tensor = makeInputTensor<DataType::U64>({1}, {2}, _memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::U64);
+
+ MulParams params{};
+ params.activation = Activation::RELU;
+
+ Mul kernel(&input1_tensor, &input2_tensor, &output_tensor, params);
+ kernel.configure();
+ _memory_manager->allocate_memory(output_tensor);
+ EXPECT_ANY_THROW(kernel.execute());
+}
+
+TEST_F(MulTest, Invalid_Quantization_NEG)
+{
+ Tensor input1_tensor = makeInputTensor<DataType::S16>({1}, {1}, _memory_manager.get());
+ Tensor input2_tensor = makeInputTensor<DataType::S16>({1}, {2}, _memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::S16);
+
+ MulParams params{};
+ params.activation = Activation::NONE;
+
+ Mul kernel(&input1_tensor, &input2_tensor, &output_tensor, params);
+ EXPECT_ANY_THROW(kernel.configure());
+}
+
} // namespace
} // namespace kernels
} // namespace luci_interpreter
diff --git a/compiler/luci-interpreter/src/kernels/NotEqual.cpp b/compiler/luci-interpreter/src/kernels/NotEqual.cpp
index 99d5e0fa0..54e5eee34 100644
--- a/compiler/luci-interpreter/src/kernels/NotEqual.cpp
+++ b/compiler/luci-interpreter/src/kernels/NotEqual.cpp
@@ -49,6 +49,12 @@ void NotEqual::execute() const
case DataType::FLOAT32:
evalFloat();
break;
+ case DataType::S64:
+ evalInteger<int64_t>();
+ break;
+ case DataType::S32:
+ evalInteger<int32_t>();
+ break;
case DataType::U8:
evalQuantized();
break;
@@ -79,6 +85,29 @@ void NotEqual::evalFloat() const
}
}
+template <typename T> void NotEqual::evalInteger() const
+{
+ const auto x_data = getTensorData<T>(x());
+ const auto y_data = getTensorData<T>(y());
+ auto output_data = getTensorData<bool>(output());
+
+ tflite::ComparisonParams op_params;
+ op_params.is_broadcast = x()->shape() != y()->shape();
+
+ if (op_params.is_broadcast)
+ {
+ tflite::reference_ops::Broadcast4DSlowNotEqualNoScaling(op_params, getTensorShape(x()), x_data,
+ getTensorShape(y()), y_data,
+ getTensorShape(output()), output_data);
+ }
+ else
+ {
+ tflite::reference_ops::NotEqualNoScaling(op_params, getTensorShape(x()), x_data,
+ getTensorShape(y()), y_data, getTensorShape(output()),
+ output_data);
+ }
+}
+
void NotEqual::evalQuantized() const
{
const auto x_data = getTensorData<uint8_t>(x());
diff --git a/compiler/luci-interpreter/src/kernels/NotEqual.h b/compiler/luci-interpreter/src/kernels/NotEqual.h
index 247874df7..d2aafe893 100644
--- a/compiler/luci-interpreter/src/kernels/NotEqual.h
+++ b/compiler/luci-interpreter/src/kernels/NotEqual.h
@@ -38,6 +38,7 @@ public:
private:
void evalFloat() const;
+ template <typename T> void evalInteger() const;
void evalQuantized() const;
private:
diff --git a/compiler/luci-interpreter/src/kernels/NotEqual.test.cpp b/compiler/luci-interpreter/src/kernels/NotEqual.test.cpp
index 763f86893..45bf4022a 100644
--- a/compiler/luci-interpreter/src/kernels/NotEqual.test.cpp
+++ b/compiler/luci-interpreter/src/kernels/NotEqual.test.cpp
@@ -99,6 +99,82 @@ TEST_F(NotEqualTest, FloatBroardcast)
EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray({4, 3}));
}
+template <loco::DataType DType>
+void checkIntegerSimple(luci_interpreter::IMemoryManager *memory_manager)
+{
+ using dtype = typename loco::DataTypeImpl<DType>::Type;
+ dtype min_value = std::numeric_limits<dtype>::min();
+ dtype max_value = std::numeric_limits<dtype>::max();
+ std::vector<dtype> x_data{min_value, 2, max_value};
+
+ std::vector<dtype> y_data{min_value, -2, max_value};
+
+ std::vector<bool> ref_output_data{false, true, false};
+
+ Tensor x_tensor = makeInputTensor<DType>({3}, x_data, memory_manager);
+ Tensor y_tensor = makeInputTensor<DType>({3}, y_data, memory_manager);
+ Tensor output_tensor = makeOutputTensor(DataType::BOOL);
+
+ NotEqual kernel(&x_tensor, &y_tensor, &output_tensor);
+ kernel.configure();
+ memory_manager->allocate_memory(output_tensor);
+ kernel.execute();
+
+ EXPECT_THAT(extractTensorData<bool>(output_tensor), ::testing::ElementsAreArray(ref_output_data));
+ EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray({3}));
+}
+
+template <loco::DataType DType>
+void checkIntegerBroadcast(luci_interpreter::IMemoryManager *memory_manager)
+{
+ using dtype = typename loco::DataTypeImpl<DType>::Type;
+ dtype min_value = std::numeric_limits<dtype>::min();
+ dtype max_value = std::numeric_limits<dtype>::max();
+ std::vector<dtype> x_data{
+ min_value, 2, 3, // Row 1
+ 4, 5, max_value, // Row 2
+ -1, -2, -3, // Row 3
+ min_value, -2, max_value, // Row 4
+ };
+
+ std::vector<dtype> y_data{
+ min_value, -2, max_value, // Row 1
+ };
+
+ std::vector<bool> ref_output_data{
+ false, true, true, // Row 1
+ true, true, false, // Row 2
+ true, false, true, // Row 3
+ false, false, false, // Row 4
+ };
+
+ Tensor x_tensor = makeInputTensor<DType>({4, 3}, x_data, memory_manager);
+ Tensor y_tensor = makeInputTensor<DType>({3}, y_data, memory_manager);
+ Tensor output_tensor = makeOutputTensor(DataType::BOOL);
+
+ NotEqual kernel(&x_tensor, &y_tensor, &output_tensor);
+ kernel.configure();
+ memory_manager->allocate_memory(output_tensor);
+ kernel.execute();
+
+ EXPECT_THAT(extractTensorData<bool>(output_tensor), ::testing::ElementsAreArray(ref_output_data));
+ EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray({4, 3}));
+}
+
+TEST_F(NotEqualTest, Int32)
+{
+ checkIntegerSimple<loco::DataType::S32>(_memory_manager.get());
+ checkIntegerBroadcast<loco::DataType::S32>(_memory_manager.get());
+ SUCCEED();
+}
+
+TEST_F(NotEqualTest, Int64)
+{
+ checkIntegerSimple<loco::DataType::S64>(_memory_manager.get());
+ checkIntegerBroadcast<loco::DataType::S64>(_memory_manager.get());
+ SUCCEED();
+}
+
// Choose min / max in such a way that there are exactly 256 units to avoid rounding errors.
const float F_MIN = -128.0 / 128.0;
const float F_MAX = 127.0 / 128.0;
@@ -195,6 +271,36 @@ TEST_F(NotEqualTest, Input_Output_Type_NEG)
EXPECT_ANY_THROW(kernel.configure());
}
+TEST_F(NotEqualTest, Float_Broadcast_NEG)
+{
+ Tensor x_tensor = makeInputTensor<DataType::FLOAT32>({2}, {1.f, 2.f}, _memory_manager.get());
+ Tensor y_tensor = makeInputTensor<DataType::FLOAT32>({3}, {1.f, 2.f, 3.f}, _memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::BOOL);
+
+ NotEqual kernel(&x_tensor, &y_tensor, &output_tensor);
+ ASSERT_ANY_THROW(kernel.configure());
+}
+
+TEST_F(NotEqualTest, Int32_Broadcast_NEG)
+{
+ Tensor x_tensor = makeInputTensor<DataType::S32>({2}, {1, 2}, _memory_manager.get());
+ Tensor y_tensor = makeInputTensor<DataType::S32>({3}, {1, 2, 3}, _memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::BOOL);
+
+ NotEqual kernel(&x_tensor, &y_tensor, &output_tensor);
+ ASSERT_ANY_THROW(kernel.configure());
+}
+
+TEST_F(NotEqualTest, Int64_Broadcast_NEG)
+{
+ Tensor x_tensor = makeInputTensor<DataType::S64>({2}, {1, 2}, _memory_manager.get());
+ Tensor y_tensor = makeInputTensor<DataType::S64>({3}, {1, 2, 3}, _memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::BOOL);
+
+ NotEqual kernel(&x_tensor, &y_tensor, &output_tensor);
+ ASSERT_ANY_THROW(kernel.configure());
+}
+
} // namespace
} // namespace kernels
} // namespace luci_interpreter
diff --git a/compiler/luci-interpreter/src/kernels/OneHot.cpp b/compiler/luci-interpreter/src/kernels/OneHot.cpp
new file mode 100644
index 000000000..4d3e5f2ef
--- /dev/null
+++ b/compiler/luci-interpreter/src/kernels/OneHot.cpp
@@ -0,0 +1,136 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "kernels/OneHot.h"
+#include "kernels/Utils.h"
+
+namespace luci_interpreter
+{
+namespace kernels
+{
+
+namespace
+{
+
+template <typename T>
+void OneHotComputeImpl(const Tensor *indices_tensor, const Tensor *on_value_tensor,
+ const Tensor *off_value_tensor, int32_t depth, int32_t axis,
+ Tensor *output_tensor)
+{
+ // define input shape and correct axis
+ auto const &input_shape = indices_tensor->shape();
+ axis = axis == -1 ? input_shape.num_dims() : axis;
+
+ // TODO support other integer input types
+ auto const *indices = getTensorData<int32_t>(indices_tensor);
+ auto const on_value = getTensorData<T>(on_value_tensor)[0];
+ auto const off_value = getTensorData<T>(off_value_tensor)[0];
+ auto *output = getTensorData<T>(output_tensor);
+
+ // prefix_dim_size == # of elements before the axis
+ // depth == # of elements per axis
+ // suffix_dim_size == # of elements after the axis
+ auto prefix_dim_size = 1;
+ for (int32_t i = 0; i < axis; ++i)
+ {
+ prefix_dim_size *= input_shape.dim(i);
+ }
+ assert(prefix_dim_size > 0);
+ auto const suffix_dim_size = input_shape.num_elements() / prefix_dim_size;
+
+ // View the indices as a matrix of size:
+ // prefix_dim_size x suffix_dim_size
+ // View the output as a matrix of size:
+ // prefix_dim_size x depth x suffix_dim_size
+ // Then the output is:
+ // output(i, j, k) == (indices(i, k) == j) ? on : off
+ for (int32_t i = 0; i < prefix_dim_size; ++i)
+ for (int32_t j = 0; j < depth; ++j)
+ for (int32_t k = 0; k < suffix_dim_size; ++k, ++output)
+ *output = indices[i * suffix_dim_size + k] == j ? on_value : off_value;
+}
+
+} // namespace
+
+OneHot::OneHot(const Tensor *indices, const Tensor *depth, const Tensor *on_value,
+ const Tensor *off_value, Tensor *output, const OneHotParams &params)
+ : KernelWithParams<OneHotParams>({indices, depth, on_value, off_value}, {output}, params)
+{
+ // Do nothing
+}
+
+void OneHot::configure()
+{
+ // check types
+ LUCI_INTERPRETER_CHECK(indices()->element_type() == DataType::S32);
+ LUCI_INTERPRETER_CHECK(depth()->element_type() == DataType::S32);
+ LUCI_INTERPRETER_CHECK(on_value()->element_type() == off_value()->element_type());
+ LUCI_INTERPRETER_CHECK(output()->element_type() == on_value()->element_type());
+
+ // check shape dependent parameters
+ LUCI_INTERPRETER_CHECK(on_value()->shape().num_elements() == 1);
+ LUCI_INTERPRETER_CHECK(off_value()->shape().num_elements() == 1);
+ LUCI_INTERPRETER_CHECK(depth()->shape().num_elements() == 1);
+ LUCI_INTERPRETER_CHECK(params().axis >= -1 && params().axis <= indices()->shape().num_dims());
+
+ // define parameters that affect the output shape
+ auto const depth_value = getTensorData<int32_t>(depth())[0];
+ auto const &input_shape = indices()->shape();
+ auto const input_dims = input_shape.num_dims();
+ auto const axis = params().axis == -1 ? input_dims : params().axis;
+
+ // define output shape
+ Shape output_shape(input_shape.num_dims() + 1);
+ {
+ for (int32_t d = 0; d < axis; ++d)
+ output_shape.dim(d) = input_shape.dim(d);
+
+ output_shape.dim(axis) = depth_value;
+
+ for (int32_t d = axis + 1; d < output_shape.num_dims(); ++d)
+ output_shape.dim(d) = input_shape.dim(d - 1);
+ }
+
+ // reshape output
+ output()->resize(output_shape);
+}
+
+void OneHot::execute() const
+{
+ auto const depth_value = getTensorData<int32_t>(depth())[0];
+ auto const axis = params().axis;
+
+ switch (output()->element_type())
+ {
+ case loco::DataType::FLOAT32:
+ OneHotComputeImpl<float>(indices(), on_value(), off_value(), depth_value, axis, output());
+ break;
+ case loco::DataType::U8:
+ OneHotComputeImpl<uint8_t>(indices(), on_value(), off_value(), depth_value, axis, output());
+ break;
+ case loco::DataType::S16:
+ OneHotComputeImpl<int16_t>(indices(), on_value(), off_value(), depth_value, axis, output());
+ break;
+ default:
+ // TODO Support other data types
+ throw std::runtime_error("Not supported, yet!");
+ break;
+ }
+}
+
+} // namespace kernels
+} // namespace luci_interpreter
diff --git a/compiler/luci-interpreter/src/kernels/OneHot.h b/compiler/luci-interpreter/src/kernels/OneHot.h
new file mode 100644
index 000000000..572f857ae
--- /dev/null
+++ b/compiler/luci-interpreter/src/kernels/OneHot.h
@@ -0,0 +1,48 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LUCI_INTERPRETER_KERNELS_ONEHOT_H
+#define LUCI_INTERPRETER_KERNELS_ONEHOT_H
+
+#include "core/Kernel.h"
+#include "core/KernelParams.h"
+
+namespace luci_interpreter
+{
+namespace kernels
+{
+
+class OneHot : public KernelWithParams<OneHotParams>
+{
+public:
+ OneHot(const Tensor *indices, const Tensor *depth, const Tensor *on_value,
+ const Tensor *off_value, Tensor *output, const OneHotParams &params);
+
+ const Tensor *indices() const { return _inputs[0]; }
+ const Tensor *depth() const { return _inputs[1]; }
+ const Tensor *on_value() const { return _inputs[2]; }
+ const Tensor *off_value() const { return _inputs[3]; }
+
+ Tensor *output() const { return _outputs[0]; }
+
+ void configure() override;
+ void execute() const override;
+};
+
+} // namespace kernels
+} // namespace luci_interpreter
+
+#endif // LUCI_INTERPRETER_KERNELS_ONEHOT_H
diff --git a/compiler/luci-interpreter/src/kernels/OneHot.test.cpp b/compiler/luci-interpreter/src/kernels/OneHot.test.cpp
new file mode 100644
index 000000000..45b6968fa
--- /dev/null
+++ b/compiler/luci-interpreter/src/kernels/OneHot.test.cpp
@@ -0,0 +1,192 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "kernels/OneHot.h"
+#include "kernels/TestUtils.h"
+#include "luci_interpreter/TestMemoryManager.h"
+
+namespace luci_interpreter
+{
+namespace kernels
+{
+namespace
+{
+
+using namespace testing;
+
+template <typename T1, typename T2>
+void Check(std::initializer_list<int32_t> input_shape, std::initializer_list<int32_t> output_shape,
+ std::initializer_list<T1> input_data, std::initializer_list<int32_t> depth_data,
+ std::initializer_list<T2> on_value_data, std::initializer_list<T2> off_value_data,
+ int32_t axis, std::initializer_list<T2> output_data)
+{
+ std::unique_ptr<IMemoryManager> memory_manager = std::make_unique<TestMemoryManager>();
+
+ constexpr auto input_type = getElementType<T1>();
+ constexpr auto output_type = getElementType<T2>();
+
+ Tensor input_tensor = makeInputTensor<input_type>(input_shape, input_data, memory_manager.get());
+ Tensor depth_tensor = makeInputTensor<DataType::S32>({}, depth_data, memory_manager.get());
+ Tensor on_value_tensor = makeInputTensor<output_type>({}, on_value_data, memory_manager.get());
+ Tensor off_value_tensor = makeInputTensor<output_type>({}, off_value_data, memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(output_type);
+
+ OneHotParams params{};
+ params.axis = axis;
+
+ OneHot kernel(&input_tensor, &depth_tensor, &on_value_tensor, &off_value_tensor, &output_tensor,
+ params);
+ kernel.configure();
+ memory_manager->allocate_memory(output_tensor);
+ kernel.execute();
+
+ EXPECT_THAT(extractTensorShape(output_tensor), output_shape);
+ EXPECT_THAT(extractTensorData<T2>(output_tensor), ::testing::ElementsAreArray(output_data));
+}
+
+template <typename T> class OneHotTest : public ::testing::Test
+{
+};
+
+using DataTypes = ::testing::Types<float, uint8_t, int16_t>;
+TYPED_TEST_SUITE(OneHotTest, DataTypes);
+
+TYPED_TEST(OneHotTest, BasicPattern)
+{
+ // axis 0
+ Check<int32_t, TypeParam>(/*input_shape=*/{2, 3}, /*output_shape=*/{4, 2, 3},
+ /*input_data=*/
+ {
+ 0, 3, 5, //
+ 7, 3, 0, //
+ },
+ /*depth_data=*/{4}, /*on_value_data=*/{1}, /*off_value_data=*/{0},
+ /*axis=*/0,
+ /*output_data=*/
+ {
+ 1, 0, 0, //
+ 0, 0, 1, //
+
+ 0, 0, 0, //
+ 0, 0, 0, //
+
+ 0, 0, 0, //
+ 0, 0, 0, //
+
+ 0, 1, 0, //
+ 0, 1, 0, //
+ });
+ // axis 1
+ Check<int32_t, TypeParam>(/*input_shape=*/{2, 3}, /*output_shape=*/{2, 4, 3},
+ /*input_data=*/
+ {
+ 0, 3, 5, //
+ 7, 3, 0, //
+ },
+ /*depth_data=*/{4}, /*on_value_data=*/{1}, /*off_value_data=*/{0},
+ /*axis=*/1,
+ /*output_data=*/
+ {
+ 1, 0, 0, //
+ 0, 0, 0, //
+ 0, 0, 0, //
+ 0, 1, 0, //
+
+ 0, 0, 1, //
+ 0, 0, 0, //
+ 0, 0, 0, //
+ 0, 1, 0, //
+ });
+ // axis -1
+ Check<int32_t, TypeParam>(/*input_shape=*/{2, 3}, /*output_shape=*/{2, 3, 4},
+ /*input_data=*/
+ {
+ 0, 3, 5, //
+ 7, 3, 0, //
+ },
+ /*depth_data=*/{4}, /*on_value_data=*/{1}, /*off_value_data=*/{0},
+ /*axis=*/-1,
+ /*output_data=*/
+ {
+ 1, 0, 0, 0, //
+ 0, 0, 0, 1, //
+ 0, 0, 0, 0, //
+
+ 0, 0, 0, 0, //
+ 0, 0, 0, 1, //
+ 1, 0, 0, 0, //
+ });
+}
+
+TEST(OneHotTest, UnsupportedInputType_NEG)
+{
+ std::unique_ptr<IMemoryManager> memory_manager = std::make_unique<TestMemoryManager>();
+
+ // input type should be integer
+ Tensor input_tensor = makeInputTensor<DataType::FLOAT32>({1}, {0}, memory_manager.get());
+
+ Tensor depth_tensor = makeInputTensor<DataType::S32>({}, {1}, memory_manager.get());
+ Tensor on_value_tensor = makeInputTensor<DataType::FLOAT32>({}, {1.0}, memory_manager.get());
+ Tensor off_value_tensor = makeInputTensor<DataType::FLOAT32>({}, {0.0}, memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);
+
+ OneHotParams params = {-1};
+
+ OneHot kernel(&input_tensor, &depth_tensor, &on_value_tensor, &off_value_tensor, &output_tensor,
+ params);
+ EXPECT_ANY_THROW(kernel.configure());
+}
+
+TEST(OneHotTest, OutputTypeMismatch_NEG)
+{
+ std::unique_ptr<IMemoryManager> memory_manager = std::make_unique<TestMemoryManager>();
+
+ Tensor input_tensor = makeInputTensor<DataType::S32>({1}, {0}, memory_manager.get());
+ Tensor depth_tensor = makeInputTensor<DataType::S32>({}, {1}, memory_manager.get());
+
+ // type of on_value, off_value and output_tensor should be same
+ Tensor on_value_tensor = makeInputTensor<DataType::FLOAT32>({}, {1.0}, memory_manager.get());
+ Tensor off_value_tensor = makeInputTensor<DataType::FLOAT32>({}, {0.0}, memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::S16);
+
+ OneHotParams params = {-1};
+
+ OneHot kernel(&input_tensor, &depth_tensor, &on_value_tensor, &off_value_tensor, &output_tensor,
+ params);
+ EXPECT_ANY_THROW(kernel.configure());
+}
+
+TEST(OneHotTest, InvalidAxis_NEG)
+{
+ std::unique_ptr<IMemoryManager> memory_manager = std::make_unique<TestMemoryManager>();
+
+ Tensor input_tensor = makeInputTensor<DataType::S32>({1}, {0}, memory_manager.get());
+ Tensor depth_tensor = makeInputTensor<DataType::S32>({}, {1}, memory_manager.get());
+ Tensor on_value_tensor = makeInputTensor<DataType::FLOAT32>({}, {1.0}, memory_manager.get());
+ Tensor off_value_tensor = makeInputTensor<DataType::FLOAT32>({}, {0.0}, memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);
+
+ // axis should be in [-1, input_shape.rank]
+ OneHotParams params = {-2};
+
+ OneHot kernel(&input_tensor, &depth_tensor, &on_value_tensor, &off_value_tensor, &output_tensor,
+ params);
+ EXPECT_ANY_THROW(kernel.configure());
+}
+
+} // namespace
+} // namespace kernels
+} // namespace luci_interpreter
diff --git a/compiler/luci-interpreter/src/kernels/Pack.test.cpp b/compiler/luci-interpreter/src/kernels/Pack.test.cpp
index 90a0f894e..2404e4303 100644
--- a/compiler/luci-interpreter/src/kernels/Pack.test.cpp
+++ b/compiler/luci-interpreter/src/kernels/Pack.test.cpp
@@ -80,7 +80,7 @@ template <typename T> class PackTest : public ::testing::Test
};
using DataTypes = ::testing::Types<uint8_t, float>;
-TYPED_TEST_CASE(PackTest, DataTypes);
+TYPED_TEST_SUITE(PackTest, DataTypes);
TYPED_TEST(PackTest, ThreeInputs)
{
diff --git a/compiler/luci-interpreter/src/kernels/Pad.cpp b/compiler/luci-interpreter/src/kernels/Pad.cpp
index 700448e7a..fe172884b 100644
--- a/compiler/luci-interpreter/src/kernels/Pad.cpp
+++ b/compiler/luci-interpreter/src/kernels/Pad.cpp
@@ -93,6 +93,16 @@ void Pad::execute() const
getTensorData<uint8_t>(output()));
break;
}
+ case DataType::S8:
+ {
+ assert(output()->zero_point() >= std::numeric_limits<int8_t>::min());
+ assert(output()->zero_point() <= std::numeric_limits<int8_t>::max());
+ const auto pad_value = static_cast<int8_t>(output()->zero_point());
+ tflite::reference_ops::Pad(params, getTensorShape(input()), getTensorData<int8_t>(input()),
+ &pad_value, getTensorShape(output()),
+ getTensorData<int8_t>(output()));
+ break;
+ }
default:
throw std::runtime_error("Unsupported type.");
}
diff --git a/compiler/luci-interpreter/src/kernels/Pad.test.cpp b/compiler/luci-interpreter/src/kernels/Pad.test.cpp
index 7994263e2..dd3ce947c 100644
--- a/compiler/luci-interpreter/src/kernels/Pad.test.cpp
+++ b/compiler/luci-interpreter/src/kernels/Pad.test.cpp
@@ -54,6 +54,32 @@ TEST(Pad, Uint8)
EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray({1, 4, 7, 1}));
}
+TEST(Pad, Int8)
+{
+ std::unique_ptr<IMemoryManager> memory_manager = std::make_unique<TestMemoryManager>();
+ float kQuantizedTolerance = GetTolerance(-1.0, 1.0);
+ std::pair<float, int32_t> quant_param = quantizationParams<int8_t>(-1.0f, 1.0f);
+ std::vector<float> input_data{-0.2, 0.4, 0.5, -0.7, -0.1, -0.9, 0.7, 0.1, 0.2};
+ std::vector<int32_t> paddings_data{0, 0, 1, 2, 2, 1, 0, 0};
+ Tensor input_tensor = makeInputTensor<DataType::S8>(
+ {1, 3, 3, 1}, quant_param.first, quant_param.second, input_data, memory_manager.get());
+ Tensor paddings_tensor =
+ makeInputTensor<DataType::S32>({4, 2}, paddings_data, memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::S8, quant_param.first, quant_param.second);
+
+ Pad kernel(&input_tensor, &paddings_tensor, &output_tensor);
+ kernel.configure();
+ memory_manager->allocate_memory(output_tensor);
+ kernel.execute();
+
+ std::vector<float> ref_output_data{0, 0, 0, 0, 0, 0, 0, 0, -0.2, 0.4, 0.5, 0,
+ 0, 0, -0.7, -0.1, -0.9, 0, 0, 0, 0.7, 0.1, 0.2, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
+ EXPECT_THAT(dequantizeTensorData(output_tensor),
+ FloatArrayNear(ref_output_data, kQuantizedTolerance));
+ EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray({1, 6, 6, 1}));
+}
+
TEST(Pad, Float)
{
std::unique_ptr<IMemoryManager> memory_manager = std::make_unique<TestMemoryManager>();
diff --git a/compiler/luci-interpreter/src/kernels/Quantize.cpp b/compiler/luci-interpreter/src/kernels/Quantize.cpp
new file mode 100644
index 000000000..0c8544a65
--- /dev/null
+++ b/compiler/luci-interpreter/src/kernels/Quantize.cpp
@@ -0,0 +1,160 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "kernels/Quantize.h"
+#include "kernels/Utils.h"
+#include "PALQuantize.h"
+
+namespace luci_interpreter
+{
+namespace kernels
+{
+
+namespace
+{
+
+template <typename input_dtype> void call_requantize(const Tensor *input, Tensor *output)
+{
+ int32_t multiplier;
+ int shift;
+
+ const double effective_output_scale = input->scale() / output->scale();
+ quantizeMultiplier(effective_output_scale, &multiplier, &shift);
+
+ const auto input_shape = getTensorShape(input);
+ const auto output_shape = getTensorShape(output);
+ const auto size = tflite::MatchingFlatSize(input_shape, output_shape);
+
+ const auto input_data = getTensorData<input_dtype>(input);
+
+ switch (output->element_type())
+ {
+ case loco::DataType::S8:
+ luci_interpreter_pal::Requantize(input_data, size, multiplier, shift, input->zero_point(),
+ output->zero_point(), getTensorData<int8_t>(output));
+ break;
+ case loco::DataType::U8:
+ luci_interpreter_pal::Requantize(input_data, size, multiplier, shift, input->zero_point(),
+ output->zero_point(), getTensorData<uint8_t>(output));
+ break;
+ case loco::DataType::S16:
+ luci_interpreter_pal::Requantize(input_data, size, multiplier, shift, input->zero_point(),
+ output->zero_point(), getTensorData<int16_t>(output));
+ break;
+ default:
+ throw std::runtime_error("Unsupported quantized type, yet!");
+ }
+}
+
+} // namespace
+
+Quantize::Quantize(const Tensor *input, Tensor *output) : Kernel({input}, {output}) {}
+
+void Quantize::configure()
+{
+
+ if (input()->element_type() == loco::DataType::S16)
+ LUCI_INTERPRETER_CHECK(input()->zero_point() == 0);
+
+ switch (input()->element_type())
+ {
+ case loco::DataType::FLOAT32:
+ {
+ LUCI_INTERPRETER_CHECK(output()->element_type() == loco::DataType::U8 ||
+ output()->element_type() == loco::DataType::S8 ||
+ output()->element_type() == loco::DataType::S16);
+ break;
+ }
+ case loco::DataType::S16:
+ case loco::DataType::S8:
+ case loco::DataType::U8:
+ {
+ LUCI_INTERPRETER_CHECK(output()->element_type() == loco::DataType::S8 ||
+ output()->element_type() == loco::DataType::U8 ||
+ output()->element_type() == loco::DataType::S16);
+ if (output()->element_type() == loco::DataType::S16)
+ {
+ LUCI_INTERPRETER_CHECK(output()->zero_point() == 0);
+ }
+ break;
+ }
+ default:
+ throw std::runtime_error("Unsupported type");
+ }
+
+ output()->resize(input()->shape());
+}
+
+void Quantize::execute() const
+{
+ switch (input()->element_type())
+ {
+ case loco::DataType::FLOAT32:
+ {
+ tflite::QuantizationParams op_params;
+ op_params.zero_point = output()->zero_point();
+ op_params.scale = output()->scale();
+ const auto input_data = getTensorData<float>(input());
+
+ switch (output()->element_type())
+ {
+ case loco::DataType::S8:
+ {
+ luci_interpreter_pal::Quantize(op_params, getTensorShape(input()), input_data,
+ getTensorShape(output()), getTensorData<int8_t>(output()));
+ break;
+ }
+ case loco::DataType::U8:
+ {
+ luci_interpreter_pal::Quantize(op_params, getTensorShape(input()), input_data,
+ getTensorShape(output()),
+ getTensorData<uint8_t>(output()));
+ break;
+ }
+ case loco::DataType::S16:
+ {
+ luci_interpreter_pal::Quantize(op_params, getTensorShape(input()), input_data,
+ getTensorShape(output()),
+ getTensorData<int16_t>(output()));
+ break;
+ }
+ default:
+ throw std::runtime_error("Unsupported type.");
+ }
+ break;
+ }
+ case loco::DataType::S16:
+ {
+ call_requantize<int16_t>(input(), output());
+ break;
+ }
+ case loco::DataType::S8:
+ {
+ call_requantize<int8_t>(input(), output());
+ break;
+ }
+ case loco::DataType::U8:
+ {
+ call_requantize<uint8_t>(input(), output());
+ break;
+ }
+ default:
+ throw std::runtime_error("Unsupported type.");
+ }
+}
+
+} // namespace kernels
+} // namespace luci_interpreter
diff --git a/compiler/luci-interpreter/src/kernels/Quantize.h b/compiler/luci-interpreter/src/kernels/Quantize.h
new file mode 100644
index 000000000..006c5366f
--- /dev/null
+++ b/compiler/luci-interpreter/src/kernels/Quantize.h
@@ -0,0 +1,43 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LUCI_INTERPRETER_KERNELS_QUANTIZE_H
+#define LUCI_INTERPRETER_KERNELS_QUANTIZE_H
+
+#include "core/Kernel.h"
+#include "core/KernelParams.h"
+
+namespace luci_interpreter
+{
+namespace kernels
+{
+
+class Quantize : public Kernel
+{
+public:
+ Quantize(const Tensor *input, Tensor *output);
+
+ const Tensor *input() const { return _inputs[0]; }
+ Tensor *output() const { return _outputs[0]; }
+
+ void configure() override;
+ void execute() const override;
+};
+
+} // namespace kernels
+} // namespace luci_interpreter
+
+#endif // LUCI_INTERPRETER_KERNELS_QUANTIZE_H
diff --git a/compiler/luci-interpreter/src/kernels/Quantize.test.cpp b/compiler/luci-interpreter/src/kernels/Quantize.test.cpp
new file mode 100644
index 000000000..22e67fe3f
--- /dev/null
+++ b/compiler/luci-interpreter/src/kernels/Quantize.test.cpp
@@ -0,0 +1,254 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "kernels/Quantize.h"
+#include "kernels/TestUtils.h"
+#include "luci_interpreter/TestMemoryManager.h"
+
+namespace luci_interpreter
+{
+namespace kernels
+{
+namespace
+{
+
+using namespace testing;
+
+class QuantizeTest : public ::testing::Test
+{
+protected:
+ void SetUp() override { _memory_manager = std::make_unique<TestMemoryManager>(); }
+
+ std::unique_ptr<IMemoryManager> _memory_manager;
+};
+
+TEST_F(QuantizeTest, FloatUint8)
+{
+ std::vector<float> input_data{-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64};
+
+ std::vector<uint8_t> ref_output_data{0, 1, 2, 3, 4, 251, 252, 253, 254, 255};
+
+ Tensor input_tensor =
+ makeInputTensor<DataType::FLOAT32>({2, 5}, input_data, _memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::U8, /*scale*/ 0.5, /*zero_point*/ 127);
+
+ Quantize kernel(&input_tensor, &output_tensor);
+ kernel.configure();
+ _memory_manager->allocate_memory(output_tensor);
+ kernel.execute();
+
+ EXPECT_THAT(extractTensorData<uint8_t>(output_tensor),
+ ::testing::ElementsAreArray(ref_output_data));
+ EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray({2, 5}));
+}
+
+TEST_F(QuantizeTest, FloatInt8)
+{
+ std::vector<float> input_data{-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64};
+
+ std::vector<int8_t> ref_output_data{-128, -127, -126, -125, -124, 123, 124, 125, 126, 127};
+
+ Tensor input_tensor =
+ makeInputTensor<DataType::FLOAT32>({2, 5}, input_data, _memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::S8, /*scale*/ 0.5, /*zero_point*/ -1);
+
+ Quantize kernel(&input_tensor, &output_tensor);
+ kernel.configure();
+ _memory_manager->allocate_memory(output_tensor);
+ kernel.execute();
+
+ EXPECT_THAT(extractTensorData<int8_t>(output_tensor),
+ ::testing::ElementsAreArray(ref_output_data));
+ EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray({2, 5}));
+}
+
+TEST_F(QuantizeTest, FloatInt16)
+{
+ std::vector<float> input_data{-63.5, -63, -3, -2, -1, 1, 2, 3, 63.5, 64};
+
+ std::vector<int16_t> ref_output_data{-12700, -12600, -600, -400, -200,
+ 200, 400, 600, 12700, 12800};
+
+ Tensor input_tensor =
+ makeInputTensor<DataType::FLOAT32>({2, 5}, input_data, _memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::S16, /*scale*/ 0.005, /*zero_point*/ 0);
+
+ Quantize kernel(&input_tensor, &output_tensor);
+ kernel.configure();
+ _memory_manager->allocate_memory(output_tensor);
+ kernel.execute();
+
+ EXPECT_THAT(extractTensorData<int16_t>(output_tensor),
+ ::testing::ElementsAreArray(ref_output_data));
+ EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray({2, 5}));
+}
+
+TEST_F(QuantizeTest, Int16Int16)
+{
+ std::vector<float> input_data{1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
+
+ std::vector<int16_t> ref_output_data{2, 4, 6, 8, 10, 12, 14, 16, 18, 20};
+
+ Tensor input_tensor = makeInputTensor<DataType::S16>(
+ {1, 1, 2, 5}, /*scale*/ 1.0, /*zero_point*/ 0, input_data, _memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::S16, /*scale*/ 0.5, /*zero_point*/ 0);
+
+ Quantize kernel(&input_tensor, &output_tensor);
+ kernel.configure();
+ _memory_manager->allocate_memory(output_tensor);
+ kernel.execute();
+
+ EXPECT_THAT(extractTensorData<int16_t>(output_tensor),
+ ::testing::ElementsAreArray(ref_output_data));
+ EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray({1, 1, 2, 5}));
+}
+
+TEST_F(QuantizeTest, Int8Int8)
+{
+ std::vector<float> input_data{1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
+
+ std::vector<int8_t> ref_output_data{1, 3, 5, 7, 9, 11, 13, 15, 17, 19};
+
+ Tensor input_tensor = makeInputTensor<DataType::S8>(
+ {1, 1, 2, 5}, /*scale*/ 0.5, /*zero_point*/ -1, input_data, _memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::S8, /*scale*/ 0.5, /*zero_point*/ -1);
+
+ Quantize kernel(&input_tensor, &output_tensor);
+ kernel.configure();
+ _memory_manager->allocate_memory(output_tensor);
+ kernel.execute();
+
+ EXPECT_THAT(extractTensorData<int8_t>(output_tensor),
+ ::testing::ElementsAreArray(ref_output_data));
+ EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray({1, 1, 2, 5}));
+}
+
+TEST_F(QuantizeTest, Uint8Uint8)
+{
+ std::vector<float> input_data{1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
+
+ std::vector<uint8_t> ref_output_data{129, 131, 133, 135, 137, 139, 141, 143, 145, 147};
+
+ Tensor input_tensor = makeInputTensor<DataType::U8>(
+ {1, 1, 2, 5}, /*scale*/ 0.5, /*zero_point*/ 127, input_data, _memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::U8, /*scale*/ 0.5, /*zero_point*/ 127);
+
+ Quantize kernel(&input_tensor, &output_tensor);
+ kernel.configure();
+ _memory_manager->allocate_memory(output_tensor);
+ kernel.execute();
+
+ EXPECT_THAT(extractTensorData<uint8_t>(output_tensor),
+ ::testing::ElementsAreArray(ref_output_data));
+ EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray({1, 1, 2, 5}));
+}
+
+TEST_F(QuantizeTest, Int16Int8)
+{
+ std::vector<float> input_data{1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
+
+ std::vector<int8_t> ref_output_data{1, 3, 5, 7, 9, 11, 13, 15, 17, 19};
+
+ Tensor input_tensor = makeInputTensor<DataType::S16>(
+ {1, 1, 2, 5}, /*scale*/ 1.0, /*zero_point*/ 0, input_data, _memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::S8, /*scale*/ 0.5, /*zero_point*/ -1);
+
+ Quantize kernel(&input_tensor, &output_tensor);
+ kernel.configure();
+ _memory_manager->allocate_memory(output_tensor);
+ kernel.execute();
+
+ EXPECT_THAT(extractTensorData<int8_t>(output_tensor),
+ ::testing::ElementsAreArray(ref_output_data));
+ EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray({1, 1, 2, 5}));
+}
+
+TEST_F(QuantizeTest, InvalidInputType_NEG)
+{
+ std::vector<float> input_data{1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
+
+ Tensor input_tensor =
+ makeInputTensor<DataType::S32>({1, 1, 2, 5}, 0.5, 0, input_data, _memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::S8, /*scale*/ 0.5, /*zero_point*/ -1);
+
+ Quantize kernel(&input_tensor, &output_tensor);
+ EXPECT_ANY_THROW(kernel.configure());
+}
+
+TEST_F(QuantizeTest, InvalidOutputTypeForFloatInput_NEG)
+{
+ std::vector<float> input_data{1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
+
+ Tensor input_tensor =
+ makeInputTensor<DataType::FLOAT32>({1, 1, 2, 5}, input_data, _memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);
+
+ Quantize kernel(&input_tensor, &output_tensor);
+ EXPECT_ANY_THROW(kernel.configure());
+}
+
+TEST_F(QuantizeTest, InvalidOutputTypeForInt16Input_NEG)
+{
+ std::vector<float> input_data{1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
+
+ Tensor input_tensor =
+ makeInputTensor<DataType::S16>({1, 1, 2, 5}, 0.5, 0, input_data, _memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);
+
+ Quantize kernel(&input_tensor, &output_tensor);
+ EXPECT_ANY_THROW(kernel.configure());
+}
+
+TEST_F(QuantizeTest, InvalidOutputTypeForInt8Input_NEG)
+{
+ std::vector<float> input_data{1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
+
+ Tensor input_tensor =
+ makeInputTensor<DataType::S8>({1, 1, 2, 5}, 0.5, 0, input_data, _memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);
+
+ Quantize kernel(&input_tensor, &output_tensor);
+ EXPECT_ANY_THROW(kernel.configure());
+}
+
+TEST_F(QuantizeTest, InvalidOutputTypeForUint8Input_NEG)
+{
+ std::vector<float> input_data{1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
+
+ Tensor input_tensor =
+ makeInputTensor<DataType::U8>({1, 1, 2, 5}, 0.5, 0, input_data, _memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::S32);
+
+ Quantize kernel(&input_tensor, &output_tensor);
+ EXPECT_ANY_THROW(kernel.configure());
+}
+
+TEST_F(QuantizeTest, InvalidInputZeroPoint_NEG)
+{
+ std::vector<float> input_data{1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
+
+ Tensor input_tensor =
+ makeInputTensor<DataType::S16>({1, 1, 2, 5}, 0.5, -1, input_data, _memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::S16, 0.5, 0);
+
+ Quantize kernel(&input_tensor, &output_tensor);
+ EXPECT_ANY_THROW(kernel.configure());
+}
+
+} // namespace
+} // namespace kernels
+} // namespace luci_interpreter
diff --git a/compiler/luci-interpreter/src/kernels/ResizeBilinear.test.cpp b/compiler/luci-interpreter/src/kernels/ResizeBilinear.test.cpp
index 7af20f8c4..933a1128c 100644
--- a/compiler/luci-interpreter/src/kernels/ResizeBilinear.test.cpp
+++ b/compiler/luci-interpreter/src/kernels/ResizeBilinear.test.cpp
@@ -90,7 +90,7 @@ template <typename T> class ResizeBilinearTest : public ::testing::Test
};
using DataTypes = ::testing::Types<float, uint8_t>;
-TYPED_TEST_CASE(ResizeBilinearTest, DataTypes);
+TYPED_TEST_SUITE(ResizeBilinearTest, DataTypes);
TYPED_TEST(ResizeBilinearTest, SimpleTest)
{
diff --git a/compiler/luci-interpreter/src/kernels/ResizeNearestNeighbor.test.cpp b/compiler/luci-interpreter/src/kernels/ResizeNearestNeighbor.test.cpp
index 0e9017c78..7ade02a6f 100644
--- a/compiler/luci-interpreter/src/kernels/ResizeNearestNeighbor.test.cpp
+++ b/compiler/luci-interpreter/src/kernels/ResizeNearestNeighbor.test.cpp
@@ -92,7 +92,7 @@ template <typename T> class ResizeNearestNeighborTest : public ::testing::Test
};
using DataTypes = ::testing::Types<float, uint8_t>;
-TYPED_TEST_CASE(ResizeNearestNeighborTest, DataTypes);
+TYPED_TEST_SUITE(ResizeNearestNeighborTest, DataTypes);
TYPED_TEST(ResizeNearestNeighborTest, SimpleTest)
{
diff --git a/compiler/luci-interpreter/src/kernels/ReverseV2.test.cpp b/compiler/luci-interpreter/src/kernels/ReverseV2.test.cpp
index 2bd94875b..c0025faca 100644
--- a/compiler/luci-interpreter/src/kernels/ReverseV2.test.cpp
+++ b/compiler/luci-interpreter/src/kernels/ReverseV2.test.cpp
@@ -33,7 +33,7 @@ template <typename T> class ReverseV2Test : public ::testing::Test
};
using DataTypes = ::testing::Types<float, uint8_t>;
-TYPED_TEST_CASE(ReverseV2Test, DataTypes);
+TYPED_TEST_SUITE(ReverseV2Test, DataTypes);
TYPED_TEST(ReverseV2Test, MultiDimensions)
{
diff --git a/compiler/luci-interpreter/src/kernels/SVDF.cpp b/compiler/luci-interpreter/src/kernels/SVDF.cpp
new file mode 100644
index 000000000..40d79aaa3
--- /dev/null
+++ b/compiler/luci-interpreter/src/kernels/SVDF.cpp
@@ -0,0 +1,241 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "kernels/SVDF.h"
+#include "kernels/Utils.h"
+#include "PALSVDF.h"
+
+#include <tensorflow/lite/kernels/internal/quantization_util.h>
+
+namespace luci_interpreter
+{
+namespace kernels
+{
+
+namespace
+{
+TfLiteFusedActivation get_tflite_activation(Activation activation)
+{
+ switch (activation)
+ {
+ case luci::FusedActFunc::RELU:
+ return kTfLiteActRelu;
+ case luci::FusedActFunc::RELU6:
+ return kTfLiteActRelu6;
+ case luci::FusedActFunc::RELU_N1_TO_1:
+ return kTfLiteActReluN1To1;
+ case luci::FusedActFunc::TANH:
+ return kTfLiteActTanh;
+ case luci::FusedActFunc::SIGN_BIT:
+ return kTfLiteActSignBit;
+ case luci::FusedActFunc::NONE:
+ return kTfLiteActNone;
+ default:
+ throw std::runtime_error("Unsupported activation type");
+ }
+}
+} // namespace
+
+SVDF::SVDF(const Tensor *input, const Tensor *weight_feature, const Tensor *weight_time,
+ const Tensor *bias, const Tensor *input_activation_state, Tensor *output,
+ Tensor *scratchpad_activation_state, Tensor *scratchpad_1, Tensor *scratchpad_2,
+ Tensor *scratchpad_3, Tensor *scratchpad_4, Tensor *scratchpad_5, Tensor *scratchpad_6,
+ const SVDFParams &params)
+ : KernelWithParams<SVDFParams>({input, weight_feature, weight_time, bias, input_activation_state},
+ {output, scratchpad_activation_state, scratchpad_1, scratchpad_2,
+ scratchpad_3, scratchpad_4, scratchpad_5, scratchpad_6},
+ params)
+{
+ // Do nothing
+}
+
+void SVDF::configure()
+{
+ const Shape &input_shape = input()->shape();
+ const Shape &weight_features_shape = weight_feature()->shape();
+ const Shape &weight_time_shape = weight_time()->shape();
+
+ // Validate Input Tensor:
+ LUCI_INTERPRETER_CHECK(input()->element_type() == loco::DataType::FLOAT32 ||
+ input()->element_type() == loco::DataType::S8);
+ LUCI_INTERPRETER_CHECK(input_shape.num_dims() == 2);
+
+ // Validate inputs and output types
+ if (input()->element_type() == loco::DataType::S8)
+ {
+ LUCI_INTERPRETER_CHECK(weight_feature()->element_type() == loco::DataType::S8);
+ LUCI_INTERPRETER_CHECK(weight_time()->element_type() == loco::DataType::S16 ||
+ weight_time()->element_type() == loco::DataType::S8);
+ if (bias())
+ LUCI_INTERPRETER_CHECK(bias()->element_type() == loco::DataType::S32);
+
+ LUCI_INTERPRETER_CHECK(input_activation_state()->element_type() == loco::DataType::S16 ||
+ input_activation_state()->element_type() == loco::DataType::S8);
+ LUCI_INTERPRETER_CHECK(output()->element_type() == loco::DataType::S8);
+
+ // Note: now tflite support only ReLU activation for integer SVDF
+ LUCI_INTERPRETER_CHECK(params().activation == luci::FusedActFunc::RELU);
+ }
+ else if (weight_feature()->element_type() == loco::DataType::FLOAT32)
+ {
+ LUCI_INTERPRETER_CHECK(weight_feature()->element_type() == loco::DataType::FLOAT32);
+ LUCI_INTERPRETER_CHECK(weight_time()->element_type() == loco::DataType::FLOAT32);
+ LUCI_INTERPRETER_CHECK(input_activation_state()->element_type() == loco::DataType::FLOAT32);
+ if (bias())
+ LUCI_INTERPRETER_CHECK(bias()->element_type() == loco::DataType::FLOAT32);
+ LUCI_INTERPRETER_CHECK(output()->element_type() == loco::DataType::FLOAT32);
+ }
+ else if ((weight_feature()->element_type() == loco::DataType::U8 ||
+ weight_feature()->element_type() == loco::DataType::S8) &&
+ input()->element_type() == loco::DataType::FLOAT32)
+ {
+ // TODO:: support hybrid SVDF op
+ throw std::runtime_error("Hybrid type is not currently supported");
+ }
+ else
+ {
+ throw std::runtime_error("Unsupported type.");
+ }
+
+ // Check all the parameters of tensor match within themselves and match the
+ // input configuration.
+ const int rank = params().svdf_rank;
+ const int batch_size = input_shape.dim(0);
+ const int num_filters = weight_features_shape.dim(0);
+ LUCI_INTERPRETER_CHECK(rank != 0);
+ LUCI_INTERPRETER_CHECK(num_filters % rank == 0);
+
+ const int num_units = num_filters / rank;
+ const int memory_size = weight_time_shape.dim(1);
+
+ // Validate Weight_Feature Input Tensor:
+ LUCI_INTERPRETER_CHECK(weight_features_shape.num_dims() == 2);
+ LUCI_INTERPRETER_CHECK(weight_features_shape.dim(1) == input_shape.dim(1));
+
+ // Validate Weight_Time Input Tensor:
+ LUCI_INTERPRETER_CHECK(weight_time_shape.num_dims() == 2);
+ LUCI_INTERPRETER_CHECK(weight_time_shape.dim(0) == num_filters);
+
+ // Validate Bias
+ if (bias())
+ LUCI_INTERPRETER_CHECK(bias()->shape().dim(0) == num_units);
+
+ // Validate Input Activation State
+ LUCI_INTERPRETER_CHECK(input_activation_state()->shape().num_dims() == 2);
+ LUCI_INTERPRETER_CHECK(input_activation_state()->shape().dim(0) == batch_size);
+ LUCI_INTERPRETER_CHECK(input_activation_state()->shape().dim(1) == memory_size * num_filters);
+
+ // Resize scratchpad_state to input_activation_state
+ auto scratchpad_activation_state = getOutputTensors()[1];
+ scratchpad_activation_state->resize({batch_size, memory_size * num_filters});
+
+ // Resize output tensor
+ output()->resize({batch_size, num_units});
+
+ luci_interpreter_pal::SetupScratchpadTensor(
+ input()->element_type(), weight_feature()->element_type(), getOutputTensors()[2],
+ getOutputTensors()[3], getOutputTensors()[4], getOutputTensors()[5], getOutputTensors()[6],
+ getOutputTensors()[7], input_shape, weight_time_shape, batch_size, num_filters, num_units);
+}
+
+void SVDF::execute() const
+{
+ switch (weight_feature()->element_type())
+ {
+ case loco::DataType::FLOAT32:
+ evalFloat();
+ break;
+ case loco::DataType::S8:
+ {
+ if (input()->element_type() == loco::DataType::S8)
+ evalInteger();
+ else
+ // TODO:: support hybrid SVDF op
+ throw std::runtime_error("Hybrid type is not currently supported");
+ break;
+ }
+ default:
+ throw std::runtime_error("Unsupported type");
+ }
+}
+
+void SVDF::evalInteger() const
+{
+ const auto effective_scale_1 = static_cast<double>(input()->scale() * weight_feature()->scale() /
+ input_activation_state()->scale());
+ const auto effective_scale_2 = static_cast<double>(input_activation_state()->scale() *
+ weight_time()->scale() / output()->scale());
+
+ int32_t effective_scale_1_a;
+ int effective_scale_1_b;
+ int32_t effective_scale_2_a;
+ int effective_scale_2_b;
+
+ tflite::QuantizeMultiplier(effective_scale_1, &effective_scale_1_a, &effective_scale_1_b);
+ tflite::QuantizeMultiplier(effective_scale_2, &effective_scale_2_a, &effective_scale_2_b);
+
+ TfLiteSVDFParams params_svdf{};
+ params_svdf.asymmetric_quantize_inputs = params().asymmetric_quantize_inputs;
+ params_svdf.rank = params().svdf_rank;
+ params_svdf.activation = get_tflite_activation(params().activation);
+
+ auto scratchpad_activation_state = getOutputTensors()[1];
+ // Note: it is expected that activation_state input variable tensor reset to zero,
+ // also expected that this variable tensor doesn't have buffer
+ auto scratchpad_data = getTensorData<int16_t>(scratchpad_activation_state);
+ std::fill_n(scratchpad_data, scratchpad_activation_state->shape().num_elements(), 0);
+
+ auto scratchpad = getOutputTensors()[2];
+ auto output_temp = getOutputTensors()[3];
+
+ int32_t input_zp = input()->zero_point();
+ int32_t output_zp = output()->zero_point();
+ luci_interpreter_pal::IntegerSVDF(
+ params_svdf, getTensorShape(input()), getTensorData<int8_t>(input()),
+ getTensorShape(weight_feature()), getTensorData<int8_t>(weight_feature()),
+ getTensorShape(weight_time()), getTensorData<int16_t>(weight_time()), getTensorShape(bias()),
+ getTensorData<int32_t>(bias()), scratchpad_data, getTensorShape(output()),
+ getTensorData<int8_t>(output()), getTensorData<int32_t>(scratchpad),
+ getTensorData<int32_t>(output_temp), effective_scale_1_a, effective_scale_1_b,
+ effective_scale_2_a, effective_scale_2_b, input_zp, output_zp);
+}
+
+void SVDF::evalFloat() const
+{
+ TfLiteSVDFParams params_svdf{};
+ params_svdf.asymmetric_quantize_inputs = params().asymmetric_quantize_inputs;
+ params_svdf.rank = params().svdf_rank;
+ params_svdf.activation = get_tflite_activation(params().activation);
+
+ auto scratchpad_activation_state = getOutputTensors()[1];
+ // Note: it is expected that activation_state input variable tensor reset to zero,
+ // also expected that this variable tensor doesn't have buffer
+ auto scratchpad_data = getTensorData<float>(scratchpad_activation_state);
+ std::fill_n(scratchpad_data, scratchpad_activation_state->shape().num_elements(), 0);
+
+ auto scratchpad_1 = getOutputTensors()[2];
+
+ luci_interpreter_pal::FloatSVDF(
+ params_svdf, getTensorShape(input()), getTensorData<float>(input()),
+ getTensorShape(weight_feature()), getTensorData<float>(weight_feature()),
+ getTensorShape(weight_time()), getTensorData<float>(weight_time()), getTensorShape(bias()),
+ getTensorData<float>(bias()), getTensorData<float>(scratchpad_1), scratchpad_data,
+ getTensorShape(output()), getTensorData<float>(output()));
+}
+
+} // namespace kernels
+} // namespace luci_interpreter
diff --git a/compiler/luci-interpreter/src/kernels/SVDF.h b/compiler/luci-interpreter/src/kernels/SVDF.h
new file mode 100644
index 000000000..335a6cd8f
--- /dev/null
+++ b/compiler/luci-interpreter/src/kernels/SVDF.h
@@ -0,0 +1,56 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LUCI_INTERPRETER_KERNELS_SVDF_H
+#define LUCI_INTERPRETER_KERNELS_SVDF_H
+
+#include "core/Kernel.h"
+#include "core/KernelParams.h"
+
+namespace luci_interpreter
+{
+namespace kernels
+{
+
+class SVDF : public KernelWithParams<SVDFParams>
+{
+public:
+ SVDF(const Tensor *input, const Tensor *weight_feature, const Tensor *weight_time,
+ const Tensor *bias, const Tensor *input_activation_state, Tensor *output,
+ Tensor *scratchpad_activation_state, Tensor *scratchpad_1, Tensor *scratchpad_2,
+ Tensor *scratchpad_3, Tensor *scratchpad_4, Tensor *scratchpad_5, Tensor *scratchpad_6,
+ const SVDFParams &params);
+
+ const Tensor *input() const { return _inputs[0]; }
+ const Tensor *weight_feature() const { return _inputs[1]; }
+ const Tensor *weight_time() const { return _inputs[2]; }
+ const Tensor *bias() const { return _inputs[3]; }
+ const Tensor *input_activation_state() const { return _inputs[4]; }
+
+ Tensor *output() const { return _outputs[0]; }
+
+ void configure() override;
+ void execute() const override;
+
+private:
+ void evalFloat() const;
+ void evalInteger() const;
+};
+
+} // namespace kernels
+} // namespace luci_interpreter
+
+#endif // LUCI_INTERPRETER_KERNELS_SVDF_H
diff --git a/compiler/luci-interpreter/src/kernels/SVDF.test.cpp b/compiler/luci-interpreter/src/kernels/SVDF.test.cpp
new file mode 100644
index 000000000..82bd9b009
--- /dev/null
+++ b/compiler/luci-interpreter/src/kernels/SVDF.test.cpp
@@ -0,0 +1,341 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "kernels/SVDF.h"
+#include "kernels/TestUtils.h"
+#include "luci_interpreter/TestMemoryManager.h"
+
+namespace luci_interpreter
+{
+namespace kernels
+{
+namespace
+{
+
+using namespace testing;
+
+class SVDFTest : public ::testing::Test
+{
+protected:
+ void SetUp() override { _memory_manager = std::make_unique<TestMemoryManager>(); }
+
+ std::unique_ptr<IMemoryManager> _memory_manager;
+};
+
+TEST_F(SVDFTest, FullIntegerTest)
+{
+ const int32_t batches = 2;
+ const int32_t input_size = 3;
+ const int32_t units = 4;
+ const int32_t memory_size = 10;
+ const int32_t rank = 1;
+ const int32_t num_filters = units * rank;
+
+ Shape input_shape{batches, input_size};
+ Shape weight_feature_shape{num_filters, input_size};
+ Shape weight_time_shape{num_filters, memory_size};
+ Shape bias_shape{units};
+ Shape activation_state_shape{batches, memory_size * num_filters};
+
+ std::vector<float> input_data{0.49837467, 0.19278903, 0.26584083,
+ 0.17660543, 0.52949083, -0.77931279};
+
+ std::vector<float> weight_feature_data{-0.31930989, -0.36118156, 0.0079667, 0.37613347,
+ 0.22197971, 0.12416199, 0.27901134, 0.27557442,
+ 0.3905206, -0.36137494, -0.06634006, -0.10640851};
+
+ std::vector<float> weight_time_data{
+ -0.31930989, 0.37613347, 0.27901134, -0.36137494, -0.36118156,
+ 0.22197971, 0.27557442, -0.06634006, 0.0079667, 0.12416199,
+
+ 0.3905206, -0.10640851, -0.0976817, 0.15294972, 0.39635518,
+ -0.02702999, 0.39296314, 0.15785322, 0.21931258, 0.31053296,
+
+ -0.36916667, 0.38031587, -0.21580373, 0.27072677, 0.23622236,
+ 0.34936687, 0.18174365, 0.35907319, -0.17493086, 0.324846,
+
+ -0.10781813, 0.27201805, 0.14324132, -0.23681851, -0.27115166,
+ -0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657};
+
+ std::vector<float> bias_data{-0.0976817, 0.15294972, 0.39635518, -0.02702999};
+
+ std::pair<float, int32_t> input_quant_param = quantizationParams<int8_t>(-1, 1);
+ std::pair<float, int32_t> weight_feature_quant_param = quantizationParams<int8_t>(-0.5, 0.5);
+ std::pair<float, int32_t> weight_time_quant_param = quantizationParams<int16_t>(-1, 1);
+ std::pair<float, int32_t> bias_quant_param = quantizationParams<int32_t>(-512, 512);
+ std::pair<float, int32_t> activation_state_quant_param = quantizationParams<int16_t>(-16, 16);
+
+ std::pair<float, int32_t> output_quant_param = quantizationParams<int8_t>(-0.5, 0.5);
+
+ Tensor input_tensor =
+ makeInputTensor<DataType::S8>(input_shape, input_quant_param.first, input_quant_param.second,
+ input_data, _memory_manager.get());
+ Tensor weight_feature_tensor = makeInputTensor<DataType::S8>(
+ weight_feature_shape, weight_feature_quant_param.first, weight_feature_quant_param.second,
+ weight_feature_data, _memory_manager.get());
+ Tensor weight_time_tensor = makeInputTensor<DataType::S16>(
+ weight_time_shape, weight_time_quant_param.first, weight_time_quant_param.second,
+ weight_time_data, _memory_manager.get());
+ Tensor bias_tensor = makeInputTensor<DataType::S32>(
+ bias_shape, bias_quant_param.first, bias_quant_param.second, bias_data, _memory_manager.get());
+ Tensor activation_state_tensor = makeOutputTensor(
+ DataType::S16, activation_state_quant_param.first, activation_state_quant_param.second);
+ activation_state_tensor.resize(activation_state_shape);
+ Tensor output_tensor =
+ makeOutputTensor(DataType::S8, output_quant_param.first, output_quant_param.second);
+
+ Tensor scratchpad_activation_state(DataType::S16, Shape({}), {}, "");
+ Tensor scratchpad_1(DataType::S32, Shape({}), {}, "");
+ Tensor scratchpad_2(DataType::S32, Shape({}), {}, "");
+ Tensor scratchpad_3(DataType::FLOAT32, Shape({}), {}, "");
+ Tensor scratchpad_4(DataType::FLOAT32, Shape({}), {}, "");
+ Tensor scratchpad_5(DataType::FLOAT32, Shape({}), {}, "");
+ Tensor scratchpad_6(DataType::FLOAT32, Shape({}), {}, "");
+
+ SVDFParams params{};
+ params.activation = Activation::RELU;
+ params.asymmetric_quantize_inputs = false;
+ params.svdf_rank = rank;
+
+ SVDF kernel(&input_tensor, &weight_feature_tensor, &weight_time_tensor, &bias_tensor,
+ &activation_state_tensor, &output_tensor, &scratchpad_activation_state, &scratchpad_1,
+ &scratchpad_2, &scratchpad_3, &scratchpad_4, &scratchpad_5, &scratchpad_6, params);
+ kernel.configure();
+ _memory_manager->allocate_memory(output_tensor);
+ _memory_manager->allocate_memory(scratchpad_activation_state);
+ _memory_manager->allocate_memory(scratchpad_1);
+ _memory_manager->allocate_memory(scratchpad_2);
+ _memory_manager->allocate_memory(scratchpad_3);
+ _memory_manager->allocate_memory(scratchpad_4);
+ _memory_manager->allocate_memory(scratchpad_5);
+ _memory_manager->allocate_memory(scratchpad_6);
+ kernel.execute();
+
+ std::vector<int8_t> ref_output_data{-9, 24, 31, 1, -10, 10, -3, 0};
+
+ std::vector<int32_t> ref_output_shape{batches, units};
+ EXPECT_THAT(extractTensorData<int8_t>(output_tensor), ref_output_data);
+ EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray(ref_output_shape));
+}
+
+TEST_F(SVDFTest, FloatTest)
+{
+ const int32_t batches = 2;
+ const int32_t input_size = 3;
+ const int32_t units = 4;
+ const int32_t memory_size = 10;
+ const int32_t rank = 1;
+ const int32_t num_filters = units * rank;
+
+ Shape input_shape{batches, input_size};
+ Shape weight_feature_shape{num_filters, input_size};
+ Shape weight_time_shape{num_filters, memory_size};
+ Shape activation_state_shape{batches, memory_size * num_filters};
+
+ std::vector<float> input_data{0.12609188, -0.46347019, -0.89598465,
+ 0.35867718, 0.36897406, 0.73463392};
+
+ std::vector<float> weight_feature_data{-0.31930989, -0.36118156, 0.0079667, 0.37613347,
+ 0.22197971, 0.12416199, 0.27901134, 0.27557442,
+ 0.3905206, -0.36137494, -0.06634006, -0.10640851};
+
+ std::vector<float> weight_time_data{
+ -0.31930989, 0.37613347, 0.27901134, -0.36137494, -0.36118156,
+ 0.22197971, 0.27557442, -0.06634006, 0.0079667, 0.12416199,
+
+ 0.3905206, -0.10640851, -0.0976817, 0.15294972, 0.39635518,
+ -0.02702999, 0.39296314, 0.15785322, 0.21931258, 0.31053296,
+
+ -0.36916667, 0.38031587, -0.21580373, 0.27072677, 0.23622236,
+ 0.34936687, 0.18174365, 0.35907319, -0.17493086, 0.324846,
+
+ -0.10781813, 0.27201805, 0.14324132, -0.23681851, -0.27115166,
+ -0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657};
+
+ Tensor input_tensor =
+ makeInputTensor<DataType::FLOAT32>(input_shape, input_data, _memory_manager.get());
+ Tensor weight_feature_tensor = makeInputTensor<DataType::FLOAT32>(
+ weight_feature_shape, weight_feature_data, _memory_manager.get());
+ Tensor weight_time_tensor =
+ makeInputTensor<DataType::FLOAT32>(weight_time_shape, weight_time_data, _memory_manager.get());
+ Tensor activation_state_tensor = makeOutputTensor(DataType::FLOAT32);
+ activation_state_tensor.resize(activation_state_shape);
+ Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);
+
+ Tensor scratchpad_activation_state(DataType::FLOAT32, Shape({}), {}, "");
+ Tensor scratchpad_1(DataType::FLOAT32, Shape({}), {}, "");
+ Tensor scratchpad_2(DataType::FLOAT32, Shape({}), {}, "");
+ Tensor scratchpad_3(DataType::FLOAT32, Shape({}), {}, "");
+ Tensor scratchpad_4(DataType::FLOAT32, Shape({}), {}, "");
+ Tensor scratchpad_5(DataType::FLOAT32, Shape({}), {}, "");
+ Tensor scratchpad_6(DataType::FLOAT32, Shape({}), {}, "");
+
+ SVDFParams params{};
+ params.activation = Activation::NONE;
+ params.asymmetric_quantize_inputs = false;
+ params.svdf_rank = rank;
+
+ SVDF kernel(&input_tensor, &weight_feature_tensor, &weight_time_tensor, nullptr,
+ &activation_state_tensor, &output_tensor, &scratchpad_activation_state, &scratchpad_1,
+ &scratchpad_2, &scratchpad_3, &scratchpad_4, &scratchpad_5, &scratchpad_6, params);
+ kernel.configure();
+ _memory_manager->allocate_memory(output_tensor);
+ _memory_manager->allocate_memory(scratchpad_activation_state);
+ _memory_manager->allocate_memory(scratchpad_1);
+ _memory_manager->allocate_memory(scratchpad_2);
+ _memory_manager->allocate_memory(scratchpad_3);
+ _memory_manager->allocate_memory(scratchpad_4);
+ _memory_manager->allocate_memory(scratchpad_5);
+ _memory_manager->allocate_memory(scratchpad_6);
+ kernel.execute();
+
+ std::vector<float> ref_output_data{0.014899, -0.0517661, -0.143725, -0.00271883,
+ -0.03004015, 0.09565311, 0.1587342, 0.00784263};
+
+ std::vector<float> ref_output_shape{batches, units};
+ const float tolerance = 1e-5;
+ EXPECT_THAT(extractTensorData<float>(output_tensor), FloatArrayNear(ref_output_data, tolerance));
+ EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray(ref_output_shape));
+}
+
+TEST_F(SVDFTest, Unsupported_Type_Configure_NEG)
+{
+ const int32_t batches = 2;
+ const int32_t input_size = 3;
+ const int32_t units = 4;
+ const int32_t memory_size = 10;
+ const int32_t rank = 1;
+ const int32_t num_filters = units * rank;
+
+ Shape input_shape{batches, input_size};
+ Shape weight_feature_shape{num_filters, input_size};
+ Shape weight_time_shape{num_filters, memory_size};
+ Shape activation_state_shape{batches, memory_size * num_filters};
+
+ std::vector<int32_t> input_data{0, 1, 3, 4, 4, -2};
+
+ std::vector<float> weight_feature_data{-0.31930989, -0.36118156, 0.0079667, 0.37613347,
+ 0.22197971, 0.12416199, 0.27901134, 0.27557442,
+ 0.3905206, -0.36137494, -0.06634006, -0.10640851};
+
+ std::vector<float> weight_time_data{
+ -0.31930989, 0.37613347, 0.27901134, -0.36137494, -0.36118156,
+ 0.22197971, 0.27557442, -0.06634006, 0.0079667, 0.12416199,
+
+ 0.3905206, -0.10640851, -0.0976817, 0.15294972, 0.39635518,
+ -0.02702999, 0.39296314, 0.15785322, 0.21931258, 0.31053296,
+
+ -0.36916667, 0.38031587, -0.21580373, 0.27072677, 0.23622236,
+ 0.34936687, 0.18174365, 0.35907319, -0.17493086, 0.324846,
+
+ -0.10781813, 0.27201805, 0.14324132, -0.23681851, -0.27115166,
+ -0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657};
+
+ Tensor input_tensor =
+ makeInputTensor<DataType::S32>(input_shape, input_data, _memory_manager.get());
+ Tensor weight_feature_tensor = makeInputTensor<DataType::FLOAT32>(
+ weight_feature_shape, weight_feature_data, _memory_manager.get());
+ Tensor weight_time_tensor =
+ makeInputTensor<DataType::FLOAT32>(weight_time_shape, weight_time_data, _memory_manager.get());
+ Tensor activation_state_tensor = makeOutputTensor(DataType::FLOAT32);
+ activation_state_tensor.resize(activation_state_shape);
+ Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);
+
+ Tensor scratchpad_activation_state(DataType::FLOAT32, Shape({}), {}, "");
+ Tensor scratchpad_1(DataType::FLOAT32, Shape({}), {}, "");
+ Tensor scratchpad_2(DataType::FLOAT32, Shape({}), {}, "");
+ Tensor scratchpad_3(DataType::FLOAT32, Shape({}), {}, "");
+ Tensor scratchpad_4(DataType::FLOAT32, Shape({}), {}, "");
+ Tensor scratchpad_5(DataType::FLOAT32, Shape({}), {}, "");
+ Tensor scratchpad_6(DataType::FLOAT32, Shape({}), {}, "");
+
+ SVDFParams params{};
+ params.activation = Activation::NONE;
+ params.asymmetric_quantize_inputs = false;
+ params.svdf_rank = rank;
+
+ SVDF kernel(&input_tensor, &weight_feature_tensor, &weight_time_tensor, nullptr,
+ &activation_state_tensor, &output_tensor, &scratchpad_activation_state, &scratchpad_1,
+ &scratchpad_2, &scratchpad_3, &scratchpad_4, &scratchpad_5, &scratchpad_6, params);
+ EXPECT_ANY_THROW(kernel.configure());
+}
+
+TEST_F(SVDFTest, Invalid_Input_Shape_NEG)
+{
+ const int32_t batches = 2;
+ const int32_t right_input_size = 3;
+ const int32_t wrong_input_size = 4;
+ const int32_t units = 4;
+ const int32_t memory_size = 10;
+ const int32_t rank = 1;
+ const int32_t num_filters = units * rank;
+
+ Shape input_shape{batches, wrong_input_size};
+ Shape weight_feature_shape{num_filters, right_input_size};
+ Shape weight_time_shape{num_filters, memory_size};
+ Shape activation_state_shape{batches, memory_size * num_filters};
+
+ std::vector<float> input_data{0, 1, 3, 2, 4, 4, -2, 1};
+
+ std::vector<float> weight_feature_data{-0.31930989, -0.36118156, 0.0079667, 0.37613347,
+ 0.22197971, 0.12416199, 0.27901134, 0.27557442,
+ 0.3905206, -0.36137494, -0.06634006, -0.10640851};
+
+ std::vector<float> weight_time_data{
+ -0.31930989, 0.37613347, 0.27901134, -0.36137494, -0.36118156,
+ 0.22197971, 0.27557442, -0.06634006, 0.0079667, 0.12416199,
+
+ 0.3905206, -0.10640851, -0.0976817, 0.15294972, 0.39635518,
+ -0.02702999, 0.39296314, 0.15785322, 0.21931258, 0.31053296,
+
+ -0.36916667, 0.38031587, -0.21580373, 0.27072677, 0.23622236,
+ 0.34936687, 0.18174365, 0.35907319, -0.17493086, 0.324846,
+
+ -0.10781813, 0.27201805, 0.14324132, -0.23681851, -0.27115166,
+ -0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657};
+
+ Tensor input_tensor =
+ makeInputTensor<DataType::FLOAT32>(input_shape, input_data, _memory_manager.get());
+ Tensor weight_feature_tensor = makeInputTensor<DataType::FLOAT32>(
+ weight_feature_shape, weight_feature_data, _memory_manager.get());
+ Tensor weight_time_tensor =
+ makeInputTensor<DataType::FLOAT32>(weight_time_shape, weight_time_data, _memory_manager.get());
+ Tensor activation_state_tensor = makeOutputTensor(DataType::FLOAT32);
+ activation_state_tensor.resize(activation_state_shape);
+ Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);
+
+ Tensor scratchpad_activation_state(DataType::FLOAT32, Shape({}), {}, "");
+ Tensor scratchpad_1(DataType::FLOAT32, Shape({}), {}, "");
+ Tensor scratchpad_2(DataType::FLOAT32, Shape({}), {}, "");
+ Tensor scratchpad_3(DataType::FLOAT32, Shape({}), {}, "");
+ Tensor scratchpad_4(DataType::FLOAT32, Shape({}), {}, "");
+ Tensor scratchpad_5(DataType::FLOAT32, Shape({}), {}, "");
+ Tensor scratchpad_6(DataType::FLOAT32, Shape({}), {}, "");
+
+ SVDFParams params{};
+ params.activation = Activation::NONE;
+ params.asymmetric_quantize_inputs = false;
+ params.svdf_rank = rank;
+
+ SVDF kernel(&input_tensor, &weight_feature_tensor, &weight_time_tensor, nullptr,
+ &activation_state_tensor, &output_tensor, &scratchpad_activation_state, &scratchpad_1,
+ &scratchpad_2, &scratchpad_3, &scratchpad_4, &scratchpad_5, &scratchpad_6, params);
+ EXPECT_ANY_THROW(kernel.configure());
+}
+
+} // namespace
+} // namespace kernels
+} // namespace luci_interpreter
diff --git a/compiler/luci-interpreter/src/kernels/Slice.cpp b/compiler/luci-interpreter/src/kernels/Slice.cpp
index 37a834a18..2fe2c5471 100644
--- a/compiler/luci-interpreter/src/kernels/Slice.cpp
+++ b/compiler/luci-interpreter/src/kernels/Slice.cpp
@@ -139,6 +139,11 @@ void Slice::execute() const
getTensorData<uint8_t>(input()), getTensorShape(output()),
getTensorData<uint8_t>(output()));
break;
+ case DataType::S8:
+ luci_interpreter_pal::Slice(op_params, getTensorShape(input()),
+ getTensorData<int8_t>(input()), getTensorShape(output()),
+ getTensorData<int8_t>(output()));
+ break;
default:
throw std::runtime_error("Unsupported input type.");
}
diff --git a/compiler/luci-interpreter/src/kernels/Slice.test.cpp b/compiler/luci-interpreter/src/kernels/Slice.test.cpp
index 3e0d0b0d7..517982990 100644
--- a/compiler/luci-interpreter/src/kernels/Slice.test.cpp
+++ b/compiler/luci-interpreter/src/kernels/Slice.test.cpp
@@ -31,8 +31,8 @@ template <typename T> class SliceTest : public ::testing::Test
{
};
-using DataTypes = ::testing::Types<float, uint8_t>;
-TYPED_TEST_CASE(SliceTest, DataTypes);
+using DataTypes = ::testing::Types<float, uint8_t, int8_t>;
+TYPED_TEST_SUITE(SliceTest, DataTypes);
TYPED_TEST(SliceTest, SimpleTest)
{
diff --git a/compiler/luci-interpreter/src/kernels/Softmax.test.cpp b/compiler/luci-interpreter/src/kernels/Softmax.test.cpp
index 9de40b6ec..08e70672d 100644
--- a/compiler/luci-interpreter/src/kernels/Softmax.test.cpp
+++ b/compiler/luci-interpreter/src/kernels/Softmax.test.cpp
@@ -93,7 +93,7 @@ template <typename T> class SoftmaxTest : public ::testing::Test
};
using DataTypes = ::testing::Types<float, uint8_t, int8_t>;
-TYPED_TEST_CASE(SoftmaxTest, DataTypes);
+TYPED_TEST_SUITE(SoftmaxTest, DataTypes);
TYPED_TEST(SoftmaxTest, Simple)
{
diff --git a/compiler/luci-interpreter/src/kernels/SpaceToBatchND.test.cpp b/compiler/luci-interpreter/src/kernels/SpaceToBatchND.test.cpp
index e06501c8c..3a8b0a812 100644
--- a/compiler/luci-interpreter/src/kernels/SpaceToBatchND.test.cpp
+++ b/compiler/luci-interpreter/src/kernels/SpaceToBatchND.test.cpp
@@ -90,7 +90,7 @@ template <typename T> class SpaceToBatchNDTest : public ::testing::Test
};
using DataTypes = ::testing::Types<float, uint8_t>;
-TYPED_TEST_CASE(SpaceToBatchNDTest, DataTypes);
+TYPED_TEST_SUITE(SpaceToBatchNDTest, DataTypes);
TYPED_TEST(SpaceToBatchNDTest, Simple)
{
diff --git a/compiler/luci-interpreter/src/kernels/SpaceToDepth.test.cpp b/compiler/luci-interpreter/src/kernels/SpaceToDepth.test.cpp
index 735c010b9..4af488618 100644
--- a/compiler/luci-interpreter/src/kernels/SpaceToDepth.test.cpp
+++ b/compiler/luci-interpreter/src/kernels/SpaceToDepth.test.cpp
@@ -32,7 +32,7 @@ template <typename T> class SpaceToDepthTest : public ::testing::Test
};
using DataTypes = ::testing::Types<float, uint8_t>;
-TYPED_TEST_CASE(SpaceToDepthTest, DataTypes);
+TYPED_TEST_SUITE(SpaceToDepthTest, DataTypes);
TYPED_TEST(SpaceToDepthTest, SimpleCase)
{
diff --git a/compiler/luci-interpreter/src/kernels/Split.test.cpp b/compiler/luci-interpreter/src/kernels/Split.test.cpp
index 74d57aed3..283cd9aa9 100644
--- a/compiler/luci-interpreter/src/kernels/Split.test.cpp
+++ b/compiler/luci-interpreter/src/kernels/Split.test.cpp
@@ -73,7 +73,7 @@ template <typename T> class SplitTest : public ::testing::Test
};
using DataTypes = ::testing::Types<float, uint8_t>;
-TYPED_TEST_CASE(SplitTest, DataTypes);
+TYPED_TEST_SUITE(SplitTest, DataTypes);
TYPED_TEST(SplitTest, FourDimensional)
{
diff --git a/compiler/luci-interpreter/src/kernels/SplitV.test.cpp b/compiler/luci-interpreter/src/kernels/SplitV.test.cpp
index aac0567d7..035bc2122 100644
--- a/compiler/luci-interpreter/src/kernels/SplitV.test.cpp
+++ b/compiler/luci-interpreter/src/kernels/SplitV.test.cpp
@@ -77,7 +77,7 @@ template <typename T> class SplitVTest : public ::testing::Test
};
using DataTypes = ::testing::Types<float, uint8_t, int16_t>;
-TYPED_TEST_CASE(SplitVTest, DataTypes);
+TYPED_TEST_SUITE(SplitVTest, DataTypes);
TYPED_TEST(SplitVTest, ThreeDimensional)
{
diff --git a/compiler/luci-interpreter/src/kernels/Squeeze.test.cpp b/compiler/luci-interpreter/src/kernels/Squeeze.test.cpp
index d3326fe98..1bc0b6459 100644
--- a/compiler/luci-interpreter/src/kernels/Squeeze.test.cpp
+++ b/compiler/luci-interpreter/src/kernels/Squeeze.test.cpp
@@ -56,7 +56,7 @@ template <typename T> class SqueezeTest : public ::testing::Test
};
using DataTypes = ::testing::Types<float, uint8_t>;
-TYPED_TEST_CASE(SqueezeTest, DataTypes);
+TYPED_TEST_SUITE(SqueezeTest, DataTypes);
TYPED_TEST(SqueezeTest, TotalTest)
{
diff --git a/compiler/luci-interpreter/src/kernels/Sub.cpp b/compiler/luci-interpreter/src/kernels/Sub.cpp
index 603c62d0f..24b6a72e5 100644
--- a/compiler/luci-interpreter/src/kernels/Sub.cpp
+++ b/compiler/luci-interpreter/src/kernels/Sub.cpp
@@ -37,6 +37,7 @@ Sub::Sub(const Tensor *input1, const Tensor *input2, Tensor *output, const SubPa
void Sub::configure()
{
LUCI_INTERPRETER_CHECK(!(input1()->element_type() != input2()->element_type()))
+ LUCI_INTERPRETER_CHECK(!(input1()->element_type() != output()->element_type()))
output()->resize(calculateShapeForBroadcast(input1()->shape(), input2()->shape()));
}
@@ -47,6 +48,12 @@ void Sub::execute() const
case DataType::FLOAT32:
evalFloat();
break;
+ case DataType::S64:
+ evalInteger<int64_t>();
+ break;
+ case DataType::S32:
+ evalInteger<int32_t>();
+ break;
case DataType::U8:
evalQuantized();
break;
@@ -57,13 +64,8 @@ void Sub::execute() const
void Sub::evalFloat() const
{
- float activation_min{};
- float activation_max{};
- calculateActivationRange(_params.activation, &activation_min, &activation_max);
-
tflite::ArithmeticParams params{};
- params.float_activation_min = activation_min;
- params.float_activation_max = activation_max;
+ fillArithmeticActivationRange<float>(params, _params.activation);
const bool need_broadcast = tflite::reference_ops::ProcessBroadcastShapes(
getTensorShape(input1()), getTensorShape(input2()), &params);
@@ -82,6 +84,28 @@ void Sub::evalFloat() const
}
}
+template <typename T> void Sub::evalInteger() const
+{
+ tflite::ArithmeticParams params{};
+ fillArithmeticActivationRange<T>(params, _params.activation);
+
+ const bool need_broadcast = tflite::reference_ops::ProcessBroadcastShapes(
+ getTensorShape(input1()), getTensorShape(input2()), &params);
+
+ if (need_broadcast)
+ {
+ tflite::reference_ops::BroadcastSubSlow(
+ params, getTensorShape(input1()), getTensorData<T>(input1()), getTensorShape(input2()),
+ getTensorData<T>(input2()), getTensorShape(output()), getTensorData<T>(output()));
+ }
+ else
+ {
+ tflite::reference_ops::Sub(params, getTensorShape(input1()), getTensorData<T>(input1()),
+ getTensorShape(input2()), getTensorData<T>(input2()),
+ getTensorShape(output()), getTensorData<T>(output()));
+ }
+}
+
void Sub::evalQuantized() const
{
const auto input1_scale = static_cast<double>(input1()->scale());
diff --git a/compiler/luci-interpreter/src/kernels/Sub.h b/compiler/luci-interpreter/src/kernels/Sub.h
index d7940b5c6..23952b3bd 100644
--- a/compiler/luci-interpreter/src/kernels/Sub.h
+++ b/compiler/luci-interpreter/src/kernels/Sub.h
@@ -39,6 +39,7 @@ public:
private:
void evalFloat() const;
+ template <typename T> void evalInteger() const;
void evalQuantized() const;
};
diff --git a/compiler/luci-interpreter/src/kernels/Sub.test.cpp b/compiler/luci-interpreter/src/kernels/Sub.test.cpp
index c189f4481..9abafd49a 100644
--- a/compiler/luci-interpreter/src/kernels/Sub.test.cpp
+++ b/compiler/luci-interpreter/src/kernels/Sub.test.cpp
@@ -162,6 +162,51 @@ TEST_F(SubTest, Float)
}
}
+template <loco::DataType DType> void CheckInteger(luci_interpreter::IMemoryManager *memory_manager)
+{
+ using dtype = typename loco::DataTypeImpl<DType>::Type;
+ Shape base_shape = {2, 3, 1, 2};
+ std::vector<Shape> test_shapes{{1, 1, 3, 2}, {1, 3, 1, 2}, {2, 1, 3, 1}, {2, 3, 1, 1}};
+ std::vector<std::vector<dtype>> test_outputs = {
+ {0, 1, 2, 3, 0, 0, 0, 0, 4, 1, 0, 0, 0, 0, 7, 0, 3, 0,
+ 0, 2, 4, 4, 0, 0, 3, 0, 10, 0, 6, 0, 3, 0, 10, 2, 6, 0},
+ {0, 1, 4, 1, 3, 0, 0, 2, 10, 0, 6, 0},
+ {0, 0, 0, 1, 2, 5, 0, 0, 0, 0, 4, 3, 0, 0, 3, 0, 7, 0,
+ 2, 4, 0, 2, 0, 0, 8, 0, 6, 0, 1, 0, 8, 2, 6, 0, 1, 0},
+ {0, 0, 0, 0, 7, 0, 2, 4, 6, 0, 1, 0}};
+ std::vector<dtype> input1_data{-1, 2, 1, 0, 4, -5, 1, 3, 7, -1, 7, 1};
+ std::vector<dtype> input2_data{4, 1, -3, -1, 1, 6};
+ for (size_t i = 0; i < test_shapes.size(); ++i)
+ {
+ Tensor input1_tensor = makeInputTensor<DType>(base_shape, input1_data, memory_manager);
+ Tensor input2_tensor = makeInputTensor<DType>(test_shapes[i], input2_data, memory_manager);
+ Tensor output_tensor = makeOutputTensor(DType);
+
+ SubParams params{};
+ params.activation = Activation::RELU;
+
+ Sub kernel(&input1_tensor, &input2_tensor, &output_tensor, params);
+ kernel.configure();
+ memory_manager->allocate_memory(output_tensor);
+ kernel.execute();
+
+ EXPECT_THAT(extractTensorData<dtype>(output_tensor), test_outputs[i])
+ << "With shape number " << i;
+ }
+};
+
+TEST_F(SubTest, SInt32)
+{
+ CheckInteger<loco::DataType::S32>(_memory_manager.get());
+ SUCCEED();
+}
+
+TEST_F(SubTest, SInt64)
+{
+ CheckInteger<loco::DataType::S64>(_memory_manager.get());
+ SUCCEED();
+}
+
TEST_F(SubTest, Input_Output_Type_NEG)
{
Tensor input1_tensor = makeInputTensor<DataType::FLOAT32>({1}, {1.f}, _memory_manager.get());
@@ -175,11 +220,24 @@ TEST_F(SubTest, Input_Output_Type_NEG)
EXPECT_ANY_THROW(kernel.configure());
}
-TEST_F(SubTest, Invalid_Input_Type_NEG)
+TEST_F(SubTest, Invalid_Output_Type_NEG)
{
Tensor input1_tensor = makeInputTensor<DataType::S64>({1}, {1}, _memory_manager.get());
Tensor input2_tensor = makeInputTensor<DataType::S64>({1}, {2}, _memory_manager.get());
- Tensor output_tensor = makeOutputTensor(DataType::S64);
+ Tensor output_tensor = makeOutputTensor(DataType::S32);
+
+ SubParams params{};
+ params.activation = Activation::RELU;
+
+ Sub kernel(&input1_tensor, &input2_tensor, &output_tensor, params);
+ EXPECT_ANY_THROW(kernel.configure());
+}
+
+TEST_F(SubTest, Invalid_Input_Type_NEG)
+{
+ Tensor input1_tensor = makeInputTensor<DataType::U64>({1}, {1}, _memory_manager.get());
+ Tensor input2_tensor = makeInputTensor<DataType::U64>({1}, {2}, _memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::U64);
SubParams params{};
params.activation = Activation::RELU;
@@ -190,6 +248,19 @@ TEST_F(SubTest, Invalid_Input_Type_NEG)
EXPECT_ANY_THROW(kernel.execute());
}
+TEST_F(SubTest, Mismatching_Input_Int_Types_NEG)
+{
+ Tensor input1_tensor = makeInputTensor<DataType::S32>({1}, {1}, _memory_manager.get());
+ Tensor input2_tensor = makeInputTensor<DataType::S64>({1}, {2}, _memory_manager.get());
+ Tensor output_tensor = makeOutputTensor(DataType::S32);
+
+ SubParams params{};
+ params.activation = Activation::NONE;
+
+ Sub kernel(&input1_tensor, &input2_tensor, &output_tensor, params);
+ EXPECT_ANY_THROW(kernel.configure());
+}
+
} // namespace
} // namespace kernels
} // namespace luci_interpreter
diff --git a/compiler/luci-interpreter/src/kernels/Transpose.test.cpp b/compiler/luci-interpreter/src/kernels/Transpose.test.cpp
index 107179910..43be8f8b9 100644
--- a/compiler/luci-interpreter/src/kernels/Transpose.test.cpp
+++ b/compiler/luci-interpreter/src/kernels/Transpose.test.cpp
@@ -52,7 +52,7 @@ template <typename T> class TransposeTest : public ::testing::Test
};
using DataTypes = ::testing::Types<float, uint8_t>;
-TYPED_TEST_CASE(TransposeTest, DataTypes);
+TYPED_TEST_SUITE(TransposeTest, DataTypes);
TYPED_TEST(TransposeTest, Small3D)
{
diff --git a/compiler/luci-interpreter/src/kernels/Unpack.test.cpp b/compiler/luci-interpreter/src/kernels/Unpack.test.cpp
index 4f22c9f30..9384ddc83 100644
--- a/compiler/luci-interpreter/src/kernels/Unpack.test.cpp
+++ b/compiler/luci-interpreter/src/kernels/Unpack.test.cpp
@@ -75,7 +75,7 @@ template <typename T> class UnpackTest : public ::testing::Test
};
using DataTypes = ::testing::Types<float, uint8_t>;
-TYPED_TEST_CASE(UnpackTest, DataTypes);
+TYPED_TEST_SUITE(UnpackTest, DataTypes);
TYPED_TEST(UnpackTest, ThreeOutputs)
{
diff --git a/compiler/luci-interpreter/src/kernels/Utils.cpp b/compiler/luci-interpreter/src/kernels/Utils.cpp
index 586cfa1e1..5d8e5db83 100644
--- a/compiler/luci-interpreter/src/kernels/Utils.cpp
+++ b/compiler/luci-interpreter/src/kernels/Utils.cpp
@@ -27,17 +27,18 @@ namespace luci_interpreter
namespace kernels
{
-void calculateActivationRange(Activation activation, float *activation_min, float *activation_max)
+template <typename T>
+void calculateActivationRange(Activation activation, T *activation_min, T *activation_max)
{
switch (activation)
{
case Activation::NONE:
- *activation_min = std::numeric_limits<float>::lowest();
- *activation_max = std::numeric_limits<float>::max();
+ *activation_min = std::numeric_limits<T>::lowest();
+ *activation_max = std::numeric_limits<T>::max();
break;
case Activation::RELU:
*activation_min = 0;
- *activation_max = std::numeric_limits<float>::max();
+ *activation_max = std::numeric_limits<T>::max();
break;
case Activation::RELU_N1_TO_1:
*activation_min = -1;
@@ -52,6 +53,13 @@ void calculateActivationRange(Activation activation, float *activation_min, floa
}
}
+template void calculateActivationRange(Activation activation, float *activation_min,
+ float *activation_max);
+template void calculateActivationRange(Activation activation, int32_t *activation_min,
+ int32_t *activation_max);
+template void calculateActivationRange(Activation activation, int64_t *activation_min,
+ int64_t *activation_max);
+
static void calculateActivationRangeQuantizedImpl(Activation activation, int32_t qmin, int32_t qmax,
const Tensor *output, int32_t *activation_min,
int32_t *activation_max)
@@ -175,7 +183,11 @@ Shape calculateShapeForBroadcast(const Shape &input1_shape, const Shape &input2_
{
const int32_t input1_dim = i < num_input1_dims ? input1_shape.dim(num_input1_dims - i - 1) : 1;
const int32_t input2_dim = i < num_input2_dims ? input2_shape.dim(num_input2_dims - i - 1) : 1;
- assert(input1_dim == input2_dim || input1_dim == 1 || input2_dim == 1);
+
+ bool need_broadcast = input1_dim != input2_dim;
+ bool can_broadcast = input1_dim == 1 || input2_dim == 1;
+ LUCI_INTERPRETER_CHECK(!need_broadcast || can_broadcast);
+
output_shape.dim(num_out_dims - i - 1) = std::max(input1_dim, input2_dim);
}
diff --git a/compiler/luci-interpreter/src/kernels/Utils.h b/compiler/luci-interpreter/src/kernels/Utils.h
index 817a42f83..ebeb20e66 100644
--- a/compiler/luci-interpreter/src/kernels/Utils.h
+++ b/compiler/luci-interpreter/src/kernels/Utils.h
@@ -76,11 +76,42 @@ inline int32_t calcOffset(const Shape &shape, int32_t d0, int32_t d1, int32_t d2
return ((d0 * shape.dim(1) + d1) * shape.dim(2) + d2) * shape.dim(3) + d3;
}
-void calculateActivationRange(Activation activation, float *activation_min, float *activation_max);
+template <typename T>
+void calculateActivationRange(Activation activation, T *activation_min, T *activation_max);
void calculateActivationRangeQuantized(Activation activation, const Tensor *output,
int32_t *activation_min, int32_t *activation_max);
+template <typename T> constexpr bool one_of_types() { return false; }
+
+// Checks if T is equal to one of {U,Other} types
+template <typename T, typename U, typename... Other> constexpr bool one_of_types()
+{
+ return std::is_same<T, U>::value || one_of_types<T, Other...>();
+}
+
+/**
+ * Fills activation min and max parameters depending on given data type and activation
+ *
+ * T is a template parameter, so after optimization only the branch required for T remains
+ *
+ * @tparam T data type of arithmetic operation output tensor
+ * @param p tflite arithmetic params to fill
+ * @param activation luci_interpreter::Activation of arithmetic operation
+ */
+template <typename T>
+void fillArithmeticActivationRange(tflite::ArithmeticParams &p, Activation act)
+{
+ static_assert(one_of_types<T, float, int32_t, int64_t>(), "Unsupported dtype");
+
+ if (std::is_same<T, float>::value)
+ calculateActivationRange(act, &p.float_activation_min, &p.float_activation_max);
+  else if (std::is_same<T, int32_t>::value)
+ calculateActivationRange(act, &p.quantized_activation_min, &p.quantized_activation_max);
+ else
+ calculateActivationRange(act, &p.int64_activation_min, &p.int64_activation_max);
+}
+
// Decompose a double multiplier into a Q0.31 int32 representation of its
// significand, and shift representation of its exponent.
//
diff --git a/compiler/luci-interpreter/src/loader/CMakeLists.txt b/compiler/luci-interpreter/src/loader/CMakeLists.txt
index 2cde99f5d..292771592 100644
--- a/compiler/luci-interpreter/src/loader/CMakeLists.txt
+++ b/compiler/luci-interpreter/src/loader/CMakeLists.txt
@@ -17,7 +17,9 @@ endmacro(REGISTER_KERNEL)
include(${KERNEL_REGISTER_FILE})
add_library(${LUCI_INTERPRETER_LOADER} STATIC ${SOURCES})
-set_target_properties(${LUCI_INTERPRETER_LOADER} PROPERTIES POSITION_INDEPENDENT_CODE ON)
+if (NOT NNCC_LIBRARY_NO_PIC)
+ set_target_properties(${LUCI_INTERPRETER_LOADER} PROPERTIES POSITION_INDEPENDENT_CODE ON)
+endif(NOT NNCC_LIBRARY_NO_PIC)
target_include_directories(${LUCI_INTERPRETER_LOADER} PUBLIC "${LUCI_INTERPRETER_PAL_DIR}")
target_include_directories(${LUCI_INTERPRETER_LOADER} PUBLIC "${LUCI_INTERPRETER_SOURCE_DIR}")
diff --git a/compiler/luci-interpreter/src/loader/GraphLoader.cpp b/compiler/luci-interpreter/src/loader/GraphLoader.cpp
index a14442ed5..dba39050c 100644
--- a/compiler/luci-interpreter/src/loader/GraphLoader.cpp
+++ b/compiler/luci-interpreter/src/loader/GraphLoader.cpp
@@ -73,6 +73,26 @@ const void *getNodeData(const luci::CircleConst *node, size_t *data_size)
}
}
+const void *getNodeData(const luci::CircleCustom *node, size_t *data_size)
+{
+ if (node->custom_code() != "CircleReferencingConst")
+ return nullptr;
+
+ // helper struct which describes data loaded to custom_options of CircleReferencingConst node
+ // TODO move this struct to header
+ struct ConstDataReference
+ {
+ const uint8_t *data = nullptr;
+ uint32_t size = 0;
+ };
+
+ const auto &custom_options = node->custom_options();
+ const auto &const_data_ref = *reinterpret_cast<const ConstDataReference *>(custom_options.data());
+
+ *data_size = const_data_ref.size;
+ return const_data_ref.data;
+}
+
bool isExecutableNode(const luci::CircleNode *node)
{
switch (node->opcode())
@@ -83,12 +103,30 @@ bool isExecutableNode(const luci::CircleNode *node)
case luci::CircleOpcode::CIRCLEOUTPUT:
case luci::CircleOpcode::CIRCLEOUTPUTEXCLUDE:
// The following nodes denote outputs of multiple-output nodes.
+ case luci::CircleOpcode::CIRCLEBIDIRECTIONAL_SEQUENCE_LSTM_OUT:
+ case luci::CircleOpcode::CIRCLECUSTOMOUT:
case luci::CircleOpcode::CIRCLEIFOUT:
+ case luci::CircleOpcode::CIRCLENONMAXSUPPRESSIONV4OUT:
+ case luci::CircleOpcode::CIRCLENONMAXSUPPRESSIONV5OUT:
case luci::CircleOpcode::CIRCLESPLITOUT:
case luci::CircleOpcode::CIRCLESPLITVOUT:
+ case luci::CircleOpcode::CIRCLETOPKV2OUT:
+ case luci::CircleOpcode::CIRCLEUNIQUEOUT:
case luci::CircleOpcode::CIRCLEUNPACKOUT:
+ case luci::CircleOpcode::CIRCLEVARIABLE:
case luci::CircleOpcode::CIRCLEWHILEOUT:
return false;
+ // Custom nodes may be executable and non-executable
+ case luci::CircleOpcode::CUSTOM:
+ {
+ auto const custom_node = loco::must_cast<const luci::CircleCustom *>(node);
+
+ // TODO handle more non-executable Custom ops here
+ if (custom_node->custom_code() == "CircleReferencingConst")
+ return false;
+
+ return true;
+ }
default:
return true;
}
@@ -102,15 +140,34 @@ bool isTensorProducingNode(const luci::CircleNode *node)
case luci::CircleOpcode::CIRCLEOUTPUT:
// The following nodes are multiple-output nodes. They do not produce tensors, the tensors
// are produced by the corresponding *Out nodes instead.
+ case luci::CircleOpcode::BIDIRECTIONAL_SEQUENCE_LSTM:
+ case luci::CircleOpcode::CUSTOM:
case luci::CircleOpcode::IF:
+ case luci::CircleOpcode::NON_MAX_SUPPRESSION_V4:
+ case luci::CircleOpcode::NON_MAX_SUPPRESSION_V5:
case luci::CircleOpcode::SPLIT:
+ case luci::CircleOpcode::SPLIT_V:
+ case luci::CircleOpcode::TOPK_V2:
+ case luci::CircleOpcode::UNIQUE:
case luci::CircleOpcode::UNPACK:
+ case luci::CircleOpcode::WHILE:
return false;
default:
return true;
}
}
+bool isSupportedCustomNode(const luci::CircleNode *node)
+{
+ const auto custom_node = loco::must_cast<const luci::CircleCustom *>(node);
+
+ // TODO handle more Custom ops here
+ if (custom_node->custom_code() == "CircleReferencingConst")
+ return true;
+
+ return false;
+}
+
} // namespace
GraphLoader::GraphLoader(
@@ -129,18 +186,25 @@ void GraphLoader::loadTensors()
{
const auto *node = loco::must_cast<const luci::CircleNode *>(_graph->nodes()->at(i));
+ if (node->opcode() == luci::CircleOpcode::CUSTOM && !isSupportedCustomNode(node))
+ throw std::runtime_error("Unknown Custom Node, yet.");
+
if (!isTensorProducingNode(node))
continue;
- // Only Input and Const nodes have shapes. Shapes of intermediate tensors will be inferred.
+ // Only Input, Const, Custom and Variable nodes have shapes. Shapes of intermediate tensors will
+ // be inferred.
Shape shape{};
- if (const auto *input_node = dynamic_cast<const luci::CircleInput *>(node))
+ switch (node->opcode())
{
- shape = getNodeShape(input_node);
- }
- else if (const auto *const_node = dynamic_cast<const luci::CircleConst *>(node))
- {
- shape = getNodeShape(const_node);
+ case luci::CircleOpcode::CIRCLECONST:
+ case luci::CircleOpcode::CIRCLECUSTOMOUT:
+ case luci::CircleOpcode::CIRCLEINPUT:
+ case luci::CircleOpcode::CIRCLEVARIABLE:
+ shape = getNodeShape(node);
+ break;
+ default:
+ break;
}
AffineQuantization quantization;
@@ -175,6 +239,22 @@ void GraphLoader::loadTensors()
tensor->writeData(const_data, data_size);
}
}
+ else if (const auto *custom_out_node = dynamic_cast<const luci::CircleCustomOut *>(node))
+ {
+ const auto *custom_node =
+ loco::must_cast<const luci::CircleCustom *>(custom_out_node->input());
+
+ if (custom_node->custom_code() == "CircleReferencingConst")
+ {
+ size_t data_size{};
+ const void *const_data = getNodeData(custom_node, &data_size);
+ if (const_data != nullptr)
+ {
+ _memory_manager->allocate_memory(*tensor);
+ tensor->writeData(const_data, data_size);
+ }
+ }
+ }
_node_to_tensor.emplace(node, tensor.get());
_runtime_to_ir.tensor_to_node.emplace(tensor.get(), node);
diff --git a/compiler/luci-interpreter/src/loader/KernelBuilder.test.cpp b/compiler/luci-interpreter/src/loader/KernelBuilder.test.cpp
index 7a457a62f..b221b6921 100644
--- a/compiler/luci-interpreter/src/loader/KernelBuilder.test.cpp
+++ b/compiler/luci-interpreter/src/loader/KernelBuilder.test.cpp
@@ -21,6 +21,7 @@
#include <kernels/Add.h>
#include <kernels/ArgMax.h>
#include <kernels/AveragePool2D.h>
+#include <kernels/BatchMatMul.h>
#include <kernels/Cast.h>
#include <kernels/Concatenation.h>
#include <kernels/Conv2D.h>
@@ -54,6 +55,7 @@
#include <kernels/Mul.h>
#include <kernels/Neg.h>
#include <kernels/NotEqual.h>
+#include <kernels/OneHot.h>
#include <kernels/Pad.h>
#include <kernels/PadV2.h>
#include <kernels/Pow.h>
@@ -209,6 +211,27 @@ TEST_F(KernelBuilderTest, AveragePool2D)
EXPECT_THAT(kernel->params().activation, Eq(op->fusedActivationFunction()));
}
+TEST_F(KernelBuilderTest, BatchMatMul)
+{
+ auto *lhs = createInputNode();
+ auto *rhs = createInputNode();
+
+ auto *op = createNode<luci::CircleBatchMatMul>();
+ op->x(lhs);
+ op->y(rhs);
+ op->adj_x(false);
+ op->adj_y(false);
+
+ auto kernel = buildKernel<kernels::BatchMatMul>(op);
+ ASSERT_THAT(kernel, NotNull());
+
+ checkTensor(kernel->x(), lhs);
+ checkTensor(kernel->y(), rhs);
+ checkTensor(kernel->output(), op);
+ EXPECT_THAT(kernel->params().adj_x, Eq(op->adj_x()));
+ EXPECT_THAT(kernel->params().adj_y, Eq(op->adj_y()));
+}
+
TEST_F(KernelBuilderTest, Cast)
{
auto *input = createInputNode();
@@ -832,6 +855,31 @@ TEST_F(KernelBuilderTest, NotEqual)
checkTensor(kernel->output(), op);
}
+TEST_F(KernelBuilderTest, OneHot)
+{
+ auto *indices = createInputNode();
+ auto *depth = createInputNode();
+ auto *on_value = createInputNode();
+ auto *off_value = createInputNode();
+ auto axis = 1;
+
+ auto *op = createNode<luci::CircleOneHot>();
+ op->indices(indices);
+ op->depth(depth);
+ op->on_value(on_value);
+ op->off_value(off_value);
+ op->axis(axis);
+
+ auto kernel = buildKernel<kernels::OneHot>(op);
+ ASSERT_THAT(kernel, NotNull());
+
+ checkTensor(kernel->indices(), indices);
+ checkTensor(kernel->depth(), depth);
+ checkTensor(kernel->on_value(), on_value);
+ checkTensor(kernel->off_value(), off_value);
+ EXPECT_THAT(kernel->params().axis, Eq(op->axis()));
+}
+
TEST_F(KernelBuilderTest, Pad)
{
auto *input = createInputNode();
diff --git a/compiler/luci-interpreter/src/loader/nodes/AveragePool2D.cpp b/compiler/luci-interpreter/src/loader/nodes/AveragePool2D.cpp
index 5bc37bd4a..efb011257 100644
--- a/compiler/luci-interpreter/src/loader/nodes/AveragePool2D.cpp
+++ b/compiler/luci-interpreter/src/loader/nodes/AveragePool2D.cpp
@@ -17,6 +17,7 @@
#include "Builders.h"
#include "kernels/AveragePool2D.h"
+#include <luci/Plan/CircleNodeExecutionPlan.h>
namespace luci_interpreter
{
@@ -40,7 +41,26 @@ std::unique_ptr<Kernel> build_kernel_CircleAveragePool2D(const luci::CircleNode
params.stride_width = node->stride()->w();
params.activation = node->fusedActivationFunction();
- return std::make_unique<kernels::AveragePool2D>(input, output, params);
+ // It is unknown what data will be stored in scratchpad tensor,
+ // using UINT8 as a most general option
+ auto scratchpad = std::make_unique<Tensor>(DataType::U8, Shape({}), AffineQuantization{}, "");
+ scratchpad->set_observable(false);
+ scratchpad->set_data_buffer(nullptr);
+ // If node has execution plan then read memory offsets for scratchpad temporary tensor
+ // from the beginning of shared memory buffer.
+ // Used in Static Memory Manager.
+ // TODO move tensors offset initialization to one place
+ if (luci::has_execution_plan(node))
+ {
+ const auto execution_plan = luci::get_execution_plan(node);
+    // Check whether the offset for the current CircleAveragePool2D temporary was found.
+ if (execution_plan.offsets().size() > 1)
+ // If this is true, then we keep this offset in scratchpad.
+ scratchpad->set_offset(execution_plan.offsets().at(1));
+ }
+ Tensor *tmp = helper.getRuntimeGraph(node->graph())->addTensor(std::move(scratchpad));
+
+ return std::make_unique<kernels::AveragePool2D>(input, output, tmp, params);
}
} // namespace luci_interpreter
diff --git a/compiler/luci-interpreter/src/loader/nodes/BatchMatMul.cpp b/compiler/luci-interpreter/src/loader/nodes/BatchMatMul.cpp
new file mode 100644
index 000000000..aae3dbab1
--- /dev/null
+++ b/compiler/luci-interpreter/src/loader/nodes/BatchMatMul.cpp
@@ -0,0 +1,72 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Builders.h"
+
+#include "kernels/BatchMatMul.h"
+#include <luci/Plan/CircleNodeExecutionPlan.h>
+
+namespace luci_interpreter
+{
+
+std::unique_ptr<Kernel> build_kernel_CircleBatchMatMul(const luci::CircleNode *circle_node,
+ KernelBuilderHelper &helper)
+{
+ const auto *node = dynamic_cast<const luci::CircleBatchMatMul *>(circle_node);
+ if (node == nullptr)
+ throw std::runtime_error("wrong builder for operation");
+ assert(node->arity() == 2);
+
+ const Tensor *lhs = helper.getInputTensor(node->x());
+ const Tensor *rhs = helper.getInputTensor(node->y());
+ Tensor *output = helper.getOutputTensor(node);
+
+ auto lhs_scratchpad =
+ std::make_unique<Tensor>(lhs->element_type(), Shape({}), AffineQuantization{}, "");
+ lhs_scratchpad->set_observable(false);
+ lhs_scratchpad->set_data_buffer(nullptr);
+ auto rhs_scratchpad =
+ std::make_unique<Tensor>(rhs->element_type(), Shape({}), AffineQuantization{}, "");
+ rhs_scratchpad->set_observable(false);
+ rhs_scratchpad->set_data_buffer(nullptr);
+ // If node has execution plan then read memory offsets for scratchpad temporary tensor
+ // from the beginning of shared memory buffer.
+ // Used in Static Memory Manager.
+ // TODO move tensors offset initialization to one place
+ if (luci::has_execution_plan(node))
+ {
+ const auto execution_plan = luci::get_execution_plan(node);
+ // Check whether the offset for the current BatchMatMul temporary was found.
+ if (execution_plan.offsets().size() > 1)
+ {
+ assert(execution_plan.offsets().size() == 3);
+
+ // If this is true, then we keep this offset in scratchpad.
+ lhs_scratchpad->set_offset(execution_plan.offsets().at(1));
+ rhs_scratchpad->set_offset(execution_plan.offsets().at(2));
+ }
+ }
+ Tensor *lhs_tmp = helper.getRuntimeGraph(node->graph())->addTensor(std::move(lhs_scratchpad));
+ Tensor *rhs_tmp = helper.getRuntimeGraph(node->graph())->addTensor(std::move(rhs_scratchpad));
+
+ BatchMatMulParams params;
+ params.adj_x = node->adj_x();
+ params.adj_y = node->adj_y();
+
+ return std::make_unique<kernels::BatchMatMul>(lhs, rhs, output, lhs_tmp, rhs_tmp, params);
+}
+
+} // namespace luci_interpreter
diff --git a/compiler/luci-interpreter/src/loader/nodes/Conv2D.cpp b/compiler/luci-interpreter/src/loader/nodes/Conv2D.cpp
index 22fd1aca4..b48d97d19 100644
--- a/compiler/luci-interpreter/src/loader/nodes/Conv2D.cpp
+++ b/compiler/luci-interpreter/src/loader/nodes/Conv2D.cpp
@@ -35,11 +35,12 @@ std::unique_ptr<Kernel> build_kernel_CircleConv2D(const luci::CircleNode *circle
const Tensor *bias = helper.getOptionalInputTensor(node->bias());
Tensor *output = helper.getOutputTensor(node);
- auto im2col =
- std::make_unique<Tensor>(input->element_type(), Shape({}), AffineQuantization{}, "");
- im2col->set_observable(false);
- im2col->set_data_buffer(nullptr);
- // If node has execution plan then read memory offsets for im2col temporary tensor
+ // It is unknown what data will be stored in scratchpad tensor,
+ // using UINT8 as the most general option
+ auto scratchpad = std::make_unique<Tensor>(DataType::U8, Shape({}), AffineQuantization{}, "");
+ scratchpad->set_observable(false);
+ scratchpad->set_data_buffer(nullptr);
+ // If node has execution plan then read memory offsets for scratchpad temporary tensor
// from the beginning of shared memory buffer.
// Used in Static Memory Manager.
// TODO move tensors offset initialization to one place
@@ -48,10 +49,10 @@ std::unique_ptr<Kernel> build_kernel_CircleConv2D(const luci::CircleNode *circle
const auto execution_plan = luci::get_execution_plan(node);
// Check whether the offset for the current CircleConv2D temporary was found.
if (execution_plan.offsets().size() > 1)
- // If this is true, then we keep this offset in im2col.
- im2col->set_offset(execution_plan.offsets().at(1));
+ // If this is true, then we keep this offset in scratchpad.
+ scratchpad->set_offset(execution_plan.offsets().at(1));
}
- Tensor *tmp = helper.getRuntimeGraph(node->graph())->addTensor(std::move(im2col));
+ Tensor *tmp = helper.getRuntimeGraph(node->graph())->addTensor(std::move(scratchpad));
Conv2DParams params{};
params.padding = node->padding();
diff --git a/compiler/luci-interpreter/src/loader/nodes/DepthwiseConv2D.cpp b/compiler/luci-interpreter/src/loader/nodes/DepthwiseConv2D.cpp
index c2f0346a2..db26ecf2e 100644
--- a/compiler/luci-interpreter/src/loader/nodes/DepthwiseConv2D.cpp
+++ b/compiler/luci-interpreter/src/loader/nodes/DepthwiseConv2D.cpp
@@ -17,6 +17,7 @@
#include "Builders.h"
#include "kernels/DepthwiseConv2D.h"
+#include <luci/Plan/CircleNodeExecutionPlan.h>
namespace luci_interpreter
{
@@ -43,7 +44,26 @@ std::unique_ptr<Kernel> build_kernel_CircleDepthwiseConv2D(const luci::CircleNod
params.dilation_width_factor = node->dilation()->w();
params.activation = node->fusedActivationFunction();
- return std::make_unique<kernels::DepthwiseConv2D>(input, filter, bias, output, params);
+ // It is unknown what data will be stored in scratchpad tensor,
+ // using UINT8 as the most general option
+ auto scratchpad = std::make_unique<Tensor>(DataType::U8, Shape({}), AffineQuantization{}, "");
+ scratchpad->set_observable(false);
+ scratchpad->set_data_buffer(nullptr);
+ // If node has execution plan then read memory offsets for scratchpad temporary tensor
+ // from the beginning of shared memory buffer.
+ // Used in Static Memory Manager.
+ // TODO move tensors offset initialization to one place
+ if (luci::has_execution_plan(node))
+ {
+ const auto execution_plan = luci::get_execution_plan(node);
+ // Check whether the offset for the current CircleDepthwiseConv2D temporary was found.
+ if (execution_plan.offsets().size() > 1)
+ // If this is true, then we keep this offset in scratchpad.
+ scratchpad->set_offset(execution_plan.offsets().at(1));
+ }
+ Tensor *tmp = helper.getRuntimeGraph(node->graph())->addTensor(std::move(scratchpad));
+
+ return std::make_unique<kernels::DepthwiseConv2D>(input, filter, bias, output, tmp, params);
}
} // namespace luci_interpreter
diff --git a/compiler/luci-interpreter/src/loader/nodes/Dequantize.cpp b/compiler/luci-interpreter/src/loader/nodes/Dequantize.cpp
new file mode 100644
index 000000000..4aae56469
--- /dev/null
+++ b/compiler/luci-interpreter/src/loader/nodes/Dequantize.cpp
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Builders.h"
+
+#include "kernels/Dequantize.h"
+
+namespace luci_interpreter
+{
+
+std::unique_ptr<Kernel> build_kernel_CircleDequantize(const luci::CircleNode *circle_node,
+ KernelBuilderHelper &helper)
+{
+ const auto *node = dynamic_cast<const luci::CircleDequantize *>(circle_node);
+ if (node == nullptr)
+ throw std::runtime_error("wrong builder for operation");
+
+ const Tensor *input = helper.getInputTensor(node->input());
+ Tensor *output = helper.getOutputTensor(node);
+
+ return std::make_unique<kernels::Dequantize>(input, output);
+}
+
+} // namespace luci_interpreter
diff --git a/compiler/luci-interpreter/src/loader/nodes/ExpandDims.cpp b/compiler/luci-interpreter/src/loader/nodes/ExpandDims.cpp
new file mode 100644
index 000000000..9840c34e5
--- /dev/null
+++ b/compiler/luci-interpreter/src/loader/nodes/ExpandDims.cpp
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Builders.h"
+
+#include "kernels/ExpandDims.h"
+
+namespace luci_interpreter
+{
+
+std::unique_ptr<Kernel> build_kernel_CircleExpandDims(const luci::CircleNode *circle_node,
+ KernelBuilderHelper &helper)
+{
+ const auto *node = loco::must_cast<const luci::CircleExpandDims *>(circle_node);
+ assert(node->arity() == 2);
+
+ const Tensor *input = helper.getInputTensor(node->input());
+ const Tensor *axis = helper.getInputTensor(node->axis());
+ Tensor *output = helper.getOutputTensor(node);
+
+ return std::make_unique<kernels::ExpandDims>(input, axis, output);
+}
+
+} // namespace luci_interpreter
diff --git a/compiler/luci-interpreter/src/loader/nodes/FullyConnected.cpp b/compiler/luci-interpreter/src/loader/nodes/FullyConnected.cpp
index 2917598fc..0b8ac44bd 100644
--- a/compiler/luci-interpreter/src/loader/nodes/FullyConnected.cpp
+++ b/compiler/luci-interpreter/src/loader/nodes/FullyConnected.cpp
@@ -36,6 +36,7 @@ std::unique_ptr<Kernel> build_kernel_CircleFullyConnected(const luci::CircleNode
FullyConnectedParams params{};
params.activation = node->fusedActivationFunction();
+ params.keep_num_dims = node->keep_num_dims();
return std::make_unique<kernels::FullyConnected>(input, weights, bias, output, params);
}
diff --git a/compiler/luci-interpreter/src/loader/nodes/Gather.cpp b/compiler/luci-interpreter/src/loader/nodes/Gather.cpp
new file mode 100644
index 000000000..9df9775c5
--- /dev/null
+++ b/compiler/luci-interpreter/src/loader/nodes/Gather.cpp
@@ -0,0 +1,44 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Builders.h"
+
+#include "kernels/Gather.h"
+
+namespace luci_interpreter
+{
+
+std::unique_ptr<Kernel> build_kernel_CircleGather(const luci::CircleNode *circle_node,
+ KernelBuilderHelper &helper)
+{
+ const auto *node = dynamic_cast<const luci::CircleGather *>(circle_node);
+ if (node == nullptr)
+ throw std::runtime_error("wrong builder for operation");
+ assert(node->arity() == 2);
+
+ const Tensor *params = helper.getInputTensor(node->params());
+ const Tensor *indices = helper.getInputTensor(node->indices());
+ Tensor *output = helper.getOutputTensor(node);
+
+ GatherParams gparams{};
+ gparams.axis = node->axis();
+ // TODO support batch_dims
+ gparams.batch_dims = 0;
+
+ return std::make_unique<kernels::Gather>(params, indices, output, gparams);
+}
+
+} // namespace luci_interpreter
diff --git a/compiler/luci-interpreter/src/loader/nodes/OneHot.cpp b/compiler/luci-interpreter/src/loader/nodes/OneHot.cpp
new file mode 100644
index 000000000..a40160945
--- /dev/null
+++ b/compiler/luci-interpreter/src/loader/nodes/OneHot.cpp
@@ -0,0 +1,42 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Builders.h"
+
+#include "kernels/OneHot.h"
+
+namespace luci_interpreter
+{
+
+std::unique_ptr<Kernel> build_kernel_CircleOneHot(const luci::CircleNode *circle_node,
+ KernelBuilderHelper &helper)
+{
+ const auto *node = loco::must_cast<const luci::CircleOneHot *>(circle_node);
+ assert(node->arity() == 4);
+
+ const Tensor *indices = helper.getInputTensor(node->indices());
+ const Tensor *depth = helper.getInputTensor(node->depth());
+ const Tensor *on_value = helper.getInputTensor(node->on_value());
+ const Tensor *off_value = helper.getInputTensor(node->off_value());
+ Tensor *output = helper.getOutputTensor(node);
+
+ OneHotParams params{};
+ params.axis = node->axis();
+
+ return std::make_unique<kernels::OneHot>(indices, depth, on_value, off_value, output, params);
+}
+
+} // namespace luci_interpreter
diff --git a/compiler/luci-interpreter/src/loader/nodes/Quantize.cpp b/compiler/luci-interpreter/src/loader/nodes/Quantize.cpp
new file mode 100644
index 000000000..fd9836345
--- /dev/null
+++ b/compiler/luci-interpreter/src/loader/nodes/Quantize.cpp
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Builders.h"
+
+#include "kernels/Quantize.h"
+
+namespace luci_interpreter
+{
+
+std::unique_ptr<Kernel> build_kernel_CircleQuantize(const luci::CircleNode *circle_node,
+ KernelBuilderHelper &helper)
+{
+ const auto *node = dynamic_cast<const luci::CircleQuantize *>(circle_node);
+ if (node == nullptr)
+ throw std::runtime_error("wrong builder for operation");
+
+ const Tensor *input = helper.getInputTensor(node->input());
+ Tensor *output = helper.getOutputTensor(node);
+
+ return std::make_unique<kernels::Quantize>(input, output);
+}
+
+} // namespace luci_interpreter
diff --git a/compiler/luci-interpreter/src/loader/nodes/SVDF.cpp b/compiler/luci-interpreter/src/loader/nodes/SVDF.cpp
new file mode 100644
index 000000000..89528d5ee
--- /dev/null
+++ b/compiler/luci-interpreter/src/loader/nodes/SVDF.cpp
@@ -0,0 +1,93 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Builders.h"
+
+#include "kernels/SVDF.h"
+
+namespace luci_interpreter
+{
+
+std::unique_ptr<Kernel> build_kernel_CircleSVDF(const luci::CircleNode *circle_node,
+ KernelBuilderHelper &helper)
+{
+ const auto *node = dynamic_cast<const luci::CircleSVDF *>(circle_node);
+ if (node == nullptr)
+ throw std::runtime_error("wrong builder for operation");
+
+ const Tensor *input = helper.getInputTensor(node->input());
+ const Tensor *feature = helper.getInputTensor(node->weight_feature());
+ const Tensor *time = helper.getInputTensor(node->weight_time());
+ const Tensor *bias = helper.getOptionalInputTensor(node->bias());
+ const Tensor *input_activation_state = helper.getInputTensor(node->input_activation_state());
+ Tensor *output = helper.getOutputTensor(node);
+
+ auto scratchpad_tensor = std::make_unique<Tensor>(input_activation_state->element_type(),
+ Shape({}), AffineQuantization{}, "");
+ scratchpad_tensor->set_observable(false);
+ scratchpad_tensor->set_data_buffer(nullptr);
+ Tensor *tmp = helper.getRuntimeGraph(node->graph())->addTensor(std::move(scratchpad_tensor));
+
+ DataType data_type = input->element_type() == DataType::S8 ? DataType::S32 : DataType::FLOAT32;
+
+ scratchpad_tensor = std::make_unique<Tensor>(data_type, Shape({}), AffineQuantization{}, "");
+ scratchpad_tensor->set_observable(false);
+ scratchpad_tensor->set_data_buffer(nullptr);
+ Tensor *tmp_1 = helper.getRuntimeGraph(node->graph())->addTensor(std::move(scratchpad_tensor));
+
+ if (data_type == DataType::FLOAT32 &&
+ (feature->element_type() == DataType::S8 || feature->element_type() == DataType::U8))
+ {
+ data_type = feature->element_type();
+ }
+
+ scratchpad_tensor = std::make_unique<Tensor>(data_type, Shape({}), AffineQuantization{}, "");
+ scratchpad_tensor->set_observable(false);
+ scratchpad_tensor->set_data_buffer(nullptr);
+ Tensor *tmp_2 = helper.getRuntimeGraph(node->graph())->addTensor(std::move(scratchpad_tensor));
+
+ data_type = DataType::FLOAT32;
+
+ scratchpad_tensor = std::make_unique<Tensor>(data_type, Shape({}), AffineQuantization{}, "");
+ scratchpad_tensor->set_observable(false);
+ scratchpad_tensor->set_data_buffer(nullptr);
+ Tensor *tmp_3 = helper.getRuntimeGraph(node->graph())->addTensor(std::move(scratchpad_tensor));
+
+ scratchpad_tensor = std::make_unique<Tensor>(data_type, Shape({}), AffineQuantization{}, "");
+ scratchpad_tensor->set_observable(false);
+ scratchpad_tensor->set_data_buffer(nullptr);
+ Tensor *tmp_4 = helper.getRuntimeGraph(node->graph())->addTensor(std::move(scratchpad_tensor));
+
+ scratchpad_tensor = std::make_unique<Tensor>(data_type, Shape({}), AffineQuantization{}, "");
+ scratchpad_tensor->set_observable(false);
+ scratchpad_tensor->set_data_buffer(nullptr);
+ Tensor *tmp_5 = helper.getRuntimeGraph(node->graph())->addTensor(std::move(scratchpad_tensor));
+
+ scratchpad_tensor = std::make_unique<Tensor>(data_type, Shape({}), AffineQuantization{}, "");
+ scratchpad_tensor->set_observable(false);
+ scratchpad_tensor->set_data_buffer(nullptr);
+ Tensor *tmp_6 = helper.getRuntimeGraph(node->graph())->addTensor(std::move(scratchpad_tensor));
+
+ SVDFParams params{};
+ params.activation = node->fusedActivationFunction();
+ params.svdf_rank = node->svdf_rank();
+ params.asymmetric_quantize_inputs = node->asymmetric_quantize_inputs();
+
+ return std::make_unique<kernels::SVDF>(input, feature, time, bias, input_activation_state, output,
+ tmp, tmp_1, tmp_2, tmp_3, tmp_4, tmp_5, tmp_6, params);
+}
+
+} // namespace luci_interpreter
diff --git a/compiler/luci-micro/CMakeLists.txt b/compiler/luci-micro/CMakeLists.txt
index 94347082c..c8a2e12e1 100644
--- a/compiler/luci-micro/CMakeLists.txt
+++ b/compiler/luci-micro/CMakeLists.txt
@@ -6,7 +6,7 @@ set(ARM_OBJCOPY "arm-none-eabi-objcopy")
find_program(ARM_C_COMPILER_PATH ${ARM_C_COMPILER})
if(NOT ARM_C_COMPILER_PATH)
- message(WARNING "ARM compiler is NOT FOUND, skipping luci-micro build")
+ message(STATUS "Build luci-micro: FALSE(ARM compiler is NOT FOUND)")
return()
endif()
diff --git a/compiler/luci-pass-value-test/CMakeLists.txt b/compiler/luci-pass-value-test/CMakeLists.txt
index b31415870..034fe5269 100644
--- a/compiler/luci-pass-value-test/CMakeLists.txt
+++ b/compiler/luci-pass-value-test/CMakeLists.txt
@@ -1,3 +1,7 @@
+if(NOT ENABLE_TEST)
+ return()
+endif(NOT ENABLE_TEST)
+
unset(TEST_DEPS)
unset(LUCI_PASS_VALUE_TESTS)
@@ -38,7 +42,7 @@ add_test(NAME luci_pass_value_test
COMMAND "${CMAKE_CURRENT_SOURCE_DIR}/eval_driver.sh"
"${CMAKE_CURRENT_BINARY_DIR}"
"${ARTIFACTS_BIN_PATH}"
- "${NNCC_OVERLAY_DIR}/venv_2_6_0"
+ "${NNCC_OVERLAY_DIR}/venv_2_8_0"
"$<TARGET_FILE:luci_eval_driver>"
${LUCI_PASS_VALUE_TESTS}
)
diff --git a/compiler/luci-pass-value-test/eval_result_verifier.py b/compiler/luci-pass-value-test/eval_result_verifier.py
index c6005edfc..0073c4db5 100644
--- a/compiler/luci-pass-value-test/eval_result_verifier.py
+++ b/compiler/luci-pass-value-test/eval_result_verifier.py
@@ -22,6 +22,18 @@ circle_model = args.circle
interpreter = tf.lite.Interpreter(tflite_model)
interpreter.allocate_tensors()
+# Read SignatureDef and get output tensor id orders for remapping
+full_signatures = interpreter._get_full_signature_list()
+full_signatures_outputs_remap = None
+if full_signatures != None:
+ signature_serving_default = full_signatures.get('serving_default', None)
+ if signature_serving_default != None:
+ signature_outputs = signature_serving_default['outputs']
+
+ full_signatures_outputs_remap = []
+ for index, (key, value) in enumerate(signature_outputs.items()):
+ full_signatures_outputs_remap.append(value)
+
# Generate random input data.
num_inputs = len(interpreter.get_input_details())
for i in range(num_inputs):
@@ -33,6 +45,10 @@ for i in range(num_inputs):
input_data = np.array(
np.random.randint(0, 256, size=input_details["shape"]),
input_details["dtype"])
+ elif input_details["dtype"] == np.int16:
+ input_data = np.array(
+ np.random.randint(0, 100, size=input_details["shape"]),
+ input_details["dtype"])
elif input_details["dtype"] == np.bool_:
input_data = np.array(
np.random.choice(a=[True, False], size=input_details["shape"]),
@@ -55,48 +71,38 @@ subprocess.run(
check=True)
# Compare the results.
-for idx in range(len(interpreter.get_output_details())):
- output_details = interpreter.get_output_details()[idx]
+inpt_output_details = interpreter.get_output_details()
+for idx in range(len(inpt_output_details)):
+ output_details = inpt_output_details[idx]
output_data = np.fromfile(circle_model + ".output" + str(idx),
output_details["dtype"])
shape_file = open(circle_model + ".output" + str(idx) + ".shape", 'r')
output_shape = [int(i) for i in shape_file.read().split(',')]
luci_output_data = np.reshape(output_data, output_shape)
+ output_tensor = output_details["index"]
+ if full_signatures_outputs_remap != None:
+ output_tensor = full_signatures_outputs_remap[idx]
+ intp_output_data = interpreter.get_tensor(output_tensor)
try:
if output_details["dtype"] == np.uint8:
- if np.allclose(
- luci_output_data,
- interpreter.get_tensor(
- interpreter.get_output_details()[idx]["index"]),
- rtol=0,
- atol=0) == False:
+ if np.allclose(luci_output_data, intp_output_data, rtol=0, atol=0) == False:
raise SystemExit("Execution result of " + tflite_model +
" does not match with " + circle_model)
elif output_details["dtype"] == np.float32:
if np.allclose(
- luci_output_data,
- interpreter.get_tensor(
- interpreter.get_output_details()[idx]["index"]),
- rtol=1.e-5,
- atol=1.e-5) == False:
+ luci_output_data, intp_output_data, rtol=1.e-5, atol=1.e-5) == False:
raise SystemExit("Execution result of " + tflite_model +
" does not match with " + circle_model)
elif output_details["dtype"] == np.int64:
- if np.allclose(
- luci_output_data,
- interpreter.get_tensor(
- interpreter.get_output_details()[idx]["index"]),
- rtol=0,
- atol=0) == False:
+ if np.allclose(luci_output_data, intp_output_data, rtol=0, atol=0) == False:
raise SystemExit("Execution result of " + tflite_model +
" does not match with " + circle_model)
elif output_details["dtype"] == np.int32:
- if np.allclose(
- luci_output_data,
- interpreter.get_tensor(
- interpreter.get_output_details()[idx]["index"]),
- rtol=0,
- atol=0) == False:
+ if np.allclose(luci_output_data, intp_output_data, rtol=0, atol=0) == False:
+ raise SystemExit("Execution result of " + tflite_model +
+ " does not match with " + circle_model)
+ elif output_details["dtype"] == np.int16:
+ if np.allclose(luci_output_data, intp_output_data, rtol=0, atol=0) == False:
raise SystemExit("Execution result of " + tflite_model +
" does not match with " + circle_model)
else:
diff --git a/compiler/luci-pass-value-test/test.lst b/compiler/luci-pass-value-test/test.lst
index 9c408887d..67476c644 100644
--- a/compiler/luci-pass-value-test/test.lst
+++ b/compiler/luci-pass-value-test/test.lst
@@ -29,3 +29,7 @@ addeval(Net_InstanceNorm_001 fuse_instnorm)
addeval(Net_InstanceNorm_002 fuse_instnorm)
addeval(Net_InstanceNorm_003 fuse_instnorm)
addeval(Net_StridedSlice_StridedSlice_000 remove_unnecessary_strided_slice)
+
+# test SignatureDef, with any optimization
+#addeval(SignatureDef_MultiOut_000 fuse_instnorm)
+#addeval(SignatureDef_MultiOut_001 fuse_instnorm)
diff --git a/compiler/luci-value-test/CMakeLists.txt b/compiler/luci-value-test/CMakeLists.txt
index 3c7185b80..ebf9c5926 100644
--- a/compiler/luci-value-test/CMakeLists.txt
+++ b/compiler/luci-value-test/CMakeLists.txt
@@ -1,9 +1,18 @@
+if(NOT ENABLE_TEST)
+ return()
+endif(NOT ENABLE_TEST)
+
unset(LUCI_VALUE_TESTS)
+unset(LUCI_VALUE_TESTS_TOL)
macro(addeval NAME)
list(APPEND LUCI_VALUE_TESTS ${NAME})
endmacro(addeval)
+macro(addevaltol NAME RTOL ATOL)
+ list(APPEND LUCI_VALUE_TESTS_TOL ${NAME} ${RTOL} ${ATOL})
+endmacro(addevaltol)
+
# Read "test.lst"
include("test.lst")
# Read "test.local.lst" if exists
@@ -12,13 +21,60 @@ include("test.local.lst" OPTIONAL)
# Generate dependencies
add_custom_target(luci_eval_testfiles ALL DEPENDS ${TESTFILES})
-get_target_property(ARTIFACTS_BIN_PATH testDataGenerator BINARY_DIR)
+if(NOT CMAKE_CROSSCOMPILING)
+
+ get_target_property(ARTIFACTS_BIN_PATH testDataGenerator BINARY_DIR)
+
+ add_test(NAME luci_value_test
+ COMMAND "${CMAKE_CURRENT_SOURCE_DIR}/evalverify.sh"
+ "${CMAKE_CURRENT_BINARY_DIR}"
+ "${ARTIFACTS_BIN_PATH}"
+ "${NNCC_OVERLAY_DIR}/venv_2_8_0"
+ "$<TARGET_FILE:luci_eval_driver>"
+ ${LUCI_VALUE_TESTS}
+ )
+
+ if(DEFINED LUCI_VALUE_TESTS_TOL)
+ add_test(NAME luci_value_tol_test
+ COMMAND "${CMAKE_CURRENT_SOURCE_DIR}/evalverifytol.sh"
+ "${CMAKE_CURRENT_BINARY_DIR}"
+ "${ARTIFACTS_BIN_PATH}"
+ "${NNCC_OVERLAY_DIR}/venv_2_8_0"
+ "$<TARGET_FILE:luci_eval_driver>"
+ ${LUCI_VALUE_TESTS_TOL}
+ )
+ endif()
+
+else(NOT CMAKE_CROSSCOMPILING)
+ # NOTE target test is carried out using reference input/output data from host
+ # test results. this is because it would be difficult to prepare
+ # TensorFlow lite for target device.
+ # thus, one must run the host test and then run the test in target device
+ # with the test result files from the host test.
+
+ if(NOT DEFINED ENV{BUILD_HOST_EXEC})
+ message(STATUS "BUILD_HOST_EXEC not set: Skip luci-value-test")
+ return()
+ endif(NOT DEFINED ENV{BUILD_HOST_EXEC})
+
+ set(ARTIFACTS_BIN_PATH $ENV{BUILD_HOST_EXEC}/compiler/common-artifacts)
+
+ add_test(NAME luci_value_cross_test
+ COMMAND "${CMAKE_CURRENT_SOURCE_DIR}/evalverify_ref.sh"
+ "${CMAKE_CURRENT_BINARY_DIR}"
+ "${ARTIFACTS_BIN_PATH}"
+ "$<TARGET_FILE:luci_eval_driver>"
+ ${LUCI_VALUE_TESTS}
+ )
+
+ if(DEFINED LUCI_VALUE_TESTS_TOL)
+ add_test(NAME luci_value_cross_tol_test
+ COMMAND "${CMAKE_CURRENT_SOURCE_DIR}/evalverifytol_ref.sh"
+ "${CMAKE_CURRENT_BINARY_DIR}"
+ "${ARTIFACTS_BIN_PATH}"
+ "$<TARGET_FILE:luci_eval_driver>"
+ ${LUCI_VALUE_TESTS_TOL}
+ )
+ endif()
-add_test(NAME luci_value_test
- COMMAND "${CMAKE_CURRENT_SOURCE_DIR}/evalverify.sh"
- "${CMAKE_CURRENT_BINARY_DIR}"
- "${ARTIFACTS_BIN_PATH}"
- "${NNCC_OVERLAY_DIR}/venv_2_6_0"
- "$<TARGET_FILE:luci_eval_driver>"
- ${LUCI_VALUE_TESTS}
-)
+endif(NOT CMAKE_CROSSCOMPILING)
diff --git a/compiler/luci-value-test/evalverify.sh b/compiler/luci-value-test/evalverify.sh
index 01c4bce46..3d2091176 100755
--- a/compiler/luci-value-test/evalverify.sh
+++ b/compiler/luci-value-test/evalverify.sh
@@ -4,10 +4,12 @@
#
# HOW TO USE
#
-# ./evalverify.sh <path/to/bin_dir> <path/to/work_dir> <path/to/venv_dir> <TEST 1> <TEST 2> ...
+# ./evalverify.sh <path/to/bin_dir> <path/to/work_dir> <path/to/venv_dir> <path/to/eval_driver> \
+# <TEST 1> <TEST 2> ...
# bin_dir : build directory of luci-value-test (ex: build/compiler/luci-value-test)
# work_dir : artifacts directoy where test materials exist
# venv_dir : python virtual environment home directory
+# eval_driver : luci_eval_driver path for evaluation
VERIFY_SOURCE_PATH="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
VERIFY_SCRIPT_PATH="${VERIFY_SOURCE_PATH}/luci_eval_verifier.py"
diff --git a/compiler/luci-value-test/evalverify_ref.sh b/compiler/luci-value-test/evalverify_ref.sh
new file mode 100755
index 000000000..f1e538aa3
--- /dev/null
+++ b/compiler/luci-value-test/evalverify_ref.sh
@@ -0,0 +1,63 @@
+#!/bin/bash
+
+# This script verifies the basic behavior of luci interpreter
+#
+# HOW TO USE
+#
+# ./evalverify_ref.sh <path/to/bin_dir> <path/to/ref_dir> <path/to/eval_driver> \
+# <TEST 1> <TEST 2> ...
+# bin_dir : build directory of luci-value-test (ex: build/compiler/luci-value-test)
+# ref_dir : artifacts directory where reference test materials exist
+# eval_driver : luci_eval_driver path for evaluation
+
+VERIFY_SOURCE_PATH="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+VERIFY_SCRIPT_PATH="${VERIFY_SOURCE_PATH}/luci_eval_verifier_ref.py"
+BINDIR="$1"; shift
+REFDIR="$1"; shift
+INTERPRETER_DRIVER_PATH="$1"; shift
+
+TESTED=()
+PASSED=()
+FAILED=()
+
+for TESTCASE in "$@"; do
+ TESTED+=("${TESTCASE}")
+
+ TESTCASE_FILE="${REFDIR}/${TESTCASE}"
+ TEST_RESULT_FILE="${BINDIR}/${TESTCASE}"
+
+ PASSED_TAG="${TEST_RESULT_FILE}.passed"
+ rm -f "${PASSED_TAG}"
+
+ cat > "${TEST_RESULT_FILE}.log" <(
+ exec 2>&1
+ set -ex
+
+ "python3" "${VERIFY_SCRIPT_PATH}" \
+ --driver "${INTERPRETER_DRIVER_PATH}" \
+ --model_ref "${TESTCASE_FILE}" \
+ --work_path "${TEST_RESULT_FILE}"
+
+ if [[ $? -eq 0 ]]; then
+ touch "${PASSED_TAG}"
+ fi
+ )
+
+ if [[ -f "${PASSED_TAG}" ]]; then
+ PASSED+=("${TESTCASE}")
+ else
+ FAILED+=("${TESTCASE}")
+ fi
+done
+
+if [[ ${#TESTED[@]} -ne ${#PASSED[@]} ]]; then
+ echo "FAILED"
+ for TEST in "${FAILED[@]}"
+ do
+ echo "- ${TEST}"
+ done
+ exit 255
+fi
+
+echo "PASSED"
+exit 0
diff --git a/compiler/luci-value-test/evalverifytol.sh b/compiler/luci-value-test/evalverifytol.sh
new file mode 100755
index 000000000..92094055a
--- /dev/null
+++ b/compiler/luci-value-test/evalverifytol.sh
@@ -0,0 +1,71 @@
+#!/bin/bash
+
+# This script verifies the basic behavior of luci interpreter
+#
+# HOW TO USE
+#
+# ./evalverifytol.sh <path/to/bin_dir> <path/to/work_dir> <path/to/venv_dir> <path/to/eval_driver> \
+# <TEST 1> <RTOL 1> <ATOL 1> <TEST 2> <RTOL 2> <ATOL 2> ...
+# bin_dir : build directory of luci-value-test (ex: build/compiler/luci-value-test)
+# work_dir : artifacts directory where test materials exist
+# venv_dir : python virtual environment home directory
+
+VERIFY_SOURCE_PATH="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+VERIFY_SCRIPT_PATH="${VERIFY_SOURCE_PATH}/luci_eval_verifier.py"
+BINDIR="$1"; shift
+WORKDIR="$1"; shift
+VIRTUALENV="$1"; shift
+INTERPRETER_DRIVER_PATH="$1"; shift
+
+TESTED=()
+PASSED=()
+FAILED=()
+
+while (( "$#" >= 3 )); do
+ TESTCASE=$1
+ RTOLERANCE=$2
+ ATOLERANCE=$3
+ shift 3
+
+ TESTED+=("${TESTCASE}")
+
+ TESTCASE_FILE="${WORKDIR}/${TESTCASE}"
+ TEST_RESULT_FILE="${BINDIR}/${TESTCASE}"
+
+ PASSED_TAG="${TEST_RESULT_FILE}.passed"
+ rm -f "${PASSED_TAG}"
+
+ cat > "${TEST_RESULT_FILE}.log" <(
+ exec 2>&1
+ set -ex
+
+ source "${VIRTUALENV}/bin/activate"
+ "${VIRTUALENV}/bin/python" "${VERIFY_SCRIPT_PATH}" \
+ --driver "${INTERPRETER_DRIVER_PATH}" \
+ --model "${TESTCASE_FILE}" \
+ --rtolf32 "${RTOLERANCE}" \
+ --atolf32 "${ATOLERANCE}"
+
+ if [[ $? -eq 0 ]]; then
+ touch "${PASSED_TAG}"
+ fi
+ )
+
+ if [[ -f "${PASSED_TAG}" ]]; then
+ PASSED+=("${TESTCASE}")
+ else
+ FAILED+=("${TESTCASE}")
+ fi
+done
+
+if [[ ${#TESTED[@]} -ne ${#PASSED[@]} ]]; then
+ echo "FAILED"
+ for TEST in "${FAILED[@]}"
+ do
+ echo "- ${TEST}"
+ done
+ exit 255
+fi
+
+echo "PASSED"
+exit 0
diff --git a/compiler/luci-value-test/evalverifytol_ref.sh b/compiler/luci-value-test/evalverifytol_ref.sh
new file mode 100755
index 000000000..cc7267b18
--- /dev/null
+++ b/compiler/luci-value-test/evalverifytol_ref.sh
@@ -0,0 +1,70 @@
+#!/bin/bash
+
+# This script verifies the basic behavior of luci interpreter
+#
+# HOW TO USE
+#
+# ./evalverifytol_ref.sh <path/to/bin_dir> <path/to/ref_dir> <path/to/eval_driver> \
+# <TEST 1> <RTOL 1> <ATOL 1> <TEST 2> <RTOL 2> <ATOL 2> ...
+# bin_dir : build directory of luci-value-test (ex: build/compiler/luci-value-test)
+# ref_dir : artifacts directory where reference test materials exist
+# eval_driver : luci_eval_driver path for evaluation
+
+VERIFY_SOURCE_PATH="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+VERIFY_SCRIPT_PATH="${VERIFY_SOURCE_PATH}/luci_eval_verifier_ref.py"
+BINDIR="$1"; shift
+REFDIR="$1"; shift
+INTERPRETER_DRIVER_PATH="$1"; shift
+
+TESTED=()
+PASSED=()
+FAILED=()
+
+while (( "$#" >= 3 )); do
+ TESTCASE=$1
+ RTOLERANCE=$2
+ ATOLERANCE=$3
+ shift 3
+
+ TESTED+=("${TESTCASE}")
+
+ TESTCASE_FILE="${REFDIR}/${TESTCASE}"
+ TEST_RESULT_FILE="${BINDIR}/${TESTCASE}"
+
+ PASSED_TAG="${TEST_RESULT_FILE}.passed"
+ rm -f "${PASSED_TAG}"
+
+ cat > "${TEST_RESULT_FILE}.log" <(
+ exec 2>&1
+ set -ex
+
+ "python3" "${VERIFY_SCRIPT_PATH}" \
+ --driver "${INTERPRETER_DRIVER_PATH}" \
+ --model_ref "${TESTCASE_FILE}" \
+ --work_path "${TEST_RESULT_FILE}" \
+ --rtolf32 "${RTOLERANCE}" \
+ --atolf32 "${ATOLERANCE}"
+
+ if [[ $? -eq 0 ]]; then
+ touch "${PASSED_TAG}"
+ fi
+ )
+
+ if [[ -f "${PASSED_TAG}" ]]; then
+ PASSED+=("${TESTCASE}")
+ else
+ FAILED+=("${TESTCASE}")
+ fi
+done
+
+if [[ ${#TESTED[@]} -ne ${#PASSED[@]} ]]; then
+ echo "FAILED"
+ for TEST in "${FAILED[@]}"
+ do
+ echo "- ${TEST}"
+ done
+ exit 255
+fi
+
+echo "PASSED"
+exit 0
diff --git a/compiler/luci-value-test/luci_eval_verifier.py b/compiler/luci-value-test/luci_eval_verifier.py
index a76bd1403..560e34fca 100755
--- a/compiler/luci-value-test/luci_eval_verifier.py
+++ b/compiler/luci-value-test/luci_eval_verifier.py
@@ -14,16 +14,41 @@ import traceback
parser = argparse.ArgumentParser()
parser.add_argument('--driver', type=str, required=True)
parser.add_argument('--model', type=str, required=True)
+parser.add_argument('--rtolf32', type=str, required=False)
+parser.add_argument('--atolf32', type=str, required=False)
args = parser.parse_args()
driver = args.driver
tflite_model = args.model + ".tflite"
circle_model = args.model + ".circle"
+rtolf32 = 1e-5
+atolf32 = 1e-5
+try:
+ if args.rtolf32 != None:
+ rtolf32 = float(args.rtolf32)
+ if args.atolf32 != None:
+ atolf32 = float(args.atolf32)
+except ValueError:
+ print("rtolf32 or atolf32 is not a number")
+ quit(128)
+
# Build TFLite interpreter.
interpreter = tf.lite.Interpreter(tflite_model)
interpreter.allocate_tensors()
+# Read SignatureDef and get output tensor id orders for remapping
+full_signatures = interpreter._get_full_signature_list()
+full_signatures_outputs_remap = None
+if full_signatures != None:
+ signature_serving_default = full_signatures.get('serving_default', None)
+ if signature_serving_default != None:
+ signature_outputs = signature_serving_default['outputs']
+
+ full_signatures_outputs_remap = []
+ for index, (key, value) in enumerate(signature_outputs.items()):
+ full_signatures_outputs_remap.append(value)
+
# Generate random input data.
num_inputs = len(interpreter.get_input_details())
for i in range(num_inputs):
@@ -31,19 +56,40 @@ for i in range(num_inputs):
if input_details["dtype"] == np.float32:
input_data = np.array(
np.random.random_sample(input_details["shape"]), input_details["dtype"])
+ input_dtype = "float32"
elif input_details["dtype"] == np.uint8:
input_data = np.array(
np.random.randint(0, 256, size=input_details["shape"]),
input_details["dtype"])
+ input_dtype = "uint8"
+ elif input_details["dtype"] == np.int16:
+ input_data = np.array(
+ np.random.randint(0, 100, size=input_details["shape"]),
+ input_details["dtype"])
+ input_dtype = "int16"
+ elif input_details["dtype"] == np.int32:
+ input_data = np.array(
+ np.random.randint(0, 100, size=input_details["shape"]),
+ input_details["dtype"])
+ input_dtype = "int32"
+ elif input_details["dtype"] == np.int64:
+ input_data = np.array(
+ np.random.randint(0, 100, size=input_details["shape"]),
+ input_details["dtype"])
+ input_dtype = "int64"
elif input_details["dtype"] == np.bool_:
input_data = np.array(
np.random.choice(a=[True, False], size=input_details["shape"]),
input_details["dtype"])
+ input_dtype = "bool"
else:
raise SystemExit("Unsupported input dtype")
interpreter.set_tensor(input_details["index"], input_data)
input_data.tofile(circle_model + ".input" + str(i))
+ input_details["shape"].tofile(circle_model + ".input" + str(i) + ".shape", sep=',')
+ with open(circle_model + ".input" + str(i) + ".dtype", 'w') as dtype_file:
+ dtype_file.write(input_dtype)
# Do inference
interpreter.invoke()
@@ -57,34 +103,57 @@ subprocess.run(
check=True)
# Compare the results.
-for idx in range(len(interpreter.get_output_details())):
- output_details = interpreter.get_output_details()[idx]
+inpt_output_details = interpreter.get_output_details()
+for idx in range(len(inpt_output_details)):
+ output_details = inpt_output_details[idx]
output_data = np.fromfile(circle_model + ".output" + str(idx),
output_details["dtype"])
shape_file = open(circle_model + ".output" + str(idx) + ".shape", 'r')
output_shape = [int(i) for i in shape_file.read().split(',')]
luci_output_data = np.reshape(output_data, output_shape)
- intp_output_data = interpreter.get_tensor(output_details["index"])
+ output_tensor = output_details["index"]
+ if full_signatures_outputs_remap != None:
+ output_tensor = full_signatures_outputs_remap[idx]
+ intp_output_data = interpreter.get_tensor(output_tensor)
try:
if output_details["dtype"] == np.uint8:
if np.allclose(luci_output_data, intp_output_data, rtol=0, atol=0) == False:
raise SystemExit("Execution result of " + tflite_model +
" does not match with " + circle_model)
+ output_dtype = "uint8"
elif output_details["dtype"] == np.float32:
if np.allclose(
- luci_output_data, intp_output_data, rtol=1.e-5, atol=1.e-5) == False:
+ luci_output_data, intp_output_data, rtol=rtolf32,
+ atol=atolf32) == False:
raise SystemExit("Execution result of " + tflite_model +
" does not match with " + circle_model)
+ output_dtype = "float32"
elif output_details["dtype"] == np.int64:
if np.allclose(luci_output_data, intp_output_data, rtol=0, atol=0) == False:
raise SystemExit("Execution result of " + tflite_model +
" does not match with " + circle_model)
+ output_dtype = "int64"
elif output_details["dtype"] == np.int32:
if np.allclose(luci_output_data, intp_output_data, rtol=0, atol=0) == False:
raise SystemExit("Execution result of " + tflite_model +
" does not match with " + circle_model)
+ output_dtype = "int32"
+ elif output_details["dtype"] == np.int16:
+ if np.allclose(luci_output_data, intp_output_data, rtol=0, atol=0) == False:
+ raise SystemExit("Execution result of " + tflite_model +
+ " does not match with " + circle_model)
+ output_dtype = "int16"
+ elif output_details["dtype"] == np.bool_:
+ if np.allclose(luci_output_data, intp_output_data, rtol=0, atol=0) == False:
+ raise SystemExit("Execution result of " + tflite_model +
+ " does not match with " + circle_model)
+ output_dtype = "bool"
else:
raise SystemExit("Unsupported data type: ", output_details["dtype"])
+
+ # save outputN.dtype file
+ with open(circle_model + ".output" + str(idx) + ".dtype", 'w') as dtype_file:
+ dtype_file.write(output_dtype)
except:
print(traceback.format_exc())
quit(255)
diff --git a/compiler/luci-value-test/luci_eval_verifier_ref.py b/compiler/luci-value-test/luci_eval_verifier_ref.py
new file mode 100755
index 000000000..5313e336e
--- /dev/null
+++ b/compiler/luci-value-test/luci_eval_verifier_ref.py
@@ -0,0 +1,151 @@
+#!/usr/bin/env python3
+import numpy as np
+import subprocess
+import argparse
+import traceback
+import os
+
+#
+# This script compares the execution result of luci-interpreter with that from ref_model path
+#
+# Basic usage:
+# luci_eval_verifier_ref.py --driver build/compiler/luci-eval-driver/luci_eval_driver
+# --ref_model ref_model_path --model this_model_path
+# Assumption:
+# these files exist with their purpose
+# - ref_model_path.circle; circle model
+# - ref_model_path.circle.inputN; N'th input numpy data
+# - ref_model_path.circle.inputN.dtype; N'th input data type in text
+# - ref_model_path.circle.inputN.shape; N'th input data shape in CSV
+# - ref_model_path.circle.outputN; N'th output numpy data
+# - ref_model_path.circle.outputN.dtype; N'th output data type in text
+# - ref_model_path.circle.outputN.shape; N'th output data shape in CSV
+
+
+def dtype_from_file(file_path):
+ with open(file_path, 'r') as dtype_file:
+ dtype_str = dtype_file.read()
+ if dtype_str == "float32":
+ return np.float32
+ if dtype_str == "uint8":
+ return np.uint8
+ if dtype_str == "int16":
+ return np.int16
+ if dtype_str == "int32":
+ return np.int32
+ if dtype_str == "int64":
+ return np.int64
+ if dtype_str == "bool":
+ return np.bool_
+ raise SystemExit("Unsupported dtype from file", dtype_str)
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--driver', type=str, required=True)
+parser.add_argument('--model_ref', type=str, required=True)
+parser.add_argument('--work_path', type=str, required=True)
+parser.add_argument('--rtolf32', type=str, required=False)
+parser.add_argument('--atolf32', type=str, required=False)
+args = parser.parse_args()
+
+driver = args.driver
+circle_model_ref = args.model_ref + ".circle"
+circle_model = args.work_path + ".circle"
+# circle_model is used as to follow existing luci_eval_verifier.py
+
+rtolf32 = 1e-5
+atolf32 = 1e-5
+try:
+ if args.rtolf32 != None:
+ rtolf32 = float(args.rtolf32)
+ if args.atolf32 != None:
+ atolf32 = float(args.atolf32)
+except ValueError:
+ print("rtolf32 or atolf32 is not a number")
+ quit(128)
+
+# get num of inputs by checking existence of model.inputN
+check_input = 0
+while True:
+ input_file_path = circle_model_ref + ".input" + str(check_input)
+ if not os.path.isfile(input_file_path):
+ num_inputs = check_input
+ break
+ check_input = check_input + 1
+
+if num_inputs == 0:
+ print("input file not exist for", circle_model_ref)
+ quit(128)
+
+# get num of outputs by checking existence of model.outputN
+check_output = 0
+while True:
+ output_file_path = circle_model_ref + ".output" + str(check_output)
+ if not os.path.isfile(output_file_path):
+ num_outputs = check_output
+ break
+ check_output = check_output + 1
+
+if num_outputs == 0:
+ print("output file not exist for", circle_model_ref)
+ quit(128)
+
+# Execute luci interpreter with reference input
+subprocess.run(
+ [
+ driver, circle_model_ref,
+ str(num_inputs), circle_model_ref + ".input", circle_model + ".output"
+ ],
+ check=True)
+
+# Compare the results.
+for idx in range(num_outputs):
+ output_dtype = dtype_from_file(circle_model_ref + ".output" + str(idx) + ".dtype")
+ shape_file = open(circle_model_ref + ".output" + str(idx) + ".shape", 'r')
+ output_shape = [int(i) for i in shape_file.read().split(',')]
+
+ output_data_ref = np.fromfile(circle_model_ref + ".output" + str(idx), output_dtype)
+ luci_output_data_ref = np.reshape(output_data_ref, output_shape)
+
+ output_data = np.fromfile(circle_model + ".output" + str(idx), output_dtype)
+ luci_output_data = np.reshape(output_data, output_shape)
+
+ try:
+ if output_dtype == np.uint8:
+ if np.allclose(
+ luci_output_data, luci_output_data_ref, rtol=0, atol=0) == False:
+ raise SystemExit("Execution result of " + circle_model_ref +
+ " does not match with " + circle_model)
+ elif output_dtype == np.float32:
+ if np.allclose(
+ luci_output_data, luci_output_data_ref, rtol=rtolf32,
+ atol=atolf32) == False:
+ raise SystemExit("Execution result of " + circle_model_ref +
+ " does not match with " + circle_model)
+ elif output_dtype == np.int64:
+ if np.allclose(
+ luci_output_data, luci_output_data_ref, rtol=0, atol=0) == False:
+ raise SystemExit("Execution result of " + circle_model_ref +
+ " does not match with " + circle_model)
+ elif output_dtype == np.int32:
+ if np.allclose(
+ luci_output_data, luci_output_data_ref, rtol=0, atol=0) == False:
+ raise SystemExit("Execution result of " + circle_model_ref +
+ " does not match with " + circle_model)
+ elif output_dtype == np.int16:
+ if np.allclose(
+ luci_output_data, luci_output_data_ref, rtol=0, atol=0) == False:
+ raise SystemExit("Execution result of " + circle_model_ref +
+ " does not match with " + circle_model)
+ elif output_dtype == np.bool_:
+ if np.allclose(
+ luci_output_data, luci_output_data_ref, rtol=0, atol=0) == False:
+ raise SystemExit("Execution result of " + circle_model_ref +
+ " does not match with " + circle_model)
+ else:
+ raise SystemExit("Unsupported data type: ", output_dtype)
+ except:
+ print(traceback.format_exc())
+ quit(255)
+
+quit(0)
diff --git a/compiler/luci-value-test/test.lst b/compiler/luci-value-test/test.lst
index 2b5c93fa3..f62b72919 100644
--- a/compiler/luci-value-test/test.lst
+++ b/compiler/luci-value-test/test.lst
@@ -20,90 +20,90 @@ addeval(ArgMax_U8_003)
#addeval(ArgMin_U8_002)
#addeval(ArgMin_U8_003)
addeval(AveragePool2D_000)
-#addeval(BatchMatMul_000)
+addeval(BatchMatMul_000)
#addeval(BatchMatMulV2_000)
#addeval(BatchMatMulV2_001)
#addeval(BatchToSpaceND_000)
-#addeval(Cast_000)
-#addeval(Cast_001)
+addeval(Cast_000)
+addeval(Cast_001)
#addeval(Ceil_000)
addeval(Concatenation_000)
addeval(Concatenation_U8_000)
addeval(Conv2D_000)
addeval(Conv2D_001)
addeval(Conv2D_002)
-#addeval(Conv2D_003)
+addeval(Conv2D_003)
addeval(Conv2D_U8_000)
addeval(Conv2D_U8_001)
#addeval(Cos_000)
-#addeval(DepthToSpace_000)
+addeval(DepthToSpace_000)
addeval(DepthwiseConv2D_000)
addeval(DepthwiseConv2D_U8_000)
#addeval(DepthwiseConv2D_U8_001)
addeval(DepthwiseConv2D_001)
-#addeval(Div_000)
+addeval(Div_000)
addeval(ELU_000)
-#addeval(Equal_000)
-#addeval(Exp_000)
+addeval(Equal_000)
+addeval(Exp_000)
#addeval(ExpandDims_000)
#addeval(ExpandDims_001)
#addeval(ExpandDims_002)
#addeval(ExpandDims_003)
#addeval(Fill_000)
#addeval(Fill_001)
-#addeval(Floor_000)
-#addeval(FloorDiv_000)
-#addeval(FloorDiv_001)
+addeval(Floor_000)
+addeval(FloorDiv_000)
+addeval(FloorDiv_001)
#addeval(FloorMod_000)
#addeval(FloorMod_001)
addeval(FullyConnected_000)
addeval(FullyConnected_001)
addeval(FullyConnected_002)
#addeval(FullyConnected_U8_000)
-#addeval(Gather_000)
+addeval(Gather_000)
#addeval(GatherNd_000)
#addeval(Greater_000)
-#addeval(GreaterEqual_000)
+addeval(GreaterEqual_000)
addeval(If_000)
addeval(If_001)
addeval(L2Normalize_000)
addeval(L2Pool2D_000)
#addeval(L2Pool2D_U8_000)
addeval(LeakyRelu_000)
-#addeval(Less_000)
-#addeval(LessEqual_000)
+addeval(Less_000)
+addeval(LessEqual_000)
addeval(LocalResponseNormalization_000)
#addeval(Log_000)
-#addeval(LogicalAnd_000)
-#addeval(LogicalNot_000)
-#addeval(LogicalOr_000)
+addeval(LogicalAnd_000)
+addeval(LogicalNot_000)
+addeval(LogicalOr_000)
addeval(Logistic_000)
-#addeval(LogSoftmax_000)
+addeval(LogSoftmax_000)
#addeval(MatMul_000)
#addeval(MatrixDiag_000)
#addeval(MatrixSetDiag_000)
-#addeval(Maximum_000)
+addeval(Maximum_000)
addeval(MaxPool2D_000)
addeval(MaxPool2D_U8_000)
addeval(Mean_000)
addeval(Mean_001)
-#addeval(Mean_U8_000)
-#addeval(Minimum_000)
+addeval(Mean_U8_000)
+addeval(Minimum_000)
#addeval(MirrorPad_000)
addeval(Mul_000)
#addeval(Mul_U8_000)
-#addeval(Neg_000)
-#addeval(NotEqual_000)
-#addeval(OneHot_000)
-#addeval(OneHot_001)
-#addeval(OneHot_002)
+addeval(Neg_000)
+addeval(NotEqual_000)
+addeval(OneHot_000)
+addeval(OneHot_001)
+addeval(OneHot_002)
#addeval(OneHot_003)
-#addeval(Pack_000)
-#addeval(Pack_U8_000)
+addeval(Pack_000)
+addeval(Pack_U8_000)
addeval(Pad_000)
addeval(Pad_U8_000)
-#addeval(Pow_000)
-#addeval(PRelu_000)
+addeval(Pow_000)
+addeval(PRelu_000)
#addeval(Range_000)
#addeval(Rank_000)
#addeval(ReduceAny_000)
@@ -116,20 +116,20 @@ addeval(Pad_U8_000)
#addeval(ReduceProd_001)
#addeval(ReduceProd_002)
#addeval(ReduceProd_003)
-#addeval(ReLU_000)
-#addeval(ReLU6_000)
+addeval(ReLU_000)
+addeval(ReLU6_000)
#addeval(ReLUN1To1_000)
addeval(Reshape_000)
addeval(Reshape_001)
addeval(Reshape_002)
#addeval(Reshape_003)
addeval(Reshape_U8_000)
-#addeval(ResizeBilinear_000)
-#addeval(ResizeNearestNeighbor_000)
+addeval(ResizeBilinear_000)
+addeval(ResizeNearestNeighbor_000)
#addeval(ReverseSequence_000)
#addeval(ReverseV2_000)
#addeval(Round_000)
-#addeval(Rsqrt_000)
+addeval(Rsqrt_000)
#addeval(ScatterNd_000)
#addeval(SegmentSum_000)
#addeval(Select_000)
@@ -139,37 +139,39 @@ addeval(Reshape_U8_000)
#addeval(SelectV2_001)
#addeval(SelectV2_002)
#addeval(Shape_000)
+addeval(SignatureDef_MultiOut_000)
+addeval(SignatureDef_MultiOut_001)
#addeval(Sin_000)
addeval(Slice_000)
addeval(Softmax_000)
-#addeval(Softmax_U8_000)
-#addeval(SpaceToBatchND_000)
-#addeval(SpaceToBatchND_001)
-#addeval(SpaceToBatchND_002)
-#addeval(SpaceToBatchND_003)
+addeval(Softmax_U8_000)
+addeval(SpaceToBatchND_000)
+addeval(SpaceToBatchND_001)
+addeval(SpaceToBatchND_002)
+addeval(SpaceToBatchND_003)
addeval(SpaceToDepth_000)
#addeval(SparseToDense_000)
addeval(Split_000)
-#addeval(SplitV_000)
-#addeval(Sqrt_000)
-#addeval(Square_000)
-#addeval(SquaredDifference_000)
+addeval(SplitV_000)
+addeval(Sqrt_000)
+addeval(Square_000)
+addeval(SquaredDifference_000)
addeval(Squeeze_000)
addeval(Squeeze_001)
addeval(StridedSlice_000)
addeval(StridedSlice_001)
addeval(StridedSlice_002)
-#addeval(Sub_000)
-#addeval(Sub_U8_000)
+addeval(Sub_000)
+addeval(Sub_U8_000)
#addeval(Sum_000)
#addeval(Sum_001)
-#addeval(Tanh_000)
+addeval(Tanh_000)
#addeval(Tile_000)
#addeval(Tile_U8_000)
#addeval(TopKV2_000)
#addeval(TopKV2_001)
addeval(Transpose_000)
-#addeval(TransposeConv_000)
+addeval(TransposeConv_000)
addeval(Unpack_000)
addeval(Unpack_001)
addeval(Unpack_002)
@@ -180,9 +182,13 @@ addeval(Unpack_003)
#addeval(While_001)
#addeval(While_002)
#addeval(While_003)
-#addeval(YUV_TO_RGB_U8_000)
+addeval(YUV_TO_RGB_U8_000)
#addeval(ZerosLike_000)
# Simple Network test
addeval(Part_While_000)
addeval(Part_While_001)
+
+# Tests with tolerance
+addevaltol(SVDF_000 8e-3 8e-3)
+addevaltol(SVDF_001 8e-3 8e-3)
diff --git a/compiler/luci/CMakeLists.txt b/compiler/luci/CMakeLists.txt
index b92eefb40..460dc7b23 100644
--- a/compiler/luci/CMakeLists.txt
+++ b/compiler/luci/CMakeLists.txt
@@ -23,4 +23,8 @@ add_subdirectory(import)
add_subdirectory(export)
add_subdirectory(tester)
+if(NOT ENABLE_TEST)
+ return()
+endif(NOT ENABLE_TEST)
+
add_subdirectory(tests)
diff --git a/compiler/luci/export/CMakeLists.txt b/compiler/luci/export/CMakeLists.txt
index a267d0e1f..f46181eb6 100644
--- a/compiler/luci/export/CMakeLists.txt
+++ b/compiler/luci/export/CMakeLists.txt
@@ -12,7 +12,7 @@ target_include_directories(luci_export PUBLIC include)
target_link_libraries(luci_export PRIVATE luci_lang)
target_link_libraries(luci_export PRIVATE luci_service)
target_link_libraries(luci_export PRIVATE luci_pass)
-target_link_libraries(luci_export PRIVATE mio_circle)
+target_link_libraries(luci_export PRIVATE mio_circle04)
target_link_libraries(luci_export PRIVATE luci_env)
target_link_libraries(luci_export PRIVATE luci_log)
target_link_libraries(luci_export PRIVATE luci_logex)
@@ -36,6 +36,6 @@ target_include_directories(luci_export_test PRIVATE src)
target_link_libraries(luci_export_test luci_export)
target_link_libraries(luci_export_test luci_plan)
target_link_libraries(luci_export_test luci_lang)
-target_link_libraries(luci_export_test mio_circle)
+target_link_libraries(luci_export_test mio_circle04)
target_link_libraries(luci_export_test luci_env)
target_link_libraries(luci_export_test oops)
diff --git a/compiler/luci/export/src/CircleBuiltinTypesExtractor.h b/compiler/luci/export/src/CircleBuiltinTypesExtractor.h
new file mode 100644
index 000000000..0ff21a34b
--- /dev/null
+++ b/compiler/luci/export/src/CircleBuiltinTypesExtractor.h
@@ -0,0 +1,539 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __CIRCLE_BUILTIN_TYPES_EXTRACTOR_H__
+#define __CIRCLE_BUILTIN_TYPES_EXTRACTOR_H__
+
+#include "CircleExporterUtils.h"
+
+#include <luci/IR/CircleNode.h>
+#include <luci/IR/CircleNodes.h>
+#include <luci/IR/CircleNodeVisitor.h>
+
+#include <flatbuffers/flexbuffers.h>
+
+namespace luci
+{
+
+// NOTE Virtual nodes are not circle builtin operators.
+// Therefore, they are not defined here.
+class BuiltinOptionsExtractor final
+ : public luci::CircleNodeMutableVisitor<flatbuffers::Offset<void>>
+{
+public:
+ BuiltinOptionsExtractor(flatbuffers::FlatBufferBuilder &builder) : _builder{builder}
+ {
+ // DO NOTHING
+ }
+
+public:
+ flatbuffers::Offset<void> visit(luci::CircleAbs *)
+ {
+ return circle::CreateAbsOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleAdd *node)
+ {
+ return circle::CreateAddOptions(_builder, to_circle_actfunc(node->fusedActivationFunction()))
+ .Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleAddN *)
+ {
+ return circle::CreateAddNOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleArgMax *node)
+ {
+ return circle::CreateArgMaxOptions(_builder, luci::to_circle_tensortype(node->output_type()))
+ .Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleArgMin *node)
+ {
+ return circle::CreateArgMinOptions(_builder, luci::to_circle_tensortype(node->output_type()))
+ .Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleAveragePool2D *node)
+ {
+ return circle::CreatePool2DOptions(_builder, getOpPadding(node->padding()), node->stride()->w(),
+ node->stride()->h(), node->filter()->w(),
+ node->filter()->h(),
+ to_circle_actfunc(node->fusedActivationFunction()))
+ .Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleBatchMatMul *node)
+ {
+ return circle::CreateBatchMatMulOptions(_builder, node->adj_x(), node->adj_y()).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleBatchToSpaceND *)
+ {
+ return circle::CreateBatchToSpaceNDOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleBidirectionalSequenceLSTM *node)
+ {
+ return circle::CreateBidirectionalSequenceLSTMOptions(
+ _builder, to_circle_actfunc(node->fusedActivationFunction()), node->cell_clip(),
+ node->proj_clip(), node->merge_outputs(), node->time_major(),
+ node->asymmetric_quantize_inputs())
+ .Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleCast *node)
+ {
+ if (node->out_data_type() == loco::DataType::Unknown)
+ return _no_option;
+ else
+ return circle::CreateCastOptions(_builder, luci::to_circle_tensortype(node->in_data_type()),
+ luci::to_circle_tensortype(node->out_data_type()))
+ .Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleCeil *) { return _no_option; }
+ flatbuffers::Offset<void> visit(luci::CircleConcatenation *node)
+ {
+ return circle::CreateConcatenationOptions(_builder, node->axis(),
+ to_circle_actfunc(node->fusedActivationFunction()))
+ .Union();
+ }
+ // CircleConst is not virtual but not builtinOperator
+ // flatbuffers::Offset<void> visit(luci::CircleConst *)
+ flatbuffers::Offset<void> visit(luci::CircleConv2D *node)
+ {
+ return circle::CreateConv2DOptions(_builder, getOpPadding(node->padding()), node->stride()->w(),
+ node->stride()->h(),
+ to_circle_actfunc(node->fusedActivationFunction()),
+ node->dilation()->w(), node->dilation()->h())
+ .Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleCos *)
+ {
+ return circle::CreateCosOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleCustom *) { return _no_option; }
+ flatbuffers::Offset<void> visit(luci::CircleDepthToSpace *node)
+ {
+ return circle::CreateDepthToSpaceOptions(_builder, node->block_size()).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleDepthwiseConv2D *node)
+ {
+ return circle::CreateDepthwiseConv2DOptions(
+ _builder, getOpPadding(node->padding()), node->stride()->w(), node->stride()->h(),
+ node->depthMultiplier(), to_circle_actfunc(node->fusedActivationFunction()),
+ node->dilation()->w(), node->dilation()->h())
+ .Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleDequantize *) { return _no_option; }
+ flatbuffers::Offset<void> visit(luci::CircleDiv *node)
+ {
+ return circle::CreateDivOptions(_builder, to_circle_actfunc(node->fusedActivationFunction()))
+ .Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleElu *) { return _no_option; }
+ flatbuffers::Offset<void> visit(luci::CircleEqual *)
+ {
+ return circle::CreateEqualOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleExp *)
+ {
+ return circle::CreateExpOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleExpandDims *)
+ {
+ return circle::CreateExpandDimsOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleFakeQuant *node)
+ {
+ return circle::CreateFakeQuantOptions(_builder, node->min(), node->max(), node->num_bits(),
+ node->narrow_range())
+ .Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleFill *)
+ {
+ return circle::CreateFillOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleFloor *) { return _no_option; }
+ flatbuffers::Offset<void> visit(luci::CircleFloorDiv *)
+ {
+ return circle::CreateFloorDivOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleFloorMod *)
+ {
+ return circle::CreateFloorModOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleFullyConnected *node)
+ {
+ return circle::CreateFullyConnectedOptions(
+ _builder, to_circle_actfunc(node->fusedActivationFunction()),
+ to_circle_weightsformat(node->weights_format()), node->keep_num_dims())
+ .Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleGather *node)
+ {
+ return circle::CreateGatherOptions(_builder, node->axis()).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleGatherNd *)
+ {
+ return circle::CreateGatherNdOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleGreater *)
+ {
+ return circle::CreateGreaterOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleGreaterEqual *)
+ {
+ return circle::CreateGreaterEqualOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleIf *node)
+ {
+ return circle::CreateIfOptions(_builder, node->then_branch(), node->else_branch()).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleL2Normalize *node)
+ {
+ return circle::CreateL2NormOptions(_builder, to_circle_actfunc(node->fusedActivationFunction()))
+ .Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleL2Pool2D *node)
+ {
+ return circle::CreatePool2DOptions(_builder, getOpPadding(node->padding()), node->stride()->w(),
+ node->stride()->h(), node->filter()->w(),
+ node->filter()->h(),
+ to_circle_actfunc(node->fusedActivationFunction()))
+ .Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleLeakyRelu *node)
+ {
+ return circle::CreateLeakyReluOptions(_builder, node->alpha()).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleLess *)
+ {
+ return circle::CreateLessOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleLessEqual *)
+ {
+ return circle::CreateLessEqualOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleLocalResponseNormalization *node)
+ {
+ return circle::CreateLocalResponseNormalizationOptions(_builder, node->radius(), node->bias(),
+ node->alpha(), node->beta())
+ .Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleLog *) { return _no_option; }
+ flatbuffers::Offset<void> visit(luci::CircleLogicalAnd *)
+ {
+ return circle::CreateLogicalAndOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleLogicalNot *)
+ {
+ return circle::CreateLogicalNotOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleLogicalOr *)
+ {
+ return circle::CreateLogicalOrOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleLogistic *) { return _no_option; }
+ flatbuffers::Offset<void> visit(luci::CircleLogSoftmax *)
+ {
+ return circle::CreateLogSoftmaxOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleMatrixDiag *)
+ {
+ return circle::CreateMatrixDiagOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleMatrixSetDiag *)
+ {
+ return circle::CreateMatrixSetDiagOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleMaximum *)
+ {
+ return circle::CreateMaximumMinimumOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleMaxPool2D *node)
+ {
+ return circle::CreatePool2DOptions(_builder, getOpPadding(node->padding()), node->stride()->w(),
+ node->stride()->h(), node->filter()->w(),
+ node->filter()->h(),
+ to_circle_actfunc(node->fusedActivationFunction()))
+ .Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleMean *node)
+ {
+ return circle::CreateReducerOptions(_builder, node->keep_dims()).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleMinimum *)
+ {
+ return circle::CreateMaximumMinimumOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleMirrorPad *node)
+ {
+ return circle::CreateMirrorPadOptions(_builder, to_circle_mirrorpadmode(node->mode())).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleMul *node)
+ {
+ return circle::CreateMulOptions(_builder, to_circle_actfunc(node->fusedActivationFunction()))
+ .Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleNeg *)
+ {
+ return circle::CreateNegOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleNonMaxSuppressionV4 *)
+ {
+ return circle::CreateNonMaxSuppressionV4Options(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleNonMaxSuppressionV5 *)
+ {
+ return circle::CreateNonMaxSuppressionV5Options(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleNotEqual *)
+ {
+ return circle::CreateNotEqualOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleOneHot *node)
+ {
+ return circle::CreateOneHotOptions(_builder, node->axis()).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CirclePack *node)
+ {
+ return circle::CreatePackOptions(_builder, node->values_count(), node->axis()).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CirclePad *)
+ {
+ return circle::CreatePadOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CirclePadV2 *)
+ {
+ return circle::CreatePadV2Options(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CirclePow *)
+ {
+ return circle::CreatePowOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CirclePRelu *) { return _no_option; }
+ flatbuffers::Offset<void> visit(luci::CircleQuantize *) { return _no_option; }
+ flatbuffers::Offset<void> visit(luci::CircleRange *)
+ {
+ return circle::CreateRangeOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleRank *)
+ {
+ return circle::CreateRankOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleReduceAny *node)
+ {
+ return circle::CreateReducerOptions(_builder, node->keep_dims()).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleReduceMax *node)
+ {
+ return circle::CreateReducerOptions(_builder, node->keep_dims()).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleReduceMin *node)
+ {
+ return circle::CreateReducerOptions(_builder, node->keep_dims()).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleReduceProd *node)
+ {
+ return circle::CreateReducerOptions(_builder, node->keep_dims()).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleRelu *) { return _no_option; }
+ flatbuffers::Offset<void> visit(luci::CircleRelu6 *) { return _no_option; }
+ flatbuffers::Offset<void> visit(luci::CircleReluN1To1 *) { return _no_option; }
+ flatbuffers::Offset<void> visit(luci::CircleReshape *node)
+ {
+ auto new_shape = _builder.CreateVector<int32_t>(
+ node->newShape()->rank(), [node](size_t i) { return node->newShape()->dim(i); });
+ return circle::CreateReshapeOptions(_builder, new_shape).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleResizeBilinear *node)
+ {
+ return circle::CreateResizeBilinearOptions(_builder, node->align_corners(),
+ node->half_pixel_centers())
+ .Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleResizeNearestNeighbor *node)
+ {
+ return circle::CreateResizeNearestNeighborOptions(_builder, node->align_corners()).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleReverseSequence *node)
+ {
+ return circle::CreateReverseSequenceOptions(_builder, node->seq_axis(), node->batch_axis())
+ .Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleReverseV2 *)
+ {
+ return circle::CreateReverseV2Options(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleRound *) { return _no_option; }
+ flatbuffers::Offset<void> visit(luci::CircleRsqrt *) { return _no_option; }
+ flatbuffers::Offset<void> visit(luci::CircleScatterNd *)
+ {
+ return circle::CreateScatterNdOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleSegmentSum *)
+ {
+ return circle::CreateSegmentSumOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleSelect *)
+ {
+ return circle::CreateSelectOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleSelectV2 *)
+ {
+ return circle::CreateSelectV2Options(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleShape *node)
+ {
+ return circle::CreateShapeOptions(_builder, luci::to_circle_tensortype(node->out_type()))
+ .Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleSin *) { return _no_option; }
+ flatbuffers::Offset<void> visit(luci::CircleSlice *)
+ {
+ return circle::CreateSliceOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleSoftmax *node)
+ {
+ return circle::CreateSoftmaxOptions(_builder, node->beta()).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleSpaceToBatchND *)
+ {
+ return circle::CreateSpaceToBatchNDOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleSpaceToDepth *node)
+ {
+ return circle::CreateSpaceToDepthOptions(_builder, node->block_size()).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleSparseToDense *node)
+ {
+ return circle::CreateSparseToDenseOptions(_builder, node->validate_indices()).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleSplit *node)
+ {
+ return circle::CreateSplitOptions(_builder, node->num_split()).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleSplitV *node)
+ {
+ return circle::CreateSplitVOptions(_builder, node->num_split()).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleSqrt *) { return _no_option; }
+ flatbuffers::Offset<void> visit(luci::CircleSquare *)
+ {
+ return circle::CreateSquareOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleSquaredDifference *)
+ {
+ return circle::CreateSquaredDifferenceOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleSqueeze *node)
+ {
+ auto squeeze_dims = _builder.CreateVector<int32_t>(node->squeeze_dims());
+ return circle::CreateSqueezeOptions(_builder, squeeze_dims).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleStridedSlice *node)
+ {
+ return circle::CreateStridedSliceOptions(_builder, node->begin_mask(), node->end_mask(),
+ node->ellipsis_mask(), node->new_axis_mask(),
+ node->shrink_axis_mask())
+ .Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleSub *node)
+ {
+ return circle::CreateSubOptions(_builder, to_circle_actfunc(node->fusedActivationFunction()))
+ .Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleSum *node)
+ {
+ return circle::CreateReducerOptions(_builder, node->keep_dims()).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleSVDF *node)
+ {
+ return circle::CreateSVDFOptions(_builder, node->svdf_rank(),
+ to_circle_actfunc(node->fusedActivationFunction()),
+ node->asymmetric_quantize_inputs())
+ .Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleTanh *) { return _no_option; }
+ flatbuffers::Offset<void> visit(luci::CircleTile *)
+ {
+ return circle::CreateTileOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleTopKV2 *)
+ {
+ return circle::CreateTopKV2Options(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleTranspose *)
+ {
+ return circle::CreateTransposeOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleTransposeConv *node)
+ {
+ return circle::CreateTransposeConvOptions(_builder, getOpPadding(node->padding()),
+ node->stride()->w(), node->stride()->h())
+ .Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleUnidirectionalSequenceLSTM *node)
+ {
+ return circle::CreateUnidirectionalSequenceLSTMOptions(
+ _builder, to_circle_actfunc(node->fusedActivationFunction()), node->cell_clip(),
+ node->proj_clip(), node->time_major(), node->asymmetric_quantize_inputs())
+ .Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleUnique *node)
+ {
+ return circle::CreateUniqueOptions(_builder, luci::to_circle_tensortype(node->idx_out_type()))
+ .Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleUnpack *node)
+ {
+ return circle::CreateUnpackOptions(_builder, node->num(), node->axis()).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleWhere *)
+ {
+ return circle::CreateWhereOptions(_builder).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleWhile *node)
+ {
+ return circle::CreateWhileOptions(_builder, node->cond_branch(), node->body_branch()).Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleZerosLike *)
+ {
+ return circle::CreateZerosLikeOptions(_builder).Union();
+ }
+ // Circle only
+ flatbuffers::Offset<void> visit(luci::CircleBCQFullyConnected *node)
+ {
+ return circle::CreateBCQFullyConnectedOptions(
+ _builder, node->weights_hidden_size(),
+ to_circle_actfunc(node->fusedActivationFunction()))
+ .Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleBCQGather *node)
+ {
+ return circle::CreateBCQGatherOptions(_builder, node->input_hidden_size(), node->axis())
+ .Union();
+ }
+ flatbuffers::Offset<void> visit(luci::CircleInstanceNorm *node)
+ {
+ return circle::CreateInstanceNormOptions(_builder, node->epsilon(),
+ to_circle_actfunc(node->fusedActivationFunction()))
+ .Union();
+ }
+
+protected:
+ flatbuffers::FlatBufferBuilder &_builder;
+
+private:
+ const flatbuffers::Offset<void> _no_option = 0;
+};
+
+} // namespace luci
+
+#endif // __CIRCLE_BUILTIN_TYPES_EXTRACTOR_H__
diff --git a/compiler/luci/export/src/CircleBuiltinTypesMappingRule.h b/compiler/luci/export/src/CircleBuiltinTypesMappingRule.h
new file mode 100644
index 000000000..6f7c0f70e
--- /dev/null
+++ b/compiler/luci/export/src/CircleBuiltinTypesMappingRule.h
@@ -0,0 +1,79 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __CIRCLE_EXPORT_BUILTIN_TYPES_MAPPING_RULE_H__
+#define __CIRCLE_EXPORT_BUILTIN_TYPES_MAPPING_RULE_H__
+
+#include <luci/IR/CircleNode.h>
+#include <luci/IR/CircleNodes.h>
+#include <luci/IR/CircleNodeVisitor.h>
+
+namespace luci
+{
+
+class BuiltinOperatorMappingRule final : public CircleNodeVisitor<circle::BuiltinOperator>
+{
+public:
+ BuiltinOperatorMappingRule()
+ {
+ // DO NOTHING
+ }
+
+public:
+ static BuiltinOperatorMappingRule &get()
+ {
+ static BuiltinOperatorMappingRule instance;
+ return instance;
+ }
+
+public:
+#define CIRCLE_NODE(CIRCLE_NODE, OP, OPTION) \
+ circle::BuiltinOperator visit(const CIRCLE_NODE *) final { return circle::OP; }
+// Virtual nodes are not circle builtin operator
+#define CIRCLE_VNODE(CIRCLE_NODE)
+#include "CircleOps.lst"
+#undef CIRCLE_VNODE
+#undef CIRCLE_NODE
+};
+
+class BuiltinOptionsMappingRule final : public CircleNodeVisitor<circle::BuiltinOptions>
+{
+public:
+ BuiltinOptionsMappingRule()
+ {
+ // DO NOTHING
+ }
+
+public:
+ static BuiltinOptionsMappingRule &get()
+ {
+ static BuiltinOptionsMappingRule instance;
+ return instance;
+ }
+
+public:
+#define CIRCLE_NODE(CIRCLE_NODE, OP, OPTION) \
+ circle::BuiltinOptions visit(const CIRCLE_NODE *) final { return circle::OPTION; }
+// Virtual nodes are not circle builtin operator
+#define CIRCLE_VNODE(CIRCLE_NODE)
+#include "CircleOps.lst"
+#undef CIRCLE_VNODE
+#undef CIRCLE_NODE
+};
+
+} // namespace luci
+
+#endif // __CIRCLE_EXPORT_BUILTIN_TYPES_MAPPING_RULE_H__
diff --git a/compiler/luci/export/src/CircleExporterImpl.cpp b/compiler/luci/export/src/CircleExporterImpl.cpp
index 5868c176c..083add9be 100644
--- a/compiler/luci/export/src/CircleExporterImpl.cpp
+++ b/compiler/luci/export/src/CircleExporterImpl.cpp
@@ -79,14 +79,19 @@ encodeOperatorCodes(FlatBufferBuilder &builder, std::unordered_map<luci::OpCode,
for (auto it : opcodes)
{
uint32_t idx = it.second;
+ int8_t dep_code = 127; // BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES
+ if (it.first.opcode < BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES)
+ dep_code = static_cast<int8_t>(it.first.opcode);
if (it.first.opcode != BuiltinOperator_CUSTOM)
{
- operator_codes_vec[idx] = CreateOperatorCode(builder, it.first.opcode, 0, it.first.version);
+ operator_codes_vec[idx] =
+ CreateOperatorCode(builder, dep_code, 0, it.first.version, it.first.opcode);
}
else
{
operator_codes_vec[idx] =
- CreateOperatorCode(builder, it.first.opcode, builder.CreateString(it.first.custom_code));
+ CreateOperatorCode(builder, dep_code, builder.CreateString(it.first.custom_code),
+ it.first.version, it.first.opcode);
}
}
diff --git a/compiler/luci/export/src/CircleExporterUtils.cpp b/compiler/luci/export/src/CircleExporterUtils.cpp
index 3a7ba304f..9473c2c4e 100644
--- a/compiler/luci/export/src/CircleExporterUtils.cpp
+++ b/compiler/luci/export/src/CircleExporterUtils.cpp
@@ -15,6 +15,7 @@
*/
#include "CircleExporterUtils.h"
+#include "CircleBuiltinTypesMappingRule.h"
#include <oops/InternalExn.h>
@@ -163,36 +164,63 @@ circle::SparseIndexVector to_circle_sparse_index_vector_type(luci::SparseIndexVe
}
}
-} // namespace luci
+circle::BuiltinOperator circle_builtin_operator(const luci::CircleNode *node)
+{
+ return node->accept(&BuiltinOperatorMappingRule::get());
+}
-namespace luci
+circle::BuiltinOptions circle_builtin_options(const luci::CircleNode *node)
{
+ if (auto cast = dynamic_cast<const luci::CircleCast *>(node))
+ {
+ return (cast->out_data_type() == loco::DataType::Unknown) ? circle::BuiltinOptions_NONE
+ : circle::BuiltinOptions_CastOptions;
+ }
-uint32_t SerializedModelData::registerBuiltinOpcode(circle::BuiltinOperator builtin_code,
- const int32_t op_version)
+ return node->accept(&BuiltinOptionsMappingRule::get());
+}
+
+std::string circle_custom_code(const luci::CircleNode *node)
{
- assert(op_version > 0);
+ if (auto custom_node = dynamic_cast<const luci::CircleCustom *>(node))
+ {
+ return custom_node->custom_code();
+ }
- auto it = _operator_codes.find(OpCode{builtin_code, "", op_version});
- if (it != _operator_codes.end())
+ return "";
+}
+
+flatbuffers::Offset<flatbuffers::Vector<uint8_t>>
+circle_custom_options(flatbuffers::FlatBufferBuilder &fb, const luci::CircleNode *node)
+{
+ if (auto custom_node = dynamic_cast<const luci::CircleCustom *>(node))
{
- return it->second;
+ std::vector<uint8_t> custom_options_vec{custom_node->custom_options().begin(),
+ custom_node->custom_options().end()};
+ return fb.CreateVector(custom_options_vec);
}
- auto idx = static_cast<uint32_t>(_operator_codes.size());
- _operator_codes.emplace(OpCode{builtin_code, "", op_version}, idx);
- return idx;
+
+ return 0;
}
-uint32_t SerializedModelData::registerCustomOpcode(const std::string &custom_code)
+} // namespace luci
+
+namespace luci
{
- const circle::BuiltinOperator builtin_code = circle::BuiltinOperator_CUSTOM;
- auto it = _operator_codes.find(OpCode{builtin_code, custom_code});
+
+uint32_t SerializedModelData::registerBuiltinOpcode(circle::BuiltinOperator builtin_code,
+ const std::string &custom_code,
+ const int32_t op_version)
+{
+ assert(op_version > 0);
+
+ auto it = _operator_codes.find(OpCode{builtin_code, custom_code, op_version});
if (it != _operator_codes.end())
{
return it->second;
}
auto idx = static_cast<uint32_t>(_operator_codes.size());
- _operator_codes.emplace(OpCode{builtin_code, custom_code}, idx);
+ _operator_codes.emplace(OpCode{builtin_code, custom_code, op_version}, idx);
return idx;
}
diff --git a/compiler/luci/export/src/CircleExporterUtils.h b/compiler/luci/export/src/CircleExporterUtils.h
index 95310b353..4a4c54a69 100644
--- a/compiler/luci/export/src/CircleExporterUtils.h
+++ b/compiler/luci/export/src/CircleExporterUtils.h
@@ -39,6 +39,12 @@ flatbuffers::Offset<void> to_circle_sparse_index_vector(flatbuffers::FlatBufferB
const SparseIndexVector &sparse_idx_vec);
circle::SparseIndexVector to_circle_sparse_index_vector_type(luci::SparseIndexVectorType type);
+circle::BuiltinOperator circle_builtin_operator(const luci::CircleNode *node);
+circle::BuiltinOptions circle_builtin_options(const luci::CircleNode *node);
+std::string circle_custom_code(const luci::CircleNode *node);
+flatbuffers::Offset<flatbuffers::Vector<uint8_t>>
+circle_custom_options(flatbuffers::FlatBufferBuilder &fb, const luci::CircleNode *node);
+
} // namespace luci
namespace luci
diff --git a/compiler/luci/export/src/CircleOperationExporter.cpp b/compiler/luci/export/src/CircleOperationExporter.cpp
index be64a52d4..b300a7fcf 100644
--- a/compiler/luci/export/src/CircleOperationExporter.cpp
+++ b/compiler/luci/export/src/CircleOperationExporter.cpp
@@ -15,1686 +15,30 @@
*/
#include "CircleOperationExporter.h"
-#include "CircleExporterUtils.h"
-#include "Check.h"
+#include "CircleOperationExporterRule.h"
#include <luci/IR/CircleNode.h>
-#include <luci/IR/CircleNodes.h>
-#include <luci/IR/CircleNodeVisitor.h>
#include <luci/Profile/CircleNodeOrigin.h>
#include <luci/Plan/CircleNodeExecutionPlan.h>
-#include <luci/UserSettings.h>
-#include <luci/Log.h>
+#include <loco/IR/Algorithm.h>
-#include <loco/IR/CanonicalNodeVisitor.h>
-#include <oops/InternalExn.h>
-
-#include <flatbuffers/flexbuffers.h>
-
-using namespace flatbuffers;
-using namespace circle;
-
-namespace
-{
-
-using namespace luci;
-
-struct ExportContext
-{
- FlatBufferBuilder &builder;
- SerializedModelData &md;
- SerializedGraphData &gd;
-};
-
-/**
- * @brief Exports CircleMaxPool2D or CircleAveragePool2D
- *
- * @note CirclePool2D should be one of CircleMaxPool2D or CircleAveragePool2D
- */
-template <class CirclePool2D>
-void export_pool_2d(ExportContext &ctx, CirclePool2D *node, circle::BuiltinOperator builtin_op)
-{
- LUCI_ASSERT(builtin_op == circle::BuiltinOperator_MAX_POOL_2D ||
- builtin_op == circle::BuiltinOperator_L2_POOL_2D ||
- builtin_op == circle::BuiltinOperator_AVERAGE_POOL_2D,
- "Should be L2Pool, MaxPool or AvgPool");
- LUCI_ASSERT(node->padding() != luci::Padding::UNDEFINED, "Padding is not set");
-
- uint32_t op_idx = ctx.md.registerBuiltinOpcode(builtin_op, node->op_version());
- std::vector<int32_t> inputs_vec{get_tensor_index(node->value())};
- std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
- auto inputs = ctx.builder.CreateVector(inputs_vec);
- auto outputs = ctx.builder.CreateVector(outputs_vec);
-
- circle::Padding padding = getOpPadding(node->padding());
-
- auto options = CreatePool2DOptions(ctx.builder, padding, node->stride()->w(), node->stride()->h(),
- node->filter()->w(), node->filter()->h(),
- to_circle_actfunc(node->fusedActivationFunction()));
- auto op_offset = CreateOperator(ctx.builder, op_idx, inputs, outputs,
- circle::BuiltinOptions_Pool2DOptions, options.Union());
- ctx.gd._operators.push_back(op_offset);
-}
-
-/**
- * @brief export simple nodes
- */
-void export_node(ExportContext &ctx, loco::Node *node, circle::BuiltinOperator bop,
- circle::BuiltinOptions bot, flatbuffers::Offset<void> options_offset)
-{
- uint32_t op_idx =
- ctx.md.registerBuiltinOpcode(bop, loco::must_cast<luci::CircleNode *>(node)->op_version());
- std::vector<int32_t> inputs_vec;
- std::vector<int32_t> outputs_vec{get_tensor_index(node)};
- for (uint32_t i = 0; i < node->arity(); ++i)
- inputs_vec.push_back(get_tensor_index(node->arg(i)));
- auto inputs = ctx.builder.CreateVector(inputs_vec);
- auto outputs = ctx.builder.CreateVector(outputs_vec);
- auto op_offset = CreateOperator(ctx.builder, op_idx, inputs, outputs, bot, options_offset);
- ctx.gd._operators.push_back(op_offset);
-}
-
-/**
- * @brief export simple nodes having void options
- */
-void export_node(ExportContext &ctx, loco::Node *node, circle::BuiltinOperator bop)
-{
- uint32_t op_idx =
- ctx.md.registerBuiltinOpcode(bop, loco::must_cast<luci::CircleNode *>(node)->op_version());
- std::vector<int32_t> inputs_vec;
- std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
- for (uint32_t i = 0; i < node->arity(); ++i)
- inputs_vec.push_back(get_tensor_index(node->arg(i)));
- auto inputs = ctx.builder.CreateVector(inputs_vec);
- auto outputs = ctx.builder.CreateVector(outputs_vec);
- auto op_offset = CreateOperator(ctx.builder, op_idx, inputs, outputs);
- ctx.gd._operators.push_back(op_offset);
-}
-
-void export_node(ExportContext &ctx, luci::CircleAddN *node)
-{
- uint32_t op_idx = ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_ADD_N, node->op_version());
- std::vector<int32_t> inputs_vec;
- std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
-
- for (uint32_t i = 0; i < node->arity(); ++i)
- inputs_vec.push_back(get_tensor_index(node->inputs(i)));
-
- auto inputs = ctx.builder.CreateVector(inputs_vec);
- auto outputs = ctx.builder.CreateVector(outputs_vec);
- auto options = CreateAddNOptions(ctx.builder);
- auto op_offset = CreateOperator(ctx.builder, op_idx, inputs, outputs,
- circle::BuiltinOptions_AddNOptions, options.Union());
- ctx.gd._operators.push_back(op_offset);
-}
-
-void export_node(ExportContext &ctx, luci::CircleCast *node)
-{
- uint32_t op_idx = ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_CAST, node->op_version());
- std::vector<int32_t> inputs_vec{get_tensor_index(node->x())};
- std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
- auto inputs = ctx.builder.CreateVector(inputs_vec);
- auto outputs = ctx.builder.CreateVector(outputs_vec);
-
- flatbuffers::Offset<Operator> op_offset;
- if (node->out_data_type() != loco::DataType::Unknown)
- {
- auto options = CreateCastOptions(ctx.builder, to_circle_tensortype(node->in_data_type()),
- to_circle_tensortype(node->out_data_type()));
- op_offset = CreateOperator(ctx.builder, op_idx, inputs, outputs,
- circle::BuiltinOptions_CastOptions, options.Union());
- }
- else
- {
- op_offset = CreateOperator(ctx.builder, op_idx, inputs, outputs);
- }
- ctx.gd._operators.push_back(op_offset);
-}
-
-void export_node(ExportContext &ctx, luci::CircleConcatenation *node)
-{
- uint32_t op_idx =
- ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_CONCATENATION, node->op_version());
- std::vector<int32_t> inputs_vec;
- std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
-
- for (uint32_t i = 0; i < node->numValues(); ++i)
- inputs_vec.push_back(get_tensor_index(node->values(i)));
-
- auto inputs = ctx.builder.CreateVector(inputs_vec);
- auto outputs = ctx.builder.CreateVector(outputs_vec);
- auto options = CreateConcatenationOptions(ctx.builder, node->axis(),
- to_circle_actfunc(node->fusedActivationFunction()));
- auto op_offset = CreateOperator(ctx.builder, op_idx, inputs, outputs,
- circle::BuiltinOptions_ConcatenationOptions, options.Union());
- ctx.gd._operators.push_back(op_offset);
-}
-
-void export_node(ExportContext &ctx, luci::CircleCustom *node)
-{
- auto custom_outputs = loco::succs(node);
- assert(custom_outputs.size() == node->numOutputs());
-
- uint32_t op_idx = ctx.md.registerCustomOpcode(node->custom_code());
- std::vector<int32_t> inputs_vec;
- std::vector<int32_t> outputs_vec;
-
- for (uint32_t index = 0; index < node->numInputs(); index++)
- {
- inputs_vec.push_back(get_tensor_index(node->inputs(index)));
- }
- for (uint32_t index = 0; index < custom_outputs.size(); index++)
- {
- // store in order of index
- bool found = false;
- for (auto out : custom_outputs)
- {
- auto custom_out = loco::must_cast<luci::CircleCustomOut *>(out);
- if (custom_out->index() == static_cast<int32_t>(index))
- {
- outputs_vec.push_back(get_tensor_index(custom_out));
- found = true;
- break;
- }
- }
- if (!found)
- {
- INTERNAL_EXN("Invalid Custom output");
- }
- }
-
- auto inputs = ctx.builder.CreateVector(inputs_vec);
- auto outputs = ctx.builder.CreateVector(outputs_vec);
- flatbuffers::Offset<flatbuffers::Vector<uint8_t>> circle_custom_options;
- std::vector<uint8_t> custom_options_vec{node->custom_options().begin(),
- node->custom_options().end()};
- circle_custom_options = ctx.builder.CreateVector(custom_options_vec);
- auto op_offset = CreateOperator(ctx.builder, op_idx, inputs, outputs, circle::BuiltinOptions_NONE,
- flatbuffers::Offset<void>(), circle_custom_options);
- ctx.gd._operators.push_back(op_offset);
-}
-
-void export_node(ExportContext &ctx, luci::CircleIf *node)
-{
- auto if_outs = loco::succs(node);
- assert(if_outs.size() == node->output_count());
-
- uint32_t op_idx = ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_IF, node->op_version());
- std::vector<int32_t> inputs_vec;
- std::vector<int32_t> outputs_vec;
-
- inputs_vec.push_back(get_tensor_index(node->cond()));
- for (uint32_t idx = 0; idx < node->input_count(); ++idx)
- inputs_vec.push_back(get_tensor_index(node->input(idx)));
-
- for (uint32_t idx = 0; idx < node->output_count(); ++idx)
- {
- // store in order of index
- bool found = false;
- for (auto out : if_outs)
- {
- auto if_out = loco::must_cast<luci::CircleIfOut *>(out);
- if (if_out->index() == static_cast<int32_t>(idx))
- {
- outputs_vec.push_back(get_tensor_index(if_out));
- found = true;
- break;
- }
- }
- if (!found)
- {
- INTERNAL_EXN("Invalid CircleIf output");
- }
- }
-
- auto inputs = ctx.builder.CreateVector(inputs_vec);
- auto outputs = ctx.builder.CreateVector(outputs_vec);
- auto options = CreateIfOptions(ctx.builder, node->then_branch(), node->else_branch());
- auto op_offset = CreateOperator(ctx.builder, op_idx, inputs, outputs,
- circle::BuiltinOptions_IfOptions, options.Union());
- ctx.gd._operators.push_back(op_offset);
-}
-
-void export_node(ExportContext &ctx, luci::CircleNonMaxSuppressionV4 *node)
-{
- auto nms_outs = loco::succs(node);
- assert(nms_outs.size() == 2);
-
- uint32_t op_idx = ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_NON_MAX_SUPPRESSION_V4,
- node->op_version());
- std::vector<int32_t> inputs_vec{
- get_tensor_index(node->boxes()), get_tensor_index(node->scores()),
- get_tensor_index(node->max_output_size()), get_tensor_index(node->iou_threshold()),
- get_tensor_index(node->score_threshold()),
- };
- std::vector<int32_t> outputs_vec;
-
- for (uint32_t idx = 0; idx < nms_outs.size(); ++idx)
- {
- // store in order of index
- bool found = false;
- for (auto out : nms_outs)
- {
- auto nms_out = loco::must_cast<luci::CircleNonMaxSuppressionV4Out *>(out);
- if (nms_out->index() == static_cast<int32_t>(idx))
- {
- outputs_vec.push_back(get_tensor_index(nms_out));
- found = true;
- break;
- }
- }
- if (!found)
- {
- INTERNAL_EXN("Invalid NonMaxSuppressionV4 output");
- }
- }
-
- auto inputs = ctx.builder.CreateVector(inputs_vec);
- auto outputs = ctx.builder.CreateVector(outputs_vec);
- auto options = CreateNonMaxSuppressionV4Options(ctx.builder);
- auto op_offset =
- CreateOperator(ctx.builder, op_idx, inputs, outputs,
- circle::BuiltinOptions_NonMaxSuppressionV4Options, options.Union());
- ctx.gd._operators.push_back(op_offset);
-}
-
-void export_node(ExportContext &ctx, luci::CircleNonMaxSuppressionV5 *node)
-{
- auto nms_outs = loco::succs(node);
- assert(nms_outs.size() == 3);
-
- uint32_t op_idx = ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_NON_MAX_SUPPRESSION_V5,
- node->op_version());
- std::vector<int32_t> inputs_vec{
- get_tensor_index(node->boxes()), get_tensor_index(node->scores()),
- get_tensor_index(node->max_output_size()), get_tensor_index(node->iou_threshold()),
- get_tensor_index(node->score_threshold()), get_tensor_index(node->soft_nms_sigma()),
- };
- std::vector<int32_t> outputs_vec;
-
- for (uint32_t idx = 0; idx < nms_outs.size(); ++idx)
- {
- // store in order of index
- bool found = false;
- for (auto out : nms_outs)
- {
- auto nms_out = loco::must_cast<luci::CircleNonMaxSuppressionV5Out *>(out);
- if (nms_out->index() == static_cast<int32_t>(idx))
- {
- outputs_vec.push_back(get_tensor_index(nms_out));
- found = true;
- break;
- }
- }
- if (!found)
- {
- INTERNAL_EXN("Invalid NonMaxSuppressionV5 output");
- }
- }
-
- auto inputs = ctx.builder.CreateVector(inputs_vec);
- auto outputs = ctx.builder.CreateVector(outputs_vec);
- auto options = CreateNonMaxSuppressionV5Options(ctx.builder);
- auto op_offset =
- CreateOperator(ctx.builder, op_idx, inputs, outputs,
- circle::BuiltinOptions_NonMaxSuppressionV5Options, options.Union());
- ctx.gd._operators.push_back(op_offset);
-}
-
-void export_node(ExportContext &ctx, luci::CircleReverseV2 *node)
-{
- uint32_t op_idx =
- ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_REVERSE_V2, node->op_version());
- std::vector<int32_t> inputs_vec{get_tensor_index(node->tensor()), get_tensor_index(node->axis())};
- std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
- auto inputs = ctx.builder.CreateVector(inputs_vec);
- auto outputs = ctx.builder.CreateVector(outputs_vec);
- auto options = CreateReverseV2Options(ctx.builder);
- auto op_offset = CreateOperator(ctx.builder, op_idx, inputs, outputs,
- circle::BuiltinOptions_ReverseSequenceOptions, options.Union());
- ctx.gd._operators.push_back(op_offset);
-}
-
-void export_node(ExportContext &ctx, luci::CircleSplit *node)
-{
- auto split_outs = loco::succs(node);
- assert(int32_t(split_outs.size()) == node->num_split());
-
- uint32_t op_idx = ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_SPLIT, node->op_version());
- // NOTE BuiltinOperator_SPLIT input is placed at second position
- std::vector<int32_t> inputs_vec{get_tensor_index(node->split_dim()),
- get_tensor_index(node->input())};
- std::vector<int32_t> outputs_vec;
-
- for (int32_t index = 0; index < node->num_split(); index++)
- {
- // store in order of index
- bool found = false;
- for (auto out : split_outs)
- {
- auto split_out = loco::must_cast<luci::CircleSplitOut *>(out);
- if (split_out->index() == index)
- {
- outputs_vec.push_back(get_tensor_index(split_out));
- found = true;
- break;
- }
- }
- if (!found)
- {
- INTERNAL_EXN("Invalid Split output");
- }
- }
-
- auto inputs = ctx.builder.CreateVector(inputs_vec);
- auto outputs = ctx.builder.CreateVector(outputs_vec);
- auto options = CreateSplitOptions(ctx.builder, node->num_split());
- auto op_offset = CreateOperator(ctx.builder, op_idx, inputs, outputs,
- circle::BuiltinOptions_SplitOptions, options.Union());
- ctx.gd._operators.push_back(op_offset);
-}
-
-void export_node(ExportContext &ctx, luci::CircleSplitV *node)
-{
- auto split_outs = loco::succs(node);
- assert(int32_t(split_outs.size()) == node->num_split());
-
- uint32_t op_idx =
- ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_SPLIT_V, node->op_version());
- std::vector<int32_t> inputs_vec{get_tensor_index(node->input()),
- get_tensor_index(node->size_splits()),
- get_tensor_index(node->split_dim())};
- std::vector<int32_t> outputs_vec;
-
- for (int32_t index = 0; index < node->num_split(); index++)
- {
- // store in order of index
- bool found = false;
- for (auto out : split_outs)
- {
- auto split_out = loco::must_cast<luci::CircleSplitVOut *>(out);
- if (split_out->index() == index)
- {
- outputs_vec.push_back(get_tensor_index(split_out));
- found = true;
- break;
- }
- }
- if (!found)
- {
- INTERNAL_EXN("Invalid SplitV output");
- }
- }
-
- auto inputs = ctx.builder.CreateVector(inputs_vec);
- auto outputs = ctx.builder.CreateVector(outputs_vec);
- auto options = CreateSplitVOptions(ctx.builder, node->num_split());
- auto op_offset = CreateOperator(ctx.builder, op_idx, inputs, outputs,
- circle::BuiltinOptions_SplitVOptions, options.Union());
- ctx.gd._operators.push_back(op_offset);
-}
-
-void export_node(ExportContext &ctx, luci::CircleTopKV2 *node)
-{
- auto topkv2_outs = loco::succs(node);
- int outs_count = int32_t(topkv2_outs.size());
- assert(outs_count == 2);
-
- uint32_t op_idx =
- ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_TOPK_V2, node->op_version());
- std::vector<int32_t> inputs_vec{get_tensor_index(node->input()), get_tensor_index(node->k())};
- std::vector<int32_t> outputs_vec;
-
- for (int32_t index = 0; index < outs_count; index++)
- {
- // store in order of index
- bool found = false;
- for (auto out : topkv2_outs)
- {
- auto topkv2_out = loco::must_cast<luci::CircleTopKV2Out *>(out);
- if (topkv2_out->index() == index)
- {
- outputs_vec.push_back(get_tensor_index(topkv2_out));
- found = true;
- break;
- }
- }
- if (!found)
- {
- INTERNAL_EXN("Invalid TopKV2 output");
- }
- }
-
- auto inputs = ctx.builder.CreateVector(inputs_vec);
- auto outputs = ctx.builder.CreateVector(outputs_vec);
- auto options = CreateTopKV2Options(ctx.builder);
- auto op_offset = CreateOperator(ctx.builder, op_idx, inputs, outputs,
- circle::BuiltinOptions_TopKV2Options, options.Union());
- ctx.gd._operators.push_back(op_offset);
-}
-
-void export_node(ExportContext &ctx, luci::CircleUnique *node)
-{
- auto unique_outs = loco::succs(node);
- assert(int32_t(unique_outs.size()) == 2);
- uint32_t op_idx =
- ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_UNIQUE, node->op_version());
-
- std::vector<int32_t> inputs_vec{get_tensor_index(node->input())};
- std::vector<int32_t> outputs_vec;
-
- for (int32_t index = 0; index < 2; index++)
- {
- // store in order of index
- bool found = false;
- for (auto out : unique_outs)
- {
- auto unique_out = loco::must_cast<luci::CircleUniqueOut *>(out);
- if (unique_out->index() == index)
- {
- outputs_vec.push_back(get_tensor_index(unique_out));
- found = true;
- break;
- }
- }
- if (!found)
- {
- INTERNAL_EXN("Invalid Unique output");
- }
- }
-
- auto inputs = ctx.builder.CreateVector(inputs_vec);
- auto outputs = ctx.builder.CreateVector(outputs_vec);
- auto options = CreateUniqueOptions(ctx.builder, to_circle_tensortype(node->idx_out_type()));
- auto op_offset = CreateOperator(ctx.builder, op_idx, inputs, outputs,
- circle::BuiltinOptions_UniqueOptions, options.Union());
- ctx.gd._operators.push_back(op_offset);
-}
-
-void export_node(ExportContext &ctx, luci::CircleUnpack *node)
-{
- LOGGER(l);
- auto settings = luci::UserSettings::settings();
-
- auto unpack_outs = loco::succs(node);
- // NOTE real models may not use all of the outputs
- if (static_cast<int32_t>(unpack_outs.size()) != node->num())
- {
- if (settings->get(luci::UserSettings::Key::DisableValidation))
- {
- WARN(l) << "Warning: export Unpack(" << node->name() << ") 'num' not same as outputs";
- }
- else
- assert(false);
- }
-
- uint32_t op_idx =
- ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_UNPACK, node->op_version());
- std::vector<int32_t> inputs_vec{get_tensor_index(node->value())};
- std::vector<int32_t> outputs_vec;
-
- for (int32_t index = 0; index < node->num(); index++)
- {
- // store in order of index
- bool found = false;
- for (auto out : unpack_outs)
- {
- auto unpack_out = loco::must_cast<luci::CircleUnpackOut *>(out);
- if (unpack_out->index() == index)
- {
- outputs_vec.push_back(get_tensor_index(unpack_out));
- found = true;
- break;
- }
- }
- // NOTE real models may not use all of the outputs
- if (!found)
- {
- if (settings->get(luci::UserSettings::Key::DisableValidation))
- {
- WARN(l) << "Warning: export Unpack(" << node->name() << ") output " << index << " not used";
- }
- else
- assert(false);
- }
- }
-
- auto inputs = ctx.builder.CreateVector(inputs_vec);
- auto outputs = ctx.builder.CreateVector(outputs_vec);
- auto options = CreateUnpackOptions(ctx.builder, node->num(), node->axis());
- auto op_offset = CreateOperator(ctx.builder, op_idx, inputs, outputs,
- circle::BuiltinOptions_UnpackOptions, options.Union());
- ctx.gd._operators.push_back(op_offset);
-}
-
-void export_node(ExportContext &ctx, luci::CircleWhile *node)
-{
- auto while_outs = loco::succs(node);
- assert(while_outs.size() == node->output_count());
-
- uint32_t op_idx = ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_WHILE, node->op_version());
- std::vector<int32_t> inputs_vec;
- std::vector<int32_t> outputs_vec;
-
- for (uint32_t idx = 0; idx < node->input_count(); ++idx)
- inputs_vec.push_back(get_tensor_index(node->input(idx)));
-
- for (uint32_t idx = 0; idx < node->output_count(); ++idx)
- {
- // store in order of index
- bool found = false;
- for (auto out : while_outs)
- {
- auto while_out = loco::must_cast<luci::CircleWhileOut *>(out);
- if (while_out->index() == static_cast<int32_t>(idx))
- {
- outputs_vec.push_back(get_tensor_index(while_out));
- found = true;
- break;
- }
- }
- if (!found)
- {
- INTERNAL_EXN("Invalid CircleWhile output");
- }
- }
-
- auto inputs = ctx.builder.CreateVector(inputs_vec);
- auto outputs = ctx.builder.CreateVector(outputs_vec);
- auto options = CreateWhileOptions(ctx.builder, node->cond_branch(), node->body_branch());
- auto op_offset = CreateOperator(ctx.builder, op_idx, inputs, outputs,
- circle::BuiltinOptions_WhileOptions, options.Union());
- ctx.gd._operators.push_back(op_offset);
-}
-
-class ExportHelper
-{
-public:
- ExportHelper(ExportContext &ctx) : _ctx{ctx}
- {
- // DO NOTHING
- }
-
-protected:
- /**
- * @brief export simple nodes
- */
- void export_simple(loco::Node *node, circle::BuiltinOperator bop, circle::BuiltinOptions bot,
- flatbuffers::Offset<void> options_offset)
- {
- export_node(_ctx, node, bop, bot, options_offset);
- }
-
- /**
- * @brief export simple nodes having void options
- */
- void export_simple(loco::Node *node, circle::BuiltinOperator bop)
- {
- export_node(_ctx, node, bop);
- }
-
-protected:
- ExportContext &_ctx;
-};
-
-enum class OE
-{
- ABC,
- DEF,
- GHIJ,
- KLMN,
- OPQR,
- STUV,
- WXYZ,
- CIRC, // circle only
- VIRT, // virtual
-};
-
-class OperationExporter final : public ExportHelper
-{
-public:
- OperationExporter(ExportContext &ctx) : ExportHelper(ctx)
- {
- // DO NOTHING
- }
-
-public:
- void export_node(luci::CircleNode *);
-};
-
-template <OE oe> class OpExporterLet;
-
-template <>
-class OpExporterLet<OE::ABC> final : public luci::CircleNodeMutableVisitor<void>,
- public ExportHelper
-{
-public:
- OpExporterLet(ExportContext &ctx) : ExportHelper(ctx)
- {
- // DO NOTHING
- }
-
-public:
- // NOTE visit for luci::CircleNode is added NOT to throw NYI
- void visit(luci::CircleNode *) final {}
-
-public:
- void visit(luci::CircleAbs *) final;
- void visit(luci::CircleAdd *) final;
- void visit(luci::CircleAddN *) final;
- void visit(luci::CircleArgMax *) final;
- void visit(luci::CircleArgMin *) final;
- void visit(luci::CircleAveragePool2D *) final;
- void visit(luci::CircleBatchMatMul *) final;
- void visit(luci::CircleBatchToSpaceND *) final;
- void visit(luci::CircleBidirectionalSequenceLSTM *) final;
- void visit(luci::CircleCast *) final;
- void visit(luci::CircleCeil *) final;
- void visit(luci::CircleConcatenation *) final;
- void visit(luci::CircleConst *) final{/* skip, everything is done in exportOpDefinedTensors */};
- void visit(luci::CircleConv2D *) final;
- void visit(luci::CircleCos *) final;
- void visit(luci::CircleCustom *) final;
-};
-
-template <>
-class OpExporterLet<OE::DEF> final : public luci::CircleNodeMutableVisitor<void>,
- public ExportHelper
-{
-public:
- OpExporterLet(ExportContext &ctx) : ExportHelper(ctx)
- {
- // DO NOTHING
- }
-
-public:
- void visit(luci::CircleNode *) final {}
-
-public:
- void visit(luci::CircleDepthToSpace *) final;
- void visit(luci::CircleDepthwiseConv2D *) final;
- void visit(luci::CircleDequantize *) final;
- void visit(luci::CircleDiv *) final;
- void visit(luci::CircleElu *) final;
- void visit(luci::CircleEqual *) final;
- void visit(luci::CircleExp *) final;
- void visit(luci::CircleExpandDims *) final;
- void visit(luci::CircleFakeQuant *) final;
- void visit(luci::CircleFill *) final;
- void visit(luci::CircleFloor *) final;
- void visit(luci::CircleFloorDiv *) final;
- void visit(luci::CircleFloorMod *) final;
- void visit(luci::CircleFullyConnected *) final;
-};
-
-template <>
-class OpExporterLet<OE::GHIJ> final : public luci::CircleNodeMutableVisitor<void>,
- public ExportHelper
-{
-public:
- OpExporterLet(ExportContext &ctx) : ExportHelper(ctx)
- {
- // DO NOTHING
- }
-
-public:
- void visit(luci::CircleNode *) final {}
-
-public:
- void visit(luci::CircleGather *) final;
- void visit(luci::CircleGatherNd *) final;
- void visit(luci::CircleGreater *) final;
- void visit(luci::CircleGreaterEqual *) final;
- void visit(luci::CircleIf *) final;
-};
-
-template <>
-class OpExporterLet<OE::KLMN> final : public luci::CircleNodeMutableVisitor<void>,
- public ExportHelper
-{
-public:
- OpExporterLet(ExportContext &ctx) : ExportHelper(ctx)
- {
- // DO NOTHING
- }
-
-public:
- void visit(luci::CircleNode *) final {}
-
-public:
- void visit(luci::CircleL2Normalize *) final;
- void visit(luci::CircleL2Pool2D *) final;
- void visit(luci::CircleLeakyRelu *) final;
- void visit(luci::CircleLess *) final;
- void visit(luci::CircleLessEqual *) final;
- void visit(luci::CircleLocalResponseNormalization *) final;
- void visit(luci::CircleLog *) final;
- void visit(luci::CircleLogicalAnd *) final;
- void visit(luci::CircleLogicalNot *) final;
- void visit(luci::CircleLogicalOr *) final;
- void visit(luci::CircleLogistic *) final;
- void visit(luci::CircleLogSoftmax *) final;
- void visit(luci::CircleMatrixDiag *) final;
- void visit(luci::CircleMatrixSetDiag *) final;
- void visit(luci::CircleMaximum *) final;
- void visit(luci::CircleMaxPool2D *) final;
- void visit(luci::CircleMean *) final;
- void visit(luci::CircleMinimum *) final;
- void visit(luci::CircleMirrorPad *) final;
- void visit(luci::CircleMul *) final;
- void visit(luci::CircleNeg *) final;
- void visit(luci::CircleNonMaxSuppressionV4 *) final;
- void visit(luci::CircleNonMaxSuppressionV5 *) final;
- void visit(luci::CircleNotEqual *) final;
-};
-
-template <>
-class OpExporterLet<OE::OPQR> final : public luci::CircleNodeMutableVisitor<void>,
- public ExportHelper
-{
-public:
- OpExporterLet(ExportContext &ctx) : ExportHelper(ctx)
- {
- // DO NOTHING
- }
-
-public:
- void visit(luci::CircleNode *) final {}
-
-public:
- void visit(luci::CircleOneHot *) final;
- void visit(luci::CirclePack *) final;
- void visit(luci::CirclePad *) final;
- void visit(luci::CirclePadV2 *) final;
- void visit(luci::CirclePow *) final;
- void visit(luci::CirclePRelu *) final;
- void visit(luci::CircleQuantize *) final;
- void visit(luci::CircleRange *) final;
- void visit(luci::CircleRank *) final;
- void visit(luci::CircleReduceAny *) final;
- void visit(luci::CircleReduceMax *) final;
- void visit(luci::CircleReduceMin *) final;
- void visit(luci::CircleReduceProd *) final;
- void visit(luci::CircleRelu *) final;
- void visit(luci::CircleRelu6 *) final;
- void visit(luci::CircleReluN1To1 *) final;
- void visit(luci::CircleReshape *) final;
- void visit(luci::CircleResizeBilinear *) final;
- void visit(luci::CircleResizeNearestNeighbor *) final;
- void visit(luci::CircleReverseSequence *) final;
- void visit(luci::CircleReverseV2 *) final;
- void visit(luci::CircleRound *) final;
- void visit(luci::CircleRsqrt *) final;
-};
-
-template <>
-class OpExporterLet<OE::STUV> final : public luci::CircleNodeMutableVisitor<void>,
- public ExportHelper
-{
-public:
- OpExporterLet(ExportContext &ctx) : ExportHelper(ctx)
- {
- // DO NOTHING
- }
-
-public:
- void visit(luci::CircleNode *) final {}
-
-public:
- void visit(luci::CircleScatterNd *) final;
- void visit(luci::CircleSegmentSum *) final;
- void visit(luci::CircleSelect *) final;
- void visit(luci::CircleSelectV2 *) final;
- void visit(luci::CircleShape *) final;
- void visit(luci::CircleSin *) final;
- void visit(luci::CircleSlice *) final;
- void visit(luci::CircleSoftmax *) final;
- void visit(luci::CircleSpaceToBatchND *) final;
- void visit(luci::CircleSpaceToDepth *) final;
- void visit(luci::CircleSparseToDense *) final;
- void visit(luci::CircleSplit *) final;
- void visit(luci::CircleSplitV *) final;
- void visit(luci::CircleSqrt *) final;
- void visit(luci::CircleSquare *) final;
- void visit(luci::CircleSquaredDifference *) final;
- void visit(luci::CircleSqueeze *) final;
- void visit(luci::CircleStridedSlice *) final;
- void visit(luci::CircleSub *) final;
- void visit(luci::CircleSum *) final;
- void visit(luci::CircleTanh *) final;
- void visit(luci::CircleTile *) final;
- void visit(luci::CircleTopKV2 *) final;
- void visit(luci::CircleTranspose *) final;
- void visit(luci::CircleTransposeConv *) final;
- void visit(luci::CircleUnidirectionalSequenceLSTM *) final;
- void visit(luci::CircleUnique *) final;
- void visit(luci::CircleUnpack *) final;
-};
-
-template <>
-class OpExporterLet<OE::WXYZ> final : public luci::CircleNodeMutableVisitor<void>,
- public ExportHelper
-{
-public:
- OpExporterLet(ExportContext &ctx) : ExportHelper(ctx)
- {
- // DO NOTHING
- }
-
-public:
- void visit(luci::CircleNode *) final {}
-
-public:
- void visit(luci::CircleWhere *) final;
- void visit(luci::CircleWhile *) final;
- void visit(luci::CircleZerosLike *) final;
-};
-
-template <>
-class OpExporterLet<OE::CIRC> final : public luci::CircleNodeMutableVisitor<void>,
- public ExportHelper
-{
-public:
- OpExporterLet(ExportContext &ctx) : ExportHelper(ctx)
- {
- // DO NOTHING
- }
-
-public:
- void visit(luci::CircleNode *) final {}
-
-public:
- // Circle only
- void visit(luci::CircleBCQFullyConnected *) final;
- void visit(luci::CircleBCQGather *) final;
- void visit(luci::CircleInstanceNorm *) final;
-};
-
-template <>
-class OpExporterLet<OE::VIRT> final : public luci::CircleNodeMutableVisitor<void>,
- public ExportHelper
-{
-public:
- OpExporterLet(ExportContext &ctx) : ExportHelper(ctx)
- {
- // DO NOTHING
- }
-
-public:
- void visit(luci::CircleNode *) final {}
-
-public:
- // Virtual
- void visit(luci::CircleInput *) final {}
- void visit(luci::CircleOutput *) final {}
- void visit(luci::CircleOutputDummy *) final {}
- void visit(luci::CircleOutputExclude *) final {}
- // Virtual for multiple-outputs
- void visit(luci::CircleBidirectionalSequenceLSTMOut *) final {}
- void visit(luci::CircleCustomOut *) final {}
- void visit(luci::CircleIfOut *) final {}
- void visit(luci::CircleNonMaxSuppressionV4Out *) final {}
- void visit(luci::CircleNonMaxSuppressionV5Out *) final {}
- void visit(luci::CircleSplitOut *) final {}
- void visit(luci::CircleSplitVOut *) final {}
- void visit(luci::CircleTopKV2Out *) final {}
- void visit(luci::CircleUniqueOut *) final {}
- void visit(luci::CircleUnpackOut *) final {}
- void visit(luci::CircleWhileOut *) final {}
-};
-
-void OperationExporter::export_node(luci::CircleNode *node)
-{
- // TODO revise return type to bool and return if handled
-#define VISIT_OE(GRP) \
- do \
- { \
- OpExporterLet<OE::GRP> oe(_ctx); \
- node->accept(&oe); \
- } while (false)
-
- VISIT_OE(ABC);
- VISIT_OE(DEF);
- VISIT_OE(GHIJ);
- VISIT_OE(KLMN);
- VISIT_OE(OPQR);
- VISIT_OE(STUV);
- VISIT_OE(WXYZ);
- VISIT_OE(CIRC);
- VISIT_OE(VIRT);
-
-#undef VISIT_OE
-}
-
-void OpExporterLet<OE::ABC>::visit(luci::CircleAbs *node)
-{
- export_simple(node, circle::BuiltinOperator_ABS, circle::BuiltinOptions_AbsOptions,
- CreateAbsOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::ABC>::visit(luci::CircleAdd *node)
-{
- export_simple(
- node, circle::BuiltinOperator_ADD, circle::BuiltinOptions_AddOptions,
- CreateAddOptions(_ctx.builder, to_circle_actfunc(node->fusedActivationFunction())).Union());
-}
-
-void OpExporterLet<OE::ABC>::visit(luci::CircleAddN *node) { export_node(_ctx, node); }
-
-void OpExporterLet<OE::ABC>::visit(luci::CircleArgMax *node)
-{
- export_simple(
- node, circle::BuiltinOperator_ARG_MAX, circle::BuiltinOptions_ArgMaxOptions,
- CreateArgMaxOptions(_ctx.builder, to_circle_tensortype(node->output_type())).Union());
-}
-
-void OpExporterLet<OE::ABC>::visit(luci::CircleArgMin *node)
-{
- export_simple(
- node, circle::BuiltinOperator_ARG_MIN, circle::BuiltinOptions_ArgMinOptions,
- CreateArgMinOptions(_ctx.builder, to_circle_tensortype(node->output_type())).Union());
-}
-
-void OpExporterLet<OE::ABC>::visit(luci::CircleAveragePool2D *node)
-{
- export_pool_2d<luci::CircleAveragePool2D>(_ctx, node, circle::BuiltinOperator_AVERAGE_POOL_2D);
-}
-
-void OpExporterLet<OE::ABC>::visit(luci::CircleBatchMatMul *node)
-{
- export_simple(node, circle::BuiltinOperator_BATCH_MATMUL,
- circle::BuiltinOptions_BatchMatMulOptions,
- CreateBatchMatMulOptions(_ctx.builder, node->adj_x(), node->adj_y()).Union());
-}
-
-void OpExporterLet<OE::ABC>::visit(luci::CircleBidirectionalSequenceLSTM *node)
-{
- auto bidi_lstm_outs = loco::succs(node);
- assert((bidi_lstm_outs.size() == 1) || (bidi_lstm_outs.size() == 2));
- uint32_t op_idx = _ctx.md.registerBuiltinOpcode(
- circle::BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM, node->op_version());
-
- std::vector<int32_t> inputs_vec{get_tensor_index(node->input())};
- std::vector<int32_t> outputs_vec;
-
- for (int32_t index = 0; index < 2; index++)
- {
- // store in order of index
- bool found = false;
- for (auto out : bidi_lstm_outs)
- {
- auto bidi_lstm_out = loco::must_cast<luci::CircleBidirectionalSequenceLSTMOut *>(out);
- if (bidi_lstm_out->index() == index)
- {
- outputs_vec.push_back(get_tensor_index(bidi_lstm_out));
- found = true;
- break;
- }
- }
- if (!found)
- {
- INTERNAL_EXN("Invalid BidirectionalSequenceLSTM output");
- }
- }
-
- auto inputs = _ctx.builder.CreateVector(inputs_vec);
- auto outputs = _ctx.builder.CreateVector(outputs_vec);
- auto options = CreateBidirectionalSequenceLSTMOptions(
- _ctx.builder, to_circle_actfunc(node->fusedActivationFunction()), node->cell_clip(),
- node->proj_clip(), node->merge_outputs(), node->time_major(),
- node->asymmetric_quantize_inputs());
- auto op_offset =
- CreateOperator(_ctx.builder, op_idx, inputs, outputs,
- circle::BuiltinOptions_BidirectionalSequenceLSTMOptions, options.Union());
- _ctx.gd._operators.push_back(op_offset);
-}
-
-void OpExporterLet<OE::ABC>::visit(luci::CircleCast *node) { export_node(_ctx, node); }
-
-void OpExporterLet<OE::ABC>::visit(luci::CircleCeil *node)
-{
- export_simple(node, circle::BuiltinOperator_CEIL);
-}
-
-void OpExporterLet<OE::ABC>::visit(luci::CircleConcatenation *node) { export_node(_ctx, node); }
-
-void OpExporterLet<OE::ABC>::visit(luci::CircleBatchToSpaceND *node)
-{
- export_simple(node, circle::BuiltinOperator_BATCH_TO_SPACE_ND,
- circle::BuiltinOptions_BatchToSpaceNDOptions,
- CreateBatchToSpaceNDOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::ABC>::visit(luci::CircleConv2D *node)
-{
- export_simple(node, circle::BuiltinOperator_CONV_2D, circle::BuiltinOptions_Conv2DOptions,
- CreateConv2DOptions(_ctx.builder, getOpPadding(node->padding()),
- node->stride()->w(), node->stride()->h(),
- to_circle_actfunc(node->fusedActivationFunction()),
- node->dilation()->w(), node->dilation()->h())
- .Union());
-}
-
-void OpExporterLet<OE::ABC>::visit(luci::CircleCos *node)
-{
- export_simple(node, circle::BuiltinOperator_COS, circle::BuiltinOptions_CosOptions,
- CreateCosOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::ABC>::visit(luci::CircleCustom *node) { export_node(_ctx, node); }
-
-void OpExporterLet<OE::DEF>::visit(luci::CircleDepthToSpace *node)
-{
- export_simple(node, circle::BuiltinOperator_DEPTH_TO_SPACE,
- circle::BuiltinOptions_DepthToSpaceOptions,
- CreateDepthToSpaceOptions(_ctx.builder, node->block_size()).Union());
-}
-
-void OpExporterLet<OE::DEF>::visit(luci::CircleDepthwiseConv2D *node)
-{
- export_simple(
- node, circle::BuiltinOperator_DEPTHWISE_CONV_2D, circle::BuiltinOptions_DepthwiseConv2DOptions,
- CreateDepthwiseConv2DOptions(_ctx.builder, getOpPadding(node->padding()), node->stride()->w(),
- node->stride()->h(), node->depthMultiplier(),
- to_circle_actfunc(node->fusedActivationFunction()),
- node->dilation()->w(), node->dilation()->h())
- .Union());
-}
-
-void OpExporterLet<OE::DEF>::visit(luci::CircleDequantize *node)
-{
- export_simple(node, circle::BuiltinOperator_DEQUANTIZE);
-}
-
-void OpExporterLet<OE::DEF>::visit(luci::CircleDiv *node)
-{
- export_simple(
- node, circle::BuiltinOperator_DIV, circle::BuiltinOptions_DivOptions,
- CreateDivOptions(_ctx.builder, to_circle_actfunc(node->fusedActivationFunction())).Union());
-}
-
-void OpExporterLet<OE::DEF>::visit(luci::CircleElu *node)
-{
- export_simple(node, circle::BuiltinOperator_ELU);
-}
-
-void OpExporterLet<OE::DEF>::visit(luci::CircleEqual *node)
-{
- export_simple(node, circle::BuiltinOperator_EQUAL, circle::BuiltinOptions_EqualOptions,
- CreateEqualOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::DEF>::visit(luci::CircleExp *node)
-{
- export_simple(node, circle::BuiltinOperator_EXP, circle::BuiltinOptions_ExpOptions,
- CreateExpOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::DEF>::visit(luci::CircleExpandDims *node)
-{
- export_simple(node, circle::BuiltinOperator_EXPAND_DIMS, circle::BuiltinOptions_ExpandDimsOptions,
- CreateExpandDimsOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::DEF>::visit(luci::CircleFakeQuant *node)
-{
- export_simple(node, circle::BuiltinOperator_FAKE_QUANT, circle::BuiltinOptions_FakeQuantOptions,
- CreateFakeQuantOptions(_ctx.builder, node->min(), node->max(), node->num_bits(),
- node->narrow_range())
- .Union());
-}
-
-void OpExporterLet<OE::DEF>::visit(luci::CircleFill *node)
-{
- export_simple(node, circle::BuiltinOperator_FILL, circle::BuiltinOptions_FillOptions,
- CreateFillOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::DEF>::visit(luci::CircleFloor *node)
-{
- export_simple(node, circle::BuiltinOperator_FLOOR);
-}
-
-void OpExporterLet<OE::DEF>::visit(luci::CircleFloorDiv *node)
-{
- export_simple(node, circle::BuiltinOperator_FLOOR_DIV, circle::BuiltinOptions_FloorDivOptions,
- CreateFloorDivOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::DEF>::visit(luci::CircleFloorMod *node)
-{
- export_simple(node, circle::BuiltinOperator_FLOOR_MOD, circle::BuiltinOptions_FloorModOptions,
- CreateFloorModOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::DEF>::visit(luci::CircleFullyConnected *node)
-{
- export_simple(
- node, circle::BuiltinOperator_FULLY_CONNECTED, circle::BuiltinOptions_FullyConnectedOptions,
- CreateFullyConnectedOptions(_ctx.builder, to_circle_actfunc(node->fusedActivationFunction()),
- to_circle_weightsformat(node->weights_format()))
- .Union());
-}
-
-void OpExporterLet<OE::GHIJ>::visit(luci::CircleGather *node)
-{
- export_simple(node, circle::BuiltinOperator_GATHER, circle::BuiltinOptions_GatherOptions,
- CreateGatherOptions(_ctx.builder, node->axis()).Union());
-}
-
-void OpExporterLet<OE::GHIJ>::visit(luci::CircleGatherNd *node)
-{
- export_simple(node, circle::BuiltinOperator_GATHER_ND, circle::BuiltinOptions_GatherNdOptions,
- CreateGatherNdOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::GHIJ>::visit(luci::CircleGreater *node)
-{
- export_simple(node, circle::BuiltinOperator_GREATER, circle::BuiltinOptions_GreaterOptions,
- CreateGreaterOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::GHIJ>::visit(luci::CircleGreaterEqual *node)
-{
- export_simple(node, circle::BuiltinOperator_GREATER_EQUAL,
- circle::BuiltinOptions_GreaterEqualOptions,
- CreateGreaterEqualOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::GHIJ>::visit(luci::CircleIf *node) { export_node(_ctx, node); }
-
-void OpExporterLet<OE::KLMN>::visit(luci::CircleL2Normalize *node)
-{
- export_simple(
- node, circle::BuiltinOperator_L2_NORMALIZATION, circle::BuiltinOptions_L2NormOptions,
- CreateL2NormOptions(_ctx.builder, to_circle_actfunc(node->fusedActivationFunction())).Union());
-}
-
-void OpExporterLet<OE::KLMN>::visit(luci::CircleL2Pool2D *node)
-{
- export_pool_2d<luci::CircleL2Pool2D>(_ctx, node, circle::BuiltinOperator_L2_POOL_2D);
-}
-
-void OpExporterLet<OE::KLMN>::visit(luci::CircleLeakyRelu *node)
-{
- export_simple(node, circle::BuiltinOperator_LEAKY_RELU, circle::BuiltinOptions_LeakyReluOptions,
- CreateLeakyReluOptions(_ctx.builder, node->alpha()).Union());
-}
-
-void OpExporterLet<OE::KLMN>::visit(luci::CircleLess *node)
-{
- export_simple(node, circle::BuiltinOperator_LESS, circle::BuiltinOptions_LessOptions,
- CreateLessOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::KLMN>::visit(luci::CircleLessEqual *node)
-{
- export_simple(node, circle::BuiltinOperator_LESS_EQUAL, circle::BuiltinOptions_LessEqualOptions,
- CreateLessEqualOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::KLMN>::visit(luci::CircleLocalResponseNormalization *node)
-{
- export_simple(node, circle::BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION,
- circle::BuiltinOptions_LocalResponseNormalizationOptions,
- CreateLocalResponseNormalizationOptions(_ctx.builder, node->radius(), node->bias(),
- node->alpha(), node->beta())
- .Union());
-}
-
-void OpExporterLet<OE::KLMN>::visit(luci::CircleLog *node)
-{
- export_simple(node, circle::BuiltinOperator_LOG);
-}
-
-void OpExporterLet<OE::KLMN>::visit(luci::CircleLogicalAnd *node)
-{
- export_simple(node, circle::BuiltinOperator_LOGICAL_AND, circle::BuiltinOptions_LogicalAndOptions,
- CreateLogicalAndOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::KLMN>::visit(luci::CircleLogicalNot *node)
-{
- export_simple(node, circle::BuiltinOperator_LOGICAL_NOT, circle::BuiltinOptions_LogicalNotOptions,
- CreateLogicalNotOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::KLMN>::visit(luci::CircleLogicalOr *node)
-{
- export_simple(node, circle::BuiltinOperator_LOGICAL_OR, circle::BuiltinOptions_LogicalOrOptions,
- CreateLogicalOrOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::KLMN>::visit(luci::CircleLogistic *node)
-{
- export_simple(node, circle::BuiltinOperator_LOGISTIC);
-}
-
-void OpExporterLet<OE::KLMN>::visit(luci::CircleLogSoftmax *node)
-{
- export_simple(node, circle::BuiltinOperator_LOG_SOFTMAX, circle::BuiltinOptions_LogSoftmaxOptions,
- CreateLogSoftmaxOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::KLMN>::visit(luci::CircleMatrixDiag *node)
-{
- export_simple(node, circle::BuiltinOperator_MATRIX_DIAG, circle::BuiltinOptions_MatrixDiagOptions,
- CreateMatrixDiagOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::KLMN>::visit(luci::CircleMatrixSetDiag *node)
-{
- export_simple(node, circle::BuiltinOperator_MATRIX_SET_DIAG,
- circle::BuiltinOptions_MatrixSetDiagOptions,
- CreateMatrixSetDiagOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::KLMN>::visit(luci::CircleMaximum *node)
-{
- export_simple(node, circle::BuiltinOperator_MAXIMUM, circle::BuiltinOptions_MaximumMinimumOptions,
- CreateMaximumMinimumOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::KLMN>::visit(luci::CircleMaxPool2D *node)
-{
- export_pool_2d<luci::CircleMaxPool2D>(_ctx, node, circle::BuiltinOperator_MAX_POOL_2D);
-}
-
-void OpExporterLet<OE::KLMN>::visit(luci::CircleMean *node)
-{
- export_simple(node, circle::BuiltinOperator_MEAN, circle::BuiltinOptions_ReducerOptions,
- CreateReducerOptions(_ctx.builder, node->keep_dims()).Union());
-}
-
-void OpExporterLet<OE::KLMN>::visit(luci::CircleMinimum *node)
-{
- export_simple(node, circle::BuiltinOperator_MINIMUM, circle::BuiltinOptions_MaximumMinimumOptions,
- CreateMaximumMinimumOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::KLMN>::visit(luci::CircleMirrorPad *node)
-{
- export_simple(
- node, circle::BuiltinOperator_MIRROR_PAD, circle::BuiltinOptions_MirrorPadOptions,
- CreateMirrorPadOptions(_ctx.builder, to_circle_mirrorpadmode(node->mode())).Union());
-}
-
-void OpExporterLet<OE::KLMN>::visit(luci::CircleMul *node)
-{
- export_simple(
- node, circle::BuiltinOperator_MUL, circle::BuiltinOptions_MulOptions,
- CreateMulOptions(_ctx.builder, to_circle_actfunc(node->fusedActivationFunction())).Union());
-}
-
-void OpExporterLet<OE::KLMN>::visit(luci::CircleNeg *node)
-{
- export_simple(node, circle::BuiltinOperator_NEG, circle::BuiltinOptions_NegOptions,
- CreateNegOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::KLMN>::visit(luci::CircleNonMaxSuppressionV4 *node)
-{
- export_node(_ctx, node);
-}
-
-void OpExporterLet<OE::KLMN>::visit(luci::CircleNonMaxSuppressionV5 *node)
-{
- export_node(_ctx, node);
-}
-
-void OpExporterLet<OE::KLMN>::visit(luci::CircleNotEqual *node)
-{
- export_simple(node, circle::BuiltinOperator_NOT_EQUAL, circle::BuiltinOptions_NotEqualOptions,
- CreateNotEqualOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::OPQR>::visit(luci::CircleOneHot *node)
-{
- export_simple(node, circle::BuiltinOperator_ONE_HOT, circle::BuiltinOptions_OneHotOptions,
- CreateOneHotOptions(_ctx.builder, node->axis()).Union());
-}
-
-void OpExporterLet<OE::OPQR>::visit(luci::CirclePack *node)
-{
- export_simple(node, circle::BuiltinOperator_PACK, circle::BuiltinOptions_PackOptions,
- CreatePackOptions(_ctx.builder, node->values_count(), node->axis()).Union());
-}
-
-void OpExporterLet<OE::OPQR>::visit(luci::CirclePad *node)
-{
- export_simple(node, circle::BuiltinOperator_PAD, circle::BuiltinOptions_PadOptions,
- CreatePadOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::OPQR>::visit(luci::CirclePadV2 *node)
-{
- export_simple(node, circle::BuiltinOperator_PADV2, circle::BuiltinOptions_PadV2Options,
- CreatePadV2Options(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::OPQR>::visit(luci::CirclePow *node)
-{
- export_simple(node, circle::BuiltinOperator_POW, circle::BuiltinOptions_PowOptions,
- CreatePowOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::OPQR>::visit(luci::CirclePRelu *node)
-{
- export_simple(node, circle::BuiltinOperator_PRELU);
-}
-
-void OpExporterLet<OE::OPQR>::visit(luci::CircleQuantize *node)
-{
- export_simple(node, circle::BuiltinOperator_QUANTIZE);
-}
-
-void OpExporterLet<OE::OPQR>::visit(luci::CircleRange *node)
-{
- export_simple(node, circle::BuiltinOperator_RANGE, circle::BuiltinOptions_RangeOptions,
- CreateRangeOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::OPQR>::visit(luci::CircleRank *node)
-{
- export_simple(node, circle::BuiltinOperator_RANK, circle::BuiltinOptions_RankOptions,
- CreateRankOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::OPQR>::visit(luci::CircleReduceAny *node)
-{
- export_simple(node, circle::BuiltinOperator_REDUCE_ANY, circle::BuiltinOptions_ReducerOptions,
- CreateReducerOptions(_ctx.builder, node->keep_dims()).Union());
-}
-
-void OpExporterLet<OE::OPQR>::visit(luci::CircleReduceMax *node)
-{
- export_simple(node, circle::BuiltinOperator_REDUCE_MAX, circle::BuiltinOptions_ReducerOptions,
- CreateReducerOptions(_ctx.builder, node->keep_dims()).Union());
-}
-
-void OpExporterLet<OE::OPQR>::visit(luci::CircleReduceMin *node)
-{
- export_simple(node, circle::BuiltinOperator_REDUCE_MIN, circle::BuiltinOptions_ReducerOptions,
- CreateReducerOptions(_ctx.builder, node->keep_dims()).Union());
-}
-
-void OpExporterLet<OE::OPQR>::visit(luci::CircleReduceProd *node)
-{
- export_simple(node, circle::BuiltinOperator_REDUCE_PROD, circle::BuiltinOptions_ReducerOptions,
- CreateReducerOptions(_ctx.builder, node->keep_dims()).Union());
-}
-
-void OpExporterLet<OE::OPQR>::visit(luci::CircleRelu *node)
-{
- export_simple(node, circle::BuiltinOperator_RELU);
-}
-
-void OpExporterLet<OE::OPQR>::visit(luci::CircleRelu6 *node)
-{
- export_simple(node, circle::BuiltinOperator_RELU6);
-}
-
-void OpExporterLet<OE::OPQR>::visit(luci::CircleReluN1To1 *node)
-{
- export_simple(node, circle::BuiltinOperator_RELU_N1_TO_1);
-}
-
-void OpExporterLet<OE::OPQR>::visit(luci::CircleReshape *node)
-{
- auto new_shape = _ctx.builder.CreateVector<int32_t>(
- node->newShape()->rank(), [node](size_t i) { return node->newShape()->dim(i); });
-
- export_simple(node, circle::BuiltinOperator_RESHAPE, circle::BuiltinOptions_ReshapeOptions,
- CreateReshapeOptions(_ctx.builder, new_shape).Union());
-}
-
-void OpExporterLet<OE::OPQR>::visit(luci::CircleResizeBilinear *node)
-{
- export_simple(
- node, circle::BuiltinOperator_RESIZE_BILINEAR, circle::BuiltinOptions_ResizeBilinearOptions,
- CreateResizeBilinearOptions(_ctx.builder, node->align_corners(), node->half_pixel_centers())
- .Union());
-}
-
-void OpExporterLet<OE::OPQR>::visit(luci::CircleResizeNearestNeighbor *node)
-{
- export_simple(node, circle::BuiltinOperator_RESIZE_NEAREST_NEIGHBOR,
- circle::BuiltinOptions_ResizeNearestNeighborOptions,
- CreateResizeNearestNeighborOptions(_ctx.builder, node->align_corners()).Union());
-}
-
-void OpExporterLet<OE::OPQR>::visit(luci::CircleReverseSequence *node)
-{
- export_simple(
- node, circle::BuiltinOperator_REVERSE_SEQUENCE, circle::BuiltinOptions_ReverseSequenceOptions,
- CreateReverseSequenceOptions(_ctx.builder, node->seq_axis(), node->batch_axis()).Union());
-}
-
-void OpExporterLet<OE::OPQR>::visit(luci::CircleReverseV2 *node) { export_node(_ctx, node); }
-
-void OpExporterLet<OE::OPQR>::visit(luci::CircleRound *node)
-{
- export_simple(node, circle::BuiltinOperator_ROUND);
-}
-
-void OpExporterLet<OE::OPQR>::visit(luci::CircleRsqrt *node)
-{
- export_simple(node, circle::BuiltinOperator_RSQRT);
-}
-
-void OpExporterLet<OE::STUV>::visit(luci::CircleScatterNd *node)
-{
- export_simple(node, circle::BuiltinOperator_SCATTER_ND, circle::BuiltinOptions_ScatterNdOptions,
- CreateScatterNdOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::STUV>::visit(luci::CircleSegmentSum *node)
-{
- export_simple(node, circle::BuiltinOperator_SEGMENT_SUM, circle::BuiltinOptions_SegmentSumOptions,
- CreateSegmentSumOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::STUV>::visit(luci::CircleSelect *node)
-{
- export_simple(node, circle::BuiltinOperator_SELECT, circle::BuiltinOptions_SelectOptions,
- CreateSelectOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::STUV>::visit(luci::CircleSelectV2 *node)
-{
- export_simple(node, circle::BuiltinOperator_SELECT_V2, circle::BuiltinOptions_SelectV2Options,
- CreateSelectV2Options(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::STUV>::visit(luci::CircleShape *node)
-{
- export_simple(node, circle::BuiltinOperator_SHAPE, circle::BuiltinOptions_ShapeOptions,
- CreateShapeOptions(_ctx.builder, to_circle_tensortype(node->out_type())).Union());
-}
-
-void OpExporterLet<OE::STUV>::visit(luci::CircleSin *node)
-{
- export_simple(node, circle::BuiltinOperator_SIN);
-}
-
-void OpExporterLet<OE::STUV>::visit(luci::CircleSlice *node)
-{
- export_simple(node, circle::BuiltinOperator_SLICE, circle::BuiltinOptions_SliceOptions,
- CreateSliceOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::STUV>::visit(luci::CircleSoftmax *node)
-{
- export_simple(node, circle::BuiltinOperator_SOFTMAX, circle::BuiltinOptions_SoftmaxOptions,
- CreateSoftmaxOptions(_ctx.builder, node->beta()).Union());
-}
-
-void OpExporterLet<OE::STUV>::visit(luci::CircleSpaceToBatchND *node)
-{
- export_simple(node, circle::BuiltinOperator_SPACE_TO_BATCH_ND,
- circle::BuiltinOptions_SpaceToBatchNDOptions,
- CreateSpaceToBatchNDOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::STUV>::visit(luci::CircleSpaceToDepth *node)
-{
- export_simple(node, circle::BuiltinOperator_SPACE_TO_DEPTH,
- circle::BuiltinOptions_SpaceToDepthOptions,
- CreateSpaceToDepthOptions(_ctx.builder, node->block_size()).Union());
-}
-
-void OpExporterLet<OE::STUV>::visit(luci::CircleSparseToDense *node)
-{
- export_simple(node, circle::BuiltinOperator_SPARSE_TO_DENSE,
- circle::BuiltinOptions_SparseToDenseOptions,
- CreateSparseToDenseOptions(_ctx.builder, node->validate_indices()).Union());
-}
-
-void OpExporterLet<OE::STUV>::visit(luci::CircleSplit *node) { export_node(_ctx, node); }
-
-void OpExporterLet<OE::STUV>::visit(luci::CircleSplitV *node) { export_node(_ctx, node); }
-
-void OpExporterLet<OE::STUV>::visit(luci::CircleSqrt *node)
-{
- export_simple(node, circle::BuiltinOperator_SQRT);
-}
-
-void OpExporterLet<OE::STUV>::visit(luci::CircleSquare *node)
-{
- export_simple(node, circle::BuiltinOperator_SQUARE, circle::BuiltinOptions_SquareOptions,
- CreateSquareOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::STUV>::visit(luci::CircleSquaredDifference *node)
-{
- export_simple(node, circle::BuiltinOperator_SQUARED_DIFFERENCE,
- circle::BuiltinOptions_SquaredDifferenceOptions,
- CreateSquaredDifferenceOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::STUV>::visit(luci::CircleSqueeze *node)
-{
- auto squeeze_dims = _ctx.builder.CreateVector<int32_t>(node->squeeze_dims());
- export_simple(node, circle::BuiltinOperator_SQUEEZE, circle::BuiltinOptions_SqueezeOptions,
- CreateSqueezeOptions(_ctx.builder, squeeze_dims).Union());
-}
-
-void OpExporterLet<OE::STUV>::visit(luci::CircleStridedSlice *node)
-{
- export_simple(node, circle::BuiltinOperator_STRIDED_SLICE,
- circle::BuiltinOptions_StridedSliceOptions,
- CreateStridedSliceOptions(_ctx.builder, node->begin_mask(), node->end_mask(),
- node->ellipsis_mask(), node->new_axis_mask(),
- node->shrink_axis_mask())
- .Union());
-}
-
-void OpExporterLet<OE::STUV>::visit(luci::CircleSub *node)
-{
- export_simple(
- node, circle::BuiltinOperator_SUB, circle::BuiltinOptions_SubOptions,
- CreateSubOptions(_ctx.builder, to_circle_actfunc(node->fusedActivationFunction())).Union());
-}
-
-void OpExporterLet<OE::STUV>::visit(luci::CircleSum *node)
-{
- export_simple(node, circle::BuiltinOperator_SUM, circle::BuiltinOptions_ReducerOptions,
- CreateReducerOptions(_ctx.builder, node->keep_dims()).Union());
-}
-
-void OpExporterLet<OE::STUV>::visit(luci::CircleTanh *node)
-{
- export_simple(node, circle::BuiltinOperator_TANH);
-}
-
-void OpExporterLet<OE::STUV>::visit(luci::CircleTile *node)
-{
- export_simple(node, circle::BuiltinOperator_TILE, circle::BuiltinOptions_TileOptions,
- CreateTileOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::STUV>::visit(luci::CircleTopKV2 *node) { export_node(_ctx, node); }
-
-void OpExporterLet<OE::STUV>::visit(luci::CircleTranspose *node)
-{
- export_simple(node, circle::BuiltinOperator_TRANSPOSE, circle::BuiltinOptions_TransposeOptions,
- CreateTransposeOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::STUV>::visit(luci::CircleTransposeConv *node)
-{
- export_simple(node, circle::BuiltinOperator_TRANSPOSE_CONV,
- circle::BuiltinOptions_TransposeConvOptions,
- CreateTransposeConvOptions(_ctx.builder, getOpPadding(node->padding()),
- node->stride()->w(), node->stride()->h())
- .Union());
-}
-
-void OpExporterLet<OE::STUV>::visit(luci::CircleUnidirectionalSequenceLSTM *node)
-{
- export_simple(node, circle::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM,
- circle::BuiltinOptions_UnidirectionalSequenceLSTMOptions,
- CreateUnidirectionalSequenceLSTMOptions(
- _ctx.builder, to_circle_actfunc(node->fusedActivationFunction()),
- node->cell_clip(), node->proj_clip(), node->time_major(),
- node->asymmetric_quantize_inputs())
- .Union());
-}
-
-void OpExporterLet<OE::STUV>::visit(luci::CircleUnique *node) { export_node(_ctx, node); }
-
-void OpExporterLet<OE::STUV>::visit(luci::CircleUnpack *node) { export_node(_ctx, node); }
-
-void OpExporterLet<OE::WXYZ>::visit(luci::CircleWhere *node)
-{
- export_simple(node, circle::BuiltinOperator_WHERE, circle::BuiltinOptions_WhereOptions,
- CreateWhereOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::WXYZ>::visit(luci::CircleWhile *node) { export_node(_ctx, node); }
-
-void OpExporterLet<OE::WXYZ>::visit(luci::CircleZerosLike *node)
-{
- export_simple(node, circle::BuiltinOperator_ZEROS_LIKE, circle::BuiltinOptions_ZerosLikeOptions,
- CreateZerosLikeOptions(_ctx.builder).Union());
-}
-
-void OpExporterLet<OE::CIRC>::visit(luci::CircleBCQFullyConnected *node)
-{
- export_simple(node, circle::BuiltinOperator_BCQ_FULLY_CONNECTED,
- circle::BuiltinOptions_BCQFullyConnectedOptions,
- CreateBCQFullyConnectedOptions(_ctx.builder, node->weights_hidden_size(),
- to_circle_actfunc(node->fusedActivationFunction()))
- .Union());
-}
-
-void OpExporterLet<OE::CIRC>::visit(luci::CircleBCQGather *node)
-{
- export_simple(
- node, circle::BuiltinOperator_BCQ_GATHER, circle::BuiltinOptions_BCQGatherOptions,
- CreateBCQGatherOptions(_ctx.builder, node->input_hidden_size(), node->axis()).Union());
-}
-
-void OpExporterLet<OE::CIRC>::visit(luci::CircleInstanceNorm *node)
+namespace luci
{
- export_simple(node, circle::BuiltinOperator_INSTANCE_NORM,
- circle::BuiltinOptions_InstanceNormOptions,
- CreateInstanceNormOptions(_ctx.builder, node->epsilon(),
- to_circle_actfunc(node->fusedActivationFunction()))
- .Union());
-}
-void exportNode(loco::Node *node, flatbuffers::FlatBufferBuilder &builder, SerializedModelData &md,
- SerializedGraphData &gd, uint32_t node_position)
+void exportNodes(loco::Graph *g, flatbuffers::FlatBufferBuilder &builder, SerializedModelData &md,
+ SerializedGraphData &gd)
{
- if (auto circle_node = dynamic_cast<luci::CircleNode *>(node))
+ uint32_t node_position = 0;
+ for (auto node : loco::postorder_traversal(loco::output_nodes(g)))
{
ExportContext ctx{builder, md, gd};
- OperationExporter exporter{ctx};
+ OperationExporterRule exporter_rule{ctx};
+
+ auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+ circle_node->accept(&exporter_rule);
const auto ops_size = gd._operators.size();
- exporter.export_node(circle_node);
if (has_origin(circle_node) && ops_size != gd._operators.size())
{
const auto node_id = gd._operators.size() - 1;
@@ -1716,25 +60,7 @@ void exportNode(loco::Node *node, flatbuffers::FlatBufferBuilder &builder, Seria
}
md._metadata.add_execution_plan_table(node_position, execution_plan_vector);
}
- }
- else
- {
- INTERNAL_EXN("Node with unsupported dialect found");
- }
-}
-} // namespace
-
-namespace luci
-{
-
-void exportNodes(loco::Graph *g, FlatBufferBuilder &builder, SerializedModelData &md,
- SerializedGraphData &gd)
-{
- uint32_t node_position = 0;
- for (auto node : loco::postorder_traversal(loco::output_nodes(g)))
- {
- exportNode(node, builder, md, gd, node_position);
node_position++;
}
}
diff --git a/compiler/luci/export/src/CircleOperationExporter.h b/compiler/luci/export/src/CircleOperationExporter.h
index de6abfc54..f2b3cfd6b 100644
--- a/compiler/luci/export/src/CircleOperationExporter.h
+++ b/compiler/luci/export/src/CircleOperationExporter.h
@@ -17,7 +17,7 @@
#ifndef __CIRCLE_OPERATION_EXPORTER_H__
#define __CIRCLE_OPERATION_EXPORTER_H__
-#include "CircleExporterUtils.h"
+#include "SerializedData.h"
#include <loco/IR/Graph.h>
diff --git a/compiler/luci/export/src/CircleOperationExporterRule.cpp b/compiler/luci/export/src/CircleOperationExporterRule.cpp
new file mode 100644
index 000000000..8dc59fa9c
--- /dev/null
+++ b/compiler/luci/export/src/CircleOperationExporterRule.cpp
@@ -0,0 +1,277 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleOperationExporterRule.h"
+#include "CircleBuiltinTypesExtractor.h"
+#include "Check.h"
+
+#include <loco/IR/Graph.h>
+#include <luci/IR/CircleNode.h>
+#include <luci/IR/CircleNodes.h>
+#include <luci/IR/CircleNodeVisitor.h>
+#include <oops/InternalExn.h>
+
+#include <vector>
+
+namespace
+{
+class OutputVectorExtractor final : public luci::CircleNodeMutableVisitor<std::vector<int32_t>>
+{
+public:
+ OutputVectorExtractor()
+ {
+ // DO NOTHING
+ }
+
+public:
+ std::vector<int32_t> visit(luci::CircleNode *node) final
+ {
+ std::vector<int32_t> outputs_vec{luci::get_tensor_index(node)};
+ return outputs_vec;
+ }
+
+ std::vector<int32_t> visit(luci::CircleBidirectionalSequenceLSTM *node) final
+ {
+ auto bidi_lstm_outs = loco::succs(node);
+ assert((bidi_lstm_outs.size() == 1) || (bidi_lstm_outs.size() == 2));
+
+ std::vector<int32_t> outputs_vec(bidi_lstm_outs.size());
+
+ for (auto out : bidi_lstm_outs)
+ {
+ auto bidi_lstm_out = loco::must_cast<luci::CircleBidirectionalSequenceLSTMOut *>(out);
+ if (bidi_lstm_out->index() >= int32_t(bidi_lstm_outs.size()))
+ INTERNAL_EXN("Invalid BidirectionalSequenceLSTM output");
+ outputs_vec[bidi_lstm_out->index()] = luci::get_tensor_index(bidi_lstm_out);
+ }
+
+ return outputs_vec;
+ }
+
+ std::vector<int32_t> visit(luci::CircleCustom *node) final
+ {
+ auto custom_outputs = loco::succs(node);
+ assert(custom_outputs.size() == node->numOutputs());
+
+ std::vector<int32_t> outputs_vec(node->numOutputs());
+
+ for (auto out : custom_outputs)
+ {
+ auto custom_out = loco::must_cast<luci::CircleCustomOut *>(out);
+ if (custom_out->index() >= int32_t(node->numOutputs()))
+ INTERNAL_EXN("Invalid Custom output");
+ outputs_vec[custom_out->index()] = luci::get_tensor_index(custom_out);
+ }
+
+ return outputs_vec;
+ }
+
+ std::vector<int32_t> visit(luci::CircleIf *node) final
+ {
+ auto if_outs = loco::succs(node);
+ assert(if_outs.size() == node->output_count());
+
+ std::vector<int32_t> outputs_vec(node->output_count());
+
+ for (auto out : if_outs)
+ {
+ auto if_out = loco::must_cast<luci::CircleIfOut *>(out);
+ if (if_out->index() >= int32_t(node->output_count()))
+ INTERNAL_EXN("Invalid If output");
+ outputs_vec[if_out->index()] = luci::get_tensor_index(if_out);
+ }
+
+ return outputs_vec;
+ }
+
+ std::vector<int32_t> visit(luci::CircleNonMaxSuppressionV4 *node) final
+ {
+ auto nms_outs = loco::succs(node);
+ assert(nms_outs.size() == 2);
+
+ std::vector<int32_t> outputs_vec(2);
+
+ for (auto out : nms_outs)
+ {
+ auto nms_out = loco::must_cast<luci::CircleNonMaxSuppressionV4Out *>(out);
+ if (nms_out->index() >= 2)
+ INTERNAL_EXN("Invalid NonMaxSuppressionV4 output");
+ outputs_vec[nms_out->index()] = luci::get_tensor_index(nms_out);
+ }
+
+ return outputs_vec;
+ }
+
+ std::vector<int32_t> visit(luci::CircleNonMaxSuppressionV5 *node) final
+ {
+ auto nms_outs = loco::succs(node);
+ assert(nms_outs.size() == 3);
+
+ std::vector<int32_t> outputs_vec(3);
+
+ for (auto out : nms_outs)
+ {
+ auto nms_out = loco::must_cast<luci::CircleNonMaxSuppressionV5Out *>(out);
+ if (nms_out->index() >= 3)
+ INTERNAL_EXN("Invalid NonMaxSuppressionV5 output");
+ outputs_vec[nms_out->index()] = luci::get_tensor_index(nms_out);
+ }
+
+ return outputs_vec;
+ }
+
+ std::vector<int32_t> visit(luci::CircleSplit *node) final
+ {
+ auto split_outs = loco::succs(node);
+ assert(int32_t(split_outs.size()) == node->num_split());
+
+ std::vector<int32_t> outputs_vec(node->num_split());
+
+ for (auto out : split_outs)
+ {
+ auto split_out = loco::must_cast<luci::CircleSplitOut *>(out);
+ if (split_out->index() >= node->num_split())
+ INTERNAL_EXN("Invalid Split output");
+ outputs_vec[split_out->index()] = luci::get_tensor_index(split_out);
+ }
+
+ return outputs_vec;
+ }
+
+ std::vector<int32_t> visit(luci::CircleSplitV *node) final
+ {
+ auto split_outs = loco::succs(node);
+ assert(int32_t(split_outs.size()) == node->num_split());
+
+ std::vector<int32_t> outputs_vec(node->num_split());
+
+ for (auto out : split_outs)
+ {
+ auto split_out = loco::must_cast<luci::CircleSplitVOut *>(out);
+ if (split_out->index() >= node->num_split())
+ INTERNAL_EXN("Invalid SplitV output");
+ outputs_vec[split_out->index()] = luci::get_tensor_index(split_out);
+ }
+
+ return outputs_vec;
+ }
+
+ std::vector<int32_t> visit(luci::CircleTopKV2 *node) final
+ {
+ auto topkv2_outs = loco::succs(node);
+ assert(topkv2_outs.size() == 2);
+
+ std::vector<int32_t> outputs_vec(2);
+
+ for (auto out : topkv2_outs)
+ {
+ auto topkv2_out = loco::must_cast<luci::CircleTopKV2Out *>(out);
+ if (topkv2_out->index() >= 2)
+ INTERNAL_EXN("Invalid TopKV2 output");
+ outputs_vec[topkv2_out->index()] = luci::get_tensor_index(topkv2_out);
+ }
+
+ return outputs_vec;
+ }
+
+ std::vector<int32_t> visit(luci::CircleUnique *node) final
+ {
+ auto unique_outs = loco::succs(node);
+ assert(unique_outs.size() == 2);
+
+ std::vector<int32_t> outputs_vec(2);
+
+ for (auto out : unique_outs)
+ {
+ auto unique_out = loco::must_cast<luci::CircleUniqueOut *>(out);
+ if (unique_out->index() >= 2)
+ INTERNAL_EXN("Invalid Unique output");
+ outputs_vec[unique_out->index()] = luci::get_tensor_index(unique_out);
+ }
+
+ return outputs_vec;
+ }
+
+ std::vector<int32_t> visit(luci::CircleUnpack *node) final
+ {
+ auto unpack_outs = loco::succs(node);
+ assert(int32_t(unpack_outs.size()) == node->num());
+
+ std::vector<int32_t> outputs_vec(node->num());
+
+ for (auto out : unpack_outs)
+ {
+ auto unpack_out = loco::must_cast<luci::CircleUnpackOut *>(out);
+ if (unpack_out->index() >= node->num())
+ INTERNAL_EXN("Invalid Unpack output");
+ outputs_vec[unpack_out->index()] = luci::get_tensor_index(unpack_out);
+ }
+
+ return outputs_vec;
+ }
+
+ std::vector<int32_t> visit(luci::CircleWhile *node) final
+ {
+ auto while_outs = loco::succs(node);
+ assert(while_outs.size() == node->output_count());
+
+ std::vector<int32_t> outputs_vec(node->output_count());
+
+ for (auto out : while_outs)
+ {
+ auto while_out = loco::must_cast<luci::CircleWhileOut *>(out);
+ if (while_out->index() >= int32_t(node->output_count()))
+ INTERNAL_EXN("Invalid While output");
+ outputs_vec[while_out->index()] = luci::get_tensor_index(while_out);
+ }
+
+ return outputs_vec;
+ }
+};
+
+} // namespace
+
+namespace luci
+{
+
+void OperationExporterRule::visit(luci::CircleNode *node)
+{
+ auto op_idx = _ctx.md.registerBuiltinOpcode(circle_builtin_operator(node),
+ circle_custom_code(node), node->op_version());
+
+ std::vector<int32_t> inputs_vec;
+ for (uint32_t i = 0; i < node->arity(); ++i)
+ inputs_vec.push_back(luci::get_tensor_index(node->arg(i)));
+ auto inputs = _ctx.builder.CreateVector(inputs_vec);
+
+ OutputVectorExtractor outputs_vec_extractor;
+ auto outputs_vec = node->accept(&outputs_vec_extractor);
+ auto outputs = _ctx.builder.CreateVector(outputs_vec);
+
+ auto builtin_options = circle_builtin_options(node);
+
+ luci::BuiltinOptionsExtractor builtin_options_extractor(_ctx.builder);
+ auto options_offset = node->accept(&builtin_options_extractor);
+
+ // If node is not CircleCustom, null offset(0) is returned
+ auto custom_options = circle_custom_options(_ctx.builder, node);
+
+ auto op_offset = circle::CreateOperator(_ctx.builder, op_idx, inputs, outputs, builtin_options,
+ options_offset, custom_options);
+ _ctx.gd._operators.push_back(op_offset);
+}
+
+} // namespace luci
diff --git a/compiler/luci/export/src/CircleOperationExporterRule.h b/compiler/luci/export/src/CircleOperationExporterRule.h
new file mode 100644
index 000000000..23e7546cf
--- /dev/null
+++ b/compiler/luci/export/src/CircleOperationExporterRule.h
@@ -0,0 +1,76 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __CIRCLE_OPERATION_EXPORTER_RULE_H__
+#define __CIRCLE_OPERATION_EXPORTER_RULE_H__
+
+#include "CircleOperationExporter.h"
+
+#include <luci/IR/CircleNode.h>
+#include <luci/IR/CircleNodes.h>
+#include <luci/IR/CircleNodeVisitor.h>
+
+namespace luci
+{
+
+struct ExportContext
+{
+ flatbuffers::FlatBufferBuilder &builder;
+ luci::SerializedModelData &md;
+ luci::SerializedGraphData &gd;
+};
+
+class OperationExporterRule final : public luci::CircleNodeMutableVisitor<void>
+{
+public:
+ OperationExporterRule(ExportContext &ctx) : _ctx{ctx}
+ {
+ // DO NOTHING
+ }
+
+public:
+ // Default export rule
+ void visit(luci::CircleNode *node) final;
+
+ // Non-virtual
+ void visit(luci::CircleConst *) final{/* skip, everything is done in exportOpDefinedTensors */};
+
+ // Virtual
+ void visit(luci::CircleInput *) final {}
+ void visit(luci::CircleOutput *) final {}
+ void visit(luci::CircleOutputDummy *) final {}
+ void visit(luci::CircleOutputExclude *) final {}
+ // Virtual for multiple-outputs
+ void visit(luci::CircleBidirectionalSequenceLSTMOut *) final {}
+ void visit(luci::CircleCustomOut *) final {}
+ void visit(luci::CircleIfOut *) final {}
+ void visit(luci::CircleNonMaxSuppressionV4Out *) final {}
+ void visit(luci::CircleNonMaxSuppressionV5Out *) final {}
+ void visit(luci::CircleSplitOut *) final {}
+ void visit(luci::CircleSplitVOut *) final {}
+ void visit(luci::CircleTopKV2Out *) final {}
+ void visit(luci::CircleUniqueOut *) final {}
+ void visit(luci::CircleUnpackOut *) final {}
+ void visit(luci::CircleVariable *) final {}
+ void visit(luci::CircleWhileOut *) final {}
+
+protected:
+ ExportContext &_ctx;
+};
+
+} // namespace luci
+
+#endif // __CIRCLE_OPERATION_EXPORTER_RULE_H__
diff --git a/compiler/luci/export/src/CircleOps.lst b/compiler/luci/export/src/CircleOps.lst
new file mode 100644
index 000000000..1b6909303
--- /dev/null
+++ b/compiler/luci/export/src/CircleOps.lst
@@ -0,0 +1,154 @@
+#ifndef CIRCLE_NODE
+#error "Define CIRCLE_NODE"
+#endif // CIRCLE_NODE
+
+#ifndef CIRCLE_VNODE
+#error "Define CIRCLE_VNODE"
+#endif // CIRCLE_VNODE
+
+//
+// PLEASE SORT NODE DECLS IN ALPHABETICAL ORDER
+//
+// NOTE : CIRCLE_VNODE does not have any additional parameters
+// because they are not circle builtin operators
+// Please add parameters when they are needed.
+//
+// CIRCLE_NODE(CircleNode, circle::BuiltinOperator, circle::BuiltinOptions)
+// CIRCLE_VNODE(CircleNode)
+//
+
+CIRCLE_NODE(CircleAbs, BuiltinOperator_ABS, BuiltinOptions_AbsOptions)
+CIRCLE_NODE(CircleAdd, BuiltinOperator_ADD, BuiltinOptions_AddOptions)
+CIRCLE_NODE(CircleAddN, BuiltinOperator_ADD_N, BuiltinOptions_AddNOptions)
+CIRCLE_NODE(CircleArgMax, BuiltinOperator_ARG_MAX, BuiltinOptions_ArgMaxOptions)
+CIRCLE_NODE(CircleArgMin, BuiltinOperator_ARG_MIN, BuiltinOptions_ArgMinOptions)
+CIRCLE_NODE(CircleAveragePool2D, BuiltinOperator_AVERAGE_POOL_2D , BuiltinOptions_Pool2DOptions)
+CIRCLE_NODE(CircleBatchToSpaceND, BuiltinOperator_BATCH_TO_SPACE_ND, BuiltinOptions_BatchToSpaceNDOptions)
+CIRCLE_NODE(CircleBatchMatMul, BuiltinOperator_BATCH_MATMUL, BuiltinOptions_BatchMatMulOptions)
+CIRCLE_NODE(CircleBidirectionalSequenceLSTM, BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM, BuiltinOptions_BidirectionalSequenceLSTMOptions)
+CIRCLE_NODE(CircleCast, BuiltinOperator_CAST, BuiltinOptions_CastOptions)
+CIRCLE_NODE(CircleCeil, BuiltinOperator_CEIL, BuiltinOptions_NONE)
+CIRCLE_NODE(CircleConcatenation, BuiltinOperator_CONCATENATION, BuiltinOptions_ConcatenationOptions)
+CIRCLE_NODE(CircleConv2D, BuiltinOperator_CONV_2D, BuiltinOptions_Conv2DOptions)
+CIRCLE_NODE(CircleCos, BuiltinOperator_COS, BuiltinOptions_CosOptions)
+CIRCLE_NODE(CircleCustom, BuiltinOperator_CUSTOM, BuiltinOptions_NONE)
+CIRCLE_NODE(CircleDepthToSpace, BuiltinOperator_DEPTH_TO_SPACE, BuiltinOptions_DepthToSpaceOptions)
+CIRCLE_NODE(CircleDepthwiseConv2D, BuiltinOperator_DEPTHWISE_CONV_2D, BuiltinOptions_DepthwiseConv2DOptions)
+CIRCLE_NODE(CircleDequantize, BuiltinOperator_DEQUANTIZE, BuiltinOptions_DequantizeOptions)
+CIRCLE_NODE(CircleDiv, BuiltinOperator_DIV, BuiltinOptions_DivOptions)
+CIRCLE_NODE(CircleElu, BuiltinOperator_ELU, BuiltinOptions_NONE)
+CIRCLE_NODE(CircleEqual, BuiltinOperator_EQUAL, BuiltinOptions_EqualOptions)
+CIRCLE_NODE(CircleExp, BuiltinOperator_EXP, BuiltinOptions_ExpOptions)
+CIRCLE_NODE(CircleExpandDims, BuiltinOperator_EXPAND_DIMS, BuiltinOptions_ExpandDimsOptions)
+CIRCLE_NODE(CircleFakeQuant, BuiltinOperator_FAKE_QUANT, BuiltinOptions_FakeQuantOptions)
+CIRCLE_NODE(CircleFill, BuiltinOperator_FILL, BuiltinOptions_FillOptions)
+CIRCLE_NODE(CircleFloor, BuiltinOperator_FLOOR, BuiltinOptions_NONE)
+CIRCLE_NODE(CircleFloorDiv, BuiltinOperator_FLOOR_DIV, BuiltinOptions_FloorDivOptions)
+CIRCLE_NODE(CircleFloorMod, BuiltinOperator_FLOOR_MOD, BuiltinOptions_FloorModOptions)
+CIRCLE_NODE(CircleFullyConnected, BuiltinOperator_FULLY_CONNECTED, BuiltinOptions_FullyConnectedOptions)
+CIRCLE_NODE(CircleGather, BuiltinOperator_GATHER, BuiltinOptions_GatherOptions)
+CIRCLE_NODE(CircleGatherNd, BuiltinOperator_GATHER_ND, BuiltinOptions_GatherNdOptions)
+CIRCLE_NODE(CircleGreater, BuiltinOperator_GREATER, BuiltinOptions_GreaterOptions)
+CIRCLE_NODE(CircleGreaterEqual, BuiltinOperator_GREATER_EQUAL, BuiltinOptions_GreaterEqualOptions)
+CIRCLE_NODE(CircleIf, BuiltinOperator_IF, BuiltinOptions_IfOptions)
+CIRCLE_NODE(CircleL2Normalize, BuiltinOperator_L2_NORMALIZATION, BuiltinOptions_L2NormOptions)
+CIRCLE_NODE(CircleL2Pool2D, BuiltinOperator_L2_POOL_2D, BuiltinOptions_Pool2DOptions)
+CIRCLE_NODE(CircleLeakyRelu, BuiltinOperator_LEAKY_RELU, BuiltinOptions_LeakyReluOptions)
+CIRCLE_NODE(CircleLess, BuiltinOperator_LESS, BuiltinOptions_LessOptions)
+CIRCLE_NODE(CircleLessEqual, BuiltinOperator_LESS_EQUAL, BuiltinOptions_LessEqualOptions)
+CIRCLE_NODE(CircleLocalResponseNormalization, BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION, BuiltinOptions_LocalResponseNormalizationOptions)
+CIRCLE_NODE(CircleLog, BuiltinOperator_LOG, BuiltinOptions_NONE)
+CIRCLE_NODE(CircleLogicalAnd, BuiltinOperator_LOGICAL_AND, BuiltinOptions_LogicalAndOptions)
+CIRCLE_NODE(CircleLogicalNot, BuiltinOperator_LOGICAL_NOT, BuiltinOptions_LogicalNotOptions)
+CIRCLE_NODE(CircleLogicalOr, BuiltinOperator_LOGICAL_OR, BuiltinOptions_LogicalOrOptions)
+CIRCLE_NODE(CircleLogistic, BuiltinOperator_LOGISTIC, BuiltinOptions_NONE)
+CIRCLE_NODE(CircleLogSoftmax, BuiltinOperator_LOG_SOFTMAX, BuiltinOptions_LogSoftmaxOptions)
+CIRCLE_NODE(CircleMatrixDiag, BuiltinOperator_MATRIX_DIAG, BuiltinOptions_MatrixDiagOptions)
+CIRCLE_NODE(CircleMaxPool2D, BuiltinOperator_MAX_POOL_2D, BuiltinOptions_Pool2DOptions)
+CIRCLE_NODE(CircleMatrixSetDiag, BuiltinOperator_MATRIX_SET_DIAG, BuiltinOptions_MatrixSetDiagOptions)
+CIRCLE_NODE(CircleMaximum, BuiltinOperator_MAXIMUM, BuiltinOptions_MaximumMinimumOptions)
+CIRCLE_NODE(CircleMean, BuiltinOperator_MEAN, BuiltinOptions_ReducerOptions)
+CIRCLE_NODE(CircleMinimum, BuiltinOperator_MINIMUM, BuiltinOptions_MaximumMinimumOptions)
+CIRCLE_NODE(CircleMirrorPad, BuiltinOperator_MIRROR_PAD, BuiltinOptions_MirrorPadOptions)
+CIRCLE_NODE(CircleMul, BuiltinOperator_MUL, BuiltinOptions_MulOptions)
+CIRCLE_NODE(CircleNeg, BuiltinOperator_NEG, BuiltinOptions_NegOptions)
+CIRCLE_NODE(CircleNonMaxSuppressionV4, BuiltinOperator_NON_MAX_SUPPRESSION_V4, BuiltinOptions_NonMaxSuppressionV4Options)
+CIRCLE_NODE(CircleNonMaxSuppressionV5, BuiltinOperator_NON_MAX_SUPPRESSION_V5, BuiltinOptions_NonMaxSuppressionV5Options)
+CIRCLE_NODE(CircleNotEqual, BuiltinOperator_NOT_EQUAL, BuiltinOptions_NotEqualOptions)
+CIRCLE_NODE(CircleOneHot, BuiltinOperator_ONE_HOT, BuiltinOptions_OneHotOptions)
+CIRCLE_NODE(CirclePack, BuiltinOperator_PACK, BuiltinOptions_PackOptions)
+CIRCLE_NODE(CirclePad, BuiltinOperator_PAD, BuiltinOptions_PadOptions)
+CIRCLE_NODE(CirclePadV2, BuiltinOperator_PADV2, BuiltinOptions_PadV2Options)
+CIRCLE_NODE(CirclePow, BuiltinOperator_POW, BuiltinOptions_PowOptions)
+CIRCLE_NODE(CirclePRelu, BuiltinOperator_PRELU, BuiltinOptions_NONE)
+CIRCLE_NODE(CircleQuantize, BuiltinOperator_QUANTIZE, BuiltinOptions_QuantizeOptions)
+CIRCLE_NODE(CircleRange, BuiltinOperator_RANGE, BuiltinOptions_RangeOptions)
+CIRCLE_NODE(CircleRank, BuiltinOperator_RANK, BuiltinOptions_RankOptions)
+CIRCLE_NODE(CircleReduceAny, BuiltinOperator_REDUCE_ANY, BuiltinOptions_ReducerOptions)
+CIRCLE_NODE(CircleReduceMax, BuiltinOperator_REDUCE_MAX, BuiltinOptions_ReducerOptions)
+CIRCLE_NODE(CircleReduceMin, BuiltinOperator_REDUCE_MIN, BuiltinOptions_ReducerOptions)
+CIRCLE_NODE(CircleReduceProd, BuiltinOperator_REDUCE_PROD, BuiltinOptions_ReducerOptions)
+CIRCLE_NODE(CircleRelu, BuiltinOperator_RELU, BuiltinOptions_NONE)
+CIRCLE_NODE(CircleRelu6, BuiltinOperator_RELU6, BuiltinOptions_NONE)
+CIRCLE_NODE(CircleReluN1To1, BuiltinOperator_RELU_N1_TO_1, BuiltinOptions_NONE)
+CIRCLE_NODE(CircleReshape, BuiltinOperator_RESHAPE, BuiltinOptions_ReshapeOptions)
+CIRCLE_NODE(CircleResizeBilinear, BuiltinOperator_RESIZE_BILINEAR, BuiltinOptions_ResizeBilinearOptions)
+CIRCLE_NODE(CircleResizeNearestNeighbor, BuiltinOperator_RESIZE_NEAREST_NEIGHBOR, BuiltinOptions_ResizeNearestNeighborOptions)
+CIRCLE_NODE(CircleReverseSequence, BuiltinOperator_REVERSE_SEQUENCE, BuiltinOptions_ReverseSequenceOptions)
+CIRCLE_NODE(CircleReverseV2, BuiltinOperator_REVERSE_V2, BuiltinOptions_ReverseV2Options)
+CIRCLE_NODE(CircleRound, BuiltinOperator_ROUND, BuiltinOptions_NONE)
+CIRCLE_NODE(CircleRsqrt, BuiltinOperator_RSQRT, BuiltinOptions_NONE)
+CIRCLE_NODE(CircleScatterNd, BuiltinOperator_SCATTER_ND, BuiltinOptions_ScatterNdOptions)
+CIRCLE_NODE(CircleSegmentSum, BuiltinOperator_SEGMENT_SUM, BuiltinOptions_SegmentSumOptions)
+CIRCLE_NODE(CircleSelect, BuiltinOperator_SELECT, BuiltinOptions_SelectOptions)
+CIRCLE_NODE(CircleSelectV2, BuiltinOperator_SELECT_V2, BuiltinOptions_SelectV2Options)
+CIRCLE_NODE(CircleShape, BuiltinOperator_SHAPE, BuiltinOptions_ShapeOptions)
+CIRCLE_NODE(CircleSin, BuiltinOperator_SIN, BuiltinOptions_NONE)
+CIRCLE_NODE(CircleSlice, BuiltinOperator_SLICE, BuiltinOptions_SliceOptions)
+CIRCLE_NODE(CircleSoftmax, BuiltinOperator_SOFTMAX, BuiltinOptions_SoftmaxOptions)
+CIRCLE_NODE(CircleSpaceToBatchND, BuiltinOperator_SPACE_TO_BATCH_ND, BuiltinOptions_SpaceToBatchNDOptions)
+CIRCLE_NODE(CircleSpaceToDepth, BuiltinOperator_SPACE_TO_DEPTH, BuiltinOptions_SpaceToDepthOptions)
+CIRCLE_NODE(CircleSparseToDense, BuiltinOperator_SPARSE_TO_DENSE, BuiltinOptions_SparseToDenseOptions)
+CIRCLE_NODE(CircleSplit, BuiltinOperator_SPLIT, BuiltinOptions_SplitOptions)
+CIRCLE_NODE(CircleSplitV, BuiltinOperator_SPLIT_V, BuiltinOptions_SplitVOptions)
+CIRCLE_NODE(CircleSqrt, BuiltinOperator_SQRT, BuiltinOptions_NONE)
+CIRCLE_NODE(CircleSquare, BuiltinOperator_SQUARE, BuiltinOptions_SquareOptions)
+CIRCLE_NODE(CircleSquaredDifference, BuiltinOperator_SQUARED_DIFFERENCE, BuiltinOptions_SquaredDifferenceOptions)
+CIRCLE_NODE(CircleSqueeze, BuiltinOperator_SQUEEZE, BuiltinOptions_SqueezeOptions)
+CIRCLE_NODE(CircleStridedSlice, BuiltinOperator_STRIDED_SLICE, BuiltinOptions_StridedSliceOptions)
+CIRCLE_NODE(CircleSub, BuiltinOperator_SUB, BuiltinOptions_SubOptions)
+CIRCLE_NODE(CircleSum, BuiltinOperator_SUM, BuiltinOptions_ReducerOptions)
+CIRCLE_NODE(CircleSVDF, BuiltinOperator_SVDF, BuiltinOptions_SVDFOptions)
+CIRCLE_NODE(CircleTanh, BuiltinOperator_TANH, BuiltinOptions_NONE)
+CIRCLE_NODE(CircleTile, BuiltinOperator_TILE, BuiltinOptions_TileOptions)
+CIRCLE_NODE(CircleTopKV2, BuiltinOperator_TOPK_V2, BuiltinOptions_TopKV2Options)
+CIRCLE_NODE(CircleTranspose, BuiltinOperator_TRANSPOSE, BuiltinOptions_TransposeOptions)
+CIRCLE_NODE(CircleTransposeConv, BuiltinOperator_TRANSPOSE_CONV, BuiltinOptions_TransposeConvOptions)
+CIRCLE_NODE(CircleUnidirectionalSequenceLSTM, BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM, BuiltinOptions_UnidirectionalSequenceLSTMOptions)
+CIRCLE_NODE(CircleUnique, BuiltinOperator_UNIQUE, BuiltinOptions_UniqueOptions)
+CIRCLE_NODE(CircleUnpack, BuiltinOperator_UNPACK, BuiltinOptions_UnpackOptions)
+CIRCLE_NODE(CircleWhere, BuiltinOperator_WHERE, BuiltinOptions_WhereOptions)
+CIRCLE_NODE(CircleWhile, BuiltinOperator_WHILE, BuiltinOptions_WhileOptions)
+CIRCLE_NODE(CircleZerosLike, BuiltinOperator_ZEROS_LIKE, BuiltinOptions_ZerosLikeOptions)
+// Circle Only
+CIRCLE_NODE(CircleBCQFullyConnected, BuiltinOperator_BCQ_FULLY_CONNECTED, BuiltinOptions_BCQFullyConnectedOptions)
+CIRCLE_NODE(CircleBCQGather, BuiltinOperator_BCQ_GATHER, BuiltinOptions_BCQGatherOptions)
+CIRCLE_NODE(CircleInstanceNorm, BuiltinOperator_INSTANCE_NORM, BuiltinOptions_InstanceNormOptions)
+// Virtual node(s)
+CIRCLE_VNODE(CircleBidirectionalSequenceLSTMOut)
+CIRCLE_VNODE(CircleConst)
+CIRCLE_VNODE(CircleInput)
+CIRCLE_VNODE(CircleOutput)
+CIRCLE_VNODE(CircleOutputDummy)
+CIRCLE_VNODE(CircleOutputExclude)
+CIRCLE_VNODE(CircleCustomOut)
+CIRCLE_VNODE(CircleIfOut)
+CIRCLE_VNODE(CircleNonMaxSuppressionV4Out)
+CIRCLE_VNODE(CircleNonMaxSuppressionV5Out)
+CIRCLE_VNODE(CircleSplitOut)
+CIRCLE_VNODE(CircleSplitVOut)
+CIRCLE_VNODE(CircleTopKV2Out)
+CIRCLE_VNODE(CircleUniqueOut)
+CIRCLE_VNODE(CircleUnpackOut)
+CIRCLE_VNODE(CircleVariable)
+CIRCLE_VNODE(CircleWhileOut)
diff --git a/compiler/luci/export/src/CircleTensorExporter.cpp b/compiler/luci/export/src/CircleTensorExporter.cpp
index 615402aa8..b3bb850cc 100644
--- a/compiler/luci/export/src/CircleTensorExporter.cpp
+++ b/compiler/luci/export/src/CircleTensorExporter.cpp
@@ -67,6 +67,9 @@ public:
luci::SparsityParam *sparsityparam(void) const { return _sparsityparam; }
void sparsityparam(luci::SparsityParam *sp) { _sparsityparam = sp; }
+ bool is_variable(void) const { return _is_variable; }
+ void is_variable(bool v) { _is_variable = v; }
+
private:
std::string _name;
@@ -77,6 +80,8 @@ private:
luci::CircleConst *_content = nullptr;
luci::CircleQuantParam *_quantparam = nullptr;
luci::SparsityParam *_sparsityparam = nullptr;
+
+ bool _is_variable = false;
};
class CircleTensorContext
@@ -145,6 +150,8 @@ void allocateCircleTensorInfo(CircleNode *node, CircleTensorContext &ctx)
tensor_info.quantparam(node->quantparam());
tensor_info.sparsityparam(node->sparsityparam());
+ tensor_info.is_variable(dynamic_cast<luci::CircleVariable *>(node) != nullptr);
+
set_tensor_index(node, tensor_index);
ctx.emplace_back(tensor_info);
@@ -592,9 +599,11 @@ void exportOpDefinedTensor(const CircleTensorInfo &info, FlatBufferBuilder &buil
auto buffer_id = get_buffer_id(builder, md, info.content());
auto name_offset = builder.CreateString(info.name());
- auto tensor_offset =
- CreateTensor(builder, shape_offset, info.dtype(), buffer_id, name_offset, quantparam,
- /*is_variable*/ false, sparsityparam, shape_signature_offset);
+
+ auto is_variable = info.is_variable();
+
+ auto tensor_offset = CreateTensor(builder, shape_offset, info.dtype(), buffer_id, name_offset,
+ quantparam, is_variable, sparsityparam, shape_signature_offset);
gd._tensors.push_back(tensor_offset);
}
diff --git a/compiler/luci/export/src/SerializedData.h b/compiler/luci/export/src/SerializedData.h
index a945eecf7..136a8ac49 100644
--- a/compiler/luci/export/src/SerializedData.h
+++ b/compiler/luci/export/src/SerializedData.h
@@ -23,7 +23,7 @@
#include <luci/IR/ExecutionPlanTable.h>
#include <vector>
-
+#include <string>
#include <unordered_map>
#include <map>
@@ -131,8 +131,8 @@ struct SerializedModelData final
* @param builtin_code
* @return idx of opcode in table of opcodes (see schema)
*/
- uint32_t registerBuiltinOpcode(circle::BuiltinOperator builtin_code, const int32_t op_version);
- uint32_t registerCustomOpcode(const std::string &custom_op);
+ uint32_t registerBuiltinOpcode(circle::BuiltinOperator builtin_code,
+ const std::string &custom_code, const int32_t op_version);
};
// Prerequisites for circle::Model object creation
diff --git a/compiler/luci/import/CMakeLists.txt b/compiler/luci/import/CMakeLists.txt
index 6630cab9f..1b2db23ae 100644
--- a/compiler/luci/import/CMakeLists.txt
+++ b/compiler/luci/import/CMakeLists.txt
@@ -12,13 +12,14 @@ target_include_directories(luci_import PUBLIC include)
target_link_libraries(luci_import PUBLIC luci_lang)
target_link_libraries(luci_import PUBLIC luci_profile)
target_link_libraries(luci_import PUBLIC luci_plan)
-target_link_libraries(luci_import PUBLIC mio_circle)
+target_link_libraries(luci_import PUBLIC mio_circle04)
target_link_libraries(luci_import PRIVATE luci_env)
target_link_libraries(luci_import PRIVATE luci_log)
target_link_libraries(luci_import PRIVATE luci_logex)
target_link_libraries(luci_import PRIVATE nncc_common)
target_link_libraries(luci_import PRIVATE locop)
target_link_libraries(luci_import PRIVATE oops)
+target_link_libraries(luci_import PRIVATE mio_circle04_helper)
install(TARGETS luci_import DESTINATION lib)
install(DIRECTORY include/ DESTINATION include
FILES_MATCHING PATTERN "*.h")
@@ -32,7 +33,3 @@ nnas_find_package(GTest REQUIRED)
GTest_AddTest(luci_import_test ${TESTS})
target_include_directories(luci_import_test PRIVATE src)
target_link_libraries(luci_import_test luci_import)
-target_link_libraries(luci_import_test oops)
-target_link_libraries(luci_import_test luci_plan)
-target_link_libraries(luci_import_test luci_lang)
-target_link_libraries(luci_import_test mio_circle)
diff --git a/compiler/luci/import/include/luci/Import/CircleReader.h b/compiler/luci/import/include/luci/Import/CircleReader.h
index fb38ba90b..a0519f661 100644
--- a/compiler/luci/import/include/luci/Import/CircleReader.h
+++ b/compiler/luci/import/include/luci/Import/CircleReader.h
@@ -35,19 +35,7 @@
namespace luci
{
-bool is_valid(const circle::OperatorCodeT &opcode);
-bool is_valid(const circle::OperatorCode *opcode);
-
-bool is_custom(const circle::OperatorCodeT &opcode);
-bool is_custom(const circle::OperatorCode *opcode);
-
-std::string opcode_name(const circle::OperatorCodeT &opcode);
-std::string opcode_name(const circle::OperatorCode *opcode);
-
-const char *tensor_name(const circle::TensorT &tensor);
const char *tensor_name(const circle::Tensor *tensor);
-
-const circle::QuantizationParametersT *tensor_quantization(const circle::TensorT &tensor);
const circle::QuantizationParameters *tensor_quantization(const circle::Tensor *tensor);
loco::DataType luci_datatype(circle::TensorType type);
@@ -57,14 +45,13 @@ MirrorPadMode luci_mirrorpad_mode(const circle::MirrorPadMode mode);
luci::CircleFullyConnected::WeightsFormat
luci_weights_format(const circle::FullyConnectedOptionsWeightsFormat weights_format);
std::unique_ptr<CircleQuantParam>
-luci_quantparam(const circle::QuantizationParametersT *quantization);
-std::unique_ptr<CircleQuantParam>
luci_quantparam(const circle::QuantizationParameters *quantization);
/// @brief Copy common tensor attributes such as name, type, etc. to node.
-void copy_tensor_attributes(const circle::TensorT &tensor, CircleNode *node);
void copy_tensor_attributes(const circle::Tensor *tensor, CircleNode *node);
+std::string fb_string2std_string(const flatbuffers::String *fb_str);
+
/**
* @brief Wrapper to use flatbuffers::Vector pointer as std::vector entity
*/
@@ -101,13 +88,6 @@ template <typename T> VectorWrapper<T> wrap(const flatbuffers::Vector<T> *vec)
*/
class CircleReader
{
-private: // unpack API
- using CircleBuffers_t = std::vector<std::unique_ptr<circle::BufferT>>;
- using CircleTensors_t = std::vector<std::unique_ptr<circle::TensorT>>;
- using CircleOperators_t = std::vector<std::unique_ptr<circle::OperatorT>>;
- using CircleOperatorCodes_t = std::vector<std::unique_ptr<circle::OperatorCodeT>>;
- using CircleMetadata_t = std::vector<std::unique_ptr<circle::MetadataT>>;
-
private: // direct API
using CircleBuffers = VectorWrapper<flatbuffers::Offset<circle::Buffer>>;
using CircleTensors = VectorWrapper<flatbuffers::Offset<circle::Tensor>>;
@@ -115,40 +95,21 @@ private: // direct API
using CircleOperatorCodes = VectorWrapper<flatbuffers::Offset<circle::OperatorCode>>;
using CircleMetadataSet = VectorWrapper<flatbuffers::Offset<circle::Metadata>>;
- using CircleSubGraphsPtr_t = flatbuffers::Vector<flatbuffers::Offset<circle::SubGraph>>;
- using CircleTensorsPtr_t = flatbuffers::Vector<flatbuffers::Offset<circle::Tensor>>;
-
public:
CircleReader() = default;
-public: // unpack API
- const CircleOperatorCodes_t &opcodes() const { return _model->operator_codes; }
- const CircleBuffers_t &buffers() const { return _model->buffers; }
- const CircleTensors_t &tensors() const { return _current_subgraph->tensors; }
- const CircleOperators_t &operators() const { return _current_subgraph->operators; }
- const std::vector<int32_t> &inputs() const { return _current_subgraph->inputs; }
- const std::vector<int32_t> &outputs() const { return _current_subgraph->outputs; }
- const std::string &name() const { return _current_subgraph->name; }
- const circle::DataFormat &data_format() const { return _current_subgraph->data_format; }
- const CircleMetadata_t &metadata() const { return _model->metadata; }
-
- const CircleTensorsPtr_t *tensors_ptr() const { return _tensors_ptr; }
-
- uint32_t num_subgraph() const { return _model->subgraphs.size(); }
-
- circle::BuiltinOperator builtin_code(const circle::OperatorT &op) const;
- std::string opcode_name(const circle::OperatorT &op) const;
-
public: // direct API
- CircleOperatorCodes native_opcodes() const { return wrap(_native_model->operator_codes()); }
- CircleBuffers native_buffers() const { return wrap(_native_model->buffers()); }
- CircleTensors native_tensors() const { return wrap(_native_subgraph->tensors()); }
- CircleOperators native_operators() const { return wrap(_native_subgraph->operators()); }
- VectorWrapper<int32_t> native_inputs() const { return wrap(_native_subgraph->inputs()); }
- VectorWrapper<int32_t> native_outputs() const { return wrap(_native_subgraph->outputs()); }
- std::string native_name() const { return _native_subgraph->name()->str(); }
- circle::DataFormat native_data_format() const { return _native_subgraph->data_format(); }
- CircleMetadataSet native_metadata() const { return wrap(_native_model->metadata()); }
+ CircleOperatorCodes opcodes() const { return wrap(_model->operator_codes()); }
+ CircleBuffers buffers() const { return wrap(_model->buffers()); }
+ CircleTensors tensors() const { return wrap(_current_subgraph->tensors()); }
+ CircleOperators operators() const { return wrap(_current_subgraph->operators()); }
+ VectorWrapper<int32_t> inputs() const { return wrap(_current_subgraph->inputs()); }
+ VectorWrapper<int32_t> outputs() const { return wrap(_current_subgraph->outputs()); }
+ std::string name() const { return fb_string2std_string(_current_subgraph->name()); }
+ circle::DataFormat data_format() const { return _current_subgraph->data_format(); }
+ CircleMetadataSet metadata() const { return wrap(_model->metadata()); }
+
+ uint32_t num_subgraph() const { return wrap(_model->subgraphs()).size(); }
circle::BuiltinOperator builtin_code(const circle::Operator *op) const;
std::string opcode_name(const circle::Operator *op) const;
@@ -158,12 +119,8 @@ public:
bool select_subgraph(uint32_t subgraph);
private:
- std::unique_ptr<const circle::ModelT> _model;
- const circle::SubGraphT *_current_subgraph{nullptr};
-
- const circle::Model *_native_model{nullptr};
- const CircleTensorsPtr_t *_tensors_ptr{nullptr};
- const circle::SubGraph *_native_subgraph{nullptr};
+ const circle::Model *_model{nullptr};
+ const circle::SubGraph *_current_subgraph{nullptr};
};
} // namespace luci
diff --git a/compiler/luci/import/include/luci/Import/GraphBuilderRegistry.h b/compiler/luci/import/include/luci/Import/GraphBuilderRegistry.h
index b8dc22fdd..93e34a56b 100644
--- a/compiler/luci/import/include/luci/Import/GraphBuilderRegistry.h
+++ b/compiler/luci/import/include/luci/Import/GraphBuilderRegistry.h
@@ -18,6 +18,7 @@
#define __LUCI_IMPORT_GRAPH_BUILDER_REGISTRY_H__
#include "GraphBuilderBase.h"
+#include "NodeBuilder.h"
#include <map>
@@ -32,6 +33,11 @@ struct GraphBuilderSource
* @brief Returns registered GraphBuilder pointer for operator (nullptr if not present)
*/
virtual const GraphBuilderBase *lookup(const circle::BuiltinOperator &op) const = 0;
+
+ /**
+ * @brief Returns registered NodeBuilderBase pointer for type (nullptr if not present)
+ */
+ virtual const NodeBuilderBase *lookup(const NodeBuilderType type) const = 0;
};
/**
@@ -61,6 +67,17 @@ public:
return _builder_map.at(op).get();
}
+ /**
+ * @brief Returns registered NodeBuilderBase pointer for type or nullptr if not registered
+ */
+ const NodeBuilderBase *lookup(const NodeBuilderType type) const final
+ {
+ if (_node_builders.find(type) == _node_builders.end())
+ return (_parent == nullptr) ? nullptr : _parent->lookup(type);
+
+ return _node_builders.at(type).get();
+ }
+
static GraphBuilderRegistry &get()
{
static GraphBuilderRegistry me;
@@ -73,11 +90,17 @@ public:
_builder_map[op] = std::move(builder);
}
+ void add(std::unique_ptr<NodeBuilderBase> &&builder)
+ {
+ _node_builders[builder->builder_type()] = std::move(builder);
+ }
+
private:
const GraphBuilderSource *_parent = nullptr;
private:
std::map<const circle::BuiltinOperator, std::unique_ptr<GraphBuilderBase>> _builder_map;
+ std::map<const NodeBuilderType, std::unique_ptr<NodeBuilderBase>> _node_builders;
};
} // namespace luci
diff --git a/compiler/luci/import/include/luci/Import/NodeBuilder.h b/compiler/luci/import/include/luci/Import/NodeBuilder.h
new file mode 100644
index 000000000..440b491b0
--- /dev/null
+++ b/compiler/luci/import/include/luci/Import/NodeBuilder.h
@@ -0,0 +1,58 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_IMPORT_NODE_BUILDER_H__
+#define __LUCI_IMPORT_NODE_BUILDER_H__
+
+#include "GraphBuilderContext.h"
+#include "GraphBuilderBase.h"
+
+#include <mio/circle/schema_generated.h>
+
+namespace luci
+{
+
+/**
+ * @brief Tensor types which requires separated node
+ */
+enum class NodeBuilderType
+{
+ BUFFER,
+ // TODO Extend this struct here if needed to add new type of NodeBuilderBase
+};
+
+/**
+ * @brief Creates nodes from given Tensor and context
+ */
+class NodeBuilderBase
+{
+public:
+ virtual CircleNode *build(TensorIndex tensor_idx, GraphBuilderContext *context) const = 0;
+ virtual NodeBuilderType builder_type() const = 0;
+};
+
+/**
+ * @brief Placeholder for builders of tensors with different types
+ */
+template <NodeBuilderType Type> class TypedNodeBuilder : public NodeBuilderBase
+{
+public:
+ NodeBuilderType builder_type() const final { return Type; }
+};
+
+} // namespace luci
+
+#endif // __LUCI_IMPORT_NODE_BUILDER_H__
diff --git a/compiler/luci/import/include/luci/Import/Nodes.h b/compiler/luci/import/include/luci/Import/Nodes.h
index f7d22e7aa..7a5045ede 100644
--- a/compiler/luci/import/include/luci/Import/Nodes.h
+++ b/compiler/luci/import/include/luci/Import/Nodes.h
@@ -122,6 +122,7 @@
#include "Nodes/CircleStridedSlice.h"
#include "Nodes/CircleSub.h"
#include "Nodes/CircleSum.h"
+#include "Nodes/CircleSVDF.h"
#include "Nodes/CircleTanh.h"
#include "Nodes/CircleTile.h"
#include "Nodes/CircleTopKV2.h"
@@ -130,6 +131,7 @@
#include "Nodes/CircleUnidirectionalSequenceLSTM.h"
#include "Nodes/CircleUnique.h"
#include "Nodes/CircleUnpack.h"
+#include "Nodes/CircleVariable.h"
#include "Nodes/CircleWhere.h"
#include "Nodes/CircleWhile.h"
#include "Nodes/CircleZerosLike.h"
diff --git a/compiler/luci/import/include/luci/Import/Nodes/CircleConst.h b/compiler/luci/import/include/luci/Import/Nodes/CircleConst.h
index 7d4f10a59..9e50ddbde 100644
--- a/compiler/luci/import/include/luci/Import/Nodes/CircleConst.h
+++ b/compiler/luci/import/include/luci/Import/Nodes/CircleConst.h
@@ -17,20 +17,21 @@
#ifndef __LUCI_IMPORT_OP_CIRCLE_CONST_H__
#define __LUCI_IMPORT_OP_CIRCLE_CONST_H__
-#include "luci/Import/GraphBuilderContext.h"
+#include "luci/Import/NodeBuilder.h"
#include <luci/IR/Nodes/CircleConst.h>
-/*
- * @note Circle does not have Const operator.
- * Methods here provide helper that creates CircleConst from
- * Tensor and Buffer in circle flatbuffer file.
- */
-
namespace luci
{
-CircleConst *create_circleconst(GraphBuilderContext *context, int32_t tensor_index);
+/**
+ * @brief Builder creates CircleConst node from Tensor with buffer.
+ */
+class CircleConstNodeBuilder : public TypedNodeBuilder<NodeBuilderType::BUFFER>
+{
+public:
+ CircleNode *build(TensorIndex tensor_index, GraphBuilderContext *ctx) const final;
+};
} // namespace luci
diff --git a/compiler/luci/import/include/luci/Import/Nodes/CircleSVDF.h b/compiler/luci/import/include/luci/Import/Nodes/CircleSVDF.h
new file mode 100644
index 000000000..a91f66019
--- /dev/null
+++ b/compiler/luci/import/include/luci/Import/Nodes/CircleSVDF.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_IMPORT_OP_CIRCLE_SVDF_H__
+#define __LUCI_IMPORT_OP_CIRCLE_SVDF_H__
+
+#include "luci/Import/GraphBuilder.h"
+
+namespace luci
+{
+
+class CircleSVDFBuilder : public GraphBuilder
+{
+public:
+ bool validate(const ValidateArgs &args) const final;
+
+private:
+ CircleNode *build_node(const circle::OperatorT &op, const std::vector<CircleNode *> &inputs,
+ loco::Graph *graph) const final;
+};
+
+} // namespace luci
+
+#endif // __LUCI_IMPORT_OP_CIRCLE_SVDF_H__
diff --git a/compiler/luci/import/include/luci/Import/Nodes/CircleVariable.h b/compiler/luci/import/include/luci/Import/Nodes/CircleVariable.h
new file mode 100644
index 000000000..4d8961fa5
--- /dev/null
+++ b/compiler/luci/import/include/luci/Import/Nodes/CircleVariable.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_IMPORT_OP_CIRCLE_VARIABLE_H__
+#define __LUCI_IMPORT_OP_CIRCLE_VARIABLE_H__
+
+#include "luci/Import/GraphBuilderContext.h"
+
+#include <luci/IR/Nodes/CircleVariable.h>
+
+/*
+ * @note Circle does not have node for variable tensor
+ * Methods here provide helper that creates CircleVariable from
+ * Tensor having is_variable true value.
+ */
+
+namespace luci
+{
+
+CircleVariable *create_circlevariable(GraphBuilderContext *context, int32_t tensor_index);
+
+} // namespace luci
+
+#endif // __LUCI_IMPORT_OP_CIRCLE_VARIABLE_H__
diff --git a/compiler/luci/import/src/CircleImportMetadata.cpp b/compiler/luci/import/src/CircleImportMetadata.cpp
index 42dcebdaa..9c1fe7356 100644
--- a/compiler/luci/import/src/CircleImportMetadata.cpp
+++ b/compiler/luci/import/src/CircleImportMetadata.cpp
@@ -21,8 +21,10 @@
namespace
{
-uint32_t read_u32(const std::vector<uint8_t> &buffer, uint32_t idx)
+template <typename VECTORTYPE> uint32_t read_u32(const VECTORTYPE &buffer, uint32_t idx)
{
+ static_assert(std::is_same<typename VECTORTYPE::value_type, uint8_t>::value, "Types mismatch!");
+
uint32_t val = 0;
val += (buffer.at(idx + 0) << 0 * 8);
val += (buffer.at(idx + 1) << 1 * 8);
@@ -37,9 +39,11 @@ namespace
{
// 'source_table' is decoded to std::map<uint32_t, std::string> format.
-const std::map<uint32_t, std::string>
-decoded_source_table(const std::vector<uint8_t> &source_table_data)
+template <typename VECTORTYPE>
+const std::map<uint32_t, std::string> decoded_source_table(const VECTORTYPE &source_table_data)
{
+ static_assert(std::is_same<typename VECTORTYPE::value_type, uint8_t>::value, "Types mismatch!");
+
std::map<uint32_t, std::string> source_id_name_map;
uint32_t idx = 0;
@@ -86,9 +90,11 @@ decoded_source_table(const std::vector<uint8_t> &source_table_data)
}
// 'op_table' is decoded to std::map<uint32_t, std::set<uint32_t>> format.
-const std::map<uint32_t, std::set<uint32_t>>
-decoded_op_table(const std::vector<uint8_t> &op_table_data)
+template <typename VECTORTYPE>
+const std::map<uint32_t, std::set<uint32_t>> decoded_op_table(const VECTORTYPE &op_table_data)
{
+ static_assert(std::is_same<typename VECTORTYPE::value_type, uint8_t>::value, "Types mismatch!");
+
std::map<uint32_t, std::set<uint32_t>> node_source_ids_map;
uint32_t idx = 0;
@@ -135,9 +141,11 @@ decoded_op_table(const std::vector<uint8_t> &op_table_data)
}
// 'execution_plan_table' is decoded to std::map<uint32_t, std::vector<uint32_t>> format.
-const luci::ExecutionPlanTable
-decoded_execution_plan(const std::vector<uint8_t> &execution_plan_data)
+template <typename VECTORTYPE>
+const luci::ExecutionPlanTable decoded_execution_plan(const VECTORTYPE &execution_plan_data)
{
+ static_assert(std::is_same<typename VECTORTYPE::value_type, uint8_t>::value, "Types mismatch!");
+
luci::ExecutionPlanTable execution_plan_table;
uint32_t idx = 0;
@@ -156,6 +164,10 @@ decoded_execution_plan(const std::vector<uint8_t> &execution_plan_data)
idx += sizeof(uint32_t);
uint32_t size = read_u32(execution_plan_data, idx);
+
+ if (size == 0)
+ throw std::runtime_error("Op table decode error : empty execution plan entry");
+
idx += sizeof(uint32_t);
if (idx + sizeof(uint32_t) * size > execution_plan_data.size())
@@ -190,19 +202,22 @@ namespace luci
CircleImportMetadata::CircleImportMetadata(const luci::CircleReader &reader)
{
- const auto &metadata = reader.metadata();
+ const auto metadata = reader.metadata();
for (uint32_t i = 0; i < metadata.size(); ++i)
{
- const circle::MetadataT &meta = *metadata[i];
+ const auto *meta = metadata[i];
+ assert(meta != nullptr);
- assert(meta.buffer < reader.buffers().size());
- const std::vector<uint8_t> &buffer = reader.buffers()[meta.buffer]->data;
+ assert(meta->buffer() < reader.buffers().size());
+ assert(reader.buffers()[meta->buffer()] != nullptr);
+ const auto buffer = luci::wrap(reader.buffers()[meta->buffer()]->data());
- if (meta.name.compare("ONE_op_table") == 0)
+ assert(meta->name() != nullptr);
+ if (meta->name()->str().compare("ONE_op_table") == 0)
_op_table = decoded_op_table(buffer);
- else if (meta.name.compare("ONE_source_table") == 0)
+ else if (meta->name()->str().compare("ONE_source_table") == 0)
_source_table = decoded_source_table(buffer);
- else if (meta.name.compare("ONE_execution_plan_table") == 0)
+ else if (meta->name()->str().compare("ONE_execution_plan_table") == 0)
_execution_plan_table = decoded_execution_plan(buffer);
}
}
diff --git a/compiler/luci/import/src/CircleReader.cpp b/compiler/luci/import/src/CircleReader.cpp
index 14917ba06..a42c3f913 100644
--- a/compiler/luci/import/src/CircleReader.cpp
+++ b/compiler/luci/import/src/CircleReader.cpp
@@ -16,6 +16,9 @@
#include "luci/Import/CircleReader.h"
+#include <mio_circle/Helper.h>
+
+#include <algorithm>
#include <memory>
#include <sstream>
#include <string>
@@ -23,103 +26,14 @@
namespace luci
{
-bool is_valid(const circle::OperatorCodeT &opcode)
-{
- circle::BuiltinOperator code = opcode.builtin_code;
- return (circle::BuiltinOperator_MIN <= code && code <= circle::BuiltinOperator_MAX);
-}
-
-bool is_valid(const circle::OperatorCode *opcode)
-{
- assert(opcode != nullptr);
- circle::BuiltinOperator code = opcode->builtin_code();
- return (circle::BuiltinOperator_MIN <= code && code <= circle::BuiltinOperator_MAX);
-}
-
-bool is_custom(const circle::OperatorCodeT &opcode)
-{
- circle::BuiltinOperator code = opcode.builtin_code;
- return (code == circle::BuiltinOperator_CUSTOM);
-}
-
-bool is_custom(const circle::OperatorCode *opcode)
-{
- assert(opcode != nullptr);
- circle::BuiltinOperator code = opcode->builtin_code();
- return (code == circle::BuiltinOperator_CUSTOM);
-}
-
-std::string opcode_name(const circle::OperatorCodeT &opcode)
-{
- if (!is_valid(opcode))
- {
- std::ostringstream oss;
- oss << "(invalid)";
- return oss.str();
- }
-
- if (is_custom(opcode))
- {
- if (opcode.custom_code.empty())
- return "(invalid custom)";
-
- return opcode.custom_code;
- }
-
- circle::BuiltinOperator code = opcode.builtin_code;
- return circle::EnumNameBuiltinOperator(code);
-}
-
-std::string opcode_name(const circle::OperatorCode *opcode)
-{
- assert(opcode != nullptr);
-
- if (!is_valid(opcode))
- {
- std::ostringstream oss;
- oss << "(invalid)";
- return oss.str();
- }
-
- if (is_custom(opcode))
- {
- auto custom_code = opcode->custom_code()->str();
- if (custom_code.empty())
- return "(invalid custom)";
-
- return custom_code;
- }
-
- circle::BuiltinOperator code = opcode->builtin_code();
- return circle::EnumNameBuiltinOperator(code);
-}
-
-const char *tensor_name(const circle::TensorT &tensor)
-{
- static const char *kEmptyTensorName = "(noname)";
-
- if (!tensor.name.empty())
- return tensor.name.c_str();
-
- return kEmptyTensorName;
-}
-
const char *tensor_name(const circle::Tensor *tensor)
{
assert(tensor != nullptr);
- static const char *kEmptyTensorName = "(noname)";
- const auto tensor_name = tensor->name()->c_str();
-
- if (!std::string(tensor_name).empty())
- return tensor_name;
+ if (tensor->name() == nullptr || std::string(tensor->name()->c_str()).empty())
+ return "(noname)";
- return kEmptyTensorName;
-}
-
-const circle::QuantizationParametersT *tensor_quantization(const circle::TensorT &tensor)
-{
- return tensor.quantization.get();
+ return tensor->name()->c_str();
}
const circle::QuantizationParameters *tensor_quantization(const circle::Tensor *tensor)
@@ -334,41 +248,6 @@ std::unique_ptr<SparsityParam> luci_sparsityparam(const circle::SparsityParamete
return luci_sparsityparam(&sparsity);
}
-void copy_tensor_attributes(const circle::TensorT &tensor, CircleNode *node)
-{
- node->name(tensor_name(tensor));
- node->dtype(luci_datatype(tensor.type));
-
- assert(tensor.shape_signature.size() == 0 ||
- tensor.shape_signature.size() == tensor.shape.size());
-
- std::vector<int32_t> dims = tensor.shape; // in NHWC
- node->rank(dims.size());
- for (uint32_t r = 0; r < dims.size(); ++r)
- {
- if (tensor.shape_signature.size() > 0 && tensor.shape_signature.at(r) == -1)
- node->dim(r).unset();
- else
- node->dim(r).set(dims[r]);
- }
-
- const auto *quantization = tensor.quantization.get();
- if (quantization != nullptr)
- {
- auto quantparam = luci_quantparam(quantization);
- if (quantparam)
- node->quantparam(std::move(quantparam));
- }
-
- const auto *sparsity = tensor.sparsity.get();
- if (sparsity != nullptr)
- {
- auto sparsityparam = luci_sparsityparam(sparsity);
- if (sparsityparam)
- node->sparsityparam(std::move(sparsityparam));
- }
-}
-
void copy_tensor_attributes(const circle::Tensor *tensor, CircleNode *node)
{
assert(tensor != nullptr);
@@ -408,63 +287,60 @@ void copy_tensor_attributes(const circle::Tensor *tensor, CircleNode *node)
}
}
-circle::BuiltinOperator CircleReader::builtin_code(const circle::OperatorT &op) const
+std::string fb_string2std_string(const flatbuffers::String *fb_str)
{
- const auto &op_codes = opcodes();
- uint32_t index = op.opcode_index;
+ return fb_str == nullptr ? "" : fb_str->str();
+}
+
+circle::BuiltinOperator CircleReader::builtin_code(const circle::Operator *op) const
+{
+ assert(op != nullptr);
+
+ const auto op_codes = opcodes();
+ uint32_t index = op->opcode_index();
assert(index < op_codes.size());
- const circle::OperatorCodeT &opcode = *op_codes[index];
+ const auto opcode = op_codes[index];
+ assert(opcode != nullptr);
- return opcode.builtin_code;
+ return mio::circle::builtin_code_neutral(opcode);
}
-std::string CircleReader::opcode_name(const circle::OperatorT &op) const
+std::string CircleReader::opcode_name(const circle::Operator *op) const
{
- const auto &op_codes = opcodes();
- uint32_t index = op.opcode_index;
- assert(index < op_codes.size());
- const circle::OperatorCodeT &opcode = *op_codes[index];
+ assert(op != nullptr);
- if (!is_valid(opcode))
- {
- std::ostringstream oss;
- oss << "(invalid: " << index << ")";
- return oss.str();
- }
+ const auto op_codes = opcodes();
+ uint32_t index = op->opcode_index();
+ assert(index < op_codes.size());
+ const auto opcode = op_codes[index];
- return ::luci::opcode_name(opcode);
+ return mio::circle::opcode_name(opcode);
}
bool CircleReader::parse(const circle::Model *model)
{
assert(model != nullptr);
- _model.reset(model->UnPack());
-
// for direct pointer access
- _native_model = model;
+ _model = model;
return true;
}
bool CircleReader::select_subgraph(uint32_t sgindex)
{
- if (_model->subgraphs.size() <= sgindex)
+ if (num_subgraph() <= sgindex)
{
assert(false);
return false;
}
- _current_subgraph = _model->subgraphs[sgindex].get();
-
// for direct pointer access
- auto subgraphs = _native_model->subgraphs();
+ auto subgraphs = _model->subgraphs();
assert(subgraphs != nullptr);
- _native_subgraph = subgraphs->Get(sgindex);
- assert(_native_subgraph != nullptr);
-
- _tensors_ptr = _native_subgraph->tensors();
+ _current_subgraph = subgraphs->Get(sgindex);
+ assert(_current_subgraph != nullptr);
return true;
}
diff --git a/compiler/luci/import/src/GraphBuilder.cpp b/compiler/luci/import/src/GraphBuilder.cpp
index 356501c2f..59a08b546 100644
--- a/compiler/luci/import/src/GraphBuilder.cpp
+++ b/compiler/luci/import/src/GraphBuilder.cpp
@@ -29,10 +29,9 @@ CircleNode *GraphBuilder::build(const circle::OperatorT &op, GraphBuilderContext
const std::vector<int32_t> &inputs = op.inputs;
const std::vector<int32_t> &outputs = op.outputs;
- const auto &tensors = context->reader()->tensors();
- const auto &opcodes = context->reader()->opcodes();
- auto tensors_ptr = context->reader()->tensors_ptr();
- assert(tensors_ptr != nullptr);
+ const auto tensors = context->reader()->tensors();
+ const auto opcodes = context->reader()->opcodes();
+ assert(!tensors.null());
std::vector<CircleNode *> input_nodes;
for (const int32_t input_tensor_index : inputs)
@@ -60,16 +59,18 @@ CircleNode *GraphBuilder::build(const circle::OperatorT &op, GraphBuilderContext
// Set up node parameters.
assert(outputs.size() == 1);
{
- const circle::TensorT &output_tensor = *tensors[outputs[0]];
+ const auto output_tensor = tensors[outputs[0]];
+ assert(output_tensor != nullptr);
copy_tensor_attributes(output_tensor, node);
// mark shape_status
- if (tensors_ptr->Get(outputs[0])->shape() == nullptr)
+ if (output_tensor->shape() == nullptr)
node->shape_status(ShapeStatus::NOSHAPE);
else
node->shape_status(ShapeStatus::VALID);
// mark operator version
- node->op_version(opcodes[op.opcode_index].get()->version);
+ assert(opcodes[op.opcode_index] != nullptr);
+ node->op_version(opcodes[op.opcode_index]->version());
}
// Register node's only output.
diff --git a/compiler/luci/import/src/GraphBuilderMultiOutput.cpp b/compiler/luci/import/src/GraphBuilderMultiOutput.cpp
index be553f4c0..4df8d1e5a 100644
--- a/compiler/luci/import/src/GraphBuilderMultiOutput.cpp
+++ b/compiler/luci/import/src/GraphBuilderMultiOutput.cpp
@@ -30,10 +30,9 @@ CircleNode *GraphBuilderMultiOutput::build(const circle::OperatorT &op,
const std::vector<int32_t> &inputs = op.inputs;
const std::vector<int32_t> &outputs = op.outputs;
- const auto &tensors = context->reader()->tensors();
- const auto &opcodes = context->reader()->opcodes();
- auto tensors_ptr = context->reader()->tensors_ptr();
- assert(tensors_ptr != nullptr);
+ const auto tensors = context->reader()->tensors();
+ const auto opcodes = context->reader()->opcodes();
+ assert(!tensors.null());
std::vector<CircleNode *> input_nodes;
for (const int32_t input_tensor_index : inputs)
@@ -64,12 +63,14 @@ CircleNode *GraphBuilderMultiOutput::build(const circle::OperatorT &op,
if (output_count > 0)
{
// Let's use attributes from output 0 for this node
- const circle::TensorT &output_tensor = *tensors[outputs[0]];
+ const auto output_tensor = tensors[outputs[0]];
+ assert(output_tensor != nullptr);
node->name(tensor_name(output_tensor));
- node->dtype(luci_datatype(output_tensor.type));
+ node->dtype(luci_datatype(output_tensor->type()));
// mark operator version
- node->op_version(opcodes[op.opcode_index].get()->version);
+ assert(opcodes[op.opcode_index] != nullptr);
+ node->op_version(opcodes[op.opcode_index]->version());
// NOTE We don't set quantization for multiple output nodes but to virtual outputs
}
@@ -77,7 +78,8 @@ CircleNode *GraphBuilderMultiOutput::build(const circle::OperatorT &op,
// Create virtual outputs of Virtual Output node(s)
for (uint32_t n = 0; n < output_count; ++n)
{
- const circle::TensorT &output_tensor = *tensors[outputs[n]];
+ const auto output_tensor = tensors[outputs[n]];
+ assert(output_tensor != nullptr);
BuildOutArgs boa(node, n);
auto *nodeout = build_out(boa);
@@ -85,7 +87,7 @@ CircleNode *GraphBuilderMultiOutput::build(const circle::OperatorT &op,
copy_tensor_attributes(output_tensor, nodeout);
// NOTE name of CxxxOut nodes may have same name
// mark shape_status
- if (tensors_ptr->Get(outputs[n])->shape() == nullptr)
+ if (output_tensor->shape() == nullptr)
nodeout->shape_status(ShapeStatus::NOSHAPE);
else
nodeout->shape_status(ShapeStatus::VALID);
diff --git a/compiler/luci/import/src/GraphBuilderRegistry.cpp b/compiler/luci/import/src/GraphBuilderRegistry.cpp
index df07d9e48..fe2d830e9 100644
--- a/compiler/luci/import/src/GraphBuilderRegistry.cpp
+++ b/compiler/luci/import/src/GraphBuilderRegistry.cpp
@@ -131,6 +131,7 @@ GraphBuilderRegistry::GraphBuilderRegistry()
CIRCLE_NODE(STRIDED_SLICE, CircleStridedSliceGraphBuilder); // 45
CIRCLE_NODE(SUB, CircleSubGraphBuilder); // 41
CIRCLE_NODE(SUM, CircleSumGraphBuilder); // 74
+ CIRCLE_NODE(SVDF, CircleSVDFBuilder); // 27
CIRCLE_NODE(TANH, CircleTanhGraphBuilder); // 28
CIRCLE_NODE(TILE, CircleTileGraphBuilder); // 69
CIRCLE_NODE(TOPK_V2, CircleTopKV2GraphBuilder); // 48
@@ -150,7 +151,6 @@ GraphBuilderRegistry::GraphBuilderRegistry()
// BuiltinOperator_LSH_PROJECTION = 15,
// BuiltinOperator_LSTM = 16,
// BuiltinOperator_RNN = 24,
- // BuiltinOperator_SVDF = 27,
// BuiltinOperator_CONCAT_EMBEDDINGS = 29,
// BuiltinOperator_SKIP_GRAM = 30,
// BuiltinOperator_CALL = 31,
@@ -161,6 +161,13 @@ GraphBuilderRegistry::GraphBuilderRegistry()
// BuiltinOperator_ARG_MAX = 56,
// BuiltinOperator_HARD_SWISH = 117,
// BuiltinOperator_DENSIFY = 124,
+
+ // Register builders for nodes which are not handled by the builders registered above.
+#define CIRCLE_NODE(CLASS) add(std::make_unique<CLASS>())
+
+ CIRCLE_NODE(CircleConstNodeBuilder);
+
+#undef CIRCLE_NODE
}
} // namespace luci
diff --git a/compiler/luci/import/src/Importer.cpp b/compiler/luci/import/src/Importer.cpp
index 3f7f78591..15de03df2 100644
--- a/compiler/luci/import/src/Importer.cpp
+++ b/compiler/luci/import/src/Importer.cpp
@@ -23,6 +23,7 @@
#include "luci/Import/GraphBuilderRegistry.h"
#include "luci/Import/CircleReader.h"
#include "luci/Import/Nodes/CircleConst.h"
+#include "luci/Import/Nodes/CircleVariable.h"
#include <luci/IR/Module.h>
#include <luci/IR/CircleNodes.h>
@@ -50,18 +51,18 @@ void convert_graph(const luci::GraphBuilderSource &source, luci::CircleReader &r
luci::GraphBuilderContext gb_context(graph, &reader, nodefinder.get(), tensoroutputs.get());
- const auto &operators = reader.operators();
- const auto &tensors = reader.tensors();
- auto tensors_ptr = reader.tensors_ptr();
- assert(tensors_ptr != nullptr);
+ const auto operators = reader.operators();
+ const auto tensors = reader.tensors();
+ assert(!tensors.null());
auto circle_metadata = std::make_unique<luci::CircleImportMetadata>(reader);
// build a cache to identify if a tensor is output of an operator
// if this is set, we should not create a CircleConst for this tensor
for (uint32_t i = 0; i < operators.size(); ++i)
{
- const circle::OperatorT &op = *operators[i];
- const auto &outputs = op.outputs;
+ const auto op = operators[i];
+ assert(op != nullptr);
+ const auto outputs = luci::wrap(op->outputs());
for (uint32_t j = 0; j < outputs.size(); ++j)
{
@@ -77,10 +78,11 @@ void convert_graph(const luci::GraphBuilderSource &source, luci::CircleReader &r
{
auto input_node = graph->nodes()->create<luci::CircleInput>();
assert(input_node != nullptr);
- const circle::TensorT &tensor = *tensors[input];
+ const auto tensor = tensors[input];
+ assert(tensor != nullptr);
luci::copy_tensor_attributes(tensor, input_node);
- if (tensors_ptr->Get(input)->shape() == nullptr)
+ if (tensor->shape() == nullptr)
input_node->shape_status(luci::ShapeStatus::NOSHAPE);
else
input_node->shape_status(luci::ShapeStatus::VALID);
@@ -101,16 +103,18 @@ void convert_graph(const luci::GraphBuilderSource &source, luci::CircleReader &r
// Data type
graph_input->dtype(input_node->dtype());
- assert(tensor.shape_signature.size() == 0 ||
- tensor.shape_signature.size() == tensor.shape.size());
+ const auto tensor_shape_signature = luci::wrap(tensor->shape_signature());
+ const auto tensor_shape = luci::wrap(tensor->shape());
+ assert(tensor_shape_signature.size() == 0 ||
+ tensor_shape_signature.size() == tensor_shape.size());
// Shape of GraphInput
auto input_shape = std::make_unique<loco::TensorShape>();
- const std::vector<int32_t> &input_dims = tensor.shape; // in NHWC
+ const auto &input_dims = tensor_shape; // in NHWC
input_shape->rank(input_dims.size());
for (uint32_t r = 0; r < input_dims.size(); ++r)
{
- if (tensor.shape_signature.size() > 0 && tensor.shape_signature.at(r) == -1)
+ if (tensor_shape_signature.size() > 0 && tensor_shape_signature.at(r) == -1)
input_shape->dim(r).unset();
else
input_shape->dim(r).set(input_dims[r]);
@@ -118,15 +122,28 @@ void convert_graph(const luci::GraphBuilderSource &source, luci::CircleReader &r
graph_input->shape(std::move(input_shape));
}
- // Create CircleConst nodes for constant tensors.
+ // Create CircleNodes for constant tensors.
// NOTE Origin is intentionally not provided for constants.
+ auto const_builder = source.lookup(luci::NodeBuilderType::BUFFER);
+ if (not const_builder)
+ throw oops::UserExn("Not supported", "tensor with buffer builder");
+
for (uint32_t i = 0; i < tensors.size(); ++i)
{
- luci::CircleConst *const_node = luci::create_circleconst(&gb_context, i);
+ auto *const_node = const_builder->build(i, &gb_context);
if (const_node != nullptr)
nodefinder->enroll(i, const_node);
}
+ // Create CircleVariable nodes for variable tensors
+ // TODO Add Origin if needed, skip for now
+ for (uint32_t i = 0; i < tensors.size(); ++i)
+ {
+ luci::CircleVariable *variable_node = luci::create_circlevariable(&gb_context, i);
+ if (variable_node != nullptr)
+ nodefinder->enroll(i, variable_node);
+ }
+
// Import the operators.
// Note that operators in model are stored in execution order. This means that when importing
// an operator, its input operators have already been imported. We exploit this fact to set up
@@ -134,18 +151,23 @@ void convert_graph(const luci::GraphBuilderSource &source, luci::CircleReader &r
auto origin_table = circle_metadata->origin_table();
for (uint32_t i = 0; i < operators.size(); ++i)
{
- const circle::OperatorT &op = *operators[i];
+ const auto op = operators[i];
+ assert(op != nullptr);
circle::BuiltinOperator builtincode = reader.builtin_code(op);
if (const auto *builder = source.lookup(builtincode))
{
- luci::GraphBuilder::ValidateArgs args(op, reader);
+ // create temporary unpack API obj
+ circle::OperatorT oper_t;
+ op->UnPackTo(&oper_t);
+
+ luci::GraphBuilder::ValidateArgs args(oper_t, reader);
if (!builder->validate(args))
{
throw oops::UserExn("Invalid operator", reader.opcode_name(op));
}
- auto built_op = builder->build(op, &gb_context);
+ auto built_op = builder->build(oper_t, &gb_context);
set_node_id(built_op, i);
if (origin_table.find(i) != origin_table.end())
add_origin(built_op, origin_table.at(i));
@@ -161,7 +183,8 @@ void convert_graph(const luci::GraphBuilderSource &source, luci::CircleReader &r
// graph outputs
for (auto output : reader.outputs())
{
- const circle::TensorT &tensor = *tensors[output];
+ const auto tensor = tensors[output];
+ assert(tensor != nullptr);
auto output_node = graph->nodes()->create<luci::CircleOutput>();
assert(output_node != nullptr);
@@ -178,7 +201,7 @@ void convert_graph(const luci::GraphBuilderSource &source, luci::CircleReader &r
output_node->from(output_dummy);
luci::copy_tensor_attributes(tensor, output_dummy);
- if (tensors_ptr->Get(output)->shape() == nullptr)
+ if (tensor->shape() == nullptr)
output_dummy->shape_status(luci::ShapeStatus::NOSHAPE);
else
output_dummy->shape_status(luci::ShapeStatus::VALID);
@@ -197,16 +220,18 @@ void convert_graph(const luci::GraphBuilderSource &source, luci::CircleReader &r
// Set GraphInputOutputIndex for graph
output_node->index(graph_output->index());
- assert(tensor.shape_signature.size() == 0 ||
- tensor.shape_signature.size() == tensor.shape.size());
+ const auto tensor_shape_signature = luci::wrap(tensor->shape_signature());
+ const auto tensor_shape = luci::wrap(tensor->shape());
+ assert(tensor_shape_signature.size() == 0 ||
+ tensor_shape_signature.size() == tensor_shape.size());
// Shape of Output
auto output_shape = std::make_unique<loco::TensorShape>();
- const std::vector<int32_t> &output_dims = tensor.shape; // in NHWC
+ const auto &output_dims = tensor_shape; // in NHWC
output_shape->rank(output_dims.size());
for (uint32_t r = 0; r < output_dims.size(); ++r)
{
- if (tensor.shape_signature.size() > 0 && tensor.shape_signature.at(r) == -1)
+ if (tensor_shape_signature.size() > 0 && tensor_shape_signature.at(r) == -1)
output_shape->dim(r).unset();
else
output_shape->dim(r).set(output_dims[r]);
@@ -214,7 +239,7 @@ void convert_graph(const luci::GraphBuilderSource &source, luci::CircleReader &r
graph_output->shape(std::move(output_shape));
// Data type
- auto dtype = luci::luci_datatype(tensor.type);
+ auto dtype = luci::luci_datatype(tensor->type());
graph_output->dtype(dtype);
}
}
@@ -355,7 +380,12 @@ std::unique_ptr<Module> Importer::importModule(const circle::Model *model) const
{
if (auto circle_node = dynamic_cast<luci::CircleNode *>(node))
{
+ if (execution_plan_table.count(node_position) == 0)
+ continue;
+
auto node_plan = execution_plan_table[node_position];
+ assert(node_plan.size() > 0);
+
luci::add_execution_plan(
circle_node,
luci::CircleNodeExecutionPlan(
diff --git a/compiler/luci/import/src/Importer.test.cpp b/compiler/luci/import/src/Importer.test.cpp
index d963b4d49..91e4860ea 100644
--- a/compiler/luci/import/src/Importer.test.cpp
+++ b/compiler/luci/import/src/Importer.test.cpp
@@ -23,7 +23,7 @@
#include <mio/circle/schema_generated.h>
#include <flatbuffers/flatbuffers.h>
-TEST(TensorFlowLiteImport, Dummy)
+TEST(CircleImport, Dummy)
{
luci::Importer import;
@@ -68,6 +68,7 @@ struct BasicCircleModel
{
uint32_t id = model->operator_codes.size();
model->operator_codes.push_back(std::make_unique<circle::OperatorCodeT>());
+ model->operator_codes[id]->deprecated_builtin_code = opcode;
model->operator_codes[id]->builtin_code = opcode;
model->operator_codes[id]->version = 1;
return id;
@@ -179,7 +180,7 @@ struct SimpleRELUModel : public BasicCircleModel
/**
* This test checks that one op RELU model with execution plan is successfully imported
*/
-TEST(TensorFlowLiteImport, simple_plan)
+TEST(CircleImport, simple_plan)
{
SimpleRELUModel model;
auto metadata_buffer_id = model.add_buffer();
@@ -240,7 +241,7 @@ TEST(TensorFlowLiteImport, simple_plan)
/**
* This test checks that model with incomplete execution plan is successfully imported
*/
-TEST(TensorFlowLiteImport, DISABLED_incomplete_plan_NEG)
+TEST(CircleImport, incomplete_plan_NEG)
{
SimpleRELUModel model;
auto metadata_buffer_id = model.add_buffer();
@@ -287,7 +288,7 @@ TEST(TensorFlowLiteImport, DISABLED_incomplete_plan_NEG)
/**
* This test checks that corrupted execution plan induce exception
*/
-TEST(TensorFlowLiteImport, corrupted_plan_NEG)
+TEST(CircleImport, corrupted_plan_NEG)
{
SimpleRELUModel model;
auto metadata_buffer_id = model.add_buffer();
@@ -309,3 +310,44 @@ TEST(TensorFlowLiteImport, corrupted_plan_NEG)
ASSERT_ANY_THROW(import.importModule(model_ptr));
}
+
+/**
+ * This test checks that an empty execution plan entry induces an exception
+ */
+TEST(CircleImport, corrupted_plan_entry_NEG)
+{
+ SimpleRELUModel model;
+ auto metadata_buffer_id = model.add_buffer();
+ model.add_plan_metadata(metadata_buffer_id);
+
+ model.add_plan_entry(metadata_buffer_id, 1, {100});
+
+ // add corrupted entry with 0 size
+ {
+ auto &buffer = model.model->buffers[metadata_buffer_id]->data;
+ auto old_size = buffer.size();
+
+ // Allocate space for new entry:
+ // 4 bytes for entry id
+ // 4 bytes for entry size
+ buffer.resize(old_size + 8);
+ uint32_t *number_of_entries_ptr = reinterpret_cast<uint32_t *>(buffer.data());
+ *number_of_entries_ptr += 1;
+
+ uint32_t *entry_data_ptr = reinterpret_cast<uint32_t *>(buffer.data() + old_size);
+
+ entry_data_ptr[0] = *number_of_entries_ptr - 1; // entry id
+ entry_data_ptr[1] = 0; // entry size
+ }
+
+ model.add_plan_entry(metadata_buffer_id, 3, {200});
+
+ flatbuffers::FlatBufferBuilder fbb;
+ auto model_offset = circle::Model::Pack(fbb, model.model.get(), nullptr);
+ circle::FinishModelBuffer(fbb, model_offset);
+
+ auto model_ptr = circle::GetModel(fbb.GetBufferPointer());
+ luci::Importer import;
+
+ ASSERT_ANY_THROW(import.importModule(model_ptr));
+}
diff --git a/compiler/luci/import/src/Nodes/CircleCast.cpp b/compiler/luci/import/src/Nodes/CircleCast.cpp
index 3e8c08bfa..acde823b1 100644
--- a/compiler/luci/import/src/Nodes/CircleCast.cpp
+++ b/compiler/luci/import/src/Nodes/CircleCast.cpp
@@ -42,12 +42,14 @@ bool CircleCastGraphBuilder::validate(const ValidateArgs &args) const
const auto *options = args.op.builtin_options.AsCastOptions();
if (options != nullptr)
{
- const auto &tensors = args.reader.tensors();
- const circle::TensorT &output_tensor = *tensors[outputs[0]];
+ const auto tensors = args.reader.tensors();
+ const auto output_tensor = tensors[outputs[0]];
+ assert(output_tensor != nullptr);
auto name = tensor_name(output_tensor);
- const auto &tensor_in = tensors.at(inputs.at(0));
- if (tensor_in->type != options->in_data_type)
+ const auto tensor_in = tensors.at(inputs.at(0));
+ assert(tensor_in != nullptr);
+ if (tensor_in->type() != options->in_data_type)
{
if (settings->get(luci::UserSettings::Key::DisableValidation))
{
@@ -57,7 +59,7 @@ bool CircleCastGraphBuilder::validate(const ValidateArgs &args) const
return false;
}
const auto &tensor_out = tensors.at(outputs[0]);
- if (tensor_out->type != options->out_data_type)
+ if (tensor_out->type() != options->out_data_type)
{
if (settings->get(luci::UserSettings::Key::DisableValidation))
{
diff --git a/compiler/luci/import/src/Nodes/CircleConst.cpp b/compiler/luci/import/src/Nodes/CircleConst.cpp
index 11fbb4e54..a4f190dd9 100644
--- a/compiler/luci/import/src/Nodes/CircleConst.cpp
+++ b/compiler/luci/import/src/Nodes/CircleConst.cpp
@@ -30,10 +30,10 @@
namespace
{
-std::ostream &operator<<(std::ostream &os, const std::vector<int32_t> &vect)
+std::ostream &operator<<(std::ostream &os, const luci::VectorWrapper<int32_t> &vect)
{
uint32_t seq = 0;
- for (auto &v : vect)
+ for (const auto &v : vect)
{
if (seq)
os << ", ";
@@ -46,7 +46,8 @@ std::ostream &operator<<(std::ostream &os, const std::vector<int32_t> &vect)
using namespace luci;
template <loco::DataType DT>
-void copy_data(const std::vector<uint8_t> &raw_data, uint32_t num_elements, CircleConst *const_node)
+void copy_data(const VectorWrapper<uint8_t> &raw_data, uint32_t num_elements,
+ CircleConst *const_node)
{
using T = typename loco::DataTypeImpl<DT>::Type;
@@ -67,8 +68,8 @@ void copy_data(const std::vector<uint8_t> &raw_data, uint32_t num_elements, Circ
}
template <>
-void copy_data<loco::DataType::STRING>(const std::vector<uint8_t> &raw_data, uint32_t num_elements,
- CircleConst *const_node)
+void copy_data<loco::DataType::STRING>(const VectorWrapper<uint8_t> &raw_data,
+ uint32_t num_elements, CircleConst *const_node)
{
assert(const_node->sparsityparam() == nullptr);
@@ -106,17 +107,26 @@ void copy_data<loco::DataType::STRING>(const std::vector<uint8_t> &raw_data, uin
namespace luci
{
-CircleConst *create_circleconst(GraphBuilderContext *context, int32_t tensor_index)
+CircleNode *CircleConstNodeBuilder::build(TensorIndex tensor_index,
+ GraphBuilderContext *context) const
{
+ assert(tensor_index >= 0);
LOGGER(l);
auto graph = context->graph();
auto reader = context->reader();
- const auto &tensors = reader->tensors();
- const circle::TensorT &const_tensor = *tensors[tensor_index];
+ const auto tensors = reader->tensors();
+ const auto const_tensor = tensors[tensor_index];
+ assert(const_tensor != nullptr);
+ if (const_tensor->is_variable())
+ {
+ // Create CircleVariable for variable
+ return nullptr;
+ }
- const std::vector<uint8_t> &buffer = reader->buffers()[const_tensor.buffer]->data;
- std::vector<int32_t> const_dims = const_tensor.shape; // in NHWC
+ assert(reader->buffers()[const_tensor->buffer()] != nullptr);
+ const auto buffer = wrap(reader->buffers()[const_tensor->buffer()]->data());
+ const auto const_dims = wrap(const_tensor->shape()); // in NHWC
if (const_dims.size() == 0 && buffer.empty())
{
// unknown shape tensor and scalar tensor
@@ -150,7 +160,7 @@ CircleConst *create_circleconst(GraphBuilderContext *context, int32_t tensor_ind
<< const_dims << std::endl;
if (num_elements > 0)
{
- switch (luci_datatype(const_tensor.type))
+ switch (luci_datatype(const_tensor->type()))
{
case loco::DataType::FLOAT32:
copy_data<loco::DataType::FLOAT32>(buffer, num_elements, const_node);
@@ -186,7 +196,7 @@ CircleConst *create_circleconst(GraphBuilderContext *context, int32_t tensor_ind
default:
throw oops::UserExn("Unsupported tensor type",
- circle::EnumNameTensorType(const_tensor.type));
+ circle::EnumNameTensorType(const_tensor->type()));
}
}
diff --git a/compiler/luci/import/src/Nodes/CircleCustom.cpp b/compiler/luci/import/src/Nodes/CircleCustom.cpp
index 01ac3e2a0..4e78d5fb7 100644
--- a/compiler/luci/import/src/Nodes/CircleCustom.cpp
+++ b/compiler/luci/import/src/Nodes/CircleCustom.cpp
@@ -39,13 +39,15 @@ CircleNode *CircleCustomGraphBuilder::build_node(const BuildNodeArgs &bna) const
node->inputs(idx, bna.input_nodes[idx]);
}
- const auto &opcodes = bna.context->reader()->opcodes();
+ const auto opcodes = bna.context->reader()->opcodes();
const uint32_t opcode_index = bna.op.opcode_index;
- const circle::OperatorCodeT &opcode = *opcodes[opcode_index];
+ const auto opcode = opcodes[opcode_index];
+ assert(opcode != nullptr);
node->custom_options(
std::vector<uint8_t>{bna.op.custom_options.begin(), bna.op.custom_options.end()});
- node->custom_code(opcode.custom_code);
+ assert(opcode->custom_code() != nullptr);
+ node->custom_code(opcode->custom_code()->c_str());
// NOTE Operator version of custom is always 1
diff --git a/compiler/luci/import/src/Nodes/CircleDepthToSpace.cpp b/compiler/luci/import/src/Nodes/CircleDepthToSpace.cpp
index 49eb30a83..83fc2e37d 100644
--- a/compiler/luci/import/src/Nodes/CircleDepthToSpace.cpp
+++ b/compiler/luci/import/src/Nodes/CircleDepthToSpace.cpp
@@ -34,9 +34,10 @@ bool CircleDepthToSpaceGraphBuilder::validate(const ValidateArgs &args) const
const auto &outputs = args.op.outputs;
const auto *options = args.op.builtin_options.AsDepthToSpaceOptions();
- const auto &tensors = args.reader.tensors();
+ const auto tensors = args.reader.tensors();
+ assert(tensors[outputs[0]] != nullptr && tensors[inputs.at(0)] != nullptr);
- if (tensors[outputs[0]]->type != tensors[inputs.at(0)]->type)
+ if (tensors[outputs[0]]->type() != tensors[inputs.at(0)]->type())
{
return false;
}
diff --git a/compiler/luci/import/src/Nodes/CircleDepthwiseConv2D.cpp b/compiler/luci/import/src/Nodes/CircleDepthwiseConv2D.cpp
index 727487c6a..a24e4160d 100644
--- a/compiler/luci/import/src/Nodes/CircleDepthwiseConv2D.cpp
+++ b/compiler/luci/import/src/Nodes/CircleDepthwiseConv2D.cpp
@@ -32,19 +32,21 @@ bool CircleDepthwiseConv2DGraphBuilder::validate(const ValidateArgs &args) const
if (args.op.outputs.size() != 1)
return false;
- const auto &tensors = args.reader.tensors();
+ const auto tensors = args.reader.tensors();
// input shape
- const auto &input = tensors.at(args.op.inputs.at(0));
- const auto &input_shape = input->shape;
+ const auto input = tensors.at(args.op.inputs.at(0));
+ assert(input != nullptr);
+ const auto input_shape = wrap(input->shape());
// input shape must be rank 4
if (input_shape.size() != 4)
return false;
// filter shape
- const auto &filter = tensors.at(args.op.inputs.at(1));
- const auto &filter_shape = filter->shape;
+ const auto filter = tensors.at(args.op.inputs.at(1));
+ assert(filter != nullptr);
+ const auto filter_shape = wrap(filter->shape());
// filter shape must be rank 4
if (filter_shape.size() != 4)
diff --git a/compiler/luci/import/src/Nodes/CircleElu.cpp b/compiler/luci/import/src/Nodes/CircleElu.cpp
index 41696a65a..e5d7a4c7a 100644
--- a/compiler/luci/import/src/Nodes/CircleElu.cpp
+++ b/compiler/luci/import/src/Nodes/CircleElu.cpp
@@ -31,10 +31,11 @@ bool CircleEluGraphBuilder::validate(const ValidateArgs &args) const
const auto &inputs = args.op.inputs;
const auto &outputs = args.op.outputs;
- const auto &tensors = args.reader.tensors();
- const auto &tensor = tensors.at(inputs.at(0));
+ const auto tensors = args.reader.tensors();
+ const auto tensor = tensors.at(inputs.at(0));
+ assert(tensor != nullptr);
- switch (tensor->type)
+ switch (tensor->type())
{
case circle::TensorType_FLOAT64:
break;
@@ -48,7 +49,8 @@ bool CircleEluGraphBuilder::validate(const ValidateArgs &args) const
return false;
}
- if (tensors[outputs[0]]->type != tensor->type)
+ assert(tensors[outputs[0]] != nullptr);
+ if (tensors[outputs[0]]->type() != tensor->type())
return false;
return true;
diff --git a/compiler/luci/import/src/Nodes/CircleEqual.cpp b/compiler/luci/import/src/Nodes/CircleEqual.cpp
index 4909692b4..b326d9b5d 100644
--- a/compiler/luci/import/src/Nodes/CircleEqual.cpp
+++ b/compiler/luci/import/src/Nodes/CircleEqual.cpp
@@ -29,9 +29,10 @@ bool CircleEqualGraphBuilder::validate(const ValidateArgs &args) const
return false;
const auto &inputs = args.op.inputs;
- const auto &tensors = args.reader.tensors();
+ const auto tensors = args.reader.tensors();
- return tensors[inputs.at(0)]->type == tensors[inputs.at(1)]->type;
+ assert(tensors[inputs.at(0)] != nullptr && tensors[inputs.at(1)] != nullptr);
+ return tensors[inputs.at(0)]->type() == tensors[inputs.at(1)]->type();
}
CircleNode *CircleEqualGraphBuilder::build_node(const circle::OperatorT &,
diff --git a/compiler/luci/import/src/Nodes/CircleExp.cpp b/compiler/luci/import/src/Nodes/CircleExp.cpp
index 5bb7bb664..82c26f0e5 100644
--- a/compiler/luci/import/src/Nodes/CircleExp.cpp
+++ b/compiler/luci/import/src/Nodes/CircleExp.cpp
@@ -30,9 +30,10 @@ bool CircleExpGraphBuilder::validate(const ValidateArgs &args) const
const auto &inputs = args.op.inputs;
// input type check
- const auto &tensors = args.reader.tensors();
- const auto &tensor = tensors.at(inputs.at(0));
- switch (tensor->type)
+ const auto tensors = args.reader.tensors();
+ const auto tensor = tensors.at(inputs.at(0));
+ assert(tensor != nullptr);
+ switch (tensor->type())
{
case circle::TensorType_FLOAT16:
case circle::TensorType_FLOAT32:
diff --git a/compiler/luci/import/src/Nodes/CircleExpandDims.cpp b/compiler/luci/import/src/Nodes/CircleExpandDims.cpp
index ee0fbdc7e..67d9b7e9e 100644
--- a/compiler/luci/import/src/Nodes/CircleExpandDims.cpp
+++ b/compiler/luci/import/src/Nodes/CircleExpandDims.cpp
@@ -29,9 +29,10 @@ bool CircleExpandDimsGraphBuilder::validate(const ValidateArgs &args) const
return false;
const auto &inputs = args.op.inputs;
- const auto &tensors = args.reader.tensors();
+ const auto tensors = args.reader.tensors();
- return tensors[inputs.at(1)]->type == circle::TensorType_INT32;
+ assert(tensors[inputs.at(1)] != nullptr);
+ return tensors[inputs.at(1)]->type() == circle::TensorType_INT32;
}
CircleNode *CircleExpandDimsGraphBuilder::build_node(const circle::OperatorT &,
diff --git a/compiler/luci/import/src/Nodes/CircleFloorDiv.cpp b/compiler/luci/import/src/Nodes/CircleFloorDiv.cpp
index ce329326a..67eeddf91 100644
--- a/compiler/luci/import/src/Nodes/CircleFloorDiv.cpp
+++ b/compiler/luci/import/src/Nodes/CircleFloorDiv.cpp
@@ -30,15 +30,18 @@ bool CircleFloorDivGraphBuilder::validate(const ValidateArgs &args) const
const auto &inputs = args.op.inputs;
const auto &outputs = args.op.outputs;
- const auto &tensors = args.reader.tensors();
- const auto &tensor_in_0 = tensors.at(inputs.at(0));
- const auto &tensor_in_1 = tensors.at(inputs.at(1));
- const auto &tensor_out = tensors.at(outputs[0]);
-
- if (tensor_in_0->type != tensor_in_1->type)
+ const auto tensors = args.reader.tensors();
+ const auto tensor_in_0 = tensors.at(inputs.at(0));
+ const auto tensor_in_1 = tensors.at(inputs.at(1));
+ const auto tensor_out = tensors.at(outputs[0]);
+ assert(tensor_in_0 != nullptr);
+ assert(tensor_in_1 != nullptr);
+ assert(tensor_out != nullptr);
+
+ if (tensor_in_0->type() != tensor_in_1->type())
return false;
- if (tensor_out->type != tensor_in_1->type)
+ if (tensor_out->type() != tensor_in_1->type())
{
return false;
}
diff --git a/compiler/luci/import/src/Nodes/CircleFloorMod.cpp b/compiler/luci/import/src/Nodes/CircleFloorMod.cpp
index d8420a43c..d2a275b62 100644
--- a/compiler/luci/import/src/Nodes/CircleFloorMod.cpp
+++ b/compiler/luci/import/src/Nodes/CircleFloorMod.cpp
@@ -29,10 +29,11 @@ bool CircleFloorModGraphBuilder::validate(const ValidateArgs &args) const
return false;
const auto &inputs = args.op.inputs;
- const auto &tensors = args.reader.tensors();
- const auto &tensor_in_0 = tensors.at(inputs.at(0));
- const auto &tensor_in_1 = tensors.at(inputs.at(1));
- if (tensor_in_0->type != tensor_in_1->type)
+ const auto tensors = args.reader.tensors();
+ const auto tensor_in_0 = tensors.at(inputs.at(0));
+ const auto tensor_in_1 = tensors.at(inputs.at(1));
+ assert(tensor_in_0 != nullptr && tensor_in_1 != nullptr);
+ if (tensor_in_0->type() != tensor_in_1->type())
return false;
// TODO dtype check
diff --git a/compiler/luci/import/src/Nodes/CircleFullyConnected.cpp b/compiler/luci/import/src/Nodes/CircleFullyConnected.cpp
index 58750d79a..cc7be1693 100644
--- a/compiler/luci/import/src/Nodes/CircleFullyConnected.cpp
+++ b/compiler/luci/import/src/Nodes/CircleFullyConnected.cpp
@@ -42,6 +42,7 @@ CircleNode *CircleFullyConnectedGraphBuilder::build_node(const circle::OperatorT
const auto *options = op.builtin_options.AsFullyConnectedOptions();
node->fusedActivationFunction(luci_actfunc(options->fused_activation_function));
node->weights_format(luci_weights_format(options->weights_format));
+ node->keep_num_dims(options->keep_num_dims);
return node;
}
diff --git a/compiler/luci/import/src/Nodes/CircleGatherNd.cpp b/compiler/luci/import/src/Nodes/CircleGatherNd.cpp
index a4bb26a10..d336878ad 100644
--- a/compiler/luci/import/src/Nodes/CircleGatherNd.cpp
+++ b/compiler/luci/import/src/Nodes/CircleGatherNd.cpp
@@ -31,10 +31,11 @@ bool CircleGatherNdGraphBuilder::validate(const ValidateArgs &args) const
return false;
const auto &inputs = args.op.inputs;
- auto &indices_tensor = args.reader.tensors()[inputs.at(1)];
+ auto indices_tensor = args.reader.tensors()[inputs.at(1)];
+ assert(indices_tensor != nullptr);
- if (!(indices_tensor->type == circle::TensorType::TensorType_INT32 ||
- indices_tensor->type == circle::TensorType::TensorType_INT64))
+ if (!(indices_tensor->type() == circle::TensorType::TensorType_INT32 ||
+ indices_tensor->type() == circle::TensorType::TensorType_INT64))
{
return false;
}
diff --git a/compiler/luci/import/src/Nodes/CircleGreater.cpp b/compiler/luci/import/src/Nodes/CircleGreater.cpp
index f9c00346c..7f031b0ba 100644
--- a/compiler/luci/import/src/Nodes/CircleGreater.cpp
+++ b/compiler/luci/import/src/Nodes/CircleGreater.cpp
@@ -37,17 +37,19 @@ bool CircleGreaterGraphBuilder::validate(const ValidateArgs &args) const
const auto &inputs = args.op.inputs;
const auto &outputs = args.op.outputs;
- const auto &tensors = args.reader.tensors();
+ const auto tensors = args.reader.tensors();
- if (tensors[inputs.at(0)]->type != tensors[inputs.at(1)]->type)
+ assert(tensors[inputs.at(0)] != nullptr && tensors[inputs.at(1)] != nullptr);
+ if (tensors[inputs.at(0)]->type() != tensors[inputs.at(1)]->type())
return false;
// NOTE: real models do have output dtype NOT BOOL
- if (tensors[outputs[0]]->type != circle::TensorType_BOOL)
+ assert(tensors[outputs[0]] != nullptr);
+ if (tensors[outputs[0]]->type() != circle::TensorType_BOOL)
{
if (settings->get(luci::UserSettings::Key::DisableValidation))
{
- const circle::TensorT &output_tensor = *tensors[outputs[0]];
+ const auto output_tensor = tensors[outputs[0]];
auto name = tensor_name(output_tensor);
WARN(l) << "Warning: import Greater(" << name << ") output dtype is not boolean";
}
diff --git a/compiler/luci/import/src/Nodes/CircleGreaterEqual.cpp b/compiler/luci/import/src/Nodes/CircleGreaterEqual.cpp
index e20038fd9..ac4ce62f5 100644
--- a/compiler/luci/import/src/Nodes/CircleGreaterEqual.cpp
+++ b/compiler/luci/import/src/Nodes/CircleGreaterEqual.cpp
@@ -30,14 +30,16 @@ bool CircleGreaterEqualGraphBuilder::validate(const ValidateArgs &args) const
const auto &inputs = args.op.inputs;
const auto &outputs = args.op.outputs;
- const auto &tensors = args.reader.tensors();
+ const auto tensors = args.reader.tensors();
- if (tensors[inputs.at(0)]->type != tensors[inputs.at(1)]->type)
+ assert(tensors[inputs.at(0)] != nullptr && tensors[inputs.at(1)] != nullptr);
+ if (tensors[inputs.at(0)]->type() != tensors[inputs.at(1)]->type())
{
return false;
}
- return tensors[outputs[0]]->type == circle::TensorType::TensorType_BOOL;
+ assert(tensors[outputs[0]] != nullptr);
+ return tensors[outputs[0]]->type() == circle::TensorType::TensorType_BOOL;
}
CircleNode *CircleGreaterEqualGraphBuilder::build_node(const circle::OperatorT &,
diff --git a/compiler/luci/import/src/Nodes/CircleIf.cpp b/compiler/luci/import/src/Nodes/CircleIf.cpp
index ffdbf0b79..e8a50ff32 100644
--- a/compiler/luci/import/src/Nodes/CircleIf.cpp
+++ b/compiler/luci/import/src/Nodes/CircleIf.cpp
@@ -42,12 +42,13 @@ bool CircleIfGraphBuilder::validate(const ValidateArgs &args) const
return false;
// input 0 should be BOOL type
- const auto &tensors = args.reader.tensors();
- const auto &tensor = tensors.at(inputs.at(0));
- if (tensor->type != circle::TensorType_BOOL)
+ const auto tensors = args.reader.tensors();
+ const auto tensor = tensors.at(inputs.at(0));
+ assert(tensor != nullptr);
+ if (tensor->type() != circle::TensorType_BOOL)
return false;
- const auto &shape = tensor->shape;
+ const auto shape = wrap(tensor->shape());
if (shape.size() != 1 && shape.size() != 0)
return false;
diff --git a/compiler/luci/import/src/Nodes/CircleLess.cpp b/compiler/luci/import/src/Nodes/CircleLess.cpp
index f9b99bebe..5c5ae51e1 100644
--- a/compiler/luci/import/src/Nodes/CircleLess.cpp
+++ b/compiler/luci/import/src/Nodes/CircleLess.cpp
@@ -30,10 +30,11 @@ bool CircleLessGraphBuilder::validate(const ValidateArgs &args) const
const auto &inputs = args.op.inputs;
const auto &outputs = args.op.outputs;
- const auto &tensors = args.reader.tensors();
- const auto &tensor = tensors.at(inputs.at(0));
+ const auto tensors = args.reader.tensors();
+ const auto tensor = tensors.at(inputs.at(0));
+ assert(tensor != nullptr);
- switch (tensor->type)
+ switch (tensor->type())
{
case circle::TensorType_FLOAT32:
case circle::TensorType_FLOAT64:
@@ -48,12 +49,14 @@ bool CircleLessGraphBuilder::validate(const ValidateArgs &args) const
return false;
}
- if (tensors[inputs.at(1)]->type != tensor->type)
+ assert(tensors[inputs.at(1)] != nullptr);
+ if (tensors[inputs.at(1)]->type() != tensor->type())
{
return false;
}
- return tensors[outputs[0]]->type == circle::TensorType_BOOL;
+ assert(tensors[outputs[0]] != nullptr);
+ return tensors[outputs[0]]->type() == circle::TensorType_BOOL;
}
CircleNode *CircleLessGraphBuilder::build_node(const circle::OperatorT &,
diff --git a/compiler/luci/import/src/Nodes/CircleLessEqual.cpp b/compiler/luci/import/src/Nodes/CircleLessEqual.cpp
index bb1712137..8a2aea8db 100644
--- a/compiler/luci/import/src/Nodes/CircleLessEqual.cpp
+++ b/compiler/luci/import/src/Nodes/CircleLessEqual.cpp
@@ -30,14 +30,16 @@ bool CircleLessEqualGraphBuilder::validate(const ValidateArgs &args) const
const auto &inputs = args.op.inputs;
const auto &outputs = args.op.outputs;
- const auto &tensors = args.reader.tensors();
+ const auto tensors = args.reader.tensors();
- if (tensors[inputs.at(0)]->type != tensors[inputs.at(1)]->type)
+ assert(tensors[inputs.at(0)] != nullptr && tensors[inputs.at(1)] != nullptr);
+ if (tensors[inputs.at(0)]->type() != tensors[inputs.at(1)]->type())
{
return false;
}
- return tensors[outputs[0]]->type == circle::TensorType::TensorType_BOOL;
+ assert(tensors[outputs[0]] != nullptr);
+ return tensors[outputs[0]]->type() == circle::TensorType::TensorType_BOOL;
}
CircleNode *CircleLessEqualGraphBuilder::build_node(const circle::OperatorT &,
diff --git a/compiler/luci/import/src/Nodes/CircleLog.cpp b/compiler/luci/import/src/Nodes/CircleLog.cpp
index 26b575070..f41926829 100644
--- a/compiler/luci/import/src/Nodes/CircleLog.cpp
+++ b/compiler/luci/import/src/Nodes/CircleLog.cpp
@@ -32,9 +32,10 @@ bool CircleLogGraphBuilder::validate(const ValidateArgs &args) const
// input type check
// Must be one of bfloat16, half, float32, float64, complex64, complex128.
// Currently circle supports half(float16), float32, float64, complex64.
- const auto &tensors = args.reader.tensors();
- const auto &tensor = tensors.at(inputs.at(0));
- switch (tensor->type)
+ const auto tensors = args.reader.tensors();
+ const auto tensor = tensors.at(inputs.at(0));
+ assert(tensor != nullptr);
+ switch (tensor->type())
{
case circle::TensorType_FLOAT16:
case circle::TensorType_FLOAT32:
diff --git a/compiler/luci/import/src/Nodes/CircleLogicalAnd.cpp b/compiler/luci/import/src/Nodes/CircleLogicalAnd.cpp
index b13fc2735..b61fb6f3e 100644
--- a/compiler/luci/import/src/Nodes/CircleLogicalAnd.cpp
+++ b/compiler/luci/import/src/Nodes/CircleLogicalAnd.cpp
@@ -30,11 +30,12 @@ bool CircleLogicalAndGraphBuilder::validate(const ValidateArgs &args) const
// Only BOOL type is allowed for inputs
const auto &inputs = args.op.inputs;
- const auto &tensors = args.reader.tensors();
+ const auto tensors = args.reader.tensors();
for (auto input : inputs)
{
- const auto &tensor = tensors.at(input);
- if (tensor->type != circle::TensorType::TensorType_BOOL)
+ const auto tensor = tensors.at(input);
+ assert(tensor != nullptr);
+ if (tensor->type() != circle::TensorType::TensorType_BOOL)
return false;
}
diff --git a/compiler/luci/import/src/Nodes/CircleLogicalNot.cpp b/compiler/luci/import/src/Nodes/CircleLogicalNot.cpp
index f68218349..43e9ed39f 100644
--- a/compiler/luci/import/src/Nodes/CircleLogicalNot.cpp
+++ b/compiler/luci/import/src/Nodes/CircleLogicalNot.cpp
@@ -30,9 +30,10 @@ bool CircleLogicalNotGraphBuilder::validate(const ValidateArgs &args) const
// Only BOOL type is allowed for the input
const auto &inputs = args.op.inputs;
- const auto &tensors = args.reader.tensors();
- const auto &tensor = tensors.at(inputs.at(0));
- if (tensor->type != circle::TensorType::TensorType_BOOL)
+ const auto tensors = args.reader.tensors();
+ const auto tensor = tensors.at(inputs.at(0));
+ assert(tensor != nullptr);
+ if (tensor->type() != circle::TensorType::TensorType_BOOL)
return false;
return true;
diff --git a/compiler/luci/import/src/Nodes/CircleLogicalOr.cpp b/compiler/luci/import/src/Nodes/CircleLogicalOr.cpp
index 8c9023dd3..6354e7dc1 100644
--- a/compiler/luci/import/src/Nodes/CircleLogicalOr.cpp
+++ b/compiler/luci/import/src/Nodes/CircleLogicalOr.cpp
@@ -30,11 +30,12 @@ bool CircleLogicalOrGraphBuilder::validate(const ValidateArgs &args) const
// Only BOOL type is allowed for inputs
const auto &inputs = args.op.inputs;
- const auto &tensors = args.reader.tensors();
+ const auto tensors = args.reader.tensors();
for (auto input : inputs)
{
- const auto &tensor = tensors.at(input);
- if (tensor->type != circle::TensorType::TensorType_BOOL)
+ const auto tensor = tensors.at(input);
+ assert(tensor != nullptr);
+ if (tensor->type() != circle::TensorType::TensorType_BOOL)
return false;
}
diff --git a/compiler/luci/import/src/Nodes/CircleLogistic.cpp b/compiler/luci/import/src/Nodes/CircleLogistic.cpp
index 0f92a9bb4..b0d08e039 100644
--- a/compiler/luci/import/src/Nodes/CircleLogistic.cpp
+++ b/compiler/luci/import/src/Nodes/CircleLogistic.cpp
@@ -30,8 +30,9 @@ bool CircleLogisticGraphBuilder::validate(const ValidateArgs &args) const
const auto &inputs = args.op.inputs;
const auto &outputs = args.op.outputs;
- const auto &tensors = args.reader.tensors();
- if (tensors.at(inputs.at(0))->type != tensors.at(outputs[0])->type)
+ const auto tensors = args.reader.tensors();
+ assert(tensors.at(inputs.at(0)) != nullptr && tensors.at(outputs[0]) != nullptr);
+ if (tensors.at(inputs.at(0))->type() != tensors.at(outputs[0])->type())
return false;
return true;
diff --git a/compiler/luci/import/src/Nodes/CircleMatrixDiag.cpp b/compiler/luci/import/src/Nodes/CircleMatrixDiag.cpp
index 590a07f2d..384b98586 100644
--- a/compiler/luci/import/src/Nodes/CircleMatrixDiag.cpp
+++ b/compiler/luci/import/src/Nodes/CircleMatrixDiag.cpp
@@ -30,10 +30,11 @@ bool CircleMatrixDiagGraphBuilder::validate(const ValidateArgs &args) const
const auto &inputs = args.op.inputs;
const auto &outputs = args.op.outputs;
- const auto &tensors = args.reader.tensors();
- const auto &tensor = tensors.at(inputs.at(0));
+ const auto tensors = args.reader.tensors();
+ const auto tensor = tensors.at(inputs.at(0));
- if (tensors[outputs[0]]->type != tensor->type)
+ assert(tensors[outputs[0]] != nullptr && tensor != nullptr);
+ if (tensors[outputs[0]]->type() != tensor->type())
return false;
return true;
diff --git a/compiler/luci/import/src/Nodes/CircleMatrixSetDiag.cpp b/compiler/luci/import/src/Nodes/CircleMatrixSetDiag.cpp
index edd7d2ae2..64870c057 100644
--- a/compiler/luci/import/src/Nodes/CircleMatrixSetDiag.cpp
+++ b/compiler/luci/import/src/Nodes/CircleMatrixSetDiag.cpp
@@ -30,10 +30,11 @@ bool CircleMatrixSetDiagGraphBuilder::validate(const ValidateArgs &args) const
const auto &inputs = args.op.inputs;
const auto &outputs = args.op.outputs;
- const auto &tensors = args.reader.tensors();
- const auto &tensor = tensors.at(inputs.at(0));
+ const auto tensors = args.reader.tensors();
+ const auto tensor = tensors.at(inputs.at(0));
- if (tensors[outputs[0]]->type != tensor->type)
+ assert(tensors[outputs[0]] != nullptr && tensor != nullptr);
+ if (tensors[outputs[0]]->type() != tensor->type())
return false;
return true;
diff --git a/compiler/luci/import/src/Nodes/CircleNonMaxSuppressionV4.cpp b/compiler/luci/import/src/Nodes/CircleNonMaxSuppressionV4.cpp
index d3d69506b..e86f2ba81 100644
--- a/compiler/luci/import/src/Nodes/CircleNonMaxSuppressionV4.cpp
+++ b/compiler/luci/import/src/Nodes/CircleNonMaxSuppressionV4.cpp
@@ -35,20 +35,26 @@ bool CircleNonMaxSuppressionV4GraphBuilder::validate(const ValidateArgs &args) c
if (outputs.size() != 2)
return false;
- const auto &tensors = args.reader.tensors();
- const auto &boxes_tensor = tensors.at(inputs[0]);
- if (boxes_tensor->shape.size() != 2)
+ const auto tensors = args.reader.tensors();
+ const auto boxes_tensor = tensors.at(inputs[0]);
+ assert(boxes_tensor != nullptr);
+ const auto boxes_tensor_shape = wrap(boxes_tensor->shape());
+ if (boxes_tensor_shape.size() != 2)
return false;
- if (boxes_tensor->shape.at(1) != 4)
+ if (boxes_tensor_shape.at(1) != 4)
return false;
- if (boxes_tensor->shape.at(0) != tensors.at(inputs[1])->shape.at(0))
+ assert(tensors.at(inputs[1]) != nullptr);
+ if (boxes_tensor_shape.at(0) != wrap(tensors.at(inputs[1])->shape()).at(0))
return false;
- if (tensors.at(inputs[2])->type != circle::TensorType_INT32)
+ assert(tensors.at(inputs[2]) != nullptr);
+ if (tensors.at(inputs[2])->type() != circle::TensorType_INT32)
return false;
- if (tensors.at(inputs[3])->type != circle::TensorType_FLOAT32)
+ assert(tensors.at(inputs[3]) != nullptr);
+ if (tensors.at(inputs[3])->type() != circle::TensorType_FLOAT32)
return false;
- if (tensors.at(inputs[4])->type != circle::TensorType_FLOAT32)
+ assert(tensors.at(inputs[4]) != nullptr);
+ if (tensors.at(inputs[4])->type() != circle::TensorType_FLOAT32)
return false;
return true;
diff --git a/compiler/luci/import/src/Nodes/CircleNonMaxSuppressionV5.cpp b/compiler/luci/import/src/Nodes/CircleNonMaxSuppressionV5.cpp
index d797d4cb7..a60eed4e4 100644
--- a/compiler/luci/import/src/Nodes/CircleNonMaxSuppressionV5.cpp
+++ b/compiler/luci/import/src/Nodes/CircleNonMaxSuppressionV5.cpp
@@ -35,22 +35,29 @@ bool CircleNonMaxSuppressionV5GraphBuilder::validate(const ValidateArgs &args) c
if (outputs.size() != 3)
return false;
- const auto &tensors = args.reader.tensors();
- const auto &boxes_tensor = tensors.at(inputs[0]);
- if (boxes_tensor->shape.size() != 2)
+ const auto tensors = args.reader.tensors();
+ const auto boxes_tensor = tensors.at(inputs[0]);
+ assert(boxes_tensor != nullptr);
+ const auto boxes_tensor_shape = wrap(boxes_tensor->shape());
+ if (boxes_tensor_shape.size() != 2)
return false;
- if (boxes_tensor->shape.at(1) != 4)
+ if (boxes_tensor_shape.at(1) != 4)
return false;
- if (boxes_tensor->shape.at(0) != tensors.at(inputs[1])->shape.at(0))
+ assert(tensors.at(inputs[1]) != nullptr);
+ if (boxes_tensor_shape.at(0) != wrap(tensors.at(inputs[1])->shape()).at(0))
return false;
- if (tensors.at(inputs[2])->type != circle::TensorType_INT32)
+ assert(tensors.at(inputs[2]) != nullptr);
+ if (tensors.at(inputs[2])->type() != circle::TensorType_INT32)
return false;
- if (tensors.at(inputs[3])->type != circle::TensorType_FLOAT32)
+ assert(tensors.at(inputs[3]) != nullptr);
+ if (tensors.at(inputs[3])->type() != circle::TensorType_FLOAT32)
return false;
- if (tensors.at(inputs[4])->type != circle::TensorType_FLOAT32)
+ assert(tensors.at(inputs[4]) != nullptr);
+ if (tensors.at(inputs[4])->type() != circle::TensorType_FLOAT32)
return false;
- if (tensors.at(inputs[5])->type != circle::TensorType_FLOAT32)
+ assert(tensors.at(inputs[5]) != nullptr);
+ if (tensors.at(inputs[5])->type() != circle::TensorType_FLOAT32)
return false;
return true;
diff --git a/compiler/luci/import/src/Nodes/CircleNotEqual.cpp b/compiler/luci/import/src/Nodes/CircleNotEqual.cpp
index a0b8f9e4f..3f5c1e033 100644
--- a/compiler/luci/import/src/Nodes/CircleNotEqual.cpp
+++ b/compiler/luci/import/src/Nodes/CircleNotEqual.cpp
@@ -30,14 +30,16 @@ bool CircleNotEqualGraphBuilder::validate(const ValidateArgs &args) const
const auto &inputs = args.op.inputs;
const auto &outputs = args.op.outputs;
- const auto &tensors = args.reader.tensors();
+ const auto tensors = args.reader.tensors();
- if (tensors[inputs.at(0)]->type != tensors[inputs.at(1)]->type)
+ assert(tensors[inputs.at(0)] != nullptr && tensors[inputs.at(1)] != nullptr);
+ if (tensors[inputs.at(0)]->type() != tensors[inputs.at(1)]->type())
{
return false;
}
- return tensors[outputs[0]]->type == circle::TensorType::TensorType_BOOL;
+ assert(tensors[outputs[0]] != nullptr);
+ return tensors[outputs[0]]->type() == circle::TensorType::TensorType_BOOL;
}
CircleNode *CircleNotEqualGraphBuilder::build_node(const circle::OperatorT &,
diff --git a/compiler/luci/import/src/Nodes/CircleOneHot.cpp b/compiler/luci/import/src/Nodes/CircleOneHot.cpp
index 3952cc21a..6e5f8e16f 100644
--- a/compiler/luci/import/src/Nodes/CircleOneHot.cpp
+++ b/compiler/luci/import/src/Nodes/CircleOneHot.cpp
@@ -32,21 +32,25 @@ bool CircleOneHotGraphBuilder::validate(const ValidateArgs &args) const
const auto &inputs = args.op.inputs;
const auto *options = args.op.builtin_options.AsOneHotOptions();
- const auto &tensors = args.reader.tensors();
- const auto &indices = tensors.at(inputs.at(0));
- const auto &depth = tensors.at(inputs.at(1));
- const auto &on_value = tensors.at(inputs.at(2));
- const auto &off_value = tensors.at(inputs.at(3));
+ const auto tensors = args.reader.tensors();
+ const auto indices = tensors.at(inputs.at(0));
+ const auto depth = tensors.at(inputs.at(1));
+ const auto on_value = tensors.at(inputs.at(2));
+ const auto off_value = tensors.at(inputs.at(3));
+ assert(indices != nullptr);
+ assert(depth != nullptr);
+ assert(on_value != nullptr);
+ assert(off_value != nullptr);
- if (options->axis < -1 || options->axis > static_cast<int32_t>(indices->shape.size()))
+ if (options->axis < -1 || options->axis > static_cast<int32_t>(wrap(indices->shape()).size()))
return false;
- if (depth->shape.size() != 0)
+ if (wrap(depth->shape()).size() != 0)
return false;
- if (on_value->shape.size() != 0)
+ if (wrap(on_value->shape()).size() != 0)
return false;
- if (off_value->shape.size() != 0)
+ if (wrap(off_value->shape()).size() != 0)
return false;
- if (on_value->type != off_value->type)
+ if (on_value->type() != off_value->type())
return false;
return true;
diff --git a/compiler/luci/import/src/Nodes/CircleReduceAny.cpp b/compiler/luci/import/src/Nodes/CircleReduceAny.cpp
index 13205dd7a..ebe2368e0 100644
--- a/compiler/luci/import/src/Nodes/CircleReduceAny.cpp
+++ b/compiler/luci/import/src/Nodes/CircleReduceAny.cpp
@@ -28,17 +28,20 @@ bool CircleReduceAnyGraphBuilder::validate(const ValidateArgs &args) const
const auto &inputs = args.op.inputs;
const auto &outputs = args.op.outputs;
- const auto &tensors = args.reader.tensors();
- const auto &tensor_0 = tensors.at(inputs.at(0));
- const auto &tensor_1 = tensors.at(inputs.at(1));
- const auto &tensor_o = tensors.at(outputs[0]);
+ const auto tensors = args.reader.tensors();
+ const auto tensor_0 = tensors.at(inputs.at(0));
+ const auto tensor_1 = tensors.at(inputs.at(1));
+ const auto tensor_o = tensors.at(outputs[0]);
+ assert(tensor_0 != nullptr);
+ assert(tensor_1 != nullptr);
+ assert(tensor_o != nullptr);
- if (tensor_0->type != circle::TensorType_BOOL)
+ if (tensor_0->type() != circle::TensorType_BOOL)
return false;
- if (tensor_o->type != circle::TensorType_BOOL)
+ if (tensor_o->type() != circle::TensorType_BOOL)
return false;
- switch (tensor_1->type)
+ switch (tensor_1->type())
{
case circle::TensorType_INT32:
case circle::TensorType_INT64:
diff --git a/compiler/luci/import/src/Nodes/CircleReduceProd.cpp b/compiler/luci/import/src/Nodes/CircleReduceProd.cpp
index 3549c1a18..3b874b7c9 100644
--- a/compiler/luci/import/src/Nodes/CircleReduceProd.cpp
+++ b/compiler/luci/import/src/Nodes/CircleReduceProd.cpp
@@ -27,13 +27,14 @@ bool CircleReduceProdGraphBuilder::validate(const ValidateArgs &args) const
return false;
const auto &inputs = args.op.inputs;
- const auto &tensors = args.reader.tensors();
- const auto &tensor_1 = tensors.at(inputs.at(1));
+ const auto tensors = args.reader.tensors();
+ const auto tensor_1 = tensors.at(inputs.at(1));
+ assert(tensor_1 != nullptr);
// TODO check input types
// Check for reduction_indices types
- switch (tensor_1->type)
+ switch (tensor_1->type())
{
case circle::TensorType_INT32:
case circle::TensorType_INT64:
diff --git a/compiler/luci/import/src/Nodes/CircleReshape.cpp b/compiler/luci/import/src/Nodes/CircleReshape.cpp
index 401dff0fc..3421620ce 100644
--- a/compiler/luci/import/src/Nodes/CircleReshape.cpp
+++ b/compiler/luci/import/src/Nodes/CircleReshape.cpp
@@ -34,12 +34,13 @@ bool CircleReshapeGraphBuilder::validate(const ValidateArgs &args) const
if (args.op.inputs.size() == 2)
{
const auto &inputs = args.op.inputs;
- const auto &tensors = args.reader.tensors();
- const auto &tensor_in = tensors.at(inputs.at(1));
+ const auto tensors = args.reader.tensors();
+ const auto tensor_in = tensors.at(inputs.at(1));
+ assert(tensor_in != nullptr);
// NOTE fix this if there is any other case
// TensorFlow lite and circle only supports S32
- if (tensor_in->type != circle::TensorType::TensorType_INT32)
+ if (tensor_in->type() != circle::TensorType::TensorType_INT32)
return false;
}
diff --git a/compiler/luci/import/src/Nodes/CircleReverseSequence.cpp b/compiler/luci/import/src/Nodes/CircleReverseSequence.cpp
index 2fbb7a87c..c9cc792bb 100644
--- a/compiler/luci/import/src/Nodes/CircleReverseSequence.cpp
+++ b/compiler/luci/import/src/Nodes/CircleReverseSequence.cpp
@@ -30,12 +30,15 @@ bool CircleReverseSequenceGraphBuilder::validate(const ValidateArgs &args) const
const auto &inputs = args.op.inputs;
const auto &outputs = args.op.outputs;
- const auto &tensors = args.reader.tensors();
- const auto &tensor_in = tensors.at(inputs.at(0));
- const auto &tensor_lengths = tensors.at(inputs.at(1));
- const auto &tensor_out = tensors.at(outputs[0]);
+ const auto tensors = args.reader.tensors();
+ const auto tensor_in = tensors.at(inputs.at(0));
+ const auto tensor_lengths = tensors.at(inputs.at(1));
+ const auto tensor_out = tensors.at(outputs[0]);
+ assert(tensor_in != nullptr);
+ assert(tensor_lengths != nullptr);
+ assert(tensor_out != nullptr);
- switch (tensor_lengths->type)
+ switch (tensor_lengths->type())
{
case circle::TensorType_INT32:
case circle::TensorType_INT64:
@@ -44,7 +47,7 @@ bool CircleReverseSequenceGraphBuilder::validate(const ValidateArgs &args) const
return false;
}
- if (tensor_in->type != tensor_out->type)
+ if (tensor_in->type() != tensor_out->type())
return false;
return true;
diff --git a/compiler/luci/import/src/Nodes/CircleReverseV2.cpp b/compiler/luci/import/src/Nodes/CircleReverseV2.cpp
index ca7653201..c19a0fdd2 100644
--- a/compiler/luci/import/src/Nodes/CircleReverseV2.cpp
+++ b/compiler/luci/import/src/Nodes/CircleReverseV2.cpp
@@ -30,12 +30,15 @@ bool CircleReverseV2GraphBuilder::validate(const ValidateArgs &args) const
const auto &inputs = args.op.inputs;
const auto &outputs = args.op.outputs;
- const auto &tensors = args.reader.tensors();
- const auto &tensor_in = tensors.at(inputs.at(0));
- const auto &tensor_axis = tensors.at(inputs.at(1));
- const auto &tensor_out = tensors.at(outputs[0]);
+ const auto tensors = args.reader.tensors();
+ const auto tensor_in = tensors.at(inputs.at(0));
+ const auto tensor_axis = tensors.at(inputs.at(1));
+ const auto tensor_out = tensors.at(outputs[0]);
+ assert(tensor_in != nullptr);
+ assert(tensor_axis != nullptr);
+ assert(tensor_out != nullptr);
- switch (tensor_axis->type)
+ switch (tensor_axis->type())
{
case circle::TensorType_INT32:
case circle::TensorType_INT64:
@@ -44,7 +47,7 @@ bool CircleReverseV2GraphBuilder::validate(const ValidateArgs &args) const
return false;
}
- if (tensor_out->type != tensor_in->type)
+ if (tensor_out->type() != tensor_in->type())
return false;
return true;
diff --git a/compiler/luci/import/src/Nodes/CircleRound.cpp b/compiler/luci/import/src/Nodes/CircleRound.cpp
index d13e0fafe..08cfae6c2 100644
--- a/compiler/luci/import/src/Nodes/CircleRound.cpp
+++ b/compiler/luci/import/src/Nodes/CircleRound.cpp
@@ -33,11 +33,13 @@ bool CircleRoundGraphBuilder::validate(const ValidateArgs &args) const
// Must be one of the following types
// bfloat16, half (float16), float32, float64, complex64, complex128
// Currently, circle supports float16, float32, complex64
- const auto &tensors = args.reader.tensors();
- const auto &tensor_in = tensors.at(inputs.at(0));
- const auto &tensor_out = tensors.at(outputs[0]);
+ const auto tensors = args.reader.tensors();
+ const auto tensor_in = tensors.at(inputs.at(0));
+ const auto tensor_out = tensors.at(outputs[0]);
+ assert(tensor_in != nullptr);
+ assert(tensor_out != nullptr);
- switch (tensor_in->type)
+ switch (tensor_in->type())
{
case circle::TensorType_FLOAT16:
case circle::TensorType_FLOAT32:
@@ -49,7 +51,7 @@ bool CircleRoundGraphBuilder::validate(const ValidateArgs &args) const
return false;
}
- if (tensor_out->type != tensor_in->type)
+ if (tensor_out->type() != tensor_in->type())
return false;
return true;
diff --git a/compiler/luci/import/src/Nodes/CircleRsqrt.cpp b/compiler/luci/import/src/Nodes/CircleRsqrt.cpp
index a9ca90832..e3bc68f8b 100644
--- a/compiler/luci/import/src/Nodes/CircleRsqrt.cpp
+++ b/compiler/luci/import/src/Nodes/CircleRsqrt.cpp
@@ -32,9 +32,10 @@ bool CircleRsqrtGraphBuilder::validate(const ValidateArgs &args) const
// Must be one of the following types
// bfloat16, half (float16), float32, float64, complex64, complex128
// Currently, circle supports float16, float32, complex64
- const auto &tensors = args.reader.tensors();
- const auto &tensor = tensors.at(inputs.at(0));
- switch (tensor->type)
+ const auto tensors = args.reader.tensors();
+ const auto tensor = tensors.at(inputs.at(0));
+ assert(tensor != nullptr);
+ switch (tensor->type())
{
case circle::TensorType_UINT8:
case circle::TensorType_INT16:
diff --git a/compiler/luci/import/src/Nodes/CircleSVDF.cpp b/compiler/luci/import/src/Nodes/CircleSVDF.cpp
new file mode 100644
index 000000000..83a025177
--- /dev/null
+++ b/compiler/luci/import/src/Nodes/CircleSVDF.cpp
@@ -0,0 +1,67 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Import/Nodes/CircleSVDF.h"
+
+#include <luci/IR/Nodes/CircleSVDF.h>
+
+#include <loco.h>
+
+namespace luci
+{
+
+bool CircleSVDFBuilder::validate(const ValidateArgs &args) const
+{
+ const auto &inputs = args.op.inputs;
+ if (!(inputs.size() == 4 || inputs.size() == 5))
+ return false;
+
+ return true;
+}
+
+CircleNode *CircleSVDFBuilder::build_node(const circle::OperatorT &op,
+ const std::vector<CircleNode *> &inputs,
+ loco::Graph *graph) const
+{
+ auto *node = graph->nodes()->create<CircleSVDF>();
+ node->input(inputs.at(0));
+ node->weight_feature(inputs.at(1));
+ node->weight_time(inputs.at(2));
+ if (inputs.size() == 4)
+ {
+ auto *bias = graph->nodes()->create<CircleOutputExclude>();
+ // CircleOutputExclude doesn't need a type, but since all nodes must have a type,
+ // a dummy type is inserted.
+ bias->dtype(inputs.at(0)->dtype());
+ node->bias(bias);
+
+ node->input_activation_state(inputs.at(3));
+ }
+ else
+ {
+ node->bias(inputs.at(3));
+ node->input_activation_state(inputs.at(4));
+ }
+
+ const auto *options = op.builtin_options.AsSVDFOptions();
+ node->svdf_rank(options->rank);
+ node->fusedActivationFunction(luci_actfunc(options->fused_activation_function));
+ node->asymmetric_quantize_inputs(options->asymmetric_quantize_inputs);
+
+ return node;
+}
+
+} // namespace luci
diff --git a/compiler/luci/import/src/Nodes/CircleScatterNd.cpp b/compiler/luci/import/src/Nodes/CircleScatterNd.cpp
index f8c175110..ebe252527 100644
--- a/compiler/luci/import/src/Nodes/CircleScatterNd.cpp
+++ b/compiler/luci/import/src/Nodes/CircleScatterNd.cpp
@@ -30,14 +30,15 @@ bool CircleScatterNdGraphBuilder::validate(const ValidateArgs &args) const
const auto &inputs = args.op.inputs;
// indices must have the same type as shape
- const auto &tensors = args.reader.tensors();
+ const auto tensors = args.reader.tensors();
- if (tensors[inputs.at(0)]->type != tensors[inputs.at(2)]->type)
+ assert(tensors[inputs.at(0)] != nullptr && tensors[inputs.at(2)] != nullptr);
+ if (tensors[inputs.at(0)]->type() != tensors[inputs.at(2)]->type())
return false;
// indices must be either int32 or int64
- if (tensors[inputs.at(0)]->type != circle::TensorType_INT32 &&
- tensors[inputs.at(0)]->type != circle::TensorType_INT64)
+ if (tensors[inputs.at(0)]->type() != circle::TensorType_INT32 &&
+ tensors[inputs.at(0)]->type() != circle::TensorType_INT64)
return false;
return true;
diff --git a/compiler/luci/import/src/Nodes/CircleSegmentSum.cpp b/compiler/luci/import/src/Nodes/CircleSegmentSum.cpp
index bfa333e8d..01d1aab44 100644
--- a/compiler/luci/import/src/Nodes/CircleSegmentSum.cpp
+++ b/compiler/luci/import/src/Nodes/CircleSegmentSum.cpp
@@ -30,12 +30,15 @@ bool CircleSegmentSumGraphBuilder::validate(const ValidateArgs &args) const
const auto &inputs = args.op.inputs;
const auto &outputs = args.op.outputs;
- const auto &tensors = args.reader.tensors();
- const auto &tensor_in = tensors.at(inputs.at(0));
- const auto &tensor_out = tensors.at(outputs[0]);
- const auto &tensor_ids = tensors.at(inputs.at(1));
+ const auto tensors = args.reader.tensors();
+ const auto tensor_in = tensors.at(inputs.at(0));
+ const auto tensor_out = tensors.at(outputs[0]);
+ const auto tensor_ids = tensors.at(inputs.at(1));
+ assert(tensor_in != nullptr);
+ assert(tensor_out != nullptr);
+ assert(tensor_ids != nullptr);
- switch (tensor_ids->type)
+ switch (tensor_ids->type())
{
case circle::TensorType_INT32:
case circle::TensorType_INT64:
@@ -44,7 +47,7 @@ bool CircleSegmentSumGraphBuilder::validate(const ValidateArgs &args) const
return false;
}
- if (tensor_out->type != tensor_in->type)
+ if (tensor_out->type() != tensor_in->type())
{
return false;
}
diff --git a/compiler/luci/import/src/Nodes/CircleSelect.cpp b/compiler/luci/import/src/Nodes/CircleSelect.cpp
index 36a5fa8a8..002f62f6c 100644
--- a/compiler/luci/import/src/Nodes/CircleSelect.cpp
+++ b/compiler/luci/import/src/Nodes/CircleSelect.cpp
@@ -29,9 +29,10 @@ bool CircleSelectGraphBuilder::validate(const ValidateArgs &args) const
return false;
const auto &inputs = args.op.inputs;
- const auto &tensors = args.reader.tensors();
- const auto &tensor = tensors.at(inputs.at(0));
- if (tensor->type != circle::TensorType_BOOL)
+ const auto tensors = args.reader.tensors();
+ const auto tensor = tensors.at(inputs.at(0));
+ assert(tensor != nullptr);
+ if (tensor->type() != circle::TensorType_BOOL)
return false;
// TODO check dtypes for input 1, 2
diff --git a/compiler/luci/import/src/Nodes/CircleSelectV2.cpp b/compiler/luci/import/src/Nodes/CircleSelectV2.cpp
index 556c8fa33..062fdc143 100644
--- a/compiler/luci/import/src/Nodes/CircleSelectV2.cpp
+++ b/compiler/luci/import/src/Nodes/CircleSelectV2.cpp
@@ -29,14 +29,16 @@ bool CircleSelectV2GraphBuilder::validate(const ValidateArgs &args) const
return false;
const auto &inputs = args.op.inputs;
- const auto &tensors = args.reader.tensors();
- const auto &condition = tensors.at(inputs.at(0));
- if (condition->type != circle::TensorType_BOOL)
+ const auto tensors = args.reader.tensors();
+ const auto condition = tensors.at(inputs.at(0));
+ assert(condition != nullptr);
+ if (condition->type() != circle::TensorType_BOOL)
return false;
- const auto &t = tensors.at(inputs.at(1));
- const auto &e = tensors.at(inputs.at(2));
- if (t->type != e->type)
+ const auto t = tensors.at(inputs.at(1));
+ const auto e = tensors.at(inputs.at(2));
+ assert(t != nullptr && e != nullptr);
+ if (t->type() != e->type())
return false;
return true;
diff --git a/compiler/luci/import/src/Nodes/CircleSin.cpp b/compiler/luci/import/src/Nodes/CircleSin.cpp
index 22f461123..51ebf0355 100644
--- a/compiler/luci/import/src/Nodes/CircleSin.cpp
+++ b/compiler/luci/import/src/Nodes/CircleSin.cpp
@@ -30,9 +30,10 @@ bool CircleSinGraphBuilder::validate(const ValidateArgs &args) const
const auto &inputs = args.op.inputs;
// input type check
- const auto &tensors = args.reader.tensors();
- const auto &tensor = tensors.at(inputs.at(0));
- switch (tensor->type)
+ const auto tensors = args.reader.tensors();
+ const auto tensor = tensors.at(inputs.at(0));
+ assert(tensor != nullptr);
+ switch (tensor->type())
{
case circle::TensorType_FLOAT16:
case circle::TensorType_FLOAT32:
diff --git a/compiler/luci/import/src/Nodes/CircleSquare.cpp b/compiler/luci/import/src/Nodes/CircleSquare.cpp
index 7ff2b84e6..bec84b4c0 100644
--- a/compiler/luci/import/src/Nodes/CircleSquare.cpp
+++ b/compiler/luci/import/src/Nodes/CircleSquare.cpp
@@ -29,13 +29,13 @@ bool CircleSquareGraphBuilder::validate(const ValidateArgs &args) const
return false;
const auto &inputs = args.op.inputs;
- // Must be one of the following types
- // bfloat16, half (float16), float32, float64, complex64, complex128
- // Currently, circle supports float16, float32, complex64
- const auto &tensors = args.reader.tensors();
- const auto &tensor = tensors.at(inputs.at(0));
- switch (tensor->type)
+ const auto tensors = args.reader.tensors();
+ const auto tensor = tensors.at(inputs.at(0));
+ assert(tensor != nullptr);
+ switch (tensor->type())
{
+ case circle::TensorType_UINT8:
+ case circle::TensorType_INT16:
case circle::TensorType_INT32:
case circle::TensorType_INT64:
case circle::TensorType_FLOAT16:
diff --git a/compiler/luci/import/src/Nodes/CircleSquaredDifference.cpp b/compiler/luci/import/src/Nodes/CircleSquaredDifference.cpp
index 33440d5ab..1983465d3 100644
--- a/compiler/luci/import/src/Nodes/CircleSquaredDifference.cpp
+++ b/compiler/luci/import/src/Nodes/CircleSquaredDifference.cpp
@@ -32,9 +32,10 @@ bool CircleSquaredDifferenceGraphBuilder::validate(const ValidateArgs &args) con
const auto &outputs = args.op.outputs;
// Inputs must be one of the following types
// bfloat16, half(float16), float32, float64, int32, int64, complex64, complex128
- const auto &tensors = args.reader.tensors();
- const auto &tensor = tensors.at(inputs.at(0));
- switch (tensor->type)
+ const auto tensors = args.reader.tensors();
+ const auto tensor = tensors.at(inputs.at(0));
+ assert(tensor != nullptr);
+ switch (tensor->type())
{
case circle::TensorType_FLOAT16:
case circle::TensorType_FLOAT32:
@@ -53,11 +54,13 @@ bool CircleSquaredDifferenceGraphBuilder::validate(const ValidateArgs &args) con
}
// Input types must match
- if (tensors.at(inputs.at(0))->type != tensors.at(inputs.at(1))->type)
+ assert(tensors.at(inputs.at(0)) != nullptr && tensors.at(inputs.at(1)) != nullptr);
+ if (tensors.at(inputs.at(0))->type() != tensors.at(inputs.at(1))->type())
return false;
// Input and output types must match
- if (tensors.at(inputs.at(0))->type != tensors.at(outputs[0])->type)
+ assert(tensors.at(outputs[0]) != nullptr);
+ if (tensors.at(inputs.at(0))->type() != tensors.at(outputs[0])->type())
return false;
return true;
diff --git a/compiler/luci/import/src/Nodes/CircleTanh.cpp b/compiler/luci/import/src/Nodes/CircleTanh.cpp
index 95625a0e4..80a0e887f 100644
--- a/compiler/luci/import/src/Nodes/CircleTanh.cpp
+++ b/compiler/luci/import/src/Nodes/CircleTanh.cpp
@@ -30,8 +30,9 @@ bool CircleTanhGraphBuilder::validate(const ValidateArgs &args) const
const auto &inputs = args.op.inputs;
const auto &outputs = args.op.outputs;
- const auto &tensors = args.reader.tensors();
- if (tensors.at(inputs.at(0))->type != tensors.at(outputs[0])->type)
+ const auto tensors = args.reader.tensors();
+ assert(tensors.at(inputs.at(0)) != nullptr && tensors.at(outputs[0]) != nullptr);
+ if (tensors.at(inputs.at(0))->type() != tensors.at(outputs[0])->type())
return false;
return true;
diff --git a/compiler/luci/import/src/Nodes/CircleTile.cpp b/compiler/luci/import/src/Nodes/CircleTile.cpp
index 6da44130c..c41a6ba3f 100644
--- a/compiler/luci/import/src/Nodes/CircleTile.cpp
+++ b/compiler/luci/import/src/Nodes/CircleTile.cpp
@@ -32,9 +32,10 @@ bool CircleTileGraphBuilder::validate(const ValidateArgs &args) const
auto outputs = args.op.outputs;
// Multiples (inputs.at(1)) must be one of the following types
// int32, int64
- const auto &tensors = args.reader.tensors();
- const auto &tensor = tensors.at(inputs.at(1));
- switch (tensor->type)
+ const auto tensors = args.reader.tensors();
+ const auto tensor = tensors.at(inputs.at(1));
+ assert(tensor != nullptr);
+ switch (tensor->type())
{
case circle::TensorType_INT32:
case circle::TensorType_INT64:
@@ -44,7 +45,8 @@ bool CircleTileGraphBuilder::validate(const ValidateArgs &args) const
}
// Type of input and output must be the same
- if (tensors.at(inputs.at(0))->type != tensors.at(outputs[0])->type)
+ assert(tensors.at(inputs.at(0)) != nullptr && tensors.at(outputs[0]) != nullptr);
+ if (tensors.at(inputs.at(0))->type() != tensors.at(outputs[0])->type())
return false;
return true;
diff --git a/compiler/luci/import/src/Nodes/CircleTopKV2.cpp b/compiler/luci/import/src/Nodes/CircleTopKV2.cpp
index 49f858798..9f9173738 100644
--- a/compiler/luci/import/src/Nodes/CircleTopKV2.cpp
+++ b/compiler/luci/import/src/Nodes/CircleTopKV2.cpp
@@ -35,9 +35,10 @@ bool CircleTopKV2GraphBuilder::validate(const ValidateArgs &args) const
if (outputs.size() != 2)
return false;
- const auto &tensors = args.reader.tensors();
- const auto &tensor = tensors.at(inputs.at(1));
- if (tensor->type != circle::TensorType_INT32)
+ const auto tensors = args.reader.tensors();
+ const auto tensor = tensors.at(inputs.at(1));
+ assert(tensor != nullptr);
+ if (tensor->type() != circle::TensorType_INT32)
return false;
return true;
diff --git a/compiler/luci/import/src/Nodes/CircleTransposeConv.cpp b/compiler/luci/import/src/Nodes/CircleTransposeConv.cpp
index 5a60e2f54..041983dac 100644
--- a/compiler/luci/import/src/Nodes/CircleTransposeConv.cpp
+++ b/compiler/luci/import/src/Nodes/CircleTransposeConv.cpp
@@ -31,11 +31,13 @@ bool CircleTransposeConvGraphBuilder::validate(const ValidateArgs &args) const
return false;
const auto &inputs = args.op.inputs;
- const auto &tensors = args.reader.tensors();
- const auto &filter_tensor = tensors.at(inputs.at(1));
- const auto &filter_shape = filter_tensor.get()->shape;
- const auto &ifm_tensor = tensors.at(inputs.at(2));
- const auto &ifm_shape = ifm_tensor.get()->shape;
+ const auto tensors = args.reader.tensors();
+ const auto filter_tensor = tensors.at(inputs.at(1));
+ assert(filter_tensor != nullptr);
+ const auto filter_shape = wrap(filter_tensor->shape());
+ const auto ifm_tensor = tensors.at(inputs.at(2));
+ assert(ifm_tensor != nullptr);
+ const auto ifm_shape = wrap(ifm_tensor->shape());
// ifm and filters must be 4-D tensor
if (ifm_shape.size() != 4)
@@ -45,7 +47,7 @@ bool CircleTransposeConvGraphBuilder::validate(const ValidateArgs &args) const
// input shape : [batch, height, width, in_channels]
// filters shape : [output_channels, height, weight, in_channels]
- if (ifm_tensor.get()->shape.at(3) != filter_tensor.get()->shape.at(3))
+ if (ifm_shape.at(3) != filter_shape.at(3))
return false;
return true;
diff --git a/compiler/luci/import/src/Nodes/CircleUnpack.cpp b/compiler/luci/import/src/Nodes/CircleUnpack.cpp
index 9bfc76b57..6b3401609 100644
--- a/compiler/luci/import/src/Nodes/CircleUnpack.cpp
+++ b/compiler/luci/import/src/Nodes/CircleUnpack.cpp
@@ -46,8 +46,8 @@ bool CircleUnpackGraphBuilder::validate(const ValidateArgs &args) const
{
if (settings->get(luci::UserSettings::Key::DisableValidation))
{
- const auto &tensors = args.reader.tensors();
- const circle::TensorT &output_tensor = *tensors[outputs[0]];
+ const auto tensors = args.reader.tensors();
+ const auto output_tensor = tensors[outputs[0]];
auto name = tensor_name(output_tensor);
WARN(l) << "Warning: import Unpack(" << name << ") 'num' is not same as outputs used";
}
@@ -58,9 +58,10 @@ bool CircleUnpackGraphBuilder::validate(const ValidateArgs &args) const
if (options->num < 0)
return false;
- const auto &tensors = args.reader.tensors();
- const auto &tensor = tensors.at(inputs.at(0));
- const auto &shape = tensor->shape;
+ const auto tensors = args.reader.tensors();
+ const auto tensor = tensors.at(inputs.at(0));
+ assert(tensor != nullptr);
+ const auto shape = wrap(tensor->shape());
auto shape_size = static_cast<int32_t>(shape.size());
if (shape_size > 0)
{
diff --git a/compiler/luci/import/src/Nodes/CircleVariable.cpp b/compiler/luci/import/src/Nodes/CircleVariable.cpp
new file mode 100644
index 000000000..23ae9e7be
--- /dev/null
+++ b/compiler/luci/import/src/Nodes/CircleVariable.cpp
@@ -0,0 +1,80 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Import/Nodes/CircleVariable.h"
+
+#include <luci/IR/Nodes/CircleVariable.h>
+#include <luci/Log.h>
+
+#include <cassert>
+#include <ostream>
+#include <string>
+#include <vector>
+
+namespace
+{
+
+std::ostream &operator<<(std::ostream &os, const luci::VectorWrapper<int32_t> &vect)
+{
+ uint32_t seq = 0;
+ for (const auto &v : vect)
+ {
+ if (seq)
+ os << ", ";
+ os << v;
+ seq++;
+ }
+ return os;
+}
+
+} // namespace
+
+namespace luci
+{
+
+CircleVariable *create_circlevariable(GraphBuilderContext *context, int32_t tensor_index)
+{
+ LOGGER(l);
+
+ auto graph = context->graph();
+ auto reader = context->reader();
+ const auto tensors = reader->tensors();
+ const auto variable_tensor = tensors[tensor_index];
+ assert(variable_tensor != nullptr);
+
+ if (not variable_tensor->is_variable())
+ {
+ // not a variable
+ return nullptr;
+ }
+ {
+ // check if there is no buffer as we don't support this for now
+ // TODO use buffer when this is enabled in Kernel
+ assert(reader->buffers()[variable_tensor->buffer()] != nullptr);
+ assert(reader->buffers()[variable_tensor->buffer()]->data() == nullptr);
+ }
+
+ auto variable_node = graph->nodes()->create<CircleVariable>();
+ copy_tensor_attributes(variable_tensor, variable_node);
+ variable_node->shape_status(luci::ShapeStatus::VALID);
+
+ INFO(l) << "[luci] NodeFinder variable node(" << tensor_index << ") -> " << variable_node << " "
+ << wrap(variable_tensor->shape()) << std::endl;
+
+ return variable_node;
+}
+
+} // namespace luci
diff --git a/compiler/luci/import/src/Nodes/CircleWhere.cpp b/compiler/luci/import/src/Nodes/CircleWhere.cpp
index 8e4f1a0c4..bc6199ace 100644
--- a/compiler/luci/import/src/Nodes/CircleWhere.cpp
+++ b/compiler/luci/import/src/Nodes/CircleWhere.cpp
@@ -30,14 +30,16 @@ bool CircleWhereGraphBuilder::validate(const ValidateArgs &args) const
const auto &inputs = args.op.inputs;
const auto &outputs = args.op.outputs;
- const auto &tensors = args.reader.tensors();
- const auto &tensor_condition = tensors.at(inputs.at(0));
- const auto &tensor_out = tensors.at(outputs[0]);
+ const auto tensors = args.reader.tensors();
+ const auto tensor_condition = tensors.at(inputs.at(0));
+ const auto tensor_out = tensors.at(outputs[0]);
+ assert(tensor_condition != nullptr);
+ assert(tensor_out != nullptr);
- if (tensor_condition->type != circle::TensorType_BOOL)
+ if (tensor_condition->type() != circle::TensorType_BOOL)
return false;
- if (tensor_out->type != circle::TensorType_INT64)
+ if (tensor_out->type() != circle::TensorType_INT64)
return false;
return true;
diff --git a/compiler/luci/import/src/Nodes/CircleWhile.cpp b/compiler/luci/import/src/Nodes/CircleWhile.cpp
index 26147562f..27a392b2a 100644
--- a/compiler/luci/import/src/Nodes/CircleWhile.cpp
+++ b/compiler/luci/import/src/Nodes/CircleWhile.cpp
@@ -67,8 +67,8 @@ CircleNode *CircleWhileGraphBuilder::build(const circle::OperatorT &op,
const std::vector<int32_t> &inputs = op.inputs;
const std::vector<int32_t> &outputs = op.outputs;
- const auto &tensors = context->reader()->tensors();
- const auto &opcodes = context->reader()->opcodes();
+ const auto tensors = context->reader()->tensors();
+ const auto opcodes = context->reader()->opcodes();
std::vector<CircleNode *> input_nodes;
for (const int32_t input_tensor_index : inputs)
@@ -96,9 +96,11 @@ CircleNode *CircleWhileGraphBuilder::build(const circle::OperatorT &op,
assert(outputs.size() > 0);
{
// Lets use name of output 0 as While name
- const circle::TensorT &output_tensor = *tensors[outputs[0]];
+ const auto output_tensor = tensors[outputs[0]];
+ assert(output_tensor != nullptr);
node->name(tensor_name(output_tensor));
- node->op_version(opcodes[op.opcode_index].get()->version);
+ assert(opcodes[op.opcode_index] != nullptr);
+ node->op_version(opcodes[op.opcode_index]->version());
// NOTE We don't set quantization for While itself but to virtual outputs
}
@@ -106,7 +108,8 @@ CircleNode *CircleWhileGraphBuilder::build(const circle::OperatorT &op,
// Create virtual outputs of While
for (uint32_t n = 0; n < output_count; ++n)
{
- const circle::TensorT &output_tensor = *tensors[outputs[n]];
+ const auto output_tensor = tensors[outputs[n]];
+ assert(output_tensor != nullptr);
auto *nodeout = graph->nodes()->create<CircleWhileOut>();
diff --git a/compiler/luci/import/src/ValidateHelpers.cpp b/compiler/luci/import/src/ValidateHelpers.cpp
index 27306ba90..fc027704b 100644
--- a/compiler/luci/import/src/ValidateHelpers.cpp
+++ b/compiler/luci/import/src/ValidateHelpers.cpp
@@ -26,9 +26,10 @@ bool validate_batch_space_nd(const GraphBuilderBase::ValidateArgs &args)
return false;
// input 1 and 2 should have INT32/INT64 type
- const auto &tensors = args.reader.tensors();
- const auto &tensor_1 = tensors.at(inputs.at(1));
- switch (tensor_1->type)
+ const auto tensors = args.reader.tensors();
+ const auto tensor_1 = tensors.at(inputs.at(1));
+ assert(tensor_1 != nullptr);
+ switch (tensor_1->type())
{
case circle::TensorType_INT32:
case circle::TensorType_INT64:
@@ -36,8 +37,9 @@ bool validate_batch_space_nd(const GraphBuilderBase::ValidateArgs &args)
default:
return false;
}
- const auto &tensor_2 = tensors.at(inputs.at(2));
- switch (tensor_2->type)
+ const auto tensor_2 = tensors.at(inputs.at(2));
+ assert(tensor_2 != nullptr);
+ switch (tensor_2->type())
{
case circle::TensorType_INT32:
case circle::TensorType_INT64:
@@ -47,8 +49,9 @@ bool validate_batch_space_nd(const GraphBuilderBase::ValidateArgs &args)
}
// Only support input shape dimension 3 and 4 only
- const auto &tensor_0 = tensors.at(inputs.at(0));
- const auto t_0_s = tensor_0->shape.size();
+ const auto tensor_0 = tensors.at(inputs.at(0));
+ assert(tensor_0 != nullptr);
+ const auto t_0_s = wrap(tensor_0->shape()).size();
if (t_0_s != 3 && t_0_s != 4)
return false;
@@ -68,10 +71,10 @@ bool validate_minmax(const GraphBuilderBase::ValidateArgs &args)
if (outputs.size() != 1)
return false;
- const auto &tensors = args.reader.tensors();
- const auto &tensor = tensors.at(inputs.at(0));
-
- switch (tensor->type)
+ const auto tensors = args.reader.tensors();
+ const auto tensor = tensors.at(inputs.at(0));
+ assert(tensor != nullptr);
+ switch (tensor->type())
{
case circle::TensorType_FLOAT16:
case circle::TensorType_FLOAT32:
@@ -84,10 +87,12 @@ bool validate_minmax(const GraphBuilderBase::ValidateArgs &args)
return false;
}
- if (tensors[inputs.at(1)]->type != tensor->type)
+ assert(tensors[inputs.at(1)] != nullptr);
+ if (tensors[inputs.at(1)]->type() != tensor->type())
return false;
- if (tensors[outputs[0]]->type != tensor->type)
+ assert(tensors[outputs[0]] != nullptr);
+ if (tensors[outputs[0]]->type() != tensor->type())
return false;
return true;
@@ -104,10 +109,10 @@ bool validate_reduce_minmax(const GraphBuilderBase::ValidateArgs &args)
if (outputs.size() != 1)
return false;
- const auto &tensors = args.reader.tensors();
- const auto &tensor_axis = tensors.at(inputs.at(1));
-
- switch (tensor_axis->type)
+ const auto tensors = args.reader.tensors();
+ const auto tensor_axis = tensors.at(inputs.at(1));
+ assert(tensor_axis != nullptr);
+ switch (tensor_axis->type())
{
case circle::TensorType_INT32:
case circle::TensorType_INT64:
diff --git a/compiler/luci/lang/include/luci/IR/CircleNodes.h b/compiler/luci/lang/include/luci/IR/CircleNodes.h
index a313f9d5b..d89ea03cc 100644
--- a/compiler/luci/lang/include/luci/IR/CircleNodes.h
+++ b/compiler/luci/lang/include/luci/IR/CircleNodes.h
@@ -29,7 +29,6 @@
#include "Nodes/CircleCast.h"
#include "Nodes/CircleCeil.h"
#include "Nodes/CircleConcatenation.h"
-#include "Nodes/CircleConst.h"
#include "Nodes/CircleConv2D.h"
#include "Nodes/CircleCos.h"
#include "Nodes/CircleCustom.h"
@@ -119,6 +118,7 @@
#include "Nodes/CircleStridedSlice.h"
#include "Nodes/CircleSub.h"
#include "Nodes/CircleSum.h"
+#include "Nodes/CircleSVDF.h"
#include "Nodes/CircleTanh.h"
#include "Nodes/CircleTile.h"
#include "Nodes/CircleTopKV2.h"
@@ -135,18 +135,21 @@
#include "Nodes/CircleBCQGather.h"
#include "Nodes/CircleInstanceNorm.h"
// Virtual nodes
+#include "Nodes/CircleConst.h"
#include "Nodes/CircleInput.h"
#include "Nodes/CircleOutput.h"
+#include "Nodes/CircleVariable.h"
+// Multi-output virtual nodes
#include "Nodes/CircleBidirectionalSequenceLSTMOut.h"
#include "Nodes/CircleCustomOut.h"
#include "Nodes/CircleIfOut.h"
#include "Nodes/CircleNonMaxSuppressionV4Out.h"
#include "Nodes/CircleNonMaxSuppressionV5Out.h"
-#include "Nodes/CircleUnpackOut.h"
-#include "Nodes/CircleUniqueOut.h"
#include "Nodes/CircleSplitOut.h"
#include "Nodes/CircleSplitVOut.h"
#include "Nodes/CircleTopKV2Out.h"
+#include "Nodes/CircleUniqueOut.h"
+#include "Nodes/CircleUnpackOut.h"
#include "Nodes/CircleWhileOut.h"
#include <loco/IR/Graph.h>
diff --git a/compiler/luci/lang/include/luci/IR/CircleNodes.lst b/compiler/luci/lang/include/luci/IR/CircleNodes.lst
index 914aa16e4..1472008df 100644
--- a/compiler/luci/lang/include/luci/IR/CircleNodes.lst
+++ b/compiler/luci/lang/include/luci/IR/CircleNodes.lst
@@ -116,6 +116,7 @@ CIRCLE_NODE(SQUEEZE, CircleSqueeze)
CIRCLE_NODE(STRIDED_SLICE, CircleStridedSlice)
CIRCLE_NODE(SUB, CircleSub)
CIRCLE_NODE(SUM, CircleSum)
+CIRCLE_NODE(SVDF, CircleSVDF)
CIRCLE_NODE(TANH, CircleTanh)
CIRCLE_NODE(TILE, CircleTile)
CIRCLE_NODE(TOPK_V2, CircleTopKV2)
@@ -132,12 +133,14 @@ CIRCLE_NODE(BCQ_FULLY_CONNECTED, CircleBCQFullyConnected)
CIRCLE_NODE(BCQ_GATHER, CircleBCQGather)
CIRCLE_NODE(INSTANCE_NORM, CircleInstanceNorm)
// Virtual node(s)
-CIRCLE_VNODE(CIRCLEBIDIRECTIONAL_SEQUENCE_LSTM_OUT, CircleBidirectionalSequenceLSTMOut)
CIRCLE_VNODE(CIRCLECONST, CircleConst)
CIRCLE_VNODE(CIRCLEINPUT, CircleInput)
CIRCLE_VNODE(CIRCLEOUTPUT, CircleOutput)
CIRCLE_VNODE(CIRCLEOUTPUTDUMMY, CircleOutputDummy)
CIRCLE_VNODE(CIRCLEOUTPUTEXCLUDE, CircleOutputExclude)
+CIRCLE_VNODE(CIRCLEVARIABLE, CircleVariable)
+// Multi-output virtual nodes
+CIRCLE_VNODE(CIRCLEBIDIRECTIONAL_SEQUENCE_LSTM_OUT, CircleBidirectionalSequenceLSTMOut)
CIRCLE_VNODE(CIRCLECUSTOMOUT, CircleCustomOut)
CIRCLE_VNODE(CIRCLEIFOUT, CircleIfOut)
CIRCLE_VNODE(CIRCLENONMAXSUPPRESSIONV4OUT, CircleNonMaxSuppressionV4Out)
diff --git a/compiler/luci/lang/include/luci/IR/CircleQuantParam.h b/compiler/luci/lang/include/luci/IR/CircleQuantParam.h
index 694437303..8afc80a76 100644
--- a/compiler/luci/lang/include/luci/IR/CircleQuantParam.h
+++ b/compiler/luci/lang/include/luci/IR/CircleQuantParam.h
@@ -32,6 +32,10 @@ struct CircleQuantParam
int32_t quantized_dimension{0};
};
+struct CircleNode;
+
+void copy_quantparam(const luci::CircleNode *src, luci::CircleNode *dst);
+
} // namespace luci
#endif // __LUCI_IR_CIRCLEQUANTPARAM_H__
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleFullyConnected.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleFullyConnected.h
index 2862cadb2..dc5aeb267 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleFullyConnected.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleFullyConnected.h
@@ -58,8 +58,12 @@ public:
WeightsFormat weights_format(void) const { return _weights_format; }
void weights_format(WeightsFormat weights_format) { _weights_format = weights_format; }
+ bool keep_num_dims(void) const { return _keep_num_dims; }
+ void keep_num_dims(bool keep_num_dims) { _keep_num_dims = keep_num_dims; }
+
private:
WeightsFormat _weights_format{WeightsFormat::DEFAULT};
+ bool _keep_num_dims{false};
};
} // namespace luci
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleSVDF.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleSVDF.h
new file mode 100644
index 000000000..839d11e04
--- /dev/null
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleSVDF.h
@@ -0,0 +1,70 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_IR_CIRCLE_SVDF_H__
+#define __LUCI_IR_CIRCLE_SVDF_H__
+
+#include "luci/IR/CircleNodeDecl.h"
+#include "luci/IR/CircleOpcode.h"
+
+#include "luci/IR/LuciNodeMixins.h"
+
+namespace luci
+{
+
+/**
+ * @brief SVDF in Circle
+ */
+class CircleSVDF final : public FixedArityNode<5, CircleNodeImpl<CircleOpcode::SVDF>>,
+ public CircleNodeMixin<CircleNodeTrait::FusedActFunc>
+{
+public:
+ CircleSVDF() = default;
+
+public:
+ loco::Node *input(void) const { return at(0)->node(); }
+ void input(loco::Node *node) { at(0)->node(node); }
+
+ loco::Node *weight_feature(void) const { return at(1)->node(); }
+ void weight_feature(loco::Node *node) { at(1)->node(node); }
+
+ loco::Node *weight_time(void) const { return at(2)->node(); }
+ void weight_time(loco::Node *node) { at(2)->node(node); }
+
+ loco::Node *bias(void) const { return at(3)->node(); }
+ void bias(loco::Node *node) { at(3)->node(node); }
+
+ loco::Node *input_activation_state(void) const { return at(4)->node(); }
+ void input_activation_state(loco::Node *node) { at(4)->node(node); }
+
+public:
+ bool asymmetric_quantize_inputs() const { return _asymmetric_quantize_inputs; }
+ void asymmetric_quantize_inputs(bool asymmetric_quantize_inputs)
+ {
+ _asymmetric_quantize_inputs = asymmetric_quantize_inputs;
+ }
+
+ int32_t svdf_rank() const { return _rank; }
+ void svdf_rank(int32_t svdf_rank) { _rank = svdf_rank; }
+
+private:
+ bool _asymmetric_quantize_inputs = false;
+ int32_t _rank = 0;
+};
+
+} // namespace luci
+
+#endif // __LUCI_IR_CIRCLE_SVDF_H__
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleVariable.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleVariable.h
new file mode 100644
index 000000000..8c15b66c9
--- /dev/null
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleVariable.h
@@ -0,0 +1,39 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_IR_CIRCLE_VARIABLE_H__
+#define __LUCI_IR_CIRCLE_VARIABLE_H__
+
+#include "luci/IR/CircleNodeDecl.h"
+#include "luci/IR/CircleOpcode.h"
+
+#include "luci/IR/CircleNodeMixins.h"
+
+namespace luci
+{
+
+/**
+ * @brief Virtual CircleVariable in Circle for 'variable' Tensor
+ */
+class CircleVariable final : public FixedArityNode<0, CircleNodeImpl<CircleOpcode::CIRCLEVARIABLE>>
+{
+public:
+ CircleVariable() = default;
+};
+
+} // namespace luci
+
+#endif // __LUCI_IR_CIRCLE_VARIABLE_H__
diff --git a/compiler/luci/lang/src/CircleQuantParam.cpp b/compiler/luci/lang/src/CircleQuantParam.cpp
new file mode 100644
index 000000000..89671d3c3
--- /dev/null
+++ b/compiler/luci/lang/src/CircleQuantParam.cpp
@@ -0,0 +1,46 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/IR/CircleQuantParam.h"
+#include "luci/IR/CircleNode.h"
+
+#include <memory>
+
+namespace luci
+{
+
+/**
+ * @brief copy CircleQuantParam of src to dst
+ */
+void copy_quantparam(const luci::CircleNode *src, luci::CircleNode *dst)
+{
+ auto q = src->quantparam();
+ if (q == nullptr)
+ dst->quantparam(nullptr);
+ else
+ {
+ auto qparam = std::make_unique<luci::CircleQuantParam>();
+ qparam->scale = q->scale;
+ qparam->zerop = q->zerop;
+ qparam->min = q->min;
+ qparam->max = q->max;
+ qparam->quantized_dimension = q->quantized_dimension;
+
+ dst->quantparam(std::move(qparam));
+ }
+}
+
+} // namespace luci
diff --git a/compiler/luci/lang/src/CircleQuantParam.test.cpp b/compiler/luci/lang/src/CircleQuantParam.test.cpp
new file mode 100644
index 000000000..520ca05cc
--- /dev/null
+++ b/compiler/luci/lang/src/CircleQuantParam.test.cpp
@@ -0,0 +1,78 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// NOTE any node will do for testing
+#include "luci/IR/Nodes/CircleAdd.h"
+
+#include <loco/IR/Graph.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+luci::CircleAdd *build_simple_add_graph(loco::Graph *g)
+{
+ auto node = g->nodes()->create<luci::CircleAdd>();
+
+ node->name("name");
+ node->dtype(loco::DataType::FLOAT32);
+ node->rank(1);
+ node->dim(0).set(3);
+ node->shape_status(luci::ShapeStatus::VALID);
+ node->fusedActivationFunction(luci::FusedActFunc::NONE);
+
+ auto qparam = std::make_unique<luci::CircleQuantParam>();
+ qparam->scale = {1.0};
+ qparam->zerop = {0};
+ qparam->min = {0.0};
+ qparam->max = {1.0};
+ qparam->quantized_dimension = 0;
+ node->quantparam(std::move(qparam));
+
+ return node;
+}
+
+} // namespace
+
+TEST(CircleNodeCloneTest, copy_quantparam)
+{
+ auto g = loco::make_graph();
+ auto node = build_simple_add_graph(g.get());
+
+ auto copy = g->nodes()->create<luci::CircleAdd>();
+ luci::copy_quantparam(node, copy);
+
+ const auto *qparam_node = node->quantparam();
+ const auto *qparam_copy = copy->quantparam();
+ ASSERT_EQ(qparam_node->scale, qparam_copy->scale);
+ ASSERT_EQ(qparam_node->zerop, qparam_copy->zerop);
+ ASSERT_EQ(qparam_node->quantized_dimension, qparam_copy->quantized_dimension);
+}
+
+TEST(CircleNodeCloneTest, copy_quantparam_NEG)
+{
+ auto g = loco::make_graph();
+ auto node = build_simple_add_graph(g.get());
+
+ node->quantparam(nullptr);
+
+ auto copy = g->nodes()->create<luci::CircleAdd>();
+ luci::copy_quantparam(node, copy);
+
+ const auto *qparam_copy = copy->quantparam();
+ ASSERT_EQ(qparam_copy, nullptr);
+}
diff --git a/compiler/luci/lang/src/Nodes/CircleFullyConnected.test.cpp b/compiler/luci/lang/src/Nodes/CircleFullyConnected.test.cpp
index bb0e3c51b..15a780085 100644
--- a/compiler/luci/lang/src/Nodes/CircleFullyConnected.test.cpp
+++ b/compiler/luci/lang/src/Nodes/CircleFullyConnected.test.cpp
@@ -32,6 +32,7 @@ TEST(CircleFullyConnectedTest, constructor)
ASSERT_EQ(nullptr, fc_node.weights());
ASSERT_EQ(nullptr, fc_node.bias());
ASSERT_EQ(luci::FusedActFunc::UNDEFINED, fc_node.fusedActivationFunction());
+ ASSERT_EQ(false, fc_node.keep_num_dims());
}
TEST(CircleFullyConnectedTest, input_NEG)
diff --git a/compiler/luci/lang/src/Nodes/CircleSVDF.test.cpp b/compiler/luci/lang/src/Nodes/CircleSVDF.test.cpp
new file mode 100644
index 000000000..833ae0732
--- /dev/null
+++ b/compiler/luci/lang/src/Nodes/CircleSVDF.test.cpp
@@ -0,0 +1,101 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/IR/Nodes/CircleSVDF.h"
+
+#include "luci/IR/CircleDialect.h"
+#include "luci/IR/CircleNodeVisitor.h"
+
+#include <gtest/gtest.h>
+
+TEST(CircleSVDFTest, constructor)
+{
+ luci::CircleSVDF svdf_node;
+
+ ASSERT_EQ(luci::CircleDialect::get(), svdf_node.dialect());
+ ASSERT_EQ(luci::CircleOpcode::SVDF, svdf_node.opcode());
+
+ ASSERT_EQ(nullptr, svdf_node.input());
+ ASSERT_EQ(nullptr, svdf_node.weight_feature());
+ ASSERT_EQ(nullptr, svdf_node.weight_time());
+ ASSERT_EQ(nullptr, svdf_node.bias());
+ ASSERT_EQ(nullptr, svdf_node.input_activation_state());
+
+ ASSERT_EQ(false, svdf_node.asymmetric_quantize_inputs());
+ ASSERT_EQ(0, svdf_node.svdf_rank());
+}
+
+TEST(CircleSVDFTest, input_NEG)
+{
+ luci::CircleSVDF svdf_node;
+ luci::CircleSVDF node;
+
+ svdf_node.input(&node);
+ svdf_node.weight_feature(&node);
+ svdf_node.weight_time(&node);
+ svdf_node.bias(&node);
+ svdf_node.input_activation_state(&node);
+
+ ASSERT_NE(nullptr, svdf_node.input());
+ ASSERT_NE(nullptr, svdf_node.weight_feature());
+ ASSERT_NE(nullptr, svdf_node.weight_time());
+ ASSERT_NE(nullptr, svdf_node.bias());
+ ASSERT_NE(nullptr, svdf_node.input_activation_state());
+
+ svdf_node.input(nullptr);
+ svdf_node.weight_feature(nullptr);
+ svdf_node.weight_time(nullptr);
+ svdf_node.bias(nullptr);
+ svdf_node.input_activation_state(nullptr);
+
+ ASSERT_EQ(nullptr, svdf_node.input());
+ ASSERT_EQ(nullptr, svdf_node.weight_feature());
+ ASSERT_EQ(nullptr, svdf_node.weight_time());
+ ASSERT_EQ(nullptr, svdf_node.bias());
+ ASSERT_EQ(nullptr, svdf_node.input_activation_state());
+}
+
+TEST(CircleSVDFTest, arity_NEG)
+{
+ luci::CircleSVDF svdf_node;
+
+ ASSERT_NO_THROW(svdf_node.arg(4));
+ ASSERT_THROW(svdf_node.arg(5), std::out_of_range);
+}
+
+TEST(CircleSVDFTest, visit_mutable_NEG)
+{
+ struct TestVisitor final : public luci::CircleNodeMutableVisitor<void>
+ {
+ };
+
+ luci::CircleSVDF svdf_node;
+
+ TestVisitor tv;
+ ASSERT_THROW(svdf_node.accept(&tv), std::exception);
+}
+
+TEST(CircleSVDFTest, visit_NEG)
+{
+ struct TestVisitor final : public luci::CircleNodeVisitor<void>
+ {
+ };
+
+ luci::CircleSVDF svdf_node;
+
+ TestVisitor tv;
+ ASSERT_THROW(svdf_node.accept(&tv), std::exception);
+}
diff --git a/compiler/luci/lang/src/Nodes/CircleVariable.test.cpp b/compiler/luci/lang/src/Nodes/CircleVariable.test.cpp
new file mode 100644
index 000000000..e1864f8da
--- /dev/null
+++ b/compiler/luci/lang/src/Nodes/CircleVariable.test.cpp
@@ -0,0 +1,61 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/IR/Nodes/CircleVariable.h"
+
+#include "luci/IR/CircleDialect.h"
+#include "luci/IR/CircleNodeVisitor.h"
+
+#include <gtest/gtest.h>
+
+TEST(CircleVariableTest, constructor)
+{
+ luci::CircleVariable var_node;
+
+ ASSERT_EQ(luci::CircleDialect::get(), var_node.dialect());
+ ASSERT_EQ(luci::CircleOpcode::CIRCLEVARIABLE, var_node.opcode());
+}
+
+TEST(CircleVariableTest, arity_NEG)
+{
+ luci::CircleVariable var_node;
+
+ ASSERT_THROW(var_node.arg(0), std::out_of_range);
+}
+
+TEST(CircleVariableTest, visit_mutable_NEG)
+{
+ struct TestVisitor final : public luci::CircleNodeMutableVisitor<void>
+ {
+ };
+
+ luci::CircleVariable var_node;
+
+ TestVisitor tv;
+ ASSERT_THROW(var_node.accept(&tv), std::exception);
+}
+
+TEST(CircleVariableTest, visit_NEG)
+{
+ struct TestVisitor final : public luci::CircleNodeVisitor<void>
+ {
+ };
+
+ luci::CircleVariable var_node;
+
+ TestVisitor tv;
+ ASSERT_THROW(var_node.accept(&tv), std::exception);
+}
diff --git a/compiler/luci/logex/CMakeLists.txt b/compiler/luci/logex/CMakeLists.txt
index aed9fb79b..b8a2111dd 100644
--- a/compiler/luci/logex/CMakeLists.txt
+++ b/compiler/luci/logex/CMakeLists.txt
@@ -1,5 +1,7 @@
# TODO Find how to test logging-ex utility
file(GLOB_RECURSE SOURCES "src/*.cpp")
+file(GLOB_RECURSE TESTS "src/*.test.cpp")
+list(REMOVE_ITEM SOURCES ${TESTS})
if (NOT LUCI_LIBRARY_TYPE)
set(LUCI_LIBRARY_TYPE "SHARED")
@@ -13,7 +15,17 @@ target_link_libraries(luci_logex PRIVATE luci_log)
target_link_libraries(luci_logex PRIVATE luci_lang)
target_link_libraries(luci_logex PRIVATE hermes_std)
target_link_libraries(luci_logex PRIVATE nncc_common)
-target_link_libraries(luci_logex PRIVATE pepper_str)
install(TARGETS luci_logex DESTINATION lib)
install(DIRECTORY include/ DESTINATION include
FILES_MATCHING PATTERN "*.h")
+
+if(NOT ENABLE_TEST)
+ return()
+endif(NOT ENABLE_TEST)
+
+nnas_find_package(GTest REQUIRED)
+
+GTest_AddTest(luci_logex_test ${TESTS})
+target_include_directories(luci_logex_test PRIVATE src)
+target_link_libraries(luci_logex_test luci_logex)
+target_link_libraries(luci_logex_test luci_lang)
diff --git a/compiler/luci/logex/src/CircleNodeSummaryBuilder.cpp b/compiler/luci/logex/src/CircleNodeSummaryBuilder.cpp
new file mode 100644
index 000000000..eff0830b4
--- /dev/null
+++ b/compiler/luci/logex/src/CircleNodeSummaryBuilder.cpp
@@ -0,0 +1,265 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleNodeSummaryBuilder.h"
+#include "CircleNodeSummaryBuilders.h"
+
+#include <luci/IR/CircleDialect.h>
+
+#include <memory>
+
+namespace
+{
+
+std::string circle_opname(luci::CircleOpcode opcode)
+{
+ static const std::string prefix{"circle."};
+
+ switch (opcode)
+ {
+#define CIRCLE_NODE(OPCODE, CLASS) \
+ case luci::CircleOpcode::OPCODE: \
+ return prefix + #OPCODE;
+#define CIRCLE_VNODE CIRCLE_NODE
+#include <luci/IR/CircleNodes.lst>
+#undef CIRCLE_VNODE
+#undef CIRCLE_NODE
+ default:
+ break;
+ };
+
+ return prefix + "Invalid";
+}
+
+} // namespace
+
+namespace luci
+{
+
+bool CircleNodeSummaryBuilder::build(const loco::Node *node, const locop::SymbolTable *tbl,
+ locop::NodeSummary &s)
+{
+ if (node->dialect() != luci::CircleDialect::get())
+ return false;
+
+ auto ptr_to_str = [](const void *ptr) {
+ std::stringstream ss;
+ ss << ptr;
+ return ss.str();
+ };
+
+ auto circle_node = loco::must_cast<const luci::CircleNode *>(node);
+ if (const auto builder = create_builder(circle_node))
+ {
+ if (!builder->validate(circle_node))
+ {
+ s.state(locop::NodeDesc::State::Invalid);
+ return false;
+ }
+
+ auto input_names = builder->get_input_names(circle_node);
+ assert(node->arity() == input_names.size());
+ for (uint32_t i = 0; i < node->arity(); ++i)
+ s.args().append(input_names.at(i), tbl->lookup(node->arg(i)));
+
+ builder->build_attributes(circle_node, s);
+ builder->update_status(s);
+
+ s.opname(circle_opname(circle_node->opcode()));
+ s.comments().append("[" + circle_node->name() + "] = " + ptr_to_str(node));
+
+ return true;
+ }
+ else
+ {
+ // When SummaryBuilder is not implemented, return false
+ return false;
+ }
+}
+
+bool CircleNodeSummaryBuilder::validate(const luci::CircleNode *) { return true; }
+
+std::vector<std::string> CircleNodeSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ // Return empty names for default
+ return std::vector<std::string>();
+}
+
+void CircleNodeSummaryBuilder::build_attributes(const luci::CircleNode *, locop::NodeSummary &)
+{
+ // Do nothing for default
+}
+
+void CircleNodeSummaryBuilder::update_status(locop::NodeSummary &s)
+{
+ s.state(locop::NodeDesc::State::Complete);
+}
+
+std::unique_ptr<CircleNodeSummaryBuilder>
+CircleNodeSummaryBuilder::create_builder(const luci::CircleNode *node)
+{
+ switch (node->opcode())
+ {
+#define CIRCLE_NODE(OPCODE, CLASS) \
+ case luci::CircleOpcode::OPCODE: \
+ { \
+ return std::make_unique<CLASS>(); \
+ }
+
+ CIRCLE_NODE(ABS, CircleAbsSummaryBuilder)
+ CIRCLE_NODE(ADD, CircleAddSummaryBuilder)
+ CIRCLE_NODE(ADD_N, CircleAddNSummaryBuilder)
+ CIRCLE_NODE(ARG_MAX, CircleArgMaxSummaryBuilder)
+ CIRCLE_NODE(ARG_MIN, CircleArgMinSummaryBuilder)
+ CIRCLE_NODE(AVERAGE_POOL_2D, CircleAveragePool2DSummaryBuilder)
+ CIRCLE_NODE(BATCH_MATMUL, CircleBatchMatMulSummaryBuilder)
+ CIRCLE_NODE(BATCH_TO_SPACE_ND, CircleBatchToSpaceNDSummaryBuilder)
+ CIRCLE_NODE(BCQ_FULLY_CONNECTED, CircleBCQFullyConnectedSummaryBuilder)
+ CIRCLE_NODE(BCQ_GATHER, CircleBCQGatherSummaryBuilder)
+ CIRCLE_NODE(BIDIRECTIONAL_SEQUENCE_LSTM, CircleBidirectionalSequenceLSTMSummaryBuilder)
+ CIRCLE_NODE(CAST, CircleCastSummaryBuilder)
+ CIRCLE_NODE(CEIL, CircleCeilSummaryBuilder)
+ CIRCLE_NODE(CONCATENATION, CircleConcatenationSummaryBuilder)
+ CIRCLE_NODE(CIRCLECONST, CircleConstSummaryBuilder)
+ CIRCLE_NODE(CONV_2D, CircleConv2DSummaryBuilder)
+ CIRCLE_NODE(COS, CircleCosSummaryBuilder)
+ CIRCLE_NODE(CUSTOM, CircleCustomSummaryBuilder)
+ CIRCLE_NODE(DEPTH_TO_SPACE, CircleDepthToSpaceSummaryBuilder)
+ CIRCLE_NODE(DEPTHWISE_CONV_2D, CircleDepthwiseConv2DSummaryBuilder)
+ CIRCLE_NODE(DEQUANTIZE, CircleDequantizeSummaryBuilder)
+ CIRCLE_NODE(DIV, CircleDivSummaryBuilder)
+ CIRCLE_NODE(ELU, CircleEluSummaryBuilder)
+ CIRCLE_NODE(EQUAL, CircleEqualSummaryBuilder)
+ CIRCLE_NODE(EXP, CircleExpSummaryBuilder)
+ CIRCLE_NODE(EXPAND_DIMS, CircleExpandDimsSummaryBuilder)
+ CIRCLE_NODE(FAKE_QUANT, CircleFakeQuantSummaryBuilder)
+ CIRCLE_NODE(FILL, CircleFillSummaryBuilder)
+ CIRCLE_NODE(FLOOR, CircleFloorSummaryBuilder)
+ CIRCLE_NODE(FLOOR_DIV, CircleFloorDivSummaryBuilder)
+ CIRCLE_NODE(FLOOR_MOD, CircleFloorModSummaryBuilder)
+ CIRCLE_NODE(FULLY_CONNECTED, CircleFullyConnectedSummaryBuilder)
+ CIRCLE_NODE(GATHER, CircleGatherSummaryBuilder)
+ CIRCLE_NODE(GATHER_ND, CircleGatherNdSummaryBuilder)
+ CIRCLE_NODE(GREATER, CircleGreaterSummaryBuilder)
+ CIRCLE_NODE(GREATER_EQUAL, CircleGreaterEqualSummaryBuilder)
+ CIRCLE_NODE(IF, CircleIfSummaryBuilder)
+ CIRCLE_NODE(INSTANCE_NORM, CircleInstanceNormSummaryBuilder)
+ CIRCLE_NODE(L2_NORMALIZATION, CircleL2NormalizeSummaryBuilder)
+ CIRCLE_NODE(L2_POOL_2D, CircleL2Pool2DSummaryBuilder)
+ CIRCLE_NODE(LEAKY_RELU, CircleLeakyReluSummaryBuilder)
+ CIRCLE_NODE(LESS, CircleLessSummaryBuilder)
+ CIRCLE_NODE(LESS_EQUAL, CircleLessEqualSummaryBuilder)
+ CIRCLE_NODE(LOCAL_RESPONSE_NORMALIZATION, CircleLocalResponseNormalizationSummaryBuilder)
+ CIRCLE_NODE(LOG, CircleLogSummaryBuilder)
+ CIRCLE_NODE(LOGICAL_AND, CircleLogicalAndSummaryBuilder)
+ CIRCLE_NODE(LOGICAL_NOT, CircleLogicalNotSummaryBuilder)
+ CIRCLE_NODE(LOGICAL_OR, CircleLogicalOrSummaryBuilder)
+ CIRCLE_NODE(LOGISTIC, CircleLogisticSummaryBuilder)
+ CIRCLE_NODE(LOG_SOFTMAX, CircleLogSoftmaxSummaryBuilder)
+ CIRCLE_NODE(MATRIX_DIAG, CircleMatrixDiagSummaryBuilder)
+ CIRCLE_NODE(MATRIX_SET_DIAG, CircleMatrixSetDiagSummaryBuilder)
+ CIRCLE_NODE(MAXIMUM, CircleMaximumSummaryBuilder)
+ CIRCLE_NODE(MAX_POOL_2D, CircleMaxPool2DSummaryBuilder)
+ CIRCLE_NODE(MEAN, CircleMeanSummaryBuilder)
+ CIRCLE_NODE(MINIMUM, CircleMinimumSummaryBuilder)
+ CIRCLE_NODE(MIRROR_PAD, CircleMirrorPadSummaryBuilder)
+ CIRCLE_NODE(MUL, CircleMulSummaryBuilder)
+ CIRCLE_NODE(NEG, CircleNegSummaryBuilder)
+ CIRCLE_NODE(NON_MAX_SUPPRESSION_V4, CircleNonMaxSuppressionV4SummaryBuilder)
+ CIRCLE_NODE(NON_MAX_SUPPRESSION_V5, CircleNonMaxSuppressionV5SummaryBuilder)
+ CIRCLE_NODE(NOT_EQUAL, CircleNotEqualSummaryBuilder)
+ CIRCLE_NODE(ONE_HOT, CircleOneHotSummaryBuilder)
+ CIRCLE_NODE(PACK, CirclePackSummaryBuilder)
+ CIRCLE_NODE(PAD, CirclePadSummaryBuilder)
+ CIRCLE_NODE(PADV2, CirclePadV2SummaryBuilder)
+ CIRCLE_NODE(POW, CirclePowSummaryBuilder)
+ CIRCLE_NODE(PRELU, CirclePReluSummaryBuilder)
+ CIRCLE_NODE(QUANTIZE, CircleQuantizeSummaryBuilder)
+ CIRCLE_NODE(RANGE, CircleRangeSummaryBuilder)
+ CIRCLE_NODE(RANK, CircleRankSummaryBuilder)
+ CIRCLE_NODE(REDUCE_ANY, CircleReduceAnySummaryBuilder)
+ CIRCLE_NODE(REDUCE_MAX, CircleReduceMaxSummaryBuilder)
+ CIRCLE_NODE(REDUCE_MIN, CircleReduceMinSummaryBuilder)
+ CIRCLE_NODE(REDUCE_PROD, CircleReduceProdSummaryBuilder)
+ CIRCLE_NODE(RELU, CircleReluSummaryBuilder)
+ CIRCLE_NODE(RELU6, CircleRelu6SummaryBuilder)
+ CIRCLE_NODE(RELU_N1_TO_1, CircleReluN1To1SummaryBuilder)
+ CIRCLE_NODE(RESHAPE, CircleReshapeSummaryBuilder)
+ CIRCLE_NODE(RESIZE_BILINEAR, CircleResizeBilinearSummaryBuilder)
+ CIRCLE_NODE(RESIZE_NEAREST_NEIGHBOR, CircleResizeNearestNeighborSummaryBuilder)
+ CIRCLE_NODE(REVERSE_SEQUENCE, CircleReverseSequenceSummaryBuilder)
+ CIRCLE_NODE(REVERSE_V2, CircleReverseV2SummaryBuilder)
+ CIRCLE_NODE(ROUND, CircleRoundSummaryBuilder)
+ CIRCLE_NODE(RSQRT, CircleRsqrtSummaryBuilder)
+ CIRCLE_NODE(SCATTER_ND, CircleScatterNdSummaryBuilder)
+ CIRCLE_NODE(SEGMENT_SUM, CircleSegmentSumSummaryBuilder)
+ CIRCLE_NODE(SELECT, CircleSelectSummaryBuilder)
+ CIRCLE_NODE(SELECT_V2, CircleSelectV2SummaryBuilder)
+ CIRCLE_NODE(SHAPE, CircleShapeSummaryBuilder)
+ CIRCLE_NODE(SIN, CircleSinSummaryBuilder)
+ CIRCLE_NODE(SLICE, CircleSliceSummaryBuilder)
+ CIRCLE_NODE(SOFTMAX, CircleSoftmaxSummaryBuilder)
+ CIRCLE_NODE(SPACE_TO_BATCH_ND, CircleSpaceToBatchNDSummaryBuilder)
+ CIRCLE_NODE(SPACE_TO_DEPTH, CircleSpaceToDepthSummaryBuilder)
+ CIRCLE_NODE(SPARSE_TO_DENSE, CircleSparseToDenseSummaryBuilder)
+ CIRCLE_NODE(SPLIT, CircleSplitSummaryBuilder)
+ CIRCLE_NODE(SPLIT_V, CircleSplitVSummaryBuilder)
+ CIRCLE_NODE(SQRT, CircleSqrtSummaryBuilder)
+ CIRCLE_NODE(SQUARE, CircleSquareSummaryBuilder)
+ CIRCLE_NODE(SQUARED_DIFFERENCE, CircleSquaredDifferenceSummaryBuilder)
+ CIRCLE_NODE(SQUEEZE, CircleSqueezeSummaryBuilder)
+ CIRCLE_NODE(STRIDED_SLICE, CircleStridedSliceSummaryBuilder)
+ CIRCLE_NODE(SUB, CircleSubSummaryBuilder)
+ CIRCLE_NODE(SUM, CircleSumSummaryBuilder)
+ CIRCLE_NODE(SVDF, CircleSVDFSummaryBuilder)
+ CIRCLE_NODE(TANH, CircleTanhSummaryBuilder)
+ CIRCLE_NODE(TILE, CircleTileSummaryBuilder)
+ CIRCLE_NODE(TOPK_V2, CircleTopKV2SummaryBuilder)
+ CIRCLE_NODE(TRANSPOSE, CircleTransposeSummaryBuilder)
+ CIRCLE_NODE(TRANSPOSE_CONV, CircleTransposeConvSummaryBuilder)
+ CIRCLE_NODE(UNIDIRECTIONAL_SEQUENCE_LSTM, CircleUnidirectionalSequenceLSTMSummaryBuilder)
+ CIRCLE_NODE(UNIQUE, CircleUniqueSummaryBuilder)
+ CIRCLE_NODE(UNPACK, CircleUnpackSummaryBuilder)
+ CIRCLE_NODE(WHERE, CircleWhereSummaryBuilder)
+ CIRCLE_NODE(WHILE, CircleWhileSummaryBuilder)
+ CIRCLE_NODE(ZEROS_LIKE, CircleZerosLikeSummaryBuilder)
+
+ CIRCLE_NODE(CIRCLEBIDIRECTIONAL_SEQUENCE_LSTM_OUT,
+ CircleBidirectionalSequenceLSTMOutSummaryBuilder)
+ CIRCLE_NODE(CIRCLECUSTOMOUT, CircleCustomOutSummaryBuilder)
+ CIRCLE_NODE(CIRCLEIFOUT, CircleIfOutSummaryBuilder)
+ CIRCLE_NODE(CIRCLEINPUT, CircleInputSummaryBuilder)
+ CIRCLE_NODE(CIRCLENONMAXSUPPRESSIONV4OUT, CircleNonMaxSuppressionV4OutSummaryBuilder)
+ CIRCLE_NODE(CIRCLENONMAXSUPPRESSIONV5OUT, CircleNonMaxSuppressionV5OutSummaryBuilder)
+ CIRCLE_NODE(CIRCLEOUTPUT, CircleOutputSummaryBuilder)
+ CIRCLE_NODE(CIRCLEOUTPUTDUMMY, CircleOutputDummySummaryBuilder)
+ CIRCLE_NODE(CIRCLEOUTPUTEXCLUDE, CircleOutputExcludeSummaryBuilder)
+ CIRCLE_NODE(CIRCLESPLITOUT, CircleSplitOutSummaryBuilder)
+ CIRCLE_NODE(CIRCLESPLITVOUT, CircleSplitVOutSummaryBuilder)
+ CIRCLE_NODE(CIRCLETOPKV2OUT, CircleTopKV2OutSummaryBuilder)
+ CIRCLE_NODE(CIRCLEUNIQUEOUT, CircleUniqueOutSummaryBuilder)
+ CIRCLE_NODE(CIRCLEUNPACKOUT, CircleUnpackOutSummaryBuilder)
+ CIRCLE_NODE(CIRCLEVARIABLE, CircleVariableSummaryBuilder)
+ CIRCLE_NODE(CIRCLEWHILEOUT, CircleWhileOutSummaryBuilder)
+
+ default:
+ return nullptr;
+
+#undef CIRCLE_NODE
+ }
+}
+
+} // namespace luci
diff --git a/compiler/luci/logex/src/CircleNodeSummaryBuilder.h b/compiler/luci/logex/src/CircleNodeSummaryBuilder.h
new file mode 100644
index 000000000..e21d77310
--- /dev/null
+++ b/compiler/luci/logex/src/CircleNodeSummaryBuilder.h
@@ -0,0 +1,52 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_LOGEX_CIRCLE_NODE_SUMMARY_BUILDER__
+#define __LUCI_LOGEX_CIRCLE_NODE_SUMMARY_BUILDER__
+
+#include <luci/IR/CircleNode.h>
+#include <locop/NodeSummary.h>
+#include <locop/SymbolTable.h>
+
+#include <memory>
+#include <sstream>
+#include <vector>
+
+namespace luci
+{
+
+class CircleNodeSummaryBuilder
+{
+public:
+ bool build(const loco::Node *node, const locop::SymbolTable *tbl, locop::NodeSummary &s);
+
+private:
+ /**
+ * @brief Template methods for building node summary.
+ * Default behavior is building a node which has no input.
+ */
+ virtual bool validate(const luci::CircleNode *node);
+ virtual std::vector<std::string> get_input_names(const luci::CircleNode *node);
+ virtual void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+ virtual void update_status(locop::NodeSummary &s);
+
+private:
+ std::unique_ptr<CircleNodeSummaryBuilder> create_builder(const luci::CircleNode *node);
+};
+
+} // namespace luci
+
+#endif // __LUCI_LOGEX_CIRCLE_NODE_SUMMARY_BUILDER__
diff --git a/compiler/luci/logex/src/CircleNodeSummaryBuilder.test.cpp b/compiler/luci/logex/src/CircleNodeSummaryBuilder.test.cpp
new file mode 100644
index 000000000..89ea213e0
--- /dev/null
+++ b/compiler/luci/logex/src/CircleNodeSummaryBuilder.test.cpp
@@ -0,0 +1,309 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleNodeSummaryBuilder.h"
+
+#include <luci/IR/CircleNodes.h>
+#include <locop/NodeSummary.h>
+#include <locop/SymbolTable.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+class MockSymbolTable : public locop::SymbolTable
+{
+ std::string lookup(const loco::Node *) const override
+ {
+ return "Do nothing because it is mocking Symbol Table!";
+ }
+};
+
+class CircleNodeSummaryBuilderTest : public ::testing::Test
+{
+protected:
+ bool mock_build(const loco::Node *node)
+ {
+ return luci::CircleNodeSummaryBuilder().build(node, &_tbl, _s);
+ }
+
+protected:
+ MockSymbolTable _tbl;
+ locop::NodeSummary _s;
+};
+
+} // namespace
+
+TEST_F(CircleNodeSummaryBuilderTest, Add_validate)
+{
+ luci::CircleAdd node;
+ node.fusedActivationFunction(luci::FusedActFunc::RELU);
+ EXPECT_TRUE(mock_build(&node));
+}
+
+TEST_F(CircleNodeSummaryBuilderTest, Add_validate_fused_NEG)
+{
+ luci::CircleAdd node;
+ node.fusedActivationFunction(luci::FusedActFunc::UNDEFINED);
+ EXPECT_FALSE(mock_build(&node));
+}
+
+TEST_F(CircleNodeSummaryBuilderTest, AveragePool2D_validate)
+{
+ luci::CircleAveragePool2D node;
+ node.fusedActivationFunction(luci::FusedActFunc::RELU);
+ node.padding(luci::Padding::SAME);
+ EXPECT_TRUE(mock_build(&node));
+}
+
+TEST_F(CircleNodeSummaryBuilderTest, AveragePool2D_validate_fused_NEG)
+{
+ luci::CircleAveragePool2D node;
+ node.fusedActivationFunction(luci::FusedActFunc::UNDEFINED);
+ node.padding(luci::Padding::SAME);
+ EXPECT_FALSE(mock_build(&node));
+}
+
+TEST_F(CircleNodeSummaryBuilderTest, AveragePool2D_validate_padding_NEG)
+{
+ luci::CircleAveragePool2D node;
+ node.fusedActivationFunction(luci::FusedActFunc::RELU);
+ node.padding(luci::Padding::UNDEFINED);
+ EXPECT_FALSE(mock_build(&node));
+}
+
+TEST_F(CircleNodeSummaryBuilderTest, BCQFullyConnected_validate)
+{
+ luci::CircleBCQFullyConnected node;
+ node.fusedActivationFunction(luci::FusedActFunc::RELU);
+ EXPECT_TRUE(mock_build(&node));
+}
+
+TEST_F(CircleNodeSummaryBuilderTest, BCQFullyConnected_validate_fused_NEG)
+{
+ luci::CircleBCQFullyConnected node;
+ node.fusedActivationFunction(luci::FusedActFunc::UNDEFINED);
+ EXPECT_FALSE(mock_build(&node));
+}
+
+TEST_F(CircleNodeSummaryBuilderTest, Concatenation_validate)
+{
+ luci::CircleConcatenation node(2);
+ node.fusedActivationFunction(luci::FusedActFunc::RELU);
+ EXPECT_TRUE(mock_build(&node));
+}
+
+TEST_F(CircleNodeSummaryBuilderTest, Concatenation_validate_fused_NEG)
+{
+ luci::CircleConcatenation node(2);
+ node.fusedActivationFunction(luci::FusedActFunc::UNDEFINED);
+ EXPECT_FALSE(mock_build(&node));
+}
+
+TEST_F(CircleNodeSummaryBuilderTest, Conv2D_validate)
+{
+ luci::CircleConv2D node;
+ node.fusedActivationFunction(luci::FusedActFunc::RELU);
+ node.padding(luci::Padding::SAME);
+ EXPECT_TRUE(mock_build(&node));
+}
+
+TEST_F(CircleNodeSummaryBuilderTest, Conv2D_validate_fused_NEG)
+{
+ luci::CircleConv2D node;
+ node.fusedActivationFunction(luci::FusedActFunc::UNDEFINED);
+ node.padding(luci::Padding::SAME);
+ EXPECT_FALSE(mock_build(&node));
+}
+
+TEST_F(CircleNodeSummaryBuilderTest, Conv2D_validate_padding_NEG)
+{
+ luci::CircleConv2D node;
+ node.fusedActivationFunction(luci::FusedActFunc::RELU);
+ node.padding(luci::Padding::UNDEFINED);
+ EXPECT_FALSE(mock_build(&node));
+}
+
+TEST_F(CircleNodeSummaryBuilderTest, DepthwiseConv2D_validate)
+{
+ luci::CircleDepthwiseConv2D node;
+ node.fusedActivationFunction(luci::FusedActFunc::RELU);
+ node.padding(luci::Padding::SAME);
+ EXPECT_TRUE(mock_build(&node));
+}
+
+TEST_F(CircleNodeSummaryBuilderTest, DepthwiseConv2D_validate_fused_NEG)
+{
+ luci::CircleDepthwiseConv2D node;
+ node.fusedActivationFunction(luci::FusedActFunc::UNDEFINED);
+ node.padding(luci::Padding::SAME);
+ EXPECT_FALSE(mock_build(&node));
+}
+
+TEST_F(CircleNodeSummaryBuilderTest, DepthwiseConv2D_validate_padding_NEG)
+{
+ luci::CircleDepthwiseConv2D node;
+ node.fusedActivationFunction(luci::FusedActFunc::RELU);
+ node.padding(luci::Padding::UNDEFINED);
+ EXPECT_FALSE(mock_build(&node));
+}
+
+TEST_F(CircleNodeSummaryBuilderTest, FullyConnected_validate)
+{
+ luci::CircleFullyConnected node;
+ node.fusedActivationFunction(luci::FusedActFunc::RELU);
+ EXPECT_TRUE(mock_build(&node));
+}
+
+TEST_F(CircleNodeSummaryBuilderTest, FullyConnected_validate_fused_NEG)
+{
+ luci::CircleFullyConnected node;
+ node.fusedActivationFunction(luci::FusedActFunc::UNDEFINED);
+ EXPECT_FALSE(mock_build(&node));
+}
+
+TEST_F(CircleNodeSummaryBuilderTest, InstanceNorm_validate)
+{
+ luci::CircleInstanceNorm node;
+ node.fusedActivationFunction(luci::FusedActFunc::RELU);
+ EXPECT_TRUE(mock_build(&node));
+}
+
+TEST_F(CircleNodeSummaryBuilderTest, InstanceNorm_validate_fused_NEG)
+{
+ luci::CircleInstanceNorm node;
+ node.fusedActivationFunction(luci::FusedActFunc::UNDEFINED);
+ EXPECT_FALSE(mock_build(&node));
+}
+
+TEST_F(CircleNodeSummaryBuilderTest, L2Normalize_validate)
+{
+ luci::CircleL2Normalize node;
+ node.fusedActivationFunction(luci::FusedActFunc::RELU);
+ EXPECT_TRUE(mock_build(&node));
+}
+
+TEST_F(CircleNodeSummaryBuilderTest, L2Normalize_validate_fused_NEG)
+{
+ luci::CircleL2Normalize node;
+ node.fusedActivationFunction(luci::FusedActFunc::UNDEFINED);
+ EXPECT_FALSE(mock_build(&node));
+}
+
+TEST_F(CircleNodeSummaryBuilderTest, L2Pool2D_validate)
+{
+ luci::CircleL2Pool2D node;
+ node.fusedActivationFunction(luci::FusedActFunc::RELU);
+ node.padding(luci::Padding::SAME);
+ EXPECT_TRUE(mock_build(&node));
+}
+
+TEST_F(CircleNodeSummaryBuilderTest, L2Pool2D_validate_fused_NEG)
+{
+ luci::CircleL2Pool2D node;
+ node.fusedActivationFunction(luci::FusedActFunc::UNDEFINED);
+ node.padding(luci::Padding::SAME);
+ EXPECT_FALSE(mock_build(&node));
+}
+
+TEST_F(CircleNodeSummaryBuilderTest, L2Pool2D_validate_padding_NEG)
+{
+ luci::CircleL2Pool2D node;
+ node.fusedActivationFunction(luci::FusedActFunc::RELU);
+ node.padding(luci::Padding::UNDEFINED);
+ EXPECT_FALSE(mock_build(&node));
+}
+
+TEST_F(CircleNodeSummaryBuilderTest, MaxPool2D_validate)
+{
+ luci::CircleMaxPool2D node;
+ node.fusedActivationFunction(luci::FusedActFunc::RELU);
+ node.padding(luci::Padding::SAME);
+ EXPECT_TRUE(mock_build(&node));
+}
+
+TEST_F(CircleNodeSummaryBuilderTest, MaxPool2D_validate_fused_NEG)
+{
+ luci::CircleMaxPool2D node;
+ node.fusedActivationFunction(luci::FusedActFunc::UNDEFINED);
+ node.padding(luci::Padding::SAME);
+ EXPECT_FALSE(mock_build(&node));
+}
+
+TEST_F(CircleNodeSummaryBuilderTest, MaxPool2D_validate_padding_NEG)
+{
+ luci::CircleMaxPool2D node;
+ node.fusedActivationFunction(luci::FusedActFunc::RELU);
+ node.padding(luci::Padding::UNDEFINED);
+ EXPECT_FALSE(mock_build(&node));
+}
+
+TEST_F(CircleNodeSummaryBuilderTest, MirrorPad_validate)
+{
+ luci::CircleMirrorPad node;
+ node.mode(luci::MirrorPadMode::REFLECT);
+ EXPECT_TRUE(mock_build(&node));
+}
+
+TEST_F(CircleNodeSummaryBuilderTest, MirrorPad_validate_mirror_padding_NEG)
+{
+ luci::CircleMirrorPad node;
+ node.mode(luci::MirrorPadMode::UNDEFINED);
+ EXPECT_FALSE(mock_build(&node));
+}
+
+TEST_F(CircleNodeSummaryBuilderTest, Mul_validate)
+{
+ luci::CircleMul node;
+ node.fusedActivationFunction(luci::FusedActFunc::RELU);
+ EXPECT_TRUE(mock_build(&node));
+}
+
+TEST_F(CircleNodeSummaryBuilderTest, Mul_validate_fused_NEG)
+{
+ luci::CircleMul node;
+ node.fusedActivationFunction(luci::FusedActFunc::UNDEFINED);
+ EXPECT_FALSE(mock_build(&node));
+}
+
+TEST_F(CircleNodeSummaryBuilderTest, SVDF_validate)
+{
+ luci::CircleSVDF node;
+ node.fusedActivationFunction(luci::FusedActFunc::RELU);
+ EXPECT_TRUE(mock_build(&node));
+}
+
+TEST_F(CircleNodeSummaryBuilderTest, SVDF_validate_fused_NEG)
+{
+ luci::CircleSVDF node;
+ node.fusedActivationFunction(luci::FusedActFunc::UNDEFINED);
+ EXPECT_FALSE(mock_build(&node));
+}
+
+TEST_F(CircleNodeSummaryBuilderTest, TransposeConv_validate)
+{
+ luci::CircleTransposeConv node;
+ node.padding(luci::Padding::SAME);
+ EXPECT_TRUE(mock_build(&node));
+}
+
+TEST_F(CircleNodeSummaryBuilderTest, TransposeConv_validate_padding_NEG)
+{
+ luci::CircleTransposeConv node;
+ node.padding(luci::Padding::UNDEFINED);
+ EXPECT_FALSE(mock_build(&node));
+}
diff --git a/compiler/luci/logex/src/CircleNodeSummaryBuilders.cpp b/compiler/luci/logex/src/CircleNodeSummaryBuilders.cpp
new file mode 100644
index 000000000..6df9270e3
--- /dev/null
+++ b/compiler/luci/logex/src/CircleNodeSummaryBuilders.cpp
@@ -0,0 +1,1128 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleNodeSummaryBuilders.h"
+
+#include <luci/IR/CircleNode.h>
+#include <luci/IR/CircleNodes.h>
+#include <loco/IR/Node.h>
+
+#include <string>
+#include <vector>
+
+namespace
+{
+
+std::string to_str(loco::DataType type)
+{
+ switch (type)
+ {
+ case loco::DataType::U8:
+ return "UINT8";
+ case loco::DataType::U16:
+ return "UINT16";
+ case loco::DataType::U32:
+ return "UINT32";
+ case loco::DataType::U64:
+ return "UINT64";
+
+ case loco::DataType::S8:
+ return "INT8";
+ case loco::DataType::S16:
+ return "INT16";
+ case loco::DataType::S32:
+ return "INT32";
+ case loco::DataType::S64:
+ return "INT64";
+
+ case loco::DataType::FLOAT16:
+ return "FLOAT16";
+ case loco::DataType::FLOAT32:
+ return "FLOAT32";
+ case loco::DataType::FLOAT64:
+ return "FLOAT64";
+
+ case loco::DataType::BOOL:
+ return "BOOL";
+
+ default:
+ return "Error";
+ }
+}
+
+std::string to_str(bool value) { return value ? "true" : "false"; }
+
+std::string to_str(luci::FusedActFunc fused)
+{
+ switch (fused)
+ {
+ case luci::FusedActFunc::NONE:
+ return "NONE";
+ case luci::FusedActFunc::RELU:
+ return "RELU";
+ case luci::FusedActFunc::RELU_N1_TO_1:
+ return "RELU_N1_TO_1";
+ case luci::FusedActFunc::RELU6:
+ return "RELU6";
+ case luci::FusedActFunc::TANH:
+ return "TANH";
+ case luci::FusedActFunc::SIGN_BIT:
+ return "SIGN_BIT";
+ default:
+ return "Error";
+ }
+}
+
+std::string to_str(luci::Padding padding)
+{
+ switch (padding)
+ {
+ case luci::Padding::SAME:
+ return "SAME";
+ case luci::Padding::VALID:
+ return "VALID";
+ default:
+ return "Error";
+ }
+}
+
+std::string to_str(const luci::Stride *stride)
+{
+ return std::to_string(stride->h()) + "," + std::to_string(stride->w());
+}
+
+std::string to_str(const luci::Filter *filter)
+{
+ return std::to_string(filter->h()) + "," + std::to_string(filter->w());
+}
+
+std::string to_str(luci::MirrorPadMode mode)
+{
+ switch (mode)
+ {
+ case luci::MirrorPadMode::REFLECT:
+ return "REFLECT";
+ case luci::MirrorPadMode::SYMMETRIC:
+ return "SYMMETRIC";
+ default:
+ return "Error";
+ }
+}
+
+} // namespace
+
+namespace luci
+{
+
+std::vector<std::string> CircleNodeWithXSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"x"};
+}
+
+std::vector<std::string>
+CircleNodeWithINPUTSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"input"};
+}
+
+std::vector<std::string> CircleNodeWithXYSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"x", "y"};
+}
+
+std::vector<std::string>
+CircleNodeWithFEATURESSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"features"};
+}
+
+} // namespace luci
+
+namespace luci
+{
+
+bool CircleAddSummaryBuilder::validate(const luci::CircleNode *node)
+{
+ auto add = loco::must_cast<const luci::CircleAdd *>(node);
+ if (add->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED)
+ return false;
+
+ return true;
+}
+
+void CircleAddSummaryBuilder::build_attributes(const luci::CircleNode *node, locop::NodeSummary &s)
+{
+ auto add = loco::must_cast<const luci::CircleAdd *>(node);
+ s.args().append("fused_activation_function", to_str(add->fusedActivationFunction()));
+}
+
+std::vector<std::string> CircleAddNSummaryBuilder::get_input_names(const luci::CircleNode *node)
+{
+ return std::vector<std::string>(node->arity(), "inputs");
+}
+
+std::vector<std::string> CircleArgMaxSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"input", "dimension"};
+}
+
+void CircleArgMaxSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto argmax = loco::must_cast<const luci::CircleArgMax *>(node);
+ s.args().append("output_type", to_str(argmax->output_type()));
+}
+
+std::vector<std::string> CircleArgMinSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"input", "dimension"};
+}
+
+void CircleArgMinSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto argmin = loco::must_cast<const luci::CircleArgMin *>(node);
+ s.args().append("output_type", to_str(argmin->output_type()));
+}
+
+bool CircleAveragePool2DSummaryBuilder::validate(const luci::CircleNode *node)
+{
+ auto avgpool = loco::must_cast<const luci::CircleAveragePool2D *>(node);
+ if (avgpool->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED)
+ return false;
+ if (avgpool->padding() == luci::Padding::UNDEFINED)
+ return false;
+
+ return true;
+}
+
+std::vector<std::string>
+CircleAveragePool2DSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"value"};
+}
+
+void CircleAveragePool2DSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto avgpool = loco::must_cast<const luci::CircleAveragePool2D *>(node);
+ s.args().append("filter(h,w)", to_str(avgpool->filter()));
+ s.args().append("stride(h,w)", to_str(avgpool->stride()));
+ s.args().append("padding", to_str(avgpool->padding()));
+ s.args().append("fused_activation_function", to_str(avgpool->fusedActivationFunction()));
+}
+
+void CircleBatchMatMulSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto batchmatmul = loco::must_cast<const luci::CircleBatchMatMul *>(node);
+ s.args().append("adj_x", to_str(batchmatmul->adj_x()));
+ s.args().append("adj_y", to_str(batchmatmul->adj_y()));
+}
+
+std::vector<std::string>
+CircleBatchToSpaceNDSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"input", "block_shape", "crops"};
+}
+
+bool CircleBCQFullyConnectedSummaryBuilder::validate(const luci::CircleNode *node)
+{
+ auto bcq_fc = loco::must_cast<const luci::CircleBCQFullyConnected *>(node);
+ if (bcq_fc->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED)
+ return false;
+
+ return true;
+}
+
+std::vector<std::string>
+CircleBCQFullyConnectedSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"input", "weights_scales", "weights_binary", "bias", "weights_clusters"};
+}
+
+void CircleBCQFullyConnectedSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto bcq_fc = loco::must_cast<const luci::CircleBCQFullyConnected *>(node);
+ s.args().append("fused_activation_function", to_str(bcq_fc->fusedActivationFunction()));
+ s.args().append("weights_hidden_size", std::to_string(bcq_fc->weights_hidden_size()));
+}
+
+std::vector<std::string> CircleBCQGatherSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"input_scales", "input_binary", "indices", "input_clusters"};
+}
+
+void CircleBCQGatherSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto bcq_gather = loco::must_cast<const luci::CircleBCQGather *>(node);
+ s.args().append("axis", std::to_string(bcq_gather->axis()));
+ s.args().append("input_hidden_size", std::to_string(bcq_gather->input_hidden_size()));
+}
+
+std::vector<std::string>
+CircleBidirectionalSequenceLSTMSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"input",
+ "fw_input_to_input_weights",
+ "fw_input_to_forget_weights",
+ "fw_input_to_cell_weights",
+ "fw_input_to_output_weights",
+ "fw_recurrent_to_input_weights",
+ "fw_recurrent_to_forget_weights",
+ "fw_recurrent_to_cell_weights",
+ "fw_recurrent_to_output_weights",
+ "fw_cell_to_input_weights",
+ "fw_cell_to_forget_weights",
+ "fw_cell_to_output_weights",
+ "fw_input_gate_bias",
+ "fw_forget_gate_bias",
+ "fw_cell_gate_bias",
+ "fw_output_gate_bias",
+ "fw_projection_weights",
+ "fw_projection_bias",
+ "bw_input_to_input_weights",
+ "bw_input_to_forget_weights",
+ "bw_input_to_cell_weights",
+ "bw_input_to_output_weights",
+ "bw_recurrent_to_input_weights",
+ "bw_recurrent_to_forget_weights",
+ "bw_recurrent_to_cell_weights",
+ "bw_recurrent_to_output_weights",
+ "bw_cell_to_input_weights",
+ "bw_cell_to_forget_weights",
+ "bw_cell_to_output_weights",
+ "bw_input_gate_bias",
+ "bw_forget_gate_bias",
+ "bw_cell_gate_bias",
+ "bw_output_gate_bias",
+ "bw_projection_weights",
+ "bw_projection_bias",
+ "fw_activation_state",
+ "fw_cell_state",
+ "bw_activation_state",
+ "bw_cell_state",
+ "auxillary_input",
+ "fw_auxillary_input_to_input_weights",
+ "fw_auxillary_input_to_forget_weights",
+ "fw_auxillary_input_to_cell_weights",
+ "fw_auxillary_input_to_output_weights",
+ "bw_auxillary_input_to_input_weights",
+ "bw_auxillary_input_to_forget_weights",
+ "bw_auxillary_input_to_cell_weights",
+ "bw_auxillary_input_to_output_weights"};
+}
+
+void CircleBidirectionalSequenceLSTMSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto lstm = loco::must_cast<const luci::CircleBidirectionalSequenceLSTM *>(node);
+ s.args().append("cell_clip", to_str(lstm->cell_clip()));
+ s.args().append("proj_clip", to_str(lstm->proj_clip()));
+ s.args().append("merge_outputs", to_str(lstm->merge_outputs()));
+ s.args().append("time_major", to_str(lstm->time_major()));
+ s.args().append("asymmetric_quantize_inputs", to_str(lstm->asymmetric_quantize_inputs()));
+}
+
+std::vector<std::string> CircleCastSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"x"};
+}
+
+void CircleCastSummaryBuilder::build_attributes(const luci::CircleNode *node, locop::NodeSummary &s)
+{
+ auto cast = loco::must_cast<const luci::CircleCast *>(node);
+ s.args().append("in_data_type", to_str(cast->in_data_type()));
+ s.args().append("out_data_type", to_str(cast->out_data_type()));
+}
+
+// A Concatenation node is only considered printable when its fused activation
+// function has been resolved; UNDEFINED is rejected.
+bool CircleConcatenationSummaryBuilder::validate(const luci::CircleNode *node)
+{
+  auto concat = loco::must_cast<const luci::CircleConcatenation *>(node);
+  if (concat->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED)
+    return false;
+
+  return true;
+}
+
+std::vector<std::string>
+CircleConcatenationSummaryBuilder::get_input_names(const luci::CircleNode *node)
+{
+ return std::vector<std::string>(node->arity(), "values");
+}
+
+void CircleConcatenationSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto concat = loco::must_cast<const luci::CircleConcatenation *>(node);
+ s.args().append("axis", std::to_string(concat->axis()));
+ s.args().append("fused_activation_function", to_str(concat->fusedActivationFunction()));
+}
+
+// CircleConst summaries are marked PartiallyKnown instead of the default
+// state; NOTE(review): presumably because the constant payload itself is not
+// rendered into the summary — confirm against CircleNodeSummaryBuilder.
+void CircleConstSummaryBuilder::update_status(locop::NodeSummary &s)
+{
+  s.state(locop::NodeDesc::State::PartiallyKnown);
+}
+
+// Conv2D is printable only when both its fused activation function and its
+// padding have been resolved (neither may be UNDEFINED).
+bool CircleConv2DSummaryBuilder::validate(const luci::CircleNode *node)
+{
+  auto conv2d = loco::must_cast<const luci::CircleConv2D *>(node);
+  if (conv2d->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED)
+    return false;
+  if (conv2d->padding() == luci::Padding::UNDEFINED)
+    return false;
+
+  return true;
+}
+
+std::vector<std::string> CircleConv2DSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"input", "filter", "bias"};
+}
+
+void CircleConv2DSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto conv2d = loco::must_cast<const luci::CircleConv2D *>(node);
+ s.args().append("stride(h,w)", to_str(conv2d->stride()));
+ s.args().append("dilation(h,w)", to_str(conv2d->dilation()));
+ s.args().append("padding", to_str(conv2d->padding()));
+ s.args().append("fused_activation_function", to_str(conv2d->fusedActivationFunction()));
+}
+
+std::vector<std::string> CircleCustomSummaryBuilder::get_input_names(const luci::CircleNode *node)
+{
+ auto input_names = std::vector<std::string>();
+ for (uint32_t i = 0; i < node->arity(); ++i)
+ input_names.push_back("input" + std::to_string(i));
+ return input_names;
+}
+
+void CircleCustomSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto custom = loco::must_cast<const luci::CircleCustom *>(node);
+ s.args().append("custom_code", custom->custom_code());
+}
+
+void CircleDepthToSpaceSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto depth_to_space = loco::must_cast<const luci::CircleDepthToSpace *>(node);
+ s.args().append("block_size", std::to_string(depth_to_space->block_size()));
+}
+
+bool CircleDepthwiseConv2DSummaryBuilder::validate(const luci::CircleNode *node)
+{
+ auto dw_conv2d = loco::must_cast<const luci::CircleDepthwiseConv2D *>(node);
+ if (dw_conv2d->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED)
+ return false;
+ if (dw_conv2d->padding() == luci::Padding::UNDEFINED)
+ return false;
+
+ return true;
+}
+
+std::vector<std::string>
+CircleDepthwiseConv2DSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"input", "filter", "bias"};
+}
+
+void CircleDepthwiseConv2DSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto dw_conv2d = loco::must_cast<const luci::CircleDepthwiseConv2D *>(node);
+ s.args().append("stride(h,w)", to_str(dw_conv2d->stride()));
+ s.args().append("dilation(h,w)", to_str(dw_conv2d->dilation()));
+ s.args().append("padding", to_str(dw_conv2d->padding()));
+ s.args().append("depthMultiplier", std::to_string(dw_conv2d->depthMultiplier()));
+ s.args().append("fused_activation_function", to_str(dw_conv2d->fusedActivationFunction()));
+}
+
+std::vector<std::string> CircleExpandDimsSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"input", "axis"};
+}
+
+std::vector<std::string> CircleFakeQuantSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"inputs"};
+}
+
+void CircleFakeQuantSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto fake_quant = loco::must_cast<const luci::CircleFakeQuant *>(node);
+ s.args().append("min", std::to_string(fake_quant->min()));
+ s.args().append("max", std::to_string(fake_quant->max()));
+ s.args().append("num_bits", std::to_string(fake_quant->num_bits()));
+ s.args().append("narrow_range", to_str(fake_quant->narrow_range()));
+}
+
+std::vector<std::string> CircleFillSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"dims", "value"};
+}
+
+bool CircleFullyConnectedSummaryBuilder::validate(const luci::CircleNode *node)
+{
+ auto fc = loco::must_cast<const luci::CircleFullyConnected *>(node);
+ if (fc->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED)
+ return false;
+
+ return true;
+}
+
+std::vector<std::string>
+CircleFullyConnectedSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"input", "weights", "bias"};
+}
+
+void CircleFullyConnectedSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto fc = loco::must_cast<const luci::CircleFullyConnected *>(node);
+ s.args().append("fused_activation_function", to_str(fc->fusedActivationFunction()));
+}
+
+std::vector<std::string> CircleGatherSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"params", "indices"};
+}
+
+void CircleGatherSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto gather = loco::must_cast<const luci::CircleGather *>(node);
+ s.args().append("axis", std::to_string(gather->axis()));
+}
+
+std::vector<std::string> CircleGatherNdSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"params", "indices"};
+}
+
+// Inputs of If: the boolean "cond" tensor followed by one generic "input"
+// label per data input forwarded to the selected branch.
+std::vector<std::string> CircleIfSummaryBuilder::get_input_names(const luci::CircleNode *node)
+{
+  auto circle_if = loco::must_cast<const luci::CircleIf *>(node);
+
+  auto input_names = std::vector<std::string>();
+  input_names.push_back("cond");
+  for (uint32_t i = 0; i < circle_if->input_count(); ++i)
+    input_names.push_back("input");
+
+  return input_names;
+}
+
+// Print the branch bindings of If: prefer the symbolic subgraph name when the
+// branch graph is already linked, otherwise fall back to the raw branch index.
+void CircleIfSummaryBuilder::build_attributes(const luci::CircleNode *node, locop::NodeSummary &s)
+{
+  auto circle_if = loco::must_cast<const luci::CircleIf *>(node);
+
+  if (circle_if->then_graph() != nullptr)
+    s.args().append("then_graph", circle_if->then_graph()->name());
+  else
+    s.args().append("then_branch", std::to_string(circle_if->then_branch()));
+
+  if (circle_if->else_graph() != nullptr)
+    s.args().append("else_graph", circle_if->else_graph()->name());
+  else
+    s.args().append("else_branch", std::to_string(circle_if->else_branch()));
+}
+
+bool CircleInstanceNormSummaryBuilder::validate(const luci::CircleNode *node)
+{
+ auto instnorm = loco::must_cast<const luci::CircleInstanceNorm *>(node);
+ if (instnorm->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED)
+ return false;
+
+ return true;
+}
+
+std::vector<std::string> CircleInstanceNormSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"input", "gamma", "beta"};
+}
+
+void CircleInstanceNormSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto instnorm = loco::must_cast<const luci::CircleInstanceNorm *>(node);
+ s.args().append("epsilon", std::to_string(instnorm->epsilon()));
+ s.args().append("fused_activation_function", to_str(instnorm->fusedActivationFunction()));
+}
+
+bool CircleL2NormalizeSummaryBuilder::validate(const luci::CircleNode *node)
+{
+ auto l2norm = loco::must_cast<const luci::CircleL2Normalize *>(node);
+ if (l2norm->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED)
+ return false;
+
+ return true;
+}
+
+std::vector<std::string> CircleL2NormalizeSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"x"};
+}
+
+void CircleL2NormalizeSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto l2norm = loco::must_cast<const luci::CircleL2Normalize *>(node);
+ s.args().append("fused_activation_function", to_str(l2norm->fusedActivationFunction()));
+}
+
+bool CircleL2Pool2DSummaryBuilder::validate(const luci::CircleNode *node)
+{
+ auto l2pool = loco::must_cast<const luci::CircleL2Pool2D *>(node);
+ if (l2pool->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED)
+ return false;
+ if (l2pool->padding() == luci::Padding::UNDEFINED)
+ return false;
+
+ return true;
+}
+
+std::vector<std::string> CircleL2Pool2DSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"value"};
+}
+
+void CircleL2Pool2DSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto l2pool = loco::must_cast<const luci::CircleL2Pool2D *>(node);
+ s.args().append("filter(h,w)", to_str(l2pool->filter()));
+ s.args().append("stride(h,w)", to_str(l2pool->stride()));
+ s.args().append("padding", to_str(l2pool->padding()));
+ s.args().append("fused_activation_function", to_str(l2pool->fusedActivationFunction()));
+}
+
+void CircleLeakyReluSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto leaky_relu = loco::must_cast<const luci::CircleLeakyRelu *>(node);
+ s.args().append("alpha", std::to_string(leaky_relu->alpha()));
+}
+
+void CircleLocalResponseNormalizationSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto lrn = loco::must_cast<const luci::CircleLocalResponseNormalization *>(node);
+ s.args().append("radius", std::to_string(lrn->radius()));
+ s.args().append("bias", std::to_string(lrn->bias()));
+ s.args().append("alpha", std::to_string(lrn->alpha()));
+ s.args().append("beta", std::to_string(lrn->beta()));
+}
+
+std::vector<std::string> CircleLogSoftmaxSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"logits"};
+}
+
+std::vector<std::string> CircleMatrixDiagSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"diagonal"};
+}
+
+std::vector<std::string>
+CircleMatrixSetDiagSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"input", "diagonal"};
+}
+
+bool CircleMaxPool2DSummaryBuilder::validate(const luci::CircleNode *node)
+{
+ auto maxpool = loco::must_cast<const luci::CircleMaxPool2D *>(node);
+ if (maxpool->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED)
+ return false;
+ if (maxpool->padding() == luci::Padding::UNDEFINED)
+ return false;
+
+ return true;
+}
+
+std::vector<std::string> CircleMaxPool2DSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"value"};
+}
+
+void CircleMaxPool2DSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto maxpool = loco::must_cast<const luci::CircleMaxPool2D *>(node);
+ s.args().append("filter(h,w)", to_str(maxpool->filter()));
+ s.args().append("stride(h,w)", to_str(maxpool->stride()));
+ s.args().append("padding", to_str(maxpool->padding()));
+ s.args().append("fused_activation_function", to_str(maxpool->fusedActivationFunction()));
+}
+
+bool CircleMirrorPadSummaryBuilder::validate(const luci::CircleNode *node)
+{
+ auto mirror_pad = loco::must_cast<const luci::CircleMirrorPad *>(node);
+ if (mirror_pad->mode() == luci::MirrorPadMode::UNDEFINED)
+ return false;
+
+ return true;
+}
+
+std::vector<std::string> CircleMirrorPadSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"input", "paddings"};
+}
+
+void CircleMirrorPadSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto mirror_pad = loco::must_cast<const luci::CircleMirrorPad *>(node);
+ s.args().append("mode", to_str(mirror_pad->mode()));
+}
+
+bool CircleMulSummaryBuilder::validate(const luci::CircleNode *node)
+{
+ auto mul = loco::must_cast<const luci::CircleMul *>(node);
+ if (mul->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED)
+ return false;
+
+ return true;
+}
+
+void CircleMulSummaryBuilder::build_attributes(const luci::CircleNode *node, locop::NodeSummary &s)
+{
+ auto mul = loco::must_cast<const luci::CircleMul *>(node);
+ s.args().append("fused_activation_function", to_str(mul->fusedActivationFunction()));
+}
+
+std::vector<std::string>
+CircleNonMaxSuppressionV4SummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"boxes", "scores", "max_output_size", "iou_threshold", "score_threshold"};
+}
+
+std::vector<std::string>
+CircleNonMaxSuppressionV5SummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"boxes", "scores", "max_output_size",
+ "iou_threshold", "score_threshold", "soft_nms_sigma"};
+}
+
+std::vector<std::string> CircleOneHotSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"indices", "depth", "on_value", "off_value"};
+}
+
+void CircleOneHotSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto onehot = loco::must_cast<const luci::CircleOneHot *>(node);
+ s.args().append("axis", std::to_string(onehot->axis()));
+}
+
+std::vector<std::string> CirclePackSummaryBuilder::get_input_names(const luci::CircleNode *node)
+{
+ return std::vector<std::string>(node->arity(), "values");
+}
+
+void CirclePackSummaryBuilder::build_attributes(const luci::CircleNode *node, locop::NodeSummary &s)
+{
+ auto pack = loco::must_cast<const luci::CirclePack *>(node);
+ s.args().append("values_count", std::to_string(pack->values_count()));
+ s.args().append("axis", std::to_string(pack->axis()));
+}
+
+std::vector<std::string> CirclePadSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"input", "paddings"};
+}
+
+std::vector<std::string> CirclePadV2SummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"input", "paddings", "constant_values"};
+}
+
+std::vector<std::string> CirclePReluSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"input", "alpha"};
+}
+
+std::vector<std::string> CircleRangeSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"start", "limit", "delta"};
+}
+
+std::vector<std::string> CircleReshapeSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"tensor", "shape"};
+}
+
+void CircleReshapeSummaryBuilder::update_status(locop::NodeSummary &s)
+{
+ s.state(locop::NodeDesc::State::PartiallyKnown);
+}
+
+std::vector<std::string>
+CircleResizeBilinearSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"input", "size"};
+}
+
+void CircleResizeBilinearSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto resize_bilinear = loco::must_cast<const luci::CircleResizeBilinear *>(node);
+ s.args().append("align_corners", to_str(resize_bilinear->align_corners()));
+ s.args().append("half_pixel_centers", to_str(resize_bilinear->half_pixel_centers()));
+}
+
+std::vector<std::string>
+CircleResizeNearestNeighborSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"input", "size"};
+}
+
+void CircleResizeNearestNeighborSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto resize_nn = loco::must_cast<const luci::CircleResizeNearestNeighbor *>(node);
+ s.args().append("align_corners", to_str(resize_nn->align_corners()));
+}
+
+std::vector<std::string>
+CircleReverseSequenceSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"input", "seq_lengths"};
+}
+
+void CircleReverseSequenceSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto reverse_seq = loco::must_cast<const luci::CircleReverseSequence *>(node);
+ s.args().append("seq_axis", std::to_string(reverse_seq->seq_axis()));
+ s.args().append("batch_axis", std::to_string(reverse_seq->batch_axis()));
+}
+
+std::vector<std::string> CircleReverseV2SummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"tensor", "axis"};
+}
+
+std::vector<std::string> CircleScatterNdSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"indices", "updates", "shape"};
+}
+
+std::vector<std::string> CircleSegmentSumSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"input", "segment_ids"};
+}
+
+std::vector<std::string> CircleSelectSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"condition", "t", "e"};
+}
+
+std::vector<std::string> CircleSelectV2SummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"condition", "t", "e"};
+}
+
+void CircleShapeSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto shape = loco::must_cast<const luci::CircleShape *>(node);
+ s.args().append("out_type", to_str(shape->out_type()));
+}
+
+std::vector<std::string> CircleSliceSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"input", "begin", "size"};
+}
+
+std::vector<std::string> CircleSoftmaxSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"logits"};
+}
+
+void CircleSoftmaxSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto softmax = loco::must_cast<const luci::CircleSoftmax *>(node);
+ s.args().append("beta", to_str(softmax->beta()));
+}
+
+std::vector<std::string>
+CircleSpaceToBatchNDSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"input", "block_shape", "paddings"};
+}
+
+void CircleSpaceToDepthSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto space_to_depth = loco::must_cast<const luci::CircleSpaceToDepth *>(node);
+ s.args().append("block_size", to_str(space_to_depth->block_size()));
+}
+
+std::vector<std::string>
+CircleSparseToDenseSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"indices", "output_shape", "values", "default_value"};
+}
+
+void CircleSparseToDenseSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto sparse_to_dense = loco::must_cast<const luci::CircleSparseToDense *>(node);
+ s.args().append("validate_indices", to_str(sparse_to_dense->validate_indices()));
+}
+
+std::vector<std::string> CircleSplitSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"split_dim", "input"};
+}
+
+void CircleSplitSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto split = loco::must_cast<const luci::CircleSplit *>(node);
+ s.args().append("num_split", std::to_string(split->num_split()));
+}
+
+std::vector<std::string> CircleSplitVSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"input", "size_splits", "split_dim"};
+}
+
+void CircleSplitVSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto split_v = loco::must_cast<const luci::CircleSplitV *>(node);
+ s.args().append("num_split", std::to_string(split_v->num_split()));
+}
+
+// Render squeeze_dims as a single parenthesized, comma-separated list,
+// e.g. "(0, 2)", so the whole vector shows up as one attribute value.
+void CircleSqueezeSummaryBuilder::build_attributes(const luci::CircleNode *node,
+                                                   locop::NodeSummary &s)
+{
+  auto squeeze = loco::must_cast<const luci::CircleSqueeze *>(node);
+
+  std::string squeeze_dims = "(";
+  for (size_t i = 0; i < squeeze->squeeze_dims().size(); ++i)
+  {
+    if (i != 0)
+      squeeze_dims += ", ";
+    squeeze_dims += std::to_string(squeeze->squeeze_dims().at(i));
+  }
+  squeeze_dims += ")";
+
+  s.args().append("squeeze_dims", squeeze_dims);
+}
+
+std::vector<std::string> CircleStridedSliceSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"input", "begin", "end", "strides"};
+}
+
+void CircleStridedSliceSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto strided_slice = loco::must_cast<const luci::CircleStridedSlice *>(node);
+ s.args().append("begin_mask", std::to_string(strided_slice->begin_mask()));
+ s.args().append("end_mask", std::to_string(strided_slice->end_mask()));
+ s.args().append("ellipsis_mask", std::to_string(strided_slice->ellipsis_mask()));
+ s.args().append("new_axis_mask", std::to_string(strided_slice->new_axis_mask()));
+ s.args().append("shrink_axis_mask", std::to_string(strided_slice->shrink_axis_mask()));
+}
+
+bool CircleSVDFSummaryBuilder::validate(const luci::CircleNode *node)
+{
+ auto svdf = loco::must_cast<const luci::CircleSVDF *>(node);
+ if (svdf->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED)
+ return false;
+
+ return true;
+}
+
+std::vector<std::string> CircleSVDFSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+  // NOTE(review): "State" is capitalized unlike every other input name in this
+  // file — looks like a typo for "state"; it is user-visible dump output, so
+  // confirm before changing the string.
+  return {"input", "weight_feature", "weight_time", "bias", "State"};
+}
+
+// Append SVDF attributes: rank, the asymmetric-quantization flag, and the
+// fused activation function (validated as non-UNDEFINED in validate()).
+void CircleSVDFSummaryBuilder::build_attributes(const luci::CircleNode *node, locop::NodeSummary &s)
+{
+  auto svdf = loco::must_cast<const luci::CircleSVDF *>(node);
+  s.args().append("rank", to_str(svdf->svdf_rank()));
+  s.args().append("asymmetric_quantize_inputs", to_str(svdf->asymmetric_quantize_inputs()));
+  s.args().append("fused_activation_function", to_str(svdf->fusedActivationFunction()));
+}
+
+std::vector<std::string> CircleTileSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"input", "multiples"};
+}
+
+std::vector<std::string> CircleTopKV2SummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"input", "k"};
+}
+
+std::vector<std::string> CircleTransposeSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"a", "perm"};
+}
+
+bool CircleTransposeConvSummaryBuilder::validate(const luci::CircleNode *node)
+{
+ auto transpose_conv = loco::must_cast<const luci::CircleTransposeConv *>(node);
+ if (transpose_conv->padding() == luci::Padding::UNDEFINED)
+ return false;
+
+ return true;
+}
+
+std::vector<std::string>
+CircleTransposeConvSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"inputSizes", "filter", "outBackProp", "bias"};
+}
+
+void CircleTransposeConvSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto transpose_conv = loco::must_cast<const luci::CircleTransposeConv *>(node);
+ s.args().append("stride(h,w)", to_str(transpose_conv->stride()));
+ s.args().append("padding", to_str(transpose_conv->padding()));
+}
+
+std::vector<std::string>
+CircleUnidirectionalSequenceLSTMSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"input",
+ "input_to_input_weights",
+ "input_to_forget_weights",
+ "input_to_cell_weights",
+ "input_to_output_weights",
+ "recurrent_to_input_weights",
+ "recurrent_to_forget_weights",
+ "recurrent_to_cell_weights",
+ "recurrent_to_output_weights",
+ "cell_to_input_weights",
+ "cell_to_forget_weights",
+ "cell_to_output_weights",
+ "input_gate_bias",
+ "forget_gate_bias",
+ "cell_gate_bias",
+ "output_gate_bias",
+ "projection_weights",
+ "projection_bias",
+ "activation_state",
+ "cell_state",
+ "input_layer_norm_coefficients",
+ "forget_layer_norm_coefficients",
+ "cell_layer_norm_coefficients",
+ "output_layer_norm_coefficients"};
+}
+
+void CircleUnidirectionalSequenceLSTMSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto lstm = loco::must_cast<const luci::CircleUnidirectionalSequenceLSTM *>(node);
+ s.args().append("cell_clip", to_str(lstm->cell_clip()));
+ s.args().append("proj_clip", to_str(lstm->proj_clip()));
+ s.args().append("time_major", to_str(lstm->time_major()));
+ s.args().append("asymmetric_quantize_inputs", to_str(lstm->asymmetric_quantize_inputs()));
+}
+
+void CircleUniqueSummaryBuilder::build_attributes(const luci::CircleNode *node,
+ locop::NodeSummary &s)
+{
+ auto unique = loco::must_cast<const luci::CircleUnique *>(node);
+ s.args().append("idx_out_type", to_str(unique->idx_out_type()));
+}
+
+std::vector<std::string> CircleUnpackSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"value"};
+}
+
+// Append Unpack attributes: the number of output tensors and the unpack axis.
+void CircleUnpackSummaryBuilder::build_attributes(const luci::CircleNode *node,
+                                                  locop::NodeSummary &s)
+{
+  auto unpack = loco::must_cast<const luci::CircleUnpack *>(node);
+  s.args().append("num", std::to_string(unpack->num()));
+  s.args().append("axis", std::to_string(unpack->axis()));
+}
+std::vector<std::string> CircleWhereSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"condition"};
+}
+
+// Inputs of While: one generic "input" label per loop-carried value (unlike
+// If, there is no separate condition tensor — cond is a subgraph).
+std::vector<std::string> CircleWhileSummaryBuilder::get_input_names(const luci::CircleNode *node)
+{
+  auto circle_while = loco::must_cast<const luci::CircleWhile *>(node);
+
+  auto input_names = std::vector<std::string>();
+  for (uint32_t i = 0; i < circle_while->input_count(); ++i)
+    input_names.push_back("input");
+
+  return input_names;
+}
+
+// Print the subgraph bindings of While: prefer the symbolic subgraph name when
+// the graph is already linked, otherwise fall back to the raw subgraph index.
+void CircleWhileSummaryBuilder::build_attributes(const luci::CircleNode *node,
+                                                 locop::NodeSummary &s)
+{
+  auto circle_while = loco::must_cast<const luci::CircleWhile *>(node);
+
+  // Label the attributes after While's own accessors (cond/body). The previous
+  // "then_*"/"else_*" labels were a copy-paste remnant of CircleIf and
+  // mislabeled the dump output.
+  if (circle_while->cond_graph() != nullptr)
+    s.args().append("cond_graph", circle_while->cond_graph()->name());
+  else
+    s.args().append("cond_branch", std::to_string(circle_while->cond_branch()));
+
+  if (circle_while->body_graph() != nullptr)
+    s.args().append("body_graph", circle_while->body_graph()->name());
+  else
+    s.args().append("body_branch", std::to_string(circle_while->body_branch()));
+}
+
+std::vector<std::string> CircleOutputSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"from"};
+}
+
+std::vector<std::string> CircleTopKV2OutSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"topkv2"};
+}
+
+std::vector<std::string> CircleUniqueOutSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"unique"};
+}
+
+std::vector<std::string> CircleUnpackOutSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"unpack"};
+}
+
+std::vector<std::string> CircleWhileOutSummaryBuilder::get_input_names(const luci::CircleNode *)
+{
+ return {"while"};
+}
+
+} // namespace luci
diff --git a/compiler/luci/logex/src/CircleNodeSummaryBuilders.h b/compiler/luci/logex/src/CircleNodeSummaryBuilders.h
new file mode 100644
index 000000000..6cd24b7f1
--- /dev/null
+++ b/compiler/luci/logex/src/CircleNodeSummaryBuilders.h
@@ -0,0 +1,821 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_LOGEX_CIRCLE_NODE_SUMMARY_BUILDERS__
+#define __LUCI_LOGEX_CIRCLE_NODE_SUMMARY_BUILDERS__
+
+#include "CircleNodeSummaryBuilder.h"
+
+#include <luci/IR/CircleNode.h>
+
+#include <string>
+#include <vector>
+
+namespace luci
+{
+
+// Shared builder for nodes whose single input is conventionally named "x".
+class CircleNodeWithXSummaryBuilder : public CircleNodeSummaryBuilder
+{
+private:
+  std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+// Shared builder for nodes whose single input is conventionally named "input".
+class CircleNodeWithINPUTSummaryBuilder : public CircleNodeSummaryBuilder
+{
+private:
+  std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+// Shared builder for binary nodes with inputs conventionally named "x"/"y".
+class CircleNodeWithXYSummaryBuilder : public CircleNodeSummaryBuilder
+{
+private:
+  std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+// Shared builder for nodes whose single input is conventionally "features"
+// (activation-style ops).
+class CircleNodeWithFEATURESSummaryBuilder : public CircleNodeSummaryBuilder
+{
+private:
+  std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+// Shared builder for reducer-style nodes: inputs are "input" and
+// "reduction_indices", and the only attribute printed is keep_dims.
+template <class REDUCER_NODE>
+class CircleNodeWithReducerSummaryBuilder : public CircleNodeSummaryBuilder
+{
+private:
+  std::vector<std::string> get_input_names(const luci::CircleNode *)
+  {
+    return {"input", "reduction_indices"};
+  }
+
+  void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s)
+  {
+    // 'mean' is historical naming; any REDUCER_NODE exposing keep_dims() works.
+    auto mean = loco::must_cast<const REDUCER_NODE *>(node);
+    s.args().append("keep_dims", mean->keep_dims() ? "true" : "false");
+  }
+};
+
+} // namespace luci
+
+namespace luci
+{
+
+class CircleAbsSummaryBuilder final : public CircleNodeWithXSummaryBuilder
+{
+};
+
+class CircleAddSummaryBuilder final : public CircleNodeWithXYSummaryBuilder
+{
+private:
+ bool validate(const luci::CircleNode *node);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleAddNSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *node);
+};
+
+class CircleArgMaxSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleArgMinSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleAveragePool2DSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ bool validate(const luci::CircleNode *node);
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleBatchMatMulSummaryBuilder final : public CircleNodeWithXYSummaryBuilder
+{
+private:
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleBatchToSpaceNDSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CircleBCQFullyConnectedSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ bool validate(const luci::CircleNode *node);
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleBCQGatherSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleBidirectionalSequenceLSTMSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleCastSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleCeilSummaryBuilder final : public CircleNodeWithXSummaryBuilder
+{
+};
+
+class CircleConcatenationSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ bool validate(const luci::CircleNode *node);
+ std::vector<std::string> get_input_names(const luci::CircleNode *node);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleConstSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ void update_status(locop::NodeSummary &s);
+};
+
+class CircleConv2DSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ bool validate(const luci::CircleNode *node);
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleCosSummaryBuilder final : public CircleNodeWithXSummaryBuilder
+{
+};
+
+class CircleCustomSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *node);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleDepthToSpaceSummaryBuilder final : public CircleNodeWithINPUTSummaryBuilder
+{
+private:
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleDepthwiseConv2DSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ bool validate(const luci::CircleNode *node);
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleDequantizeSummaryBuilder final : public CircleNodeWithINPUTSummaryBuilder
+{
+};
+
+class CircleDivSummaryBuilder final : public CircleNodeWithXYSummaryBuilder
+{
+};
+
+class CircleEluSummaryBuilder final : public CircleNodeWithFEATURESSummaryBuilder
+{
+};
+
+class CircleEqualSummaryBuilder final : public CircleNodeWithXYSummaryBuilder
+{
+};
+
+class CircleExpSummaryBuilder final : public CircleNodeWithXSummaryBuilder
+{
+};
+
+class CircleExpandDimsSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CircleFakeQuantSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleFillSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CircleFloorSummaryBuilder final : public CircleNodeWithXSummaryBuilder
+{
+};
+
+class CircleFloorDivSummaryBuilder final : public CircleNodeWithXYSummaryBuilder
+{
+};
+
+class CircleFloorModSummaryBuilder final : public CircleNodeWithXYSummaryBuilder
+{
+};
+
+class CircleFullyConnectedSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ bool validate(const luci::CircleNode *node);
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleGatherSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleGatherNdSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CircleGreaterSummaryBuilder final : public CircleNodeWithXYSummaryBuilder
+{
+};
+
+class CircleGreaterEqualSummaryBuilder final : public CircleNodeWithXYSummaryBuilder
+{
+};
+
+class CircleIfSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *node);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleInstanceNormSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ bool validate(const luci::CircleNode *node);
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleL2NormalizeSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ bool validate(const luci::CircleNode *node);
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleL2Pool2DSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ bool validate(const luci::CircleNode *node);
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleLeakyReluSummaryBuilder final : public CircleNodeWithFEATURESSummaryBuilder
+{
+private:
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleLessSummaryBuilder final : public CircleNodeWithXYSummaryBuilder
+{
+};
+
+class CircleLessEqualSummaryBuilder final : public CircleNodeWithXYSummaryBuilder
+{
+};
+
+class CircleLocalResponseNormalizationSummaryBuilder final
+ : public CircleNodeWithINPUTSummaryBuilder
+{
+private:
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleLogSummaryBuilder final : public CircleNodeWithXSummaryBuilder
+{
+};
+
+class CircleLogicalAndSummaryBuilder final : public CircleNodeWithXYSummaryBuilder
+{
+};
+
+class CircleLogicalNotSummaryBuilder final : public CircleNodeWithXSummaryBuilder
+{
+};
+
+class CircleLogicalOrSummaryBuilder final : public CircleNodeWithXYSummaryBuilder
+{
+};
+
+class CircleLogisticSummaryBuilder final : public CircleNodeWithXSummaryBuilder
+{
+};
+
+class CircleLogSoftmaxSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CircleMatrixDiagSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CircleMatrixSetDiagSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CircleMaximumSummaryBuilder final : public CircleNodeWithXYSummaryBuilder
+{
+};
+
+class CircleMaxPool2DSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ bool validate(const luci::CircleNode *node);
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleMeanSummaryBuilder final : public CircleNodeWithReducerSummaryBuilder<luci::CircleMean>
+{
+};
+
+class CircleMinimumSummaryBuilder final : public CircleNodeWithXYSummaryBuilder
+{
+};
+
+class CircleMirrorPadSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ bool validate(const luci::CircleNode *node);
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleMulSummaryBuilder final : public CircleNodeWithXYSummaryBuilder
+{
+private:
+ bool validate(const luci::CircleNode *node);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleNegSummaryBuilder final : public CircleNodeWithXSummaryBuilder
+{
+};
+
+class CircleNonMaxSuppressionV4SummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CircleNonMaxSuppressionV5SummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CircleNotEqualSummaryBuilder final : public CircleNodeWithXYSummaryBuilder
+{
+};
+
+class CircleOneHotSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CirclePackSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *node);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CirclePadSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CirclePadV2SummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CirclePowSummaryBuilder final : public CircleNodeWithXYSummaryBuilder
+{
+};
+
+class CirclePReluSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CircleQuantizeSummaryBuilder final : public CircleNodeWithINPUTSummaryBuilder
+{
+};
+
+class CircleRangeSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CircleRankSummaryBuilder final : public CircleNodeWithINPUTSummaryBuilder
+{
+};
+
+class CircleReduceAnySummaryBuilder final
+ : public CircleNodeWithReducerSummaryBuilder<luci::CircleReduceAny>
+{
+};
+
+class CircleReduceMaxSummaryBuilder final
+ : public CircleNodeWithReducerSummaryBuilder<luci::CircleReduceMax>
+{
+};
+
+class CircleReduceMinSummaryBuilder final
+ : public CircleNodeWithReducerSummaryBuilder<luci::CircleReduceMin>
+{
+};
+
+class CircleReduceProdSummaryBuilder final
+ : public CircleNodeWithReducerSummaryBuilder<luci::CircleReduceProd>
+{
+};
+
+class CircleReluSummaryBuilder final : public CircleNodeWithFEATURESSummaryBuilder
+{
+};
+
+class CircleRelu6SummaryBuilder final : public CircleNodeWithFEATURESSummaryBuilder
+{
+};
+
+class CircleReluN1To1SummaryBuilder final : public CircleNodeWithFEATURESSummaryBuilder
+{
+};
+
+class CircleReshapeSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+ void update_status(locop::NodeSummary &s);
+};
+
+class CircleResizeBilinearSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleResizeNearestNeighborSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleReverseSequenceSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleReverseV2SummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CircleRoundSummaryBuilder final : public CircleNodeWithXSummaryBuilder
+{
+};
+
+class CircleRsqrtSummaryBuilder final : public CircleNodeWithXSummaryBuilder
+{
+};
+
+class CircleScatterNdSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CircleSegmentSumSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CircleSelectSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CircleSelectV2SummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CircleShapeSummaryBuilder final : public CircleNodeWithINPUTSummaryBuilder
+{
+private:
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleSinSummaryBuilder final : public CircleNodeWithXSummaryBuilder
+{
+};
+
+class CircleSliceSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CircleSoftmaxSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleSpaceToBatchNDSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CircleSpaceToDepthSummaryBuilder final : public CircleNodeWithINPUTSummaryBuilder
+{
+private:
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleSparseToDenseSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleSplitSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleSplitVSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleSqrtSummaryBuilder final : public CircleNodeWithXSummaryBuilder
+{
+};
+
+class CircleSquareSummaryBuilder final : public CircleNodeWithXSummaryBuilder
+{
+};
+
+class CircleSquaredDifferenceSummaryBuilder final : public CircleNodeWithXYSummaryBuilder
+{
+};
+
+class CircleSqueezeSummaryBuilder final : public CircleNodeWithINPUTSummaryBuilder
+{
+private:
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleStridedSliceSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleSubSummaryBuilder final : public CircleNodeWithXYSummaryBuilder
+{
+};
+
+class CircleSumSummaryBuilder final : public CircleNodeWithReducerSummaryBuilder<luci::CircleSum>
+{
+};
+
+class CircleSVDFSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ bool validate(const luci::CircleNode *node);
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleTanhSummaryBuilder final : public CircleNodeWithXSummaryBuilder
+{
+};
+
+class CircleTileSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CircleTopKV2SummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CircleTransposeSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CircleTransposeConvSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ bool validate(const luci::CircleNode *node);
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleUnidirectionalSequenceLSTMSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleUniqueSummaryBuilder final : public CircleNodeWithINPUTSummaryBuilder
+{
+private:
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleUnpackSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleWhereSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CircleWhileSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *node);
+ void build_attributes(const luci::CircleNode *node, locop::NodeSummary &s);
+};
+
+class CircleZerosLikeSummaryBuilder final : public CircleNodeWithINPUTSummaryBuilder
+{
+};
+
+class CircleBidirectionalSequenceLSTMOutSummaryBuilder final
+ : public CircleNodeWithINPUTSummaryBuilder
+{
+};
+
+class CircleCustomOutSummaryBuilder final : public CircleNodeWithINPUTSummaryBuilder
+{
+};
+
+class CircleIfOutSummaryBuilder final : public CircleNodeWithINPUTSummaryBuilder
+{
+};
+
+class CircleInputSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+};
+
+class CircleNonMaxSuppressionV4OutSummaryBuilder final : public CircleNodeWithINPUTSummaryBuilder
+{
+};
+
+class CircleNonMaxSuppressionV5OutSummaryBuilder final : public CircleNodeWithINPUTSummaryBuilder
+{
+};
+
+class CircleOutputSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CircleOutputDummySummaryBuilder final : public CircleNodeSummaryBuilder
+{
+};
+
+class CircleOutputExcludeSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+};
+
+class CircleSplitOutSummaryBuilder final : public CircleNodeWithINPUTSummaryBuilder
+{
+};
+
+class CircleSplitVOutSummaryBuilder final : public CircleNodeWithINPUTSummaryBuilder
+{
+};
+
+class CircleTopKV2OutSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CircleUniqueOutSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CircleUnpackOutSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+class CircleVariableSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+};
+
+class CircleWhileOutSummaryBuilder final : public CircleNodeSummaryBuilder
+{
+private:
+ std::vector<std::string> get_input_names(const luci::CircleNode *);
+};
+
+} // namespace luci
+
+#endif // __LUCI_LOGEX_CIRCLE_NODE_SUMMARY_BUILDERS__
diff --git a/compiler/luci/logex/src/FormattedGraph.cpp b/compiler/luci/logex/src/FormattedGraph.cpp
index 0588ed79e..d3b2170b0 100644
--- a/compiler/luci/logex/src/FormattedGraph.cpp
+++ b/compiler/luci/logex/src/FormattedGraph.cpp
@@ -14,6 +14,7 @@
* limitations under the License.
*/
+#include "CircleNodeSummaryBuilder.h"
#include "luci/FormattedGraph.h"
#include <luci/IR/CircleDialect.h>
@@ -25,2179 +26,6 @@
#include <sstream>
#include <vector>
-using namespace luci;
-/**
- * @brief dump std::vector<int64_t> values to stream
- */
-std::ostream &operator<<(std::ostream &os, const std::vector<int64_t> &vi64)
-{
- for (auto vi : vi64)
- {
- os << vi << " ";
- }
- return os;
-}
-
-// For TF lite
-namespace
-{
-
-const char *to_str(loco::DataType type)
-{
- switch (type)
- {
- case loco::DataType::U8:
- return "UINT8";
- case loco::DataType::U16:
- return "UINT16";
- case loco::DataType::U32:
- return "UINT32";
- case loco::DataType::U64:
- return "UINT64";
-
- case loco::DataType::S8:
- return "INT8";
- case loco::DataType::S16:
- return "INT16";
- case loco::DataType::S32:
- return "INT32";
- case loco::DataType::S64:
- return "INT64";
-
- case loco::DataType::FLOAT16:
- return "FLOAT16";
- case loco::DataType::FLOAT32:
- return "FLOAT32";
- case loco::DataType::FLOAT64:
- return "FLOAT64";
-
- case loco::DataType::BOOL:
- return "BOOL";
-
- default:
- return "Error";
- }
-}
-
-const char *to_str(bool value) { return value ? "true" : "false"; }
-
-const char *to_str(luci::FusedActFunc fused)
-{
- switch (fused)
- {
- case luci::FusedActFunc::NONE:
- return "NONE";
- case luci::FusedActFunc::RELU:
- return "RELU";
- case luci::FusedActFunc::RELU_N1_TO_1:
- return "RELU_N1_TO_1";
- case luci::FusedActFunc::RELU6:
- return "RELU6";
- case luci::FusedActFunc::TANH:
- return "TANH";
- case luci::FusedActFunc::SIGN_BIT:
- return "SIGN_BIT";
- default:
- return "Error";
- }
-}
-
-const char *to_str(luci::Padding padding)
-{
- switch (padding)
- {
- case luci::Padding::SAME:
- return "SAME";
- case luci::Padding::VALID:
- return "VALID";
- default:
- return "Error";
- }
-}
-
-const char *to_str(luci::MirrorPadMode mode)
-{
- switch (mode)
- {
- case luci::MirrorPadMode::REFLECT:
- return "REFLECT";
- case luci::MirrorPadMode::SYMMETRIC:
- return "SYMMETRIC";
- default:
- return "Error";
- }
-}
-
-std::string to_str(const luci::Stride *stride)
-{
- return pepper::str(stride->h(), ",", stride->w());
-}
-
-std::string to_str(const luci::Filter *filter)
-{
- return pepper::str(filter->h(), ",", filter->w());
-}
-
-std::string circle_opname(uint32_t opnum)
-{
- static const std::string prefix{"circle."};
-
- switch (static_cast<luci::CircleOpcode>(opnum))
- {
-#define CIRCLE_NODE(OPCODE, CLASS) \
- case luci::CircleOpcode::OPCODE: \
- return prefix + #OPCODE;
-#define CIRCLE_VNODE CIRCLE_NODE
-#include <luci/IR/CircleNodes.lst>
-#undef CIRCLE_VNODE
-#undef CIRCLE_NODE
- default:
- break;
- };
-
- return prefix + "Invalid";
-}
-
-// CircleNodeSummaryBuilder with default implementation
-class CircleNodeSummaryBuilderBase : public locop::NodeSummaryBuilder
-{
-public:
- CircleNodeSummaryBuilderBase(const locop::SymbolTable *tbl) : _tbl{tbl}
- {
- // DO NOTHING
- }
-
-public:
- bool build(const loco::Node *, locop::NodeSummary &s) const final;
-
-protected:
-#define CIRCLE_NODE(OPCODE, CLASS) \
- virtual bool summary(const CLASS *, locop::NodeSummary &) const { return false; }
-#define CIRCLE_VNODE CIRCLE_NODE
-#include <luci/IR/CircleNodes.lst>
-#undef CIRCLE_VNODE
-#undef CIRCLE_NODE
-
-protected:
- const locop::SymbolTable *tbl(void) const { return _tbl; }
-
-private:
- const locop::SymbolTable *_tbl;
-};
-
-template <class CIRCLENODE>
-bool use_x(const locop::SymbolTable *tbl, const CIRCLENODE *node, locop::NodeSummary &s)
-{
- s.args().append("x", tbl->lookup(node->x()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-template <class CIRCLENODE>
-bool use_input(const locop::SymbolTable *tbl, const CIRCLENODE *node, locop::NodeSummary &s)
-{
- s.args().append("input", tbl->lookup(node->input()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-template <class CIRCLENODE>
-bool use_features(const locop::SymbolTable *tbl, const CIRCLENODE *node, locop::NodeSummary &s)
-{
- s.args().append("features", tbl->lookup(node->features()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-template <class CIRCLENODE>
-bool use_xy(const locop::SymbolTable *tbl, const CIRCLENODE *node, locop::NodeSummary &s)
-{
- s.args().append("x", tbl->lookup(node->x()));
- s.args().append("y", tbl->lookup(node->y()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-template <class CIRCLENODE>
-bool use_xy_act(const locop::SymbolTable *tbl, const CIRCLENODE *node, locop::NodeSummary &s)
-{
- assert(node->fusedActivationFunction() != luci::FusedActFunc::UNDEFINED);
-
- s.args().append("x", tbl->lookup(node->x()));
- s.args().append("y", tbl->lookup(node->y()));
- s.args().append("fused_activation_function", to_str(node->fusedActivationFunction()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-template <class CIRCLENODE>
-bool use_reducer(const locop::SymbolTable *tbl, const CIRCLENODE *node, locop::NodeSummary &s)
-{
- s.args().append("input", tbl->lookup(node->input()));
- s.args().append("reduction_indices", tbl->lookup(node->reduction_indices()));
- s.args().append("keep_dims", node->keep_dims() ? "true" : "false");
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-template <class CIRCLENODE>
-bool use_ido(const locop::SymbolTable *tbl, const CIRCLENODE *node, locop::NodeSummary &s)
-{
- s.args().append("input", tbl->lookup(node->input()));
- s.args().append("dimension", tbl->lookup(node->dimension()));
- s.args().append("output_type", to_str(node->output_type()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleAddN *node,
- locop::NodeSummary &s)
-{
- for (uint32_t i = 0; i < node->arity(); ++i)
- s.args().append("inputs", tbl->lookup(node->inputs(i)));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleAveragePool2D *node,
- locop::NodeSummary &s)
-{
- assert(node->fusedActivationFunction() != luci::FusedActFunc::UNDEFINED);
-
- s.args().append("value", tbl->lookup(node->value()));
- s.args().append("filter(h,w)", to_str(node->filter()));
- s.args().append("stride(h,w)", to_str(node->stride()));
- s.args().append("padding", to_str(node->padding()));
- s.args().append("fused", to_str(node->fusedActivationFunction()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleBatchMatMul *node,
- locop::NodeSummary &s)
-{
- s.args().append("x", tbl->lookup(node->x()));
- s.args().append("y", tbl->lookup(node->y()));
- s.args().append("adj_x", to_str(node->adj_x()));
- s.args().append("adj_y", to_str(node->adj_y()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleBatchToSpaceND *node,
- locop::NodeSummary &s)
-{
- s.args().append("input", tbl->lookup(node->input()));
- s.args().append("block_shape", tbl->lookup(node->block_shape()));
- s.args().append("crops", tbl->lookup(node->crops()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleBidirectionalSequenceLSTM *node,
- locop::NodeSummary &s)
-{
- s.args().append("input", tbl->lookup(node->input()));
-
- s.args().append("fw_input_to_input_weights", tbl->lookup(node->fw_input_to_input_weights()));
- s.args().append("fw_input_to_forget_weights", tbl->lookup(node->fw_input_to_forget_weights()));
- s.args().append("fw_input_to_cell_weights", tbl->lookup(node->fw_input_to_cell_weights()));
- s.args().append("fw_input_to_output_weights", tbl->lookup(node->fw_input_to_output_weights()));
-
- s.args().append("fw_recurrent_to_input_weights",
- tbl->lookup(node->fw_recurrent_to_input_weights()));
- s.args().append("fw_recurrent_to_forget_weights",
- tbl->lookup(node->fw_recurrent_to_forget_weights()));
- s.args().append("fw_recurrent_to_cell_weights",
- tbl->lookup(node->fw_recurrent_to_cell_weights()));
- s.args().append("fw_recurrent_to_output_weights",
- tbl->lookup(node->fw_recurrent_to_output_weights()));
-
- s.args().append("fw_cell_to_input_weights", tbl->lookup(node->fw_cell_to_input_weights()));
- s.args().append("fw_cell_to_forget_weights", tbl->lookup(node->fw_cell_to_forget_weights()));
- s.args().append("fw_cell_to_output_weights", tbl->lookup(node->fw_cell_to_output_weights()));
-
- s.args().append("fw_input_gate_bias", tbl->lookup(node->fw_input_gate_bias()));
- s.args().append("fw_forget_gate_bias", tbl->lookup(node->fw_forget_gate_bias()));
- s.args().append("fw_cell_gate_bias", tbl->lookup(node->fw_cell_gate_bias()));
- s.args().append("fw_output_gate_bias", tbl->lookup(node->fw_output_gate_bias()));
-
- s.args().append("fw_projection_weights", tbl->lookup(node->fw_projection_weights()));
- s.args().append("fw_projection_bias", tbl->lookup(node->fw_projection_bias()));
-
- s.args().append("bw_input_to_input_weights", tbl->lookup(node->bw_input_to_input_weights()));
- s.args().append("bw_input_to_forget_weights", tbl->lookup(node->bw_input_to_forget_weights()));
- s.args().append("bw_input_to_cell_weights", tbl->lookup(node->bw_input_to_cell_weights()));
- s.args().append("bw_input_to_output_weights", tbl->lookup(node->bw_input_to_output_weights()));
-
- s.args().append("bw_recurrent_to_input_weights",
- tbl->lookup(node->bw_recurrent_to_input_weights()));
- s.args().append("bw_recurrent_to_forget_weights",
- tbl->lookup(node->bw_recurrent_to_forget_weights()));
- s.args().append("bw_recurrent_to_cell_weights",
- tbl->lookup(node->bw_recurrent_to_cell_weights()));
- s.args().append("bw_recurrent_to_output_weights",
- tbl->lookup(node->bw_recurrent_to_output_weights()));
-
- s.args().append("bw_cell_to_input_weights", tbl->lookup(node->bw_cell_to_input_weights()));
- s.args().append("bw_cell_to_forget_weights", tbl->lookup(node->bw_cell_to_forget_weights()));
- s.args().append("bw_cell_to_output_weights", tbl->lookup(node->bw_cell_to_output_weights()));
-
- s.args().append("bw_input_gate_bias", tbl->lookup(node->bw_input_gate_bias()));
- s.args().append("bw_forget_gate_bias", tbl->lookup(node->bw_forget_gate_bias()));
- s.args().append("bw_cell_gate_bias", tbl->lookup(node->bw_cell_gate_bias()));
- s.args().append("bw_output_gate_bias", tbl->lookup(node->bw_output_gate_bias()));
-
- s.args().append("bw_projection_weights", tbl->lookup(node->bw_projection_weights()));
- s.args().append("bw_projection_bias", tbl->lookup(node->bw_projection_bias()));
-
- s.args().append("fw_activation_state", tbl->lookup(node->fw_activation_state()));
- s.args().append("fw_cell_state", tbl->lookup(node->fw_cell_state()));
- s.args().append("bw_activation_state", tbl->lookup(node->bw_activation_state()));
- s.args().append("bw_cell_state", tbl->lookup(node->bw_cell_state()));
-
- s.args().append("auxillary_input", tbl->lookup(node->auxillary_input()));
- s.args().append("fw_auxillary_input_to_input_weights",
- tbl->lookup(node->fw_auxillary_input_to_input_weights()));
- s.args().append("fw_auxillary_input_to_forget_weights",
- tbl->lookup(node->fw_auxillary_input_to_forget_weights()));
- s.args().append("fw_auxillary_input_to_cell_weights",
- tbl->lookup(node->fw_auxillary_input_to_cell_weights()));
- s.args().append("fw_auxillary_input_to_output_weights",
- tbl->lookup(node->fw_auxillary_input_to_output_weights()));
- s.args().append("bw_auxillary_input_to_input_weights",
- tbl->lookup(node->bw_auxillary_input_to_input_weights()));
- s.args().append("bw_auxillary_input_to_forget_weights",
- tbl->lookup(node->bw_auxillary_input_to_forget_weights()));
- s.args().append("bw_auxillary_input_to_cell_weights",
- tbl->lookup(node->bw_auxillary_input_to_cell_weights()));
- s.args().append("bw_auxillary_input_to_output_weights",
- tbl->lookup(node->bw_auxillary_input_to_output_weights()));
-
- s.args().append("cell_clip", to_str(node->cell_clip()));
- s.args().append("proj_clip", to_str(node->proj_clip()));
- s.args().append("merge_outputs", to_str(node->merge_outputs()));
- s.args().append("time_major", to_str(node->time_major()));
- s.args().append("asymmetric_quantize_inputs", to_str(node->asymmetric_quantize_inputs()));
-
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleCast *node,
- locop::NodeSummary &s)
-{
- s.args().append("x", tbl->lookup(node->x()));
- s.args().append("in_data_type", to_str(node->in_data_type()));
- s.args().append("out_data_type", to_str(node->out_data_type()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleConcatenation *node,
- locop::NodeSummary &s)
-{
- assert(node->fusedActivationFunction() != luci::FusedActFunc::UNDEFINED);
-
- for (uint32_t i = 0; i < node->numValues(); ++i)
- s.args().append("values", tbl->lookup(node->values(i)));
- s.args().append("axis", pepper::str(node->axis()));
- s.args().append("fused", to_str(node->fusedActivationFunction()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleConv2D *node,
- locop::NodeSummary &s)
-{
- assert(node->fusedActivationFunction() != luci::FusedActFunc::UNDEFINED);
- assert(node->padding() != luci::Padding::UNDEFINED);
-
- s.args().append("input", tbl->lookup(node->input()));
- s.args().append("filter", tbl->lookup(node->filter()));
- s.args().append("bias", tbl->lookup(node->bias()));
- s.args().append("stride(h,w)", to_str(node->stride()));
- s.args().append("dilation(h,w)", to_str(node->dilation()));
- s.args().append("padding", to_str(node->padding()));
- s.args().append("fused", to_str(node->fusedActivationFunction()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleCustom *node,
- locop::NodeSummary &s)
-{
- for (uint32_t i = 0; i < node->numInputs(); i++)
- {
- s.args().append("input" + std::to_string(i), tbl->lookup(node->inputs(i)));
- }
- s.args().append("custom_code", node->custom_code());
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleDepthToSpace *node,
- locop::NodeSummary &s)
-{
- s.args().append("input", tbl->lookup(node->input()));
- s.args().append("block_size", std::to_string(node->block_size()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleDepthwiseConv2D *node,
- locop::NodeSummary &s)
-{
- assert(node->fusedActivationFunction() != luci::FusedActFunc::UNDEFINED);
- assert(node->padding() != luci::Padding::UNDEFINED);
-
- s.args().append("input", tbl->lookup(node->input()));
- s.args().append("filter", tbl->lookup(node->filter()));
- s.args().append("bias", tbl->lookup(node->bias()));
- s.args().append("stride(h,w)", to_str(node->stride()));
- s.args().append("dilation(h,w)", to_str(node->dilation()));
- s.args().append("padding", to_str(node->padding()));
- s.args().append("depthMultiplier", std::to_string(node->depthMultiplier()));
- s.args().append("fused", to_str(node->fusedActivationFunction()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleExpandDims *node,
- locop::NodeSummary &s)
-{
- s.args().append("input", tbl->lookup(node->input()));
- s.args().append("axis", tbl->lookup(node->axis()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleFakeQuant *node,
- locop::NodeSummary &s)
-{
- s.args().append("inputs", tbl->lookup(node->inputs()));
- s.args().append("min", pepper::str(node->min()));
- s.args().append("max", pepper::str(node->max()));
- s.args().append("num_bits", pepper::str(node->num_bits()));
- s.args().append("narrow_range", node->narrow_range() ? "true" : "false");
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleFill *node,
- locop::NodeSummary &s)
-{
- s.args().append("dims", tbl->lookup(node->dims()));
- s.args().append("value", tbl->lookup(node->value()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleFullyConnected *node,
- locop::NodeSummary &s)
-{
- assert(node->fusedActivationFunction() != luci::FusedActFunc::UNDEFINED);
-
- s.args().append("input", tbl->lookup(node->input()));
- s.args().append("weights", tbl->lookup(node->weights()));
- s.args().append("bias", tbl->lookup(node->bias()));
- s.args().append("fused", to_str(node->fusedActivationFunction()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleGather *node,
- locop::NodeSummary &s)
-{
- s.args().append("params", tbl->lookup(node->params()));
- s.args().append("indices", tbl->lookup(node->indices()));
- s.args().append("axis", pepper::str(node->axis()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleGatherNd *node,
- locop::NodeSummary &s)
-{
- s.args().append("params", tbl->lookup(node->params()));
- s.args().append("indices", tbl->lookup(node->indices()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleIf *node, locop::NodeSummary &s)
-{
- s.args().append("cond", tbl->lookup(node->cond()));
- for (uint32_t i = 0; i < node->input_count(); ++i)
- s.args().append("input", tbl->lookup(node->input(i)));
-
- if (node->then_graph() != nullptr)
- s.args().append("then_graph", node->then_graph()->name());
- else
- s.args().append("then_branch", pepper::str(node->then_branch()));
-
- if (node->else_graph() != nullptr)
- s.args().append("else_graph", node->else_graph()->name());
- else
- s.args().append("else_branch", pepper::str(node->else_branch()));
-
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleL2Normalize *node,
- locop::NodeSummary &s)
-{
- s.args().append("x", tbl->lookup(node->x()));
- s.args().append("fused_activation_function", to_str(node->fusedActivationFunction()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleL2Pool2D *node,
- locop::NodeSummary &s)
-{
- assert(node->fusedActivationFunction() != luci::FusedActFunc::UNDEFINED);
-
- s.args().append("value", tbl->lookup(node->value()));
- s.args().append("filter(h,w)", to_str(node->filter()));
- s.args().append("stride(h,w)", to_str(node->stride()));
- s.args().append("padding", to_str(node->padding()));
- s.args().append("fused", to_str(node->fusedActivationFunction()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleLeakyRelu *node,
- locop::NodeSummary &s)
-{
- s.args().append("features", tbl->lookup(node->features()));
- s.args().append("alpha", std::to_string(node->alpha()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleLocalResponseNormalization *node,
- locop::NodeSummary &s)
-{
- s.args().append("input", tbl->lookup(node->input()));
- s.args().append("radius", pepper::str(node->radius()));
- s.args().append("bias", pepper::str(node->bias()));
- s.args().append("alpha", pepper::str(node->alpha()));
- s.args().append("beta", pepper::str(node->beta()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleLogSoftmax *node,
- locop::NodeSummary &s)
-{
- s.args().append("logits", tbl->lookup(node->logits()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleMatrixDiag *node,
- locop::NodeSummary &s)
-{
- s.args().append("diagonal", tbl->lookup(node->diagonal()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleMatrixSetDiag *node,
- locop::NodeSummary &s)
-{
- s.args().append("input", tbl->lookup(node->input()));
- s.args().append("diagonal", tbl->lookup(node->diagonal()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleMaxPool2D *node,
- locop::NodeSummary &s)
-{
- assert(node->fusedActivationFunction() != luci::FusedActFunc::UNDEFINED);
-
- s.args().append("value", tbl->lookup(node->value()));
- s.args().append("filter(h,w)", to_str(node->filter()));
- s.args().append("stride(h,w)", to_str(node->stride()));
- s.args().append("padding", to_str(node->padding()));
- s.args().append("fused", to_str(node->fusedActivationFunction()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleMirrorPad *node,
- locop::NodeSummary &s)
-{
- s.args().append("input", tbl->lookup(node->input()));
- s.args().append("paddings", tbl->lookup(node->paddings()));
- s.args().append("mode", to_str(node->mode()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleNonMaxSuppressionV4 *node,
- locop::NodeSummary &s)
-{
- s.args().append("boxes", tbl->lookup(node->boxes()));
- s.args().append("scores", tbl->lookup(node->scores()));
- s.args().append("max_output_size", tbl->lookup(node->max_output_size()));
- s.args().append("iou_threshold", tbl->lookup(node->iou_threshold()));
- s.args().append("score_threshold", tbl->lookup(node->score_threshold()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleNonMaxSuppressionV5 *node,
- locop::NodeSummary &s)
-{
- s.args().append("boxes", tbl->lookup(node->boxes()));
- s.args().append("scores", tbl->lookup(node->scores()));
- s.args().append("max_output_size", tbl->lookup(node->max_output_size()));
- s.args().append("iou_threshold", tbl->lookup(node->iou_threshold()));
- s.args().append("score_threshold", tbl->lookup(node->score_threshold()));
- s.args().append("soft_nms_sigma", tbl->lookup(node->soft_nms_sigma()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleOneHot *node,
- locop::NodeSummary &s)
-{
- s.args().append("indices", tbl->lookup(node->indices()));
- s.args().append("depth", tbl->lookup(node->depth()));
- s.args().append("on_value", tbl->lookup(node->on_value()));
- s.args().append("off_value", tbl->lookup(node->off_value()));
- s.args().append("axis", pepper::str(node->axis()));
-
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CirclePack *node,
- locop::NodeSummary &s)
-{
- for (uint32_t i = 0; i < node->values_count(); ++i)
- s.args().append("values", tbl->lookup(node->values(i)));
- s.args().append("values_count", pepper::str(node->values_count()));
- s.args().append("axis", pepper::str(node->axis()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CirclePad *node, locop::NodeSummary &s)
-{
- s.args().append("input", tbl->lookup(node->input()));
- s.args().append("paddings", tbl->lookup(node->paddings()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CirclePadV2 *node,
- locop::NodeSummary &s)
-{
- s.args().append("input", tbl->lookup(node->input()));
- s.args().append("paddings", tbl->lookup(node->paddings()));
- s.args().append("constant_values", tbl->lookup(node->constant_values()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CirclePRelu *node,
- locop::NodeSummary &s)
-{
- s.args().append("input", tbl->lookup(node->input()));
- s.args().append("alpha", tbl->lookup(node->alpha()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleRange *node,
- locop::NodeSummary &s)
-{
- s.args().append("start", tbl->lookup(node->start()));
- s.args().append("limit", tbl->lookup(node->limit()));
- s.args().append("delta", tbl->lookup(node->delta()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleReshape *node,
- locop::NodeSummary &s)
-{
- s.args().append("tensor", tbl->lookup(node->tensor()));
- s.args().append("shape", tbl->lookup(node->shape()));
- // TODO Show newShape info
- s.state(locop::NodeSummary::State::PartiallyKnown);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleResizeBilinear *node,
- locop::NodeSummary &s)
-{
- s.args().append("input", tbl->lookup(node->input()));
- s.args().append("size", tbl->lookup(node->size()));
- s.args().append("align_corners", node->align_corners() ? "true" : "false");
- s.args().append("half_pixel_centers", node->half_pixel_centers() ? "true" : "false");
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleResizeNearestNeighbor *node,
- locop::NodeSummary &s)
-{
- s.args().append("input", tbl->lookup(node->input()));
- s.args().append("size", tbl->lookup(node->size()));
- s.args().append("align_corners", node->align_corners() ? "true" : "false");
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleReverseSequence *node,
- locop::NodeSummary &s)
-{
- s.args().append("input", tbl->lookup(node->input()));
- s.args().append("seq_lengths", tbl->lookup(node->seq_lengths()));
- s.args().append("seq_axis", std::to_string(node->seq_axis()));
- s.args().append("batch_axis", std::to_string(node->batch_axis()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleReverseV2 *node,
- locop::NodeSummary &s)
-{
- s.args().append("tensor", tbl->lookup(node->tensor()));
- s.args().append("axis", tbl->lookup(node->axis()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleScatterNd *node,
- locop::NodeSummary &s)
-{
- s.args().append("indices", tbl->lookup(node->indices()));
- s.args().append("updates", tbl->lookup(node->updates()));
- s.args().append("shape", tbl->lookup(node->shape()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleSegmentSum *node,
- locop::NodeSummary &s)
-{
- s.args().append("input", tbl->lookup(node->input()));
- s.args().append("segment_ids", tbl->lookup(node->segment_ids()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleSelect *node,
- locop::NodeSummary &s)
-{
- s.args().append("condition", tbl->lookup(node->condition()));
- s.args().append("t", tbl->lookup(node->t()));
- s.args().append("e", tbl->lookup(node->e()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleSelectV2 *node,
- locop::NodeSummary &s)
-{
- s.args().append("condition", tbl->lookup(node->condition()));
- s.args().append("t", tbl->lookup(node->t()));
- s.args().append("e", tbl->lookup(node->e()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleShape *node,
- locop::NodeSummary &s)
-{
- s.args().append("input", tbl->lookup(node->input()));
- s.args().append("out_type", to_str(node->out_type()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleSlice *node,
- locop::NodeSummary &s)
-{
- s.args().append("input", tbl->lookup(node->input()));
- s.args().append("begin", tbl->lookup(node->begin()));
- s.args().append("size", tbl->lookup(node->size()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleSoftmax *node,
- locop::NodeSummary &s)
-{
- s.args().append("logits", tbl->lookup(node->logits()));
- s.args().append("beta", pepper::str(node->beta()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleSpaceToBatchND *node,
- locop::NodeSummary &s)
-{
- s.args().append("input", tbl->lookup(node->input()));
- s.args().append("block_shape", tbl->lookup(node->block_shape()));
- s.args().append("paddings", tbl->lookup(node->paddings()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleSpaceToDepth *node,
- locop::NodeSummary &s)
-{
- s.args().append("input", tbl->lookup(node->input()));
- s.args().append("block_size", pepper::str(node->block_size()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleSparseToDense *node,
- locop::NodeSummary &s)
-{
- s.args().append("indices", tbl->lookup(node->indices()));
- s.args().append("output_shape", tbl->lookup(node->output_shape()));
- s.args().append("values", tbl->lookup(node->values()));
- s.args().append("default_value", tbl->lookup(node->default_value()));
- s.args().append("Validate_indices", pepper::str(node->validate_indices()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleSplit *node,
- locop::NodeSummary &s)
-{
- s.args().append("split_dim", tbl->lookup(node->split_dim()));
- s.args().append("input", tbl->lookup(node->input()));
- s.args().append("num_split", pepper::str(node->num_split()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleSplitV *node,
- locop::NodeSummary &s)
-{
- s.args().append("input", tbl->lookup(node->input()));
- s.args().append("size_splits", tbl->lookup(node->size_splits()));
- s.args().append("split_dim", tbl->lookup(node->split_dim()));
- s.args().append("num_split", pepper::str(node->num_split()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleSqueeze *node,
- locop::NodeSummary &s)
-{
- s.args().append("input", tbl->lookup(node->input()));
-
- std::stringstream ss{"("};
- for (size_t i = 0; i < node->squeeze_dims().size(); ++i)
- {
- if (i != 0)
- ss << ", ";
- ss << node->squeeze_dims()[i];
- }
- ss << ")";
- s.args().append("squeeze_dims", ss.str());
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleStridedSlice *node,
- locop::NodeSummary &s)
-{
- s.args().append("input", tbl->lookup(node->input()));
- s.args().append("begin", tbl->lookup(node->begin()));
- s.args().append("end", tbl->lookup(node->end()));
- s.args().append("strides", tbl->lookup(node->strides()));
- s.args().append("begin_mask", pepper::str(node->begin_mask()));
- s.args().append("end_mask", pepper::str(node->end_mask()));
- s.args().append("ellipsis_mask", pepper::str(node->ellipsis_mask()));
- s.args().append("new_axis_mask", pepper::str(node->new_axis_mask()));
- s.args().append("shrink_axis_mask", pepper::str(node->shrink_axis_mask()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleTile *node,
- locop::NodeSummary &s)
-{
- s.args().append("input", tbl->lookup(node->input()));
- s.args().append("multiples", tbl->lookup(node->multiples()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleTopKV2 *node,
- locop::NodeSummary &s)
-{
- s.args().append("input", tbl->lookup(node->input()));
- s.args().append("k", tbl->lookup(node->k()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleTranspose *node,
- locop::NodeSummary &s)
-{
- s.args().append("a", tbl->lookup(node->a()));
- s.args().append("perm", tbl->lookup(node->perm()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleTransposeConv *node,
- locop::NodeSummary &s)
-{
- assert(node->padding() != luci::Padding::UNDEFINED);
-
- s.args().append("inputSizes", tbl->lookup(node->inputSizes()));
- s.args().append("filter", tbl->lookup(node->filter()));
- s.args().append("outBackprop", tbl->lookup(node->outBackprop()));
- s.args().append("bias", tbl->lookup(node->bias()));
- s.args().append("stride(h,w)", to_str(node->stride()));
- s.args().append("padding", to_str(node->padding()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleUnidirectionalSequenceLSTM *node,
- locop::NodeSummary &s)
-{
- s.args().append("input", tbl->lookup(node->input()));
-
- s.args().append("input_to_input_weights", tbl->lookup(node->input_to_input_weights()));
- s.args().append("input_to_forget_weights", tbl->lookup(node->input_to_forget_weights()));
- s.args().append("input_to_cell_weights", tbl->lookup(node->input_to_cell_weights()));
- s.args().append("input_to_output_weights", tbl->lookup(node->input_to_output_weights()));
-
- s.args().append("recurrent_to_input_weights", tbl->lookup(node->recurrent_to_input_weights()));
- s.args().append("recurrent_to_forget_weights", tbl->lookup(node->recurrent_to_forget_weights()));
- s.args().append("recurrent_to_cell_weights", tbl->lookup(node->recurrent_to_cell_weights()));
- s.args().append("recurrent_to_output_weights", tbl->lookup(node->recurrent_to_output_weights()));
-
- s.args().append("cell_to_input_weights", tbl->lookup(node->cell_to_input_weights()));
- s.args().append("cell_to_forget_weights", tbl->lookup(node->cell_to_forget_weights()));
- s.args().append("cell_to_output_weights", tbl->lookup(node->cell_to_output_weights()));
-
- s.args().append("input_gate_bias", tbl->lookup(node->input_gate_bias()));
- s.args().append("forget_gate_bias", tbl->lookup(node->forget_gate_bias()));
- s.args().append("cell_gate_bias", tbl->lookup(node->cell_gate_bias()));
- s.args().append("output_gate_bias", tbl->lookup(node->output_gate_bias()));
-
- s.args().append("projection_weights", tbl->lookup(node->projection_weights()));
- s.args().append("projection_bias", tbl->lookup(node->projection_bias()));
-
- s.args().append("activation_state", tbl->lookup(node->activation_state()));
- s.args().append("cell_state", tbl->lookup(node->cell_state()));
-
- s.args().append("input_layer_norm_coefficients",
- tbl->lookup(node->input_layer_norm_coefficients()));
- s.args().append("forget_layer_norm_coefficients",
- tbl->lookup(node->forget_layer_norm_coefficients()));
- s.args().append("cell_layer_norm_coefficients",
- tbl->lookup(node->cell_layer_norm_coefficients()));
- s.args().append("output_layer_norm_coefficients",
- tbl->lookup(node->output_layer_norm_coefficients()));
-
- s.args().append("cell_clip", to_str(node->cell_clip()));
- s.args().append("proj_clip", to_str(node->proj_clip()));
- s.args().append("time_major", to_str(node->time_major()));
- s.args().append("asymmetric_quantize_inputs", to_str(node->asymmetric_quantize_inputs()));
-
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleUnique *node,
- locop::NodeSummary &s)
-{
- s.args().append("input", tbl->lookup(node->input()));
- s.args().append("idx_out_type", to_str(node->idx_out_type()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleUnpack *node,
- locop::NodeSummary &s)
-{
- s.args().append("value", tbl->lookup(node->value()));
- s.args().append("num", pepper::str(node->num()));
- s.args().append("axis", pepper::str(node->axis()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleWhere *node,
- locop::NodeSummary &s)
-{
- s.args().append("condition", tbl->lookup(node->condition()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleWhile *node,
- locop::NodeSummary &s)
-{
- for (uint32_t i = 0; i < node->input_count(); ++i)
- s.args().append("input", tbl->lookup(node->input(i)));
-
- if (node->cond_graph() != nullptr)
- s.args().append("cond_graph", node->cond_graph()->name());
- else
- s.args().append("cond_branch", pepper::str(node->cond_branch()));
-
- if (node->body_graph() != nullptr)
- s.args().append("body_graph", node->body_graph()->name());
- else
- s.args().append("body_branch", pepper::str(node->body_branch()));
-
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleTopKV2Out *node,
- locop::NodeSummary &s)
-{
- s.args().append("topkv2", tbl->lookup(node->input()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleUniqueOut *node,
- locop::NodeSummary &s)
-{
- s.args().append("unique", tbl->lookup(node->input()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleUnpackOut *node,
- locop::NodeSummary &s)
-{
- s.args().append("unpack", tbl->lookup(node->input()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleWhileOut *node,
- locop::NodeSummary &s)
-{
- s.args().append("while", tbl->lookup(node->input()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleOutput *node,
- locop::NodeSummary &s)
-{
- s.args().append("from", tbl->lookup(node->from()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *, const luci::CircleOutputDummy *,
- locop::NodeSummary &s)
-{
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *, const luci::CircleOutputExclude *,
- locop::NodeSummary &s)
-{
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleBCQFullyConnected *node,
- locop::NodeSummary &s)
-{
- assert(node->fusedActivationFunction() != luci::FusedActFunc::UNDEFINED);
- s.args().append("input", tbl->lookup(node->input()));
- s.args().append("weights_scales", tbl->lookup(node->weights_scales()));
- s.args().append("weights_binary", tbl->lookup(node->weights_binary()));
- s.args().append("bias", tbl->lookup(node->bias()));
- s.args().append("weights_clusters", tbl->lookup(node->weights_clusters()));
- s.args().append("fused", to_str(node->fusedActivationFunction()));
- s.args().append("weights_hidden_size", pepper::str(node->weights_hidden_size()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleBCQGather *node,
- locop::NodeSummary &s)
-{
- s.args().append("input_scales", tbl->lookup(node->input_scales()));
- s.args().append("input_binary", tbl->lookup(node->input_binary()));
- s.args().append("indices", tbl->lookup(node->indices()));
- s.args().append("input_clusters", tbl->lookup(node->input_clusters()));
- s.args().append("axis", pepper::str(node->axis()));
- s.args().append("input_hidden_size", pepper::str(node->input_hidden_size()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool summary_node(const locop::SymbolTable *tbl, const luci::CircleInstanceNorm *node,
- locop::NodeSummary &s)
-{
- auto fused = node->fusedActivationFunction();
- assert(fused != luci::FusedActFunc::UNDEFINED);
-
- s.args().append("input", tbl->lookup(node->input()));
- s.args().append("gamma", tbl->lookup(node->gamma()));
- s.args().append("beta", tbl->lookup(node->beta()));
- s.args().append("epsilon", pepper::str(node->epsilon()));
- s.args().append("fused_activation_function", to_str(fused));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-// SummaryBuilderLet type
-enum class SB
-{
- ABC,
- DEF,
- GHIJ,
- KLMN,
- OPQR,
- STUV,
- WXYZ,
- CIRC, // circle only
- VIRT, // virtual
-};
-
-template <SB sb> class SummaryBuilderLet;
-
-#define IMPLEMENT(CLASS) bool summary(const CLASS *, locop::NodeSummary &) const final;
-
-template <> class SummaryBuilderLet<SB::ABC> final : public CircleNodeSummaryBuilderBase
-{
-public:
- SummaryBuilderLet(const locop::SymbolTable *tbl) : CircleNodeSummaryBuilderBase(tbl)
- {
- // DO NOTHING
- }
-
-private:
- IMPLEMENT(luci::CircleAbs)
- IMPLEMENT(luci::CircleAdd)
- IMPLEMENT(luci::CircleAddN)
- IMPLEMENT(luci::CircleArgMax)
- IMPLEMENT(luci::CircleArgMin)
- IMPLEMENT(luci::CircleAveragePool2D)
- IMPLEMENT(luci::CircleBatchMatMul)
- IMPLEMENT(luci::CircleBatchToSpaceND)
- IMPLEMENT(luci::CircleBidirectionalSequenceLSTM)
- IMPLEMENT(luci::CircleCast)
- IMPLEMENT(luci::CircleCeil)
- IMPLEMENT(luci::CircleConcatenation)
- IMPLEMENT(luci::CircleConst)
- IMPLEMENT(luci::CircleConv2D)
- IMPLEMENT(luci::CircleCos)
- IMPLEMENT(luci::CircleCustom)
-};
-
-template <> class SummaryBuilderLet<SB::DEF> final : public CircleNodeSummaryBuilderBase
-{
-public:
- SummaryBuilderLet(const locop::SymbolTable *tbl) : CircleNodeSummaryBuilderBase(tbl)
- {
- // DO NOTHING
- }
-
-private:
- IMPLEMENT(luci::CircleDepthToSpace)
- IMPLEMENT(luci::CircleDepthwiseConv2D)
- IMPLEMENT(luci::CircleDequantize)
- IMPLEMENT(luci::CircleDiv)
- IMPLEMENT(luci::CircleElu)
- IMPLEMENT(luci::CircleEqual)
- IMPLEMENT(luci::CircleExp)
- IMPLEMENT(luci::CircleExpandDims)
- IMPLEMENT(luci::CircleFakeQuant)
- IMPLEMENT(luci::CircleFill)
- IMPLEMENT(luci::CircleFloor)
- IMPLEMENT(luci::CircleFloorDiv)
- IMPLEMENT(luci::CircleFloorMod)
- IMPLEMENT(luci::CircleFullyConnected)
-};
-
-template <> class SummaryBuilderLet<SB::GHIJ> final : public CircleNodeSummaryBuilderBase
-{
-public:
- SummaryBuilderLet(const locop::SymbolTable *tbl) : CircleNodeSummaryBuilderBase(tbl)
- {
- // DO NOTHING
- }
-
-private:
- IMPLEMENT(luci::CircleGather)
- IMPLEMENT(luci::CircleGatherNd)
- IMPLEMENT(luci::CircleGreater)
- IMPLEMENT(luci::CircleGreaterEqual)
- IMPLEMENT(luci::CircleIf)
-};
-
-template <> class SummaryBuilderLet<SB::KLMN> final : public CircleNodeSummaryBuilderBase
-{
-public:
- SummaryBuilderLet(const locop::SymbolTable *tbl) : CircleNodeSummaryBuilderBase(tbl)
- {
- // DO NOTHING
- }
-
-private:
- IMPLEMENT(luci::CircleL2Normalize)
- IMPLEMENT(luci::CircleL2Pool2D)
- IMPLEMENT(luci::CircleLeakyRelu)
- IMPLEMENT(luci::CircleLess)
- IMPLEMENT(luci::CircleLessEqual)
- IMPLEMENT(luci::CircleLocalResponseNormalization)
- IMPLEMENT(luci::CircleLog)
- IMPLEMENT(luci::CircleLogicalAnd)
- IMPLEMENT(luci::CircleLogicalNot)
- IMPLEMENT(luci::CircleLogicalOr)
- IMPLEMENT(luci::CircleLogistic)
- IMPLEMENT(luci::CircleLogSoftmax)
- IMPLEMENT(luci::CircleMatrixDiag)
- IMPLEMENT(luci::CircleMatrixSetDiag)
- IMPLEMENT(luci::CircleMaximum)
- IMPLEMENT(luci::CircleMaxPool2D)
- IMPLEMENT(luci::CircleMean)
- IMPLEMENT(luci::CircleMinimum)
- IMPLEMENT(luci::CircleMirrorPad)
- IMPLEMENT(luci::CircleMul)
- IMPLEMENT(luci::CircleNeg)
- IMPLEMENT(luci::CircleNonMaxSuppressionV4)
- IMPLEMENT(luci::CircleNonMaxSuppressionV5)
- IMPLEMENT(luci::CircleNotEqual)
-};
-
-template <> class SummaryBuilderLet<SB::OPQR> final : public CircleNodeSummaryBuilderBase
-{
-public:
- SummaryBuilderLet(const locop::SymbolTable *tbl) : CircleNodeSummaryBuilderBase(tbl)
- {
- // DO NOTHING
- }
-
-private:
- IMPLEMENT(luci::CircleOneHot)
- IMPLEMENT(luci::CirclePack)
- IMPLEMENT(luci::CirclePad)
- IMPLEMENT(luci::CirclePadV2)
- IMPLEMENT(luci::CirclePow)
- IMPLEMENT(luci::CirclePRelu)
- IMPLEMENT(luci::CircleQuantize)
- IMPLEMENT(luci::CircleRange)
- IMPLEMENT(luci::CircleRank)
- IMPLEMENT(luci::CircleReduceAny)
- IMPLEMENT(luci::CircleReduceMax)
- IMPLEMENT(luci::CircleReduceMin)
- IMPLEMENT(luci::CircleReduceProd)
- IMPLEMENT(luci::CircleRelu)
- IMPLEMENT(luci::CircleRelu6)
- IMPLEMENT(luci::CircleReluN1To1)
- IMPLEMENT(luci::CircleReshape)
- IMPLEMENT(luci::CircleResizeBilinear)
- IMPLEMENT(luci::CircleResizeNearestNeighbor)
- IMPLEMENT(luci::CircleReverseSequence)
- IMPLEMENT(luci::CircleReverseV2)
- IMPLEMENT(luci::CircleRound)
- IMPLEMENT(luci::CircleRsqrt)
-};
-
-template <> class SummaryBuilderLet<SB::STUV> final : public CircleNodeSummaryBuilderBase
-{
-public:
- SummaryBuilderLet(const locop::SymbolTable *tbl) : CircleNodeSummaryBuilderBase(tbl)
- {
- // DO NOTHING
- }
-
-private:
- IMPLEMENT(luci::CircleScatterNd)
- IMPLEMENT(luci::CircleSegmentSum)
- IMPLEMENT(luci::CircleSelect)
- IMPLEMENT(luci::CircleSelectV2)
- IMPLEMENT(luci::CircleShape)
- IMPLEMENT(luci::CircleSin)
- IMPLEMENT(luci::CircleSlice)
- IMPLEMENT(luci::CircleSoftmax)
- IMPLEMENT(luci::CircleSpaceToBatchND)
- IMPLEMENT(luci::CircleSpaceToDepth)
- IMPLEMENT(luci::CircleSparseToDense)
- IMPLEMENT(luci::CircleSplit)
- IMPLEMENT(luci::CircleSplitV)
- IMPLEMENT(luci::CircleSqrt)
- IMPLEMENT(luci::CircleSquare)
- IMPLEMENT(luci::CircleSquaredDifference)
- IMPLEMENT(luci::CircleSqueeze)
- IMPLEMENT(luci::CircleStridedSlice)
- IMPLEMENT(luci::CircleSub)
- IMPLEMENT(luci::CircleSum)
- IMPLEMENT(luci::CircleTanh)
- IMPLEMENT(luci::CircleTile)
- IMPLEMENT(luci::CircleTopKV2)
- IMPLEMENT(luci::CircleTranspose)
- IMPLEMENT(luci::CircleTransposeConv)
- IMPLEMENT(luci::CircleUnidirectionalSequenceLSTM)
- IMPLEMENT(luci::CircleUnique)
- IMPLEMENT(luci::CircleUnpack)
-};
-
-template <> class SummaryBuilderLet<SB::WXYZ> final : public CircleNodeSummaryBuilderBase
-{
-public:
- SummaryBuilderLet(const locop::SymbolTable *tbl) : CircleNodeSummaryBuilderBase(tbl)
- {
- // DO NOTHING
- }
-
-private:
- IMPLEMENT(luci::CircleWhere)
- IMPLEMENT(luci::CircleWhile)
- IMPLEMENT(luci::CircleZerosLike)
-};
-
-template <> class SummaryBuilderLet<SB::CIRC> final : public CircleNodeSummaryBuilderBase
-{
-public:
- SummaryBuilderLet(const locop::SymbolTable *tbl) : CircleNodeSummaryBuilderBase(tbl)
- {
- // DO NOTHING
- }
-
-private:
- IMPLEMENT(luci::CircleBCQFullyConnected)
- IMPLEMENT(luci::CircleBCQGather)
- IMPLEMENT(luci::CircleInstanceNorm)
-};
-
-template <> class SummaryBuilderLet<SB::VIRT> final : public CircleNodeSummaryBuilderBase
-{
-public:
- SummaryBuilderLet(const locop::SymbolTable *tbl) : CircleNodeSummaryBuilderBase(tbl)
- {
- // DO NOTHING
- }
-
-private:
- IMPLEMENT(luci::CircleInput)
- IMPLEMENT(luci::CircleOutput)
- IMPLEMENT(luci::CircleCustomOut)
- IMPLEMENT(luci::CircleIfOut)
- IMPLEMENT(luci::CircleNonMaxSuppressionV4Out)
- IMPLEMENT(luci::CircleNonMaxSuppressionV5Out)
- IMPLEMENT(luci::CircleOutputDummy)
- IMPLEMENT(luci::CircleOutputExclude)
- IMPLEMENT(luci::CircleSplitOut)
- IMPLEMENT(luci::CircleSplitVOut)
- IMPLEMENT(luci::CircleTopKV2Out)
- IMPLEMENT(luci::CircleUniqueOut)
- IMPLEMENT(luci::CircleUnpackOut)
- IMPLEMENT(luci::CircleWhileOut)
-};
-
-#undef IMPLEMENT
-
-bool CircleNodeSummaryBuilderBase::build(const loco::Node *node, locop::NodeSummary &s) const
-{
- if (node->dialect() != luci::CircleDialect::get())
- return false;
-
- auto ptr_to_str = [](const void *ptr) {
- std::stringstream ss;
- ss << ptr;
- return ss.str();
- };
-
- auto add_comment = [&]() {
- auto cnode = loco::must_cast<const luci::CircleNode *>(node);
- s.opname(circle_opname(node->opnum()));
- s.comments().append("[" + cnode->name() + "] = " + ptr_to_str(node));
- };
-
-#define CIRCLE_NODE(OPCODE, CLASS) \
- if (dynamic_cast<const CLASS *>(node)) \
- { \
- if (summary(dynamic_cast<const CLASS *>(node), s)) \
- { \
- add_comment(); \
- return true; \
- } \
- }
-#define CIRCLE_VNODE CIRCLE_NODE
-#include <luci/IR/CircleNodes.lst>
-#undef CIRCLE_VNODE
-#undef CIRCLE_NODE
-
- return false;
-}
-
-bool SummaryBuilderLet<SB::ABC>::summary(const luci::CircleAbs *node, locop::NodeSummary &s) const
-{
- return use_x(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::ABC>::summary(const luci::CircleAdd *node, locop::NodeSummary &s) const
-{
- return use_xy_act(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::ABC>::summary(const luci::CircleAddN *node, locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::ABC>::summary(const luci::CircleArgMax *node,
- locop::NodeSummary &s) const
-{
- return use_ido(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::ABC>::summary(const luci::CircleArgMin *node,
- locop::NodeSummary &s) const
-{
- return use_ido(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::ABC>::summary(const luci::CircleAveragePool2D *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::ABC>::summary(const luci::CircleBatchMatMul *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::ABC>::summary(const luci::CircleBatchToSpaceND *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::ABC>::summary(const luci::CircleBidirectionalSequenceLSTM *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::ABC>::summary(const luci::CircleCast *node, locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::ABC>::summary(const luci::CircleCeil *node, locop::NodeSummary &s) const
-{
- return use_x(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::ABC>::summary(const luci::CircleConcatenation *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::ABC>::summary(const luci::CircleConst *, locop::NodeSummary &s) const
-{
- s.state(locop::NodeSummary::State::PartiallyKnown);
- return true;
-}
-
-bool SummaryBuilderLet<SB::ABC>::summary(const luci::CircleConv2D *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::ABC>::summary(const luci::CircleCos *node, locop::NodeSummary &s) const
-{
- return use_x(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::ABC>::summary(const luci::CircleCustom *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::DEF>::summary(const luci::CircleDepthToSpace *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::DEF>::summary(const luci::CircleDepthwiseConv2D *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::DEF>::summary(const luci::CircleDequantize *node,
- locop::NodeSummary &s) const
-{
- return use_input(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::DEF>::summary(const luci::CircleDiv *node, locop::NodeSummary &s) const
-{
- return use_xy(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::DEF>::summary(const luci::CircleElu *node, locop::NodeSummary &s) const
-{
- return use_features(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::DEF>::summary(const luci::CircleEqual *node, locop::NodeSummary &s) const
-{
- return use_xy(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::DEF>::summary(const luci::CircleExp *node, locop::NodeSummary &s) const
-{
- return use_x(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::DEF>::summary(const luci::CircleExpandDims *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::DEF>::summary(const luci::CircleFakeQuant *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::DEF>::summary(const luci::CircleFill *node, locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::DEF>::summary(const luci::CircleFloor *node, locop::NodeSummary &s) const
-{
- return use_x(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::DEF>::summary(const luci::CircleFloorDiv *node,
- locop::NodeSummary &s) const
-{
- return use_xy(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::DEF>::summary(const luci::CircleFloorMod *node,
- locop::NodeSummary &s) const
-{
- return use_xy(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::DEF>::summary(const luci::CircleFullyConnected *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::GHIJ>::summary(const luci::CircleGather *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::GHIJ>::summary(const luci::CircleGatherNd *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::GHIJ>::summary(const luci::CircleGreater *node,
- locop::NodeSummary &s) const
-{
- return use_xy(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::GHIJ>::summary(const luci::CircleGreaterEqual *node,
- locop::NodeSummary &s) const
-{
- return use_xy(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::GHIJ>::summary(const luci::CircleIf *node, locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleL2Normalize *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleL2Pool2D *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleLess *node, locop::NodeSummary &s) const
-{
- return use_xy(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleLessEqual *node,
- locop::NodeSummary &s) const
-{
- return use_xy(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleLeakyRelu *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleLocalResponseNormalization *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleLog *node, locop::NodeSummary &s) const
-{
- return use_x(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleLogicalAnd *node,
- locop::NodeSummary &s) const
-{
- return use_xy(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleLogicalNot *node,
- locop::NodeSummary &s) const
-{
- return use_x(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleLogicalOr *node,
- locop::NodeSummary &s) const
-{
- return use_xy(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleLogistic *node,
- locop::NodeSummary &s) const
-{
- return use_x(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleLogSoftmax *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleMatrixDiag *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleMatrixSetDiag *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleMaximum *node,
- locop::NodeSummary &s) const
-{
- return use_xy(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleMaxPool2D *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleMean *node, locop::NodeSummary &s) const
-{
- return use_reducer(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleMinimum *node,
- locop::NodeSummary &s) const
-{
- return use_xy(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleMirrorPad *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleMul *node, locop::NodeSummary &s) const
-{
- return use_xy_act(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleNeg *node, locop::NodeSummary &s) const
-{
- return use_x(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleNonMaxSuppressionV4 *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleNonMaxSuppressionV5 *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleNotEqual *node,
- locop::NodeSummary &s) const
-{
- return use_xy(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleOneHot *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CirclePack *node, locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CirclePad *node, locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CirclePadV2 *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CirclePow *node, locop::NodeSummary &s) const
-{
- return use_xy(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CirclePRelu *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleQuantize *node,
- locop::NodeSummary &s) const
-{
- return use_input(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleRange *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleRank *node, locop::NodeSummary &s) const
-{
- return use_input(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleReduceAny *node,
- locop::NodeSummary &s) const
-{
- return use_reducer(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleReduceMax *node,
- locop::NodeSummary &s) const
-{
- return use_reducer(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleReduceMin *node,
- locop::NodeSummary &s) const
-{
- return use_reducer(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleReduceProd *node,
- locop::NodeSummary &s) const
-{
- return use_reducer(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleRelu *node, locop::NodeSummary &s) const
-{
- return use_features(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleRelu6 *node,
- locop::NodeSummary &s) const
-{
- return use_features(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleReluN1To1 *node,
- locop::NodeSummary &s) const
-{
- return use_features(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleReshape *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleResizeBilinear *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleResizeNearestNeighbor *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleReverseSequence *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleReverseV2 *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleRound *node,
- locop::NodeSummary &s) const
-{
- return use_x(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleRsqrt *node,
- locop::NodeSummary &s) const
-{
- return use_x(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleScatterNd *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleSegmentSum *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleSelect *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleSelectV2 *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleShape *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleSin *node, locop::NodeSummary &s) const
-{
- return use_x(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleSlice *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleSoftmax *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleSpaceToBatchND *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleSpaceToDepth *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleSparseToDense *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleSplit *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleSplitV *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleSqrt *node, locop::NodeSummary &s) const
-{
- return use_x(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleSquare *node,
- locop::NodeSummary &s) const
-{
- return use_x(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleSquaredDifference *node,
- locop::NodeSummary &s) const
-{
- return use_xy(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleSqueeze *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleStridedSlice *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleSub *node, locop::NodeSummary &s) const
-{
- return use_xy(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleSum *node, locop::NodeSummary &s) const
-{
- return use_reducer(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleTanh *node, locop::NodeSummary &s) const
-{
- return use_x(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleTile *node, locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleTopKV2 *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleTranspose *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleTransposeConv *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleUnidirectionalSequenceLSTM *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleUnique *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleUnpack *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::WXYZ>::summary(const luci::CircleWhere *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::WXYZ>::summary(const luci::CircleWhile *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::WXYZ>::summary(const luci::CircleZerosLike *node,
- locop::NodeSummary &s) const
-{
- return use_input(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::CIRC>::summary(const luci::CircleBCQFullyConnected *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::CIRC>::summary(const luci::CircleBCQGather *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::CIRC>::summary(const luci::CircleInstanceNorm *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::VIRT>::summary(const luci::CircleInput *, locop::NodeSummary &s) const
-{
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
-bool SummaryBuilderLet<SB::VIRT>::summary(const luci::CircleOutput *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::VIRT>::summary(const luci::CircleCustomOut *node,
- locop::NodeSummary &s) const
-{
- return use_input(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::VIRT>::summary(const luci::CircleIfOut *node,
- locop::NodeSummary &s) const
-{
- return use_input(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::VIRT>::summary(const luci::CircleNonMaxSuppressionV4Out *node,
- locop::NodeSummary &s) const
-{
- return use_input(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::VIRT>::summary(const luci::CircleNonMaxSuppressionV5Out *node,
- locop::NodeSummary &s) const
-{
- return use_input(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::VIRT>::summary(const luci::CircleOutputDummy *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::VIRT>::summary(const luci::CircleOutputExclude *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::VIRT>::summary(const luci::CircleSplitOut *node,
- locop::NodeSummary &s) const
-{
- return use_input(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::VIRT>::summary(const luci::CircleSplitVOut *node,
- locop::NodeSummary &s) const
-{
- return use_input(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::VIRT>::summary(const luci::CircleTopKV2Out *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::VIRT>::summary(const luci::CircleUniqueOut *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::VIRT>::summary(const luci::CircleUnpackOut *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-bool SummaryBuilderLet<SB::VIRT>::summary(const luci::CircleWhileOut *node,
- locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
-} // namespace
-
namespace luci
{
@@ -2208,22 +36,10 @@ bool NodeSummaryBuilder::build(const loco::Node *node, locop::NodeSummary &s) co
return true;
}
-#define BUILD_GRP(GRP) \
- do \
- { \
- if (SummaryBuilderLet<SB::GRP>(_tbl).build(node, s)) \
- return true; \
- } while (false)
-
- BUILD_GRP(ABC);
- BUILD_GRP(DEF);
- BUILD_GRP(GHIJ);
- BUILD_GRP(KLMN);
- BUILD_GRP(OPQR);
- BUILD_GRP(STUV);
- BUILD_GRP(WXYZ);
- BUILD_GRP(CIRC);
- BUILD_GRP(VIRT);
+ if (CircleNodeSummaryBuilder().build(node, _tbl, s))
+ {
+ return true;
+ }
return false;
}
diff --git a/compiler/luci/partition/CMakeLists.txt b/compiler/luci/partition/CMakeLists.txt
index ec8e0b0d6..f28207df2 100644
--- a/compiler/luci/partition/CMakeLists.txt
+++ b/compiler/luci/partition/CMakeLists.txt
@@ -13,7 +13,7 @@ target_link_libraries(luci_partition PUBLIC luci_lang)
target_link_libraries(luci_partition PRIVATE luci_service)
target_link_libraries(luci_partition PRIVATE luci_log)
target_link_libraries(luci_partition PRIVATE luci_logex)
-target_link_libraries(luci_partition PRIVATE mio_circle)
+target_link_libraries(luci_partition PRIVATE mio_circle04)
target_link_libraries(luci_partition PRIVATE nncc_common)
target_link_libraries(luci_partition PRIVATE pepper_csv2vec)
target_link_libraries(luci_partition PRIVATE oops)
diff --git a/compiler/luci/partition/src/ConnectNode.h b/compiler/luci/partition/src/ConnectNode.h
index ebbff7a6a..e60567c69 100644
--- a/compiler/luci/partition/src/ConnectNode.h
+++ b/compiler/luci/partition/src/ConnectNode.h
@@ -161,6 +161,7 @@ public:
void visit(const luci::CircleSquaredDifference *) final;
void visit(const luci::CircleSqueeze *) final;
void visit(const luci::CircleStridedSlice *) final;
+ void visit(const luci::CircleSVDF *) final;
void visit(const luci::CircleSub *) final;
void visit(const luci::CircleSum *) final;
void visit(const luci::CircleTanh *) final;
@@ -197,6 +198,7 @@ public:
void visit(const luci::CircleTopKV2Out *) final;
void visit(const luci::CircleUniqueOut *) final;
void visit(const luci::CircleUnpackOut *) final;
+ void visit(const luci::CircleVariable *) final;
void visit(const luci::CircleWhileOut *) final;
public:
diff --git a/compiler/luci/partition/src/Nodes/CircleSVDF.cpp b/compiler/luci/partition/src/Nodes/CircleSVDF.cpp
new file mode 100644
index 000000000..f661a794c
--- /dev/null
+++ b/compiler/luci/partition/src/Nodes/CircleSVDF.cpp
@@ -0,0 +1,47 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleSVDF *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleSVDF *>(cn->find_clone(node));
+
+ luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input());
+ luci::CircleNode *weight_feature = loco::must_cast<luci::CircleNode *>(node->weight_feature());
+ luci::CircleNode *weight_time = loco::must_cast<luci::CircleNode *>(node->weight_time());
+ luci::CircleNode *bias = loco::must_cast<luci::CircleNode *>(node->bias());
+ luci::CircleNode *input_activation_state =
+ loco::must_cast<luci::CircleNode *>(node->input_activation_state());
+
+ cloned->input(cn->find_clone(input));
+ cloned->weight_feature(cn->find_clone(weight_feature));
+ cloned->weight_time(cn->find_clone(weight_time));
+ cloned->bias(cn->find_clone(bias));
+ cloned->input_activation_state(cn->find_clone(input_activation_state));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleSVDF *node) { connect(this, node); }
+
+} // namespace luci
diff --git a/compiler/luci/partition/src/Nodes/CircleSVDF.test.cpp b/compiler/luci/partition/src/Nodes/CircleSVDF.test.cpp
new file mode 100644
index 000000000..5fae5206e
--- /dev/null
+++ b/compiler/luci/partition/src/Nodes/CircleSVDF.test.cpp
@@ -0,0 +1,106 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleSVDF>
+{
+public:
+ NodeGraphlet() = default;
+
+public:
+ void init(loco::Graph *g)
+ {
+ NodeGraphletT<luci::CircleSVDF>::init(g);
+
+ _node->fusedActivationFunction(luci::FusedActFunc::RELU);
+ }
+};
+
+class TestNodeGraph : public TestIsOGraph<5>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<5>::init({shape, shape, shape, shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->input(input(0));
+ node()->weight_feature(input(1));
+ node()->weight_time(input(2));
+ node()->bias(input(3));
+ node()->input_activation_state(input(4));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_SVDF)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSVDF *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSVDF *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(5, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+ ASSERT_EQ(cth.inputs(2), clone->arg(2));
+ ASSERT_EQ(cth.inputs(3), clone->arg(3));
+ ASSERT_EQ(cth.inputs(4), clone->arg(4));
+}
+
+TEST(ConnectNodeTest, connect_SVDF_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSVDF *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSVDF *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
diff --git a/compiler/luci/partition/src/Nodes/CircleVariable.cpp b/compiler/luci/partition/src/Nodes/CircleVariable.cpp
new file mode 100644
index 000000000..f7f6f21fd
--- /dev/null
+++ b/compiler/luci/partition/src/Nodes/CircleVariable.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleVariable *)
+{
+ // Nothing to do
+}
+
+} // namespace luci
diff --git a/compiler/luci/partition/src/PartitionIRDump.cpp b/compiler/luci/partition/src/PartitionIRDump.cpp
index 4f2c26800..0fabfc416 100644
--- a/compiler/luci/partition/src/PartitionIRDump.cpp
+++ b/compiler/luci/partition/src/PartitionIRDump.cpp
@@ -32,18 +32,18 @@ void dump(std::ostream &os, const PNode *pnode)
void dump(std::ostream &os, const PGroup *pgroup)
{
os << "--- PGroup: " << pgroup->group << std::endl;
- os << "Input(s): ";
+ os << "Input(s): [ ";
for (auto &node_in : pgroup->inputs)
os << node_in->name() << " ";
- os << std::endl;
+ os << "]" << std::endl;
for (auto &pnode : pgroup->pnodes)
{
dump(os, pnode.get());
}
- os << "Output(s): ";
+ os << "Output(s): [ ";
for (auto &node_out : pgroup->outputs)
os << node_out->name() << " ";
- os << std::endl;
+ os << "]" << std::endl;
}
void dump(std::ostream &os, const PGroups *pgroups)
@@ -57,7 +57,8 @@ void dump(std::ostream &os, const PGroups *pgroups)
{
auto node = it->first;
auto group = it->second;
- os << " Node: " << node << "(" << node->name() << "): " << group << std::endl;
+ os << " Node: " << node << "(" << luci::opcode_name(node) << "," << node->name()
+ << "): " << group << std::endl;
}
}
diff --git a/compiler/luci/partition/src/PartitionMerge.cpp b/compiler/luci/partition/src/PartitionMerge.cpp
index c517bf93f..4c3971bd8 100644
--- a/compiler/luci/partition/src/PartitionMerge.cpp
+++ b/compiler/luci/partition/src/PartitionMerge.cpp
@@ -58,9 +58,6 @@ bool is_input_same(const luci::PGroup *pgroup, const luci::PGroups *pgroups)
// we need to clone this CircleConst for each graph of the group.
if (dynamic_cast<const luci::CircleConst *>(input) != nullptr)
continue;
- // Skip also for OutputExclude
- if (dynamic_cast<const luci::CircleOutputExclude *>(input) != nullptr)
- continue;
auto input_group = pgroups->group_of(input);
// NOTE: all the nodes should be registered and return should be valid group.
@@ -87,7 +84,7 @@ bool is_input_same(const luci::PGroup *pgroup, const luci::PGroups *pgroups)
input_pgroup = pgroup_input;
else
{
- if (input_pgroup != pgroup_input)
+ if (input_pgroup->group != pgroup_input->group)
return false;
}
}
@@ -96,6 +93,48 @@ bool is_input_same(const luci::PGroup *pgroup, const luci::PGroups *pgroups)
}
/**
+ * @brief return true if there is only one output and is fed to same group of nodes
+ * @note pgroups is used to find group of pgroup
+ * ex)
+ * /-- pgroup_user_1 (grp_1)
+ * --- pgroup
+ * \-- pgroup_user_2 (grp_2)
+ *
+ * return false if grp_1 != grp_2
+ */
+bool is_output_same(const luci::PGroup *pgroup, const luci::PGroups *pgroups)
+{
+ assert(pgroups != nullptr);
+ assert(pgroup != nullptr);
+
+ std::string group;
+ for (auto &output : pgroup->outputs)
+ {
+ // get output_group
+ auto output_group = pgroups->group_of(output);
+ assert(not output_group.empty());
+ if (output_group.empty())
+ output_group = pgroups->default_group;
+
+ // find all PGroup that uses output
+ for (auto &pgroup_user : pgroups->pgroups)
+ {
+ for (auto &user_inputs : pgroup_user->inputs)
+ {
+ if (output == user_inputs)
+ {
+ // OK, these are connected, check group is same
+ if (pgroup_user->group != output_group)
+ return false;
+ }
+ }
+ }
+ }
+
+ return true;
+}
+
+/**
* @brief merge pgroup into pgroup_i
* @note output of pgroup_i should be input of pgroup
*/
@@ -191,6 +230,9 @@ std::unique_ptr<luci::PGroups> merge_pgroups(const luci::PGroups *s_pgroups)
// skip if there are multiple inputs but inputs differ in group
if (!is_input_same(pgroup.get(), d_pgroups.get()))
continue;
+ // skip if pgroup has different group for other users of pgroup_i
+ if (!is_output_same(pgroup_i.get(), d_pgroups.get()))
+ continue;
// TODO add more condition may be needed
merge_into(pgroup.get(), pgroup_i.get());
diff --git a/compiler/luci/partition/src/PartitionPGroups.cpp b/compiler/luci/partition/src/PartitionPGroups.cpp
index 0080873e6..eaeacf9c4 100644
--- a/compiler/luci/partition/src/PartitionPGroups.cpp
+++ b/compiler/luci/partition/src/PartitionPGroups.cpp
@@ -46,6 +46,9 @@ public:
bool visit(const luci::CircleUniqueOut *) final { return true; }
bool visit(const luci::CircleUnpackOut *) final { return true; }
bool visit(const luci::CircleWhileOut *) final { return true; }
+ // For inputs not used
+ bool visit(const luci::CircleOutputExclude *) final { return true; }
+ bool visit(const luci::CircleVariable *) final { return true; }
// TODO add all virtual nodes
// default is false
@@ -69,59 +72,80 @@ bool check_allocate_partition(const luci::CircleNode *node)
return true;
}
-class FindGroupToFollow final : public luci::CircleNodeVisitor<const std::string &>
+} // namespace
+
+namespace
{
-public:
- FindGroupToFollow(const luci::PartitionTable &partition, luci::PGroups *pgroups)
- : _partition(partition), _pgroups(pgroups)
- {
- // NOTHING TODO
- }
-private:
- const std::string &groupof(const luci::CircleNode *input) const
+std::string group_from_partition(const luci::CircleNode *node,
+ const luci::PartitionTable &partition)
+{
+ LOGGER(l);
+
+ auto group = partition.default_group;
+
+ std::string opcodename; // opcodename or opname
+
+ switch (partition.comply)
{
- auto group = _pgroups->node2group[input];
- assert(not group.empty());
- if (group.empty())
- return _partition.default_group;
- return _pgroups->node2group[input];
+ case luci::PartitionTable::COMPLY::OPCODE:
+ {
+ opcodename = luci::opcode_name(node);
+ assert(!opcodename.empty());
+
+ auto it = partition.byopcodes.find(opcodename);
+ if (it != partition.byopcodes.end())
+ group = it->second;
+ break;
+ }
+ case luci::PartitionTable::COMPLY::OPNAME:
+ {
+ opcodename = node->name();
+ assert(!opcodename.empty());
+
+ auto it = partition.byopnames.find(opcodename);
+ if (it != partition.byopnames.end())
+ group = it->second;
+ break;
+ }
+
+ default:
+ throw std::runtime_error("Unsupported partition.comply");
}
+ INFO(l) << "Op: " << node->name() << ": " << opcodename << ", " << node << ", " << group
+ << std::endl;
+
+ return group;
+}
+
+class IsVirtualInputNode final : public luci::CircleNodeVisitor<bool>
+{
public:
-#define IMPLEMENT(CLASS) \
- const std::string &visit(const luci::CLASS *node) final \
- { \
- auto input = loco::must_cast<luci::CircleNode *>(node->input()); \
- return groupof(input); \
- }
+ // TODO check CircleOutputDummy
+ bool visit(const luci::CircleOutputExclude *) final { return true; }
+ bool visit(const luci::CircleVariable *) final { return true; }
- IMPLEMENT(CircleCustomOut);
- IMPLEMENT(CircleIfOut);
- IMPLEMENT(CircleNonMaxSuppressionV4Out);
- IMPLEMENT(CircleNonMaxSuppressionV5Out);
- IMPLEMENT(CircleSplitOut);
- IMPLEMENT(CircleSplitVOut);
- IMPLEMENT(CircleTopKV2Out);
- IMPLEMENT(CircleUniqueOut);
- IMPLEMENT(CircleUnpackOut);
- IMPLEMENT(CircleWhileOut);
-
-#undef IMPLEMENT
-
- // return empty for nothing to do
- const std::string &visit(const luci::CircleNode *) final { return _empty_str; }
-
-private:
- const luci::PartitionTable &_partition;
- luci::PGroups *_pgroups = nullptr;
- std::string _empty_str;
+ // default is false
+ bool visit(const luci::CircleNode *) final { return false; }
};
-} // namespace
-
-namespace
+class IsMultiOutputNode final : public luci::CircleNodeVisitor<bool>
{
+public:
+ bool visit(const luci::CircleCustom *) final { return true; }
+ bool visit(const luci::CircleIf *) final { return true; }
+ bool visit(const luci::CircleNonMaxSuppressionV4 *) final { return true; }
+ bool visit(const luci::CircleNonMaxSuppressionV5 *) final { return true; }
+ bool visit(const luci::CircleSplit *) final { return true; }
+ bool visit(const luci::CircleSplitV *) final { return true; }
+ bool visit(const luci::CircleTopKV2 *) final { return true; }
+ bool visit(const luci::CircleUnique *) final { return true; }
+ bool visit(const luci::CircleUnpack *) final { return true; }
+ bool visit(const luci::CircleWhile *) final { return true; }
+ // default is false
+ bool visit(const luci::CircleNode *) final { return false; }
+};
void append(luci::CircleNode *node, luci::PGroups *pgroups, const std::string &group, uint32_t idx)
{
@@ -136,17 +160,56 @@ void append(luci::CircleNode *node, luci::PGroups *pgroups, const std::string &g
pgroup->pnodes.push_back(std::move(pnode));
+ IsVirtualInputNode queryvi;
// Set input of PGroup
for (uint32_t in = 0; in < node->arity(); ++in)
{
auto input = loco::must_cast<luci::CircleNode *>(node->arg(in));
- // this input maybe CircleInput in source graph
- // --> not confident this is safe
- pgroup->inputs.push_back(input);
+ if (input->accept(&queryvi))
+ {
+ auto pnode = std::make_unique<luci::PNode>();
+ pnode->node = input;
+ pnode->group = group;
+ pnode->pgroup = pgroup.get();
+
+ pgroup->pnodes.push_back(std::move(pnode));
+
+ pgroups->node2group[input] = group;
+ }
+ else
+ {
+ // this input maybe CircleInput in source graph
+ // --> not confident this is safe
+ pgroup->inputs.push_back(input);
+ }
+ }
+
+ IsMultiOutputNode query;
+ if (node->accept(&query))
+ {
+ // Include CircleXXXOut virtual nodes in this group
+ auto succs = loco::succs(node);
+ for (auto &succ_node : succs)
+ {
+ auto nodeout = loco::must_cast<luci::CircleNode *>(succ_node);
+
+ auto pnode = std::make_unique<luci::PNode>();
+ pnode->node = nodeout;
+ pnode->group = group;
+ pnode->pgroup = pgroup.get();
+
+ pgroup->pnodes.push_back(std::move(pnode));
+
+ pgroups->node2group[nodeout] = group;
+
+ pgroup->outputs.push_back(nodeout);
+ }
+ }
+ else
+ {
+ // Set output of PGroup: node itself
+ pgroup->outputs.push_back(node);
}
- // Set output of PGroup: node itself or multiple virtual outputs
- // TODO support multiple virtual outputs
- pgroup->outputs.push_back(node);
pgroups->node2group[node] = group;
pgroups->id2pgroup[pgroup->id] = pgroup.get();
@@ -182,70 +245,9 @@ std::unique_ptr<luci::PGroups> produce_pgroups(const luci::Module *source,
// check if node is normal node that we are interested
if (check_allocate_partition(node))
{
- auto group = partition.default_group;
-
- std::string opcodename; // opcodename or opname
-
- switch (partition.comply)
- {
- case luci::PartitionTable::COMPLY::OPCODE:
- {
- opcodename = luci::opcode_name(node);
- assert(!opcodename.empty());
-
- auto it = partition.byopcodes.find(opcodename);
- if (it != partition.byopcodes.end())
- group = it->second;
- break;
- }
- case luci::PartitionTable::COMPLY::OPNAME:
- {
- opcodename = node->name();
- assert(!opcodename.empty());
-
- auto it = partition.byopnames.find(opcodename);
- if (it != partition.byopnames.end())
- group = it->second;
- break;
- }
-
- default:
- throw std::runtime_error("Unsupported partition.comply");
- }
-
- INFO(l) << "Op: " << node->name() << ": " << opcodename << ", " << node << ", " << group
- << std::endl;
+ auto group = group_from_partition(node, partition);
append(node, pgroups.get(), group, idx);
-#if 0
- auto pgroup = std::make_unique<luci::PGroup>();
- pgroup->group = group;
- pgroup->id = idx + 1;
-
- auto pnode = std::make_unique<luci::PNode>();
- pnode->node = node;
- pnode->group = group;
- pnode->pgroup = pgroup.get();
-
- pgroup->pnodes.push_back(std::move(pnode));
-
- // Set input of PGroup
- for (uint32_t in = 0; in < node->arity(); ++in)
- {
- auto input = loco::must_cast<luci::CircleNode *>(node->arg(in));
- // this input maybe CircleInput in source graph
- // --> not confident this is safe
- pgroup->inputs.push_back(input);
- }
- // Set output of PGroup: node itself or multiple virtual outputs
- // TODO support multiple virtual outputs
- pgroup->outputs.push_back(node);
-
- pgroups->node2group[node] = group;
- pgroups->id2pgroup[pgroup->id] = pgroup.get();
-
- pgroups->pgroups.push_back(std::move(pgroup));
-#endif
}
else
{
@@ -255,22 +257,6 @@ std::unique_ptr<luci::PGroups> produce_pgroups(const luci::Module *source,
}
}
- // handle for virtual nodes like multiple outputs
- // these nodes should follow group of the input
- for (uint32_t idx = 0; idx < nodes->size(); ++idx)
- {
- auto node = loco::must_cast<luci::CircleNode *>(nodes->at(idx));
-
- // for virtual nodes like CircleUnpackOut should follow it's input (owner)
- // or just set to default
- FindGroupToFollow query(partition, pgroups.get());
- const auto &group = node->accept(&query);
- if (not group.empty())
- {
- append(node, pgroups.get(), group, idx);
- }
- }
-
return std::move(pgroups);
}
diff --git a/compiler/luci/pass/CMakeLists.txt b/compiler/luci/pass/CMakeLists.txt
index b8b406a38..5237c6d3f 100644
--- a/compiler/luci/pass/CMakeLists.txt
+++ b/compiler/luci/pass/CMakeLists.txt
@@ -1,4 +1,4 @@
-nnas_find_package(FlatBuffers EXACT 1.12 QUIET)
+nnas_find_package(FlatBuffers EXACT 2.0 QUIET)
if(NOT FlatBuffers_FOUND)
message(STATUS "FlatBuffers NOT FOUND")
return()
@@ -23,11 +23,11 @@ target_link_libraries(luci_pass PRIVATE luci_log)
target_link_libraries(luci_pass PRIVATE luci_service)
target_link_libraries(luci_pass PRIVATE luci_logex)
target_link_libraries(luci_pass PRIVATE luci_profile)
-target_link_libraries(luci_pass PRIVATE mio_tflite260_inc)
+target_link_libraries(luci_pass PRIVATE mio_tflite280_inc)
target_link_libraries(luci_pass PRIVATE nncc_common)
target_link_libraries(luci_pass PRIVATE pepper_csv2vec)
target_link_libraries(luci_pass PRIVATE oops)
-target_link_libraries(luci_pass PRIVATE flatbuffers-1.12)
+target_link_libraries(luci_pass PRIVATE flatbuffers-2.0)
install(TARGETS luci_pass DESTINATION lib)
install(DIRECTORY include/ DESTINATION include
FILES_MATCHING PATTERN "*.h")
@@ -43,5 +43,5 @@ target_include_directories(luci_pass_test PRIVATE src)
target_link_libraries(luci_pass_test luci_pass)
target_link_libraries(luci_pass_test luci_lang)
target_link_libraries(luci_pass_test luci_testhelper)
-target_link_libraries(luci_pass_test flatbuffers-1.12)
+target_link_libraries(luci_pass_test flatbuffers-2.0)
#target_link_libraries(luci_pass_test oops)
diff --git a/compiler/luci/pass/include/luci/CircleOptimizer.h b/compiler/luci/pass/include/luci/CircleOptimizer.h
index 658563ecf..c803898f6 100644
--- a/compiler/luci/pass/include/luci/CircleOptimizer.h
+++ b/compiler/luci/pass/include/luci/CircleOptimizer.h
@@ -47,15 +47,12 @@ public:
ResolveCustomOpBatchMatMul,
ResolveCustomOpMatMul,
ResolveCustomOpMaxPoolWithArgmax,
- QuantizeDequantizeWeights,
- QuantizeWithMinMax,
- Requantize,
FoldAddV2,
FoldCast,
FoldDepthwiseConv2D,
FoldDequantize,
+ FoldGather,
FoldSparseToDense,
- ForceQuantParam,
ForwardReshapeToUnaryOp,
SparsifyTensorPass,
FusePreActivationBatchNorm,
@@ -79,6 +76,7 @@ public:
TransformMinReluToRelu6Pass,
SubstituteStridedSliceToReshape,
SubstituteTransposeToReshape,
+ RemoveRedundantQuantize,
RemoveRedundantReshape,
RemoveFakeQuant,
RemoveQuantDequantSeq,
@@ -86,16 +84,6 @@ public:
enum AlgorithmParameters
{
- // quantize
- Quantize_input_model_dtype,
- Quantize_output_model_dtype,
- Quantize_granularity, // layer-wise or channel-wise
- Quantize_tensor_names,
- Quantize_scales,
- Quantize_zero_points,
- Quantize_input_type,
- Quantize_output_type,
-
// sparsify
Sparsify_tensor_name,
Sparsify_traversal_order,
@@ -114,8 +102,6 @@ public:
virtual bool query(Algorithm) = 0;
virtual void param(AlgorithmParameters, const std::string &) = 0;
virtual const std::string param(AlgorithmParameters) const = 0;
- virtual void params(AlgorithmParameters, std::vector<std::string> &) = 0;
- virtual std::vector<std::string> params(AlgorithmParameters) const = 0;
};
public:
@@ -127,8 +113,6 @@ public:
void optimize(loco::Graph *) const;
- void quantize(loco::Graph *) const;
-
void sparsify(loco::Graph *) const;
private:
diff --git a/compiler/luci/pass/include/luci/CircleQuantizer.h b/compiler/luci/pass/include/luci/CircleQuantizer.h
new file mode 100644
index 000000000..4e7074d98
--- /dev/null
+++ b/compiler/luci/pass/include/luci/CircleQuantizer.h
@@ -0,0 +1,97 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_CIRCLE_QUANTIZER_H__
+#define __LUCI_CIRCLE_QUANTIZER_H__
+
+#include <loco.h>
+
+#include <string>
+#include <vector>
+
+namespace luci
+{
+
+class CircleQuantizer final
+{
+public:
+ struct Options
+ {
+ struct LayerParam
+ {
+ std::string name;
+ std::string dtype;
+ std::string granularity;
+ };
+
+ enum Algorithm
+ {
+ QuantizeDequantizeWeights,
+ QuantizeWithMinMax,
+ Requantize,
+ CopyQuantParam,
+ ForceQuantParam,
+ ConvertToFakeQuantizedModel,
+ };
+
+ enum AlgorithmParameters
+ {
+ // quantize
+ Quantize_input_model_dtype,
+ Quantize_output_model_dtype,
+ Quantize_granularity, // layer-wise or channel-wise
+ Quantize_tensor_names,
+ Quantize_scales,
+ Quantize_zero_points,
+ Quantize_layer_params,
+
+ // copy_quantparam
+ Quantize_src_tensor_names,
+ Quantize_dst_tensor_names,
+
+ Quantize_input_type,
+ Quantize_output_type,
+ Quantize_TF_style_maxpool,
+ };
+
+ virtual ~Options() = default;
+
+ virtual void enable(Algorithm) = 0;
+ virtual bool query(Algorithm) = 0;
+ virtual void param(AlgorithmParameters, const std::string &) = 0;
+ virtual const std::string param(AlgorithmParameters) const = 0;
+ virtual void params(AlgorithmParameters, std::vector<std::string> &) = 0;
+ virtual std::vector<std::string> params(AlgorithmParameters) const = 0;
+
+ // Quantization parameters for multiple layers
+ virtual void layer_params(AlgorithmParameters, std::vector<std::shared_ptr<LayerParam>> &) = 0;
+ virtual std::vector<std::shared_ptr<LayerParam>> layer_params(AlgorithmParameters) const = 0;
+ };
+
+public:
+ // TODO maybe caller can provide Options as ctor parameters
+ Options *options(void);
+
+public:
+ void quantize(loco::Graph *) const;
+
+private:
+ std::unique_ptr<Options> _options;
+};
+
+} // namespace luci
+
+#endif // __LUCI_CIRCLE_QUANTIZER_H__
diff --git a/compiler/luci/pass/include/luci/Pass/ConvertToFakeQuantizedModelPass.h b/compiler/luci/pass/include/luci/Pass/ConvertToFakeQuantizedModelPass.h
new file mode 100644
index 000000000..91dd2300e
--- /dev/null
+++ b/compiler/luci/pass/include/luci/Pass/ConvertToFakeQuantizedModelPass.h
@@ -0,0 +1,39 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_CONVERT_TO_FAKE_QUANTIZED_MODEL_PASS_H__
+#define __LUCI_CONVERT_TO_FAKE_QUANTIZED_MODEL_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace luci
+{
+
+/**
+ * @brief Class to convert a quantized model to a fake-quantized fp32 model.
+ */
+struct ConvertToFakeQuantizedModelPass final : public logo::Pass
+{
+ ConvertToFakeQuantizedModelPass() {}
+
+ const char *name(void) const final { return "luci::ConvertToFakeQuantizedModelPass"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace luci
+
+#endif // __LUCI_CONVERT_TO_FAKE_QUANTIZED_MODEL_PASS_H__
diff --git a/compiler/luci/pass/include/luci/Pass/CopyQuantParamPass.h b/compiler/luci/pass/include/luci/Pass/CopyQuantParamPass.h
new file mode 100644
index 000000000..18c9cd56a
--- /dev/null
+++ b/compiler/luci/pass/include/luci/Pass/CopyQuantParamPass.h
@@ -0,0 +1,53 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_COPY_QUANT_PARAM_PASS_H__
+#define __LUCI_COPY_QUANT_PARAM_PASS_H__
+
+#include <loco.h>
+
+#include <logo/Pass.h>
+
+namespace luci
+{
+
+/**
+ * @brief Pass to copy quantparam (scale, zerop) of a tensor to another tensor
+ */
+class CopyQuantParamPass : public logo::Pass
+{
+public:
+ using TensorVector = std::vector<std::string>;
+
+public:
+ CopyQuantParamPass(TensorVector &src_tensors, TensorVector &dst_tensors)
+ : _src_tensors{src_tensors}, _dst_tensors{dst_tensors}
+ {
+ // DO NOTHING
+ }
+ virtual const char *name(void) const { return "luci::CopyQuantParamPass"; }
+
+public:
+ bool run(loco::Graph *graph);
+
+private:
+ TensorVector _src_tensors;
+ TensorVector _dst_tensors;
+};
+
+} // namespace luci
+
+#endif //__LUCI_COPY_QUANT_PARAM_PASS_H__
diff --git a/compiler/luci/pass/include/luci/Pass/FoldGatherPass.h b/compiler/luci/pass/include/luci/Pass/FoldGatherPass.h
new file mode 100644
index 000000000..de08c8845
--- /dev/null
+++ b/compiler/luci/pass/include/luci/Pass/FoldGatherPass.h
@@ -0,0 +1,38 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_FOLD_GATHER_PASS_H__
+#define __LUCI_FOLD_GATHER_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace luci
+{
+
+/**
+ * @brief Class to fold Gather to a constant tensor
+ *
+ */
+struct FoldGatherPass final : public logo::Pass
+{
+ const char *name(void) const final { return "luci::FoldGatherPass"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace luci
+
+#endif // __LUCI_FOLD_GATHER_PASS_H__
diff --git a/compiler/luci/pass/include/luci/Pass/PropagateQParamBackwardPass.h b/compiler/luci/pass/include/luci/Pass/PropagateQParamBackwardPass.h
new file mode 100644
index 000000000..0c489fc30
--- /dev/null
+++ b/compiler/luci/pass/include/luci/Pass/PropagateQParamBackwardPass.h
@@ -0,0 +1,42 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_PROPAGATE_QPARAM_BACKWARD_PASS_H__
+#define __LUCI_PROPAGATE_QPARAM_BACKWARD_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace luci
+{
+
+/**
+ * @brief Class to propagate quantization parameters of an operator's output to input
+ */
+struct PropagateQParamBackwardPass final : public logo::Pass
+{
+ PropagateQParamBackwardPass(loco::DataType output) : _output_model_dtype(output) {}
+
+ const char *name(void) const final { return "luci::PropagateQParamBackwardPass"; }
+
+ bool run(loco::Graph *g) final;
+
+private:
+ loco::DataType _output_model_dtype;
+};
+
+} // namespace luci
+
+#endif // __LUCI_PROPAGATE_QPARAM_BACKWARD_PASS_H__
diff --git a/compiler/luci/pass/include/luci/Pass/PropagateQuantParamPass.h b/compiler/luci/pass/include/luci/Pass/PropagateQParamForwardPass.h
index 7e0c44b8c..952bd9614 100644
--- a/compiler/luci/pass/include/luci/Pass/PropagateQuantParamPass.h
+++ b/compiler/luci/pass/include/luci/Pass/PropagateQParamForwardPass.h
@@ -14,8 +14,8 @@
* limitations under the License.
*/
-#ifndef __LUCI_PROPAGATE_QUANT_PARAM_PASS_H__
-#define __LUCI_PROPAGATE_QUANT_PARAM_PASS_H__
+#ifndef __LUCI_PROPAGATE_QPARAM_FORWARD_PASS_H__
+#define __LUCI_PROPAGATE_QPARAM_FORWARD_PASS_H__
#include <logo/Pass.h>
@@ -23,15 +23,22 @@ namespace luci
{
/**
- * @brief Class to propagate quantization parameters of an operator's output to input
+ * @brief Class to propagate quantization parameters of an operator's input to output
*/
-struct PropagateQuantParamPass final : public logo::Pass
+struct PropagateQParamForwardPass final : public logo::Pass
{
- const char *name(void) const final { return "luci::PropagateQuantParamPass"; }
+ PropagateQParamForwardPass(bool TF_style_maxpool) : _TF_style_maxpool(TF_style_maxpool) {}
+
+ PropagateQParamForwardPass() {}
+
+ const char *name(void) const final { return "luci::PropagateQParamForwardPass"; }
bool run(loco::Graph *g) final;
+
+private:
+ bool _TF_style_maxpool = false;
};
} // namespace luci
-#endif // __LUCI_PROPAGATE_QUANT_PARAM_PASS_H__
+#endif // __LUCI_PROPAGATE_QPARAM_FORWARD_PASS_H__
diff --git a/compiler/luci/pass/include/luci/Pass/QuantizationParameters.h b/compiler/luci/pass/include/luci/Pass/QuantizationParameters.h
index 5c9cd427f..30c8db058 100644
--- a/compiler/luci/pass/include/luci/Pass/QuantizationParameters.h
+++ b/compiler/luci/pass/include/luci/Pass/QuantizationParameters.h
@@ -17,6 +17,10 @@
#ifndef __LUCI_QUANTIZATION_PARAMETERS_H__
#define __LUCI_QUANTIZATION_PARAMETERS_H__
+#include <loco.h>
+
+#include <string>
+
namespace luci
{
@@ -26,6 +30,13 @@ enum QuantizationGranularity
ChannelWise = 1,
};
+struct LayerInfo
+{
+ std::string name;
+ loco::DataType dtype;
+ QuantizationGranularity granularity;
+};
+
} // namespace luci
#endif // __LUCI_QUANTIZATION_PARAMETERS_H__
diff --git a/compiler/luci/pass/include/luci/Pass/QuantizeDequantizeWeightsPass.h b/compiler/luci/pass/include/luci/Pass/QuantizeDequantizeWeightsPass.h
index 68765ec5b..1825ee1aa 100644
--- a/compiler/luci/pass/include/luci/Pass/QuantizeDequantizeWeightsPass.h
+++ b/compiler/luci/pass/include/luci/Pass/QuantizeDequantizeWeightsPass.h
@@ -32,12 +32,30 @@ namespace luci
class QuantizeDequantizeWeightsPass : public logo::Pass
{
public:
+ struct Context
+ {
+ loco::DataType input_model_dtype = loco::DataType::Unknown;
+ loco::DataType output_model_dtype = loco::DataType::Unknown;
+ QuantizationGranularity granularity = QuantizationGranularity::ChannelWise;
+ std::vector<LayerInfo> layers_info;
+ };
+
+public:
+ QuantizeDequantizeWeightsPass(std::unique_ptr<Context> &&ctx) : _ctx{std::move(ctx)}
+ {
+ // DO NOTHING
+ }
+
+public:
QuantizeDequantizeWeightsPass(loco::DataType input_model_dtype, loco::DataType output_model_dtype,
QuantizationGranularity granularity)
- : _input_model_dtype{input_model_dtype}, _output_model_dtype{output_model_dtype}, _granularity{
- granularity}
{
- // DO NOTHING
+ _ctx = std::make_unique<Context>();
+ {
+ _ctx->input_model_dtype = input_model_dtype;
+ _ctx->output_model_dtype = output_model_dtype;
+ _ctx->granularity = granularity;
+ }
}
virtual const char *name(void) const { return "luci::QuantizeDequantizeWeightsPass"; }
@@ -45,9 +63,7 @@ public:
bool run(loco::Graph *graph);
private:
- loco::DataType _input_model_dtype;
- loco::DataType _output_model_dtype;
- QuantizationGranularity _granularity;
+ std::unique_ptr<Context> _ctx;
};
} // namespace luci
diff --git a/compiler/luci/pass/include/luci/Pass/QuantizePreCheckerPass.h b/compiler/luci/pass/include/luci/Pass/QuantizePreCheckerPass.h
new file mode 100644
index 000000000..c852f88e0
--- /dev/null
+++ b/compiler/luci/pass/include/luci/Pass/QuantizePreCheckerPass.h
@@ -0,0 +1,39 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_QUANTIZE_PRE_CHECKER_PASS_H__
+#define __LUCI_QUANTIZE_PRE_CHECKER_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace luci
+{
+
+/**
+ * @brief Pass to verify the input model has the form acceptable by quantizer
+ */
+class QuantizePreCheckerPass : public logo::Pass
+{
+public:
+ const char *name(void) const final { return "luci::QuantizePreCheckerPass"; }
+
+public:
+ bool run(loco::Graph *graph) final;
+};
+
+} // namespace luci
+
+#endif //__LUCI_QUANTIZE_PRE_CHECKER_PASS_H__
diff --git a/compiler/luci/pass/include/luci/Pass/QuantizeWithMinMaxPass.h b/compiler/luci/pass/include/luci/Pass/QuantizeWithMinMaxPass.h
index 648abad70..ea6db85d1 100644
--- a/compiler/luci/pass/include/luci/Pass/QuantizeWithMinMaxPass.h
+++ b/compiler/luci/pass/include/luci/Pass/QuantizeWithMinMaxPass.h
@@ -23,6 +23,8 @@
#include <luci/Pass/QuantizationParameters.h>
+#include <vector>
+
namespace luci
{
@@ -31,26 +33,41 @@ namespace luci
*/
class QuantizeWithMinMaxPass : public logo::Pass
{
+public:
+ struct Context
+ {
+ loco::DataType input_model_dtype = loco::DataType::Unknown;
+ loco::DataType output_model_dtype = loco::DataType::Unknown;
+ QuantizationGranularity granularity = QuantizationGranularity::ChannelWise;
+ loco::DataType input_type = loco::DataType::Unknown;
+ loco::DataType output_type = loco::DataType::Unknown;
+ bool TF_style_maxpool = false;
+ std::vector<LayerInfo> layers_info;
+ };
+
// For backward-compatibility
// TODO Remove this constructor
public:
QuantizeWithMinMaxPass(loco::DataType input_model_dtype, loco::DataType output_model_dtype,
QuantizationGranularity granularity)
- : _input_model_dtype{input_model_dtype}, _output_model_dtype{output_model_dtype},
- _granularity{granularity}, _input_type{output_model_dtype}, _output_type{output_model_dtype}
{
- // DO NOTHING
+ _ctx = std::make_unique<Context>();
+ {
+ _ctx->input_model_dtype = input_model_dtype;
+ _ctx->output_model_dtype = output_model_dtype;
+ _ctx->granularity = granularity;
+ _ctx->input_type = output_model_dtype;
+ _ctx->output_type = output_model_dtype;
+ _ctx->TF_style_maxpool = false;
+ }
}
public:
- QuantizeWithMinMaxPass(loco::DataType input_model_dtype, loco::DataType output_model_dtype,
- QuantizationGranularity granularity, loco::DataType input_type,
- loco::DataType output_type)
- : _input_model_dtype{input_model_dtype}, _output_model_dtype{output_model_dtype},
- _granularity{granularity}, _input_type{input_type}, _output_type{output_type}
+ QuantizeWithMinMaxPass(std::unique_ptr<Context> &&ctx) : _ctx{std::move(ctx)}
{
// DO NOTHING
}
+
virtual const char *name(void) const { return "luci::QuantizeWithMinMaxPass"; }
public:
@@ -61,11 +78,7 @@ private:
void set_output_type(loco::Graph *graph) const;
private:
- loco::DataType _input_model_dtype;
- loco::DataType _output_model_dtype;
- QuantizationGranularity _granularity;
- loco::DataType _input_type;
- loco::DataType _output_type;
+ std::unique_ptr<Context> _ctx;
};
} // namespace luci
diff --git a/compiler/luci/pass/include/luci/Pass/RemoveRedundantQuantizePass.h b/compiler/luci/pass/include/luci/Pass/RemoveRedundantQuantizePass.h
new file mode 100644
index 000000000..3e76bcdc3
--- /dev/null
+++ b/compiler/luci/pass/include/luci/Pass/RemoveRedundantQuantizePass.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_REMOVE_REDUNDANT_QUANTIZE_PASS_H__
+#define __LUCI_REMOVE_REDUNDANT_QUANTIZE_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace luci
+{
+
+/**
+ * @brief Class to remove redundant quantize operations
+ */
+struct RemoveRedundantQuantizePass final : public logo::Pass
+{
+ const char *name(void) const final { return "luci::RemoveRedundantQuantizePass"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace luci
+
+#endif // __LUCI_REMOVE_REDUNDANT_QUANTIZE_PASS_H__
diff --git a/compiler/luci/pass/src/BatchNormPatternFinder.cpp b/compiler/luci/pass/src/BatchNormPatternFinder.cpp
index c1a06bfda..e3f126b15 100644
--- a/compiler/luci/pass/src/BatchNormPatternFinder.cpp
+++ b/compiler/luci/pass/src/BatchNormPatternFinder.cpp
@@ -44,10 +44,26 @@ bool is_batchnorm_add(const luci::CircleAdd *add, luci::CircleMul *&mul, luci::C
return false;
}
- if (constant->rank() != 1)
+ uint32_t channel_dim = 0;
+
+ if (constant->rank() == 1)
+ {
+ channel_dim = constant->dim(0).value();
+ }
+ else if (constant->rank() == 4)
+ {
+ for (uint32_t i = 0; i < 3; i++)
+ {
+ if (constant->dim(i).value() != 1)
+ return false;
+ }
+ channel_dim = constant->dim(3).value();
+ }
+ else
+ {
return false;
+ }
- auto channel_dim = constant->dim(0);
// Assumption: Layout is channel-last
if (!(channel_dim == add->dim(add->rank() - 1)))
return false;
@@ -90,10 +106,26 @@ bool is_batchnorm_mul(const luci::CircleMul *mul, luci::CircleNode *&pred_node,
return false;
}
- if (constant->rank() != 1)
+ uint32_t channel_dim = 0;
+
+ if (constant->rank() == 1)
+ {
+ channel_dim = constant->dim(0).value();
+ }
+ else if (constant->rank() == 4)
+ {
+ for (uint32_t i = 0; i < 3; i++)
+ {
+ if (constant->dim(i).value() != 1)
+ return false;
+ }
+ channel_dim = constant->dim(3).value();
+ }
+ else
+ {
return false;
+ }
- auto channel_dim = constant->dim(0);
// Assumption: Layout is channel-last
if (!(channel_dim == mul->dim(mul->rank() - 1)))
return false;
diff --git a/compiler/luci/pass/src/BatchNormPatternFinder.test.cpp b/compiler/luci/pass/src/BatchNormPatternFinder.test.cpp
index 08e7fac1c..cc8c5615f 100644
--- a/compiler/luci/pass/src/BatchNormPatternFinder.test.cpp
+++ b/compiler/luci/pass/src/BatchNormPatternFinder.test.cpp
@@ -50,7 +50,7 @@ public:
auto channel_size = *last_it;
_add->shape(shape);
- _add_beta->shape({channel_size});
+ set_beta_shape(channel_size);
_add_beta->size<loco::DataType::FLOAT32>(channel_size);
for (uint32_t i = 0; i < channel_size; i++)
_add_beta->at<loco::DataType::FLOAT32>(i) = i;
@@ -63,10 +63,23 @@ public:
luci::CircleAdd *add() { return _add; }
protected:
+ virtual void set_beta_shape(uint32_t channel) = 0;
+
+protected:
luci::CircleAdd *_add = nullptr;
luci::CircleConst *_add_beta = nullptr;
};
+class AddRank1BetaGraphlet : public AddBetaGraphlet
+{
+ void set_beta_shape(uint32_t channel) final { _add_beta->shape({channel}); }
+};
+
+class AddRank4BetaGraphlet : public AddBetaGraphlet
+{
+ void set_beta_shape(uint32_t channel) final { _add_beta->shape({1, 1, 1, channel}); }
+};
+
/**
* @brief Graphlet with Mul and Const as gamma from BatchNorm
*/
@@ -90,7 +103,7 @@ public:
auto channel_size = *last_it;
_mul->shape(shape);
- _mul_gamma->shape({channel_size});
+ set_gamma_shape(channel_size);
_mul_gamma->size<loco::DataType::FLOAT32>(channel_size);
for (uint32_t i = 0; i < channel_size; i++)
_mul_gamma->at<loco::DataType::FLOAT32>(i) = i;
@@ -103,14 +116,27 @@ public:
luci::CircleMul *mul(void) { return _mul; }
protected:
+ virtual void set_gamma_shape(uint32_t channel) = 0;
+
+protected:
luci::CircleMul *_mul = nullptr;
luci::CircleConst *_mul_gamma = nullptr;
};
+class MulRank1GammaGraphlet : public MulGammaGraphlet
+{
+ void set_gamma_shape(uint32_t channel) final { _mul_gamma->shape({channel}); }
+};
+
+class MulRank4GammaGraphlet : public MulGammaGraphlet
+{
+ void set_gamma_shape(uint32_t channel) final { _mul_gamma->shape({1, 1, 1, channel}); }
+};
+
/**
* @brief Graph of Mul-Add pattern from BatchNorm
*/
-class MulAddGraph : public TestIOGraph, public AddBetaGraphlet, public MulGammaGraphlet
+class MulAddGraph : public TestIOGraph, public AddRank1BetaGraphlet, public MulRank1GammaGraphlet
{
public:
MulAddGraph() = default;
@@ -118,8 +144,30 @@ public:
void init(const ShapeU32 shape_in, const ShapeU32 shape_out)
{
TestIOGraph::init(shape_in, shape_out);
- MulGammaGraphlet::init(g(), shape_in, luci::FusedActFunc::NONE);
- AddBetaGraphlet::init(g(), shape_out, luci::FusedActFunc::RELU);
+ MulRank1GammaGraphlet::init(g(), shape_in, luci::FusedActFunc::NONE);
+ AddRank1BetaGraphlet::init(g(), shape_out, luci::FusedActFunc::RELU);
+
+ // connect network
+ _mul->x(input());
+ _mul->y(_mul_gamma);
+ _add->x(_mul);
+ _add->y(_add_beta);
+ output()->from(_add);
+ }
+};
+
+class MulAddRank4Graph : public TestIOGraph,
+ public AddRank4BetaGraphlet,
+ public MulRank4GammaGraphlet
+{
+public:
+ MulAddRank4Graph() = default;
+
+ void init(const ShapeU32 shape_in, const ShapeU32 shape_out)
+ {
+ TestIOGraph::init(shape_in, shape_out);
+ MulRank4GammaGraphlet::init(g(), shape_in, luci::FusedActFunc::NONE);
+ AddRank4BetaGraphlet::init(g(), shape_out, luci::FusedActFunc::RELU);
// connect network
_mul->x(input());
@@ -133,7 +181,7 @@ public:
/**
* @brief Graph of Add with Const
*/
-class AddGraph : public TestIOGraph, public AddBetaGraphlet
+class AddGraph : public TestIOGraph, public AddRank1BetaGraphlet
{
public:
AddGraph() = default;
@@ -141,7 +189,24 @@ public:
void init(const ShapeU32 shape_in, const ShapeU32 shape_out)
{
TestIOGraph::init(shape_in, shape_out);
- AddBetaGraphlet::init(g(), shape_in, luci::FusedActFunc::RELU);
+ AddRank1BetaGraphlet::init(g(), shape_in, luci::FusedActFunc::RELU);
+
+ // connect network
+ _add->x(input());
+ _add->y(_add_beta);
+ output()->from(_add);
+ }
+};
+
+class AddRank4Graph : public TestIOGraph, public AddRank4BetaGraphlet
+{
+public:
+ AddRank4Graph() = default;
+
+ void init(const ShapeU32 shape_in, const ShapeU32 shape_out)
+ {
+ TestIOGraph::init(shape_in, shape_out);
+ AddRank4BetaGraphlet::init(g(), shape_in, luci::FusedActFunc::RELU);
// connect network
_add->x(input());
@@ -160,6 +225,7 @@ public:
protected:
luci::test::MulAddGraph _mag;
+ luci::test::MulAddRank4Graph _mag_r4;
};
class BatchNormPatternFinderAddTest : public ::testing::Test
@@ -169,6 +235,7 @@ public:
protected:
luci::test::AddGraph _ag;
+ luci::test::AddRank4Graph _ag_r4;
};
TEST_F(BatchNormPatternFinderMulAddTest, is_batchnorm_add)
@@ -192,6 +259,19 @@ TEST_F(BatchNormPatternFinderMulAddTest, is_batchnorm_add2)
ASSERT_TRUE(res);
}
+TEST_F(BatchNormPatternFinderMulAddTest, is_batchnorm_add_rank4)
+{
+ _mag_r4.init({1, 16, 16, 4}, {1, 16, 16, 4});
+
+ luci::CircleMul *mul = nullptr;
+ luci::CircleConst *beta = nullptr;
+
+ auto res = luci::is_batchnorm_add(_mag_r4.add(), mul, beta);
+ ASSERT_TRUE(res);
+ ASSERT_NE(nullptr, mul);
+ ASSERT_NE(nullptr, beta);
+}
+
TEST_F(BatchNormPatternFinderAddTest, is_batchnorm_add_NEG)
{
_ag.init({1, 16, 16, 4}, {1, 16, 16, 4});
@@ -215,3 +295,16 @@ TEST_F(BatchNormPatternFinderMulAddTest, is_batchnorm_mul)
ASSERT_NE(nullptr, pred);
ASSERT_NE(nullptr, gamma);
}
+
+TEST_F(BatchNormPatternFinderMulAddTest, is_batchnorm_mul_rank4)
+{
+ _mag_r4.init({1, 16, 16, 4}, {1, 16, 16, 4});
+
+ luci::CircleNode *pred = nullptr;
+ luci::CircleConst *gamma = nullptr;
+
+ auto res = luci::is_batchnorm_mul(_mag_r4.mul(), pred, gamma);
+ ASSERT_TRUE(res);
+ ASSERT_NE(nullptr, pred);
+ ASSERT_NE(nullptr, gamma);
+}
diff --git a/compiler/luci/pass/src/CircleOptimizer.cpp b/compiler/luci/pass/src/CircleOptimizer.cpp
index 75f04b3b5..6dbb22d7c 100644
--- a/compiler/luci/pass/src/CircleOptimizer.cpp
+++ b/compiler/luci/pass/src/CircleOptimizer.cpp
@@ -22,9 +22,9 @@
#include "luci/Pass/FoldCastPass.h"
#include "luci/Pass/FoldDepthwiseConv2DPass.h"
#include "luci/Pass/FoldDequantizePass.h"
+#include "luci/Pass/FoldGatherPass.h"
#include "luci/Pass/FoldSparseToDensePass.h"
#include "luci/Pass/ForwardReshapeToUnaryOpPass.h"
-#include "luci/Pass/ForceQuantParamPass.h"
#include "luci/Pass/FuseActivationFunctionPass.h"
#include "luci/Pass/FuseAddWithFullyConnectedPass.h"
#include "luci/Pass/FuseAddWithTConvPass.h"
@@ -37,11 +37,11 @@
#include "luci/Pass/FusePreActivationBatchNormPass.h"
#include "luci/Pass/FuseTransposeWithMeanPass.h"
#include "luci/Pass/MakeBatchNormGammaPositivePass.h"
-#include "luci/Pass/PropagateQuantParamPass.h"
#include "luci/Pass/RemoveFakeQuantPass.h"
#include "luci/Pass/RemoveQuantDequantSeqPass.h"
#include "luci/Pass/RemoveRedundantReshapePass.h"
#include "luci/Pass/RemoveRedundantTransposePass.h"
+#include "luci/Pass/RemoveRedundantQuantizePass.h"
#include "luci/Pass/RemoveUnnecessaryReshapePass.h"
#include "luci/Pass/RemoveUnnecessarySlicePass.h"
#include "luci/Pass/RemoveUnnecessaryStridedSlicePass.h"
@@ -52,9 +52,6 @@
#include "luci/Pass/ResolveCustomOpBatchMatMulPass.h"
#include "luci/Pass/ResolveCustomOpMatMulPass.h"
#include "luci/Pass/ResolveCustomOpMaxPoolWithArgmaxPass.h"
-#include "luci/Pass/RequantizePass.h"
-#include "luci/Pass/QuantizeWithMinMaxPass.h"
-#include "luci/Pass/QuantizeDequantizeWeightsPass.h"
#include "luci/Pass/SparsifyTensorPass.h"
#include "luci/Pass/ShuffleWeightTo16x1Float32Pass.h"
#include "luci/Pass/SubstitutePackToReshapePass.h"
@@ -75,9 +72,6 @@
#include "ModulePhase.h"
#include "ProgressReporter.h"
-#include "helpers/Strings.h"
-
-#include "QuantizedModelVerifier.h"
#include <luci/IR/CircleNodes.h>
#include <logo/Phase.h>
@@ -91,37 +85,17 @@ namespace
using namespace luci;
-template <typename T> T lexical_cast(const std::string &str)
-{
- std::istringstream ss;
- ss.str(str);
- T data;
- ss >> data;
- return data;
-}
-
-template <typename T> std::vector<T> lexical_cast(std::vector<std::string> &sv)
-{
- std::vector<T> result;
- std::transform(sv.begin(), sv.end(), std::back_inserter(result),
- [](std::string str) -> T { return lexical_cast<T>(str); });
- return result;
-}
-
class OptimizeOptionsImpl final : public luci::CircleOptimizer::Options
{
public:
void enable(Algorithm) final;
void param(AlgorithmParameters, const std::string &) final;
const std::string param(AlgorithmParameters) const final;
- void params(AlgorithmParameters, std::vector<std::string> &) final;
- std::vector<std::string> params(AlgorithmParameters) const final;
bool query(Algorithm) final;
private:
std::vector<Algorithm> _algorithms;
std::map<AlgorithmParameters, const std::string> _algorithm_params;
- std::map<AlgorithmParameters, std::vector<std::string>> _multiple_params;
};
void OptimizeOptionsImpl::enable(Algorithm algo) { _algorithms.push_back(algo); }
@@ -144,24 +118,6 @@ const std::string OptimizeOptionsImpl::param(AlgorithmParameters param) const
}
}
-void OptimizeOptionsImpl::params(AlgorithmParameters param, std::vector<std::string> &vec)
-{
- _multiple_params[param] = vec;
-}
-
-std::vector<std::string> OptimizeOptionsImpl::params(AlgorithmParameters param) const
-{
- auto param_vec = _multiple_params.find(param);
- if (param_vec != _multiple_params.end())
- {
- return param_vec->second;
- }
- else
- {
- return std::vector<std::string>();
- }
-}
-
bool OptimizeOptionsImpl::query(Algorithm algo)
{
std::vector<Algorithm>::iterator it = std::find(_algorithms.begin(), _algorithms.end(), algo);
@@ -312,6 +268,10 @@ void CircleOptimizer::optimize(loco::Graph *g) const
{
phase.emplace_back(std::make_unique<luci::FoldDequantizePass>());
}
+ if (_options->query(Options::Algorithm::FoldGather))
+ {
+ phase.emplace_back(std::make_unique<luci::FoldGatherPass>());
+ }
if (_options->query(Options::Algorithm::FoldSparseToDense))
{
phase.emplace_back(std::make_unique<luci::FoldSparseToDensePass>());
@@ -368,6 +328,10 @@ void CircleOptimizer::optimize(loco::Graph *g) const
{
phase.emplace_back(std::make_unique<luci::RemoveRedundantTransposePass>());
}
+ if (_options->query(Options::Algorithm::RemoveRedundantQuantize))
+ {
+ phase.emplace_back(std::make_unique<luci::RemoveRedundantQuantizePass>());
+ }
if (_options->query(Options::Algorithm::ReplaceMulAddWithDepthwiseConv))
{
phase.emplace_back(std::make_unique<luci::ReplaceMulAddWithDepthwiseConvPass>());
@@ -417,174 +381,6 @@ void CircleOptimizer::optimize(loco::Graph *g) const
phase_runner.run(phase);
}
-void CircleOptimizer::quantize(loco::Graph *g) const
-{
- // Fake quantization of weights
- if (_options->query(Options::Algorithm::QuantizeDequantizeWeights))
- {
- static const std::vector<std::string> fakeq_supported_input_model_dtype{"float32"};
- static const std::vector<std::string> fakeq_supported_output_model_dtype{"uint8", "int16"};
- static const std::vector<std::string> fakeq_supported_granularity{"layer", "channel"};
-
- auto input_model_dtype =
- _options->param(Options::AlgorithmParameters::Quantize_input_model_dtype);
- auto output_model_dtype =
- _options->param(Options::AlgorithmParameters::Quantize_output_model_dtype);
- auto granularity = _options->param(Options::AlgorithmParameters::Quantize_granularity);
-
- if (!in_array(to_lower_case(input_model_dtype), fakeq_supported_input_model_dtype))
- throw std::runtime_error("Unsupported input type. List of supported input type: " +
- to_string(fakeq_supported_input_model_dtype));
-
- if (!in_array(to_lower_case(output_model_dtype), fakeq_supported_output_model_dtype))
- throw std::runtime_error("Unsupported output type. List of supported output type: " +
- to_string(fakeq_supported_output_model_dtype));
-
- if (!in_array(to_lower_case(granularity), fakeq_supported_granularity))
- throw std::runtime_error("Unsupported granularity. List of supported granularity: " +
- to_string(fakeq_supported_granularity));
-
- if (str_to_granularity(granularity) == QuantizationGranularity::LayerWise &&
- str_to_dtype(output_model_dtype) != loco::DataType::U8)
- throw std::runtime_error("Layer-wise quantization only supports uint8 dtype.");
-
- // Clear existing quantparams before doing fake quantization
- for (auto node : loco::active_nodes(loco::output_nodes(g)))
- {
- auto circle_node = loco::must_cast<luci::CircleNode *>(node);
- if (circle_node->quantparam() != nullptr)
- circle_node->quantparam(nullptr);
- }
-
- luci::QuantizeDequantizeWeightsPass fake_quantizer(str_to_dtype(input_model_dtype),
- str_to_dtype(output_model_dtype),
- str_to_granularity(granularity));
- fake_quantizer.run(g);
- }
-
- // Actual quantization of weights, bias, and activation
- if (_options->query(Options::Algorithm::QuantizeWithMinMax))
- {
- static const std::vector<std::string> qwmm_supported_input_model_dtype{"float32"};
- static const std::vector<std::string> qwmm_supported_output_model_dtype{"uint8", "int16"};
- static const std::vector<std::string> qwmm_supported_granularity{"layer", "channel"};
- static const std::vector<std::string> qwmm_supported_input_type{"uint8", "int16"};
- static const std::vector<std::string> qwmm_supported_output_type{"uint8", "int16"};
-
- auto input_model_dtype =
- _options->param(Options::AlgorithmParameters::Quantize_input_model_dtype);
- auto output_model_dtype =
- _options->param(Options::AlgorithmParameters::Quantize_output_model_dtype);
- auto granularity = _options->param(Options::AlgorithmParameters::Quantize_granularity);
- auto input_type = _options->param(Options::AlgorithmParameters::Quantize_input_type);
- if (input_type.empty())
- input_type = output_model_dtype;
- auto output_type = _options->param(Options::AlgorithmParameters::Quantize_output_type);
- if (output_type.empty())
- output_type = output_model_dtype;
-
- if (!in_array(to_lower_case(input_model_dtype), qwmm_supported_input_model_dtype))
- throw std::runtime_error("Unsupported input type. List of supported input types: " +
- to_string(qwmm_supported_input_model_dtype));
-
- if (!in_array(to_lower_case(output_model_dtype), qwmm_supported_output_model_dtype))
- throw std::runtime_error("Unsupported output type. List of supported output types: " +
- to_string(qwmm_supported_output_model_dtype));
-
- if (!in_array(to_lower_case(granularity), qwmm_supported_granularity))
- throw std::runtime_error("Unsupported granularity. List of supported granularity: " +
- to_string(qwmm_supported_granularity));
-
- if (!in_array(to_lower_case(input_type), qwmm_supported_input_type))
- throw std::runtime_error("Unsupported input type. List of supported input types: " +
- to_string(qwmm_supported_input_type));
-
- if (!in_array(to_lower_case(output_type), qwmm_supported_output_type))
- throw std::runtime_error("Unsupported output type. List of supported output types: " +
- to_string(qwmm_supported_output_type));
-
- if (str_to_granularity(granularity) == QuantizationGranularity::LayerWise &&
- str_to_dtype(output_model_dtype) != loco::DataType::U8)
- throw std::runtime_error("Layer-wise quantization only supports uint8 dtype.");
-
- luci::QuantizeWithMinMaxPass quantizer(
- str_to_dtype(input_model_dtype), str_to_dtype(output_model_dtype),
- str_to_granularity(granularity), str_to_dtype(input_type), str_to_dtype(output_type));
- quantizer.run(g);
-
- // Post-quantization optimizations
- logo::Phase phase;
-
- phase.emplace_back(std::make_unique<luci::PropagateQuantParamPass>());
-
- phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>());
- phase.emplace_back(std::make_unique<luci::CircleTypeInferencePass>());
- phase.emplace_back(std::make_unique<logo::RemoveDeadNodeWithQueryPass>());
-
- ProgressReporter prog(g, logo::PhaseStrategy::Saturate);
- logo::PhaseRunner<logo::PhaseStrategy::Saturate> phase_runner{g};
- phase_runner.attach(&prog);
- phase_runner.run(phase);
-
- // Verify the type/granularity of the quantized model
- luci::QuantizedModelVerifier verifier(str_to_dtype(output_model_dtype),
- str_to_granularity(granularity));
- verifier.verify(g);
- }
-
- // Requantize
- if (_options->query(Options::Algorithm::Requantize))
- {
- static const std::vector<std::string> rq_supported_input_model_dtype{"int8"};
- static const std::vector<std::string> rq_supported_output_model_dtype{"uint8"};
-
- auto input_model_dtype =
- _options->param(Options::AlgorithmParameters::Quantize_input_model_dtype);
- auto output_model_dtype =
- _options->param(Options::AlgorithmParameters::Quantize_output_model_dtype);
-
- if (!in_array(to_lower_case(input_model_dtype), rq_supported_input_model_dtype))
- throw std::runtime_error("Unsupported input type. List of supported input types: " +
- to_string(rq_supported_input_model_dtype));
-
- if (!in_array(to_lower_case(output_model_dtype), rq_supported_output_model_dtype))
- throw std::runtime_error("Unsupported output type. List of supported output types: " +
- to_string(rq_supported_output_model_dtype));
-
- luci::RequantizePass requantizer(str_to_dtype(input_model_dtype),
- str_to_dtype(output_model_dtype));
- requantizer.run(g);
- }
-
- // Force to write quantparam to specified tensors
- // NOTE Only per-tensor (not per-channel) qparam can be written
- if (_options->query(Options::Algorithm::ForceQuantParam))
- {
- ForceQuantParamPass::TensorVector tensors =
- _options->params(Options::AlgorithmParameters::Quantize_tensor_names);
- auto str_scales = _options->params(Options::AlgorithmParameters::Quantize_scales);
- auto str_zero_points = _options->params(Options::AlgorithmParameters::Quantize_zero_points);
-
- // Cast scales/zero_points to proper types
- ForceQuantParamPass::ScaleVector scales = lexical_cast<float>(str_scales);
- ForceQuantParamPass::ZPVector zero_points = lexical_cast<int64_t>(str_zero_points);
-
- ForceQuantParamPass fq(tensors, scales, zero_points);
- fq.run(g);
- }
-
- logo::Phase phase;
-
- // Do Shape/Type inference
- phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>());
- phase.emplace_back(std::make_unique<luci::CircleTypeInferencePass>());
-
- ProgressReporter prog(g, logo::PhaseStrategy::Saturate);
- logo::PhaseRunner<logo::PhaseStrategy::Saturate> phase_runner{g};
- phase_runner.attach(&prog);
- phase_runner.run(phase);
-}
-
void CircleOptimizer::sparsify(loco::Graph *g) const
{
if (_options->query(Options::Algorithm::SparsifyTensorPass))
diff --git a/compiler/luci/pass/src/CircleOptimizer.test.cpp b/compiler/luci/pass/src/CircleOptimizer.test.cpp
index a1b5c7f80..041fc7d75 100644
--- a/compiler/luci/pass/src/CircleOptimizer.test.cpp
+++ b/compiler/luci/pass/src/CircleOptimizer.test.cpp
@@ -71,171 +71,3 @@ TEST(CircleOptimizerTest, sparsify_simple)
SUCCEED();
}
-
-TEST(CircleOptimizerTest, quantize_quantdequant_simple)
-{
- loco::Graph g;
- luci::CircleOptimizer o;
-
- auto options = o.options();
-
- options->enable(Algorithms::QuantizeDequantizeWeights);
- options->param(AlgorithmParameters::Quantize_input_model_dtype, "float32");
- options->param(AlgorithmParameters::Quantize_output_model_dtype, "uint8");
- options->param(AlgorithmParameters::Quantize_granularity, "layer");
-
- o.quantize(&g);
-
- SUCCEED();
-}
-
-TEST(CircleOptimizerTest, quantize_quantdequant_input_NEG)
-{
- loco::Graph g;
- luci::CircleOptimizer o;
-
- auto options = o.options();
-
- options->enable(Algorithms::QuantizeDequantizeWeights);
- options->param(AlgorithmParameters::Quantize_input_model_dtype, "invalid");
- options->param(AlgorithmParameters::Quantize_output_model_dtype, "uint8");
- options->param(AlgorithmParameters::Quantize_granularity, "layer");
-
- EXPECT_THROW(o.quantize(&g), std::runtime_error);
-}
-
-TEST(CircleOptimizerTest, quantize_quantdequant_output_NEG)
-{
- loco::Graph g;
- luci::CircleOptimizer o;
-
- auto options = o.options();
-
- options->enable(Algorithms::QuantizeDequantizeWeights);
- options->param(AlgorithmParameters::Quantize_input_model_dtype, "float32");
- options->param(AlgorithmParameters::Quantize_output_model_dtype, "invalid");
- options->param(AlgorithmParameters::Quantize_granularity, "layer");
-
- EXPECT_THROW(o.quantize(&g), std::runtime_error);
-}
-
-TEST(CircleOptimizerTest, quantize_quantdequant_gran_NEG)
-{
- loco::Graph g;
- luci::CircleOptimizer o;
-
- auto options = o.options();
-
- options->enable(Algorithms::QuantizeDequantizeWeights);
- options->param(AlgorithmParameters::Quantize_input_model_dtype, "float32");
- options->param(AlgorithmParameters::Quantize_output_model_dtype, "uint8");
- options->param(AlgorithmParameters::Quantize_granularity, "invalid");
-
- EXPECT_THROW(o.quantize(&g), std::runtime_error);
-}
-
-TEST(CircleOptimizerTest, quantize_minmax_simple)
-{
- loco::Graph g;
- luci::CircleOptimizer o;
-
- auto options = o.options();
-
- options->enable(Algorithms::QuantizeWithMinMax);
- options->param(AlgorithmParameters::Quantize_input_model_dtype, "float32");
- options->param(AlgorithmParameters::Quantize_output_model_dtype, "uint8");
- options->param(AlgorithmParameters::Quantize_granularity, "layer");
-
- o.quantize(&g);
-
- SUCCEED();
-}
-
-TEST(CircleOptimizerTest, quantize_minmax_input_NEG)
-{
- loco::Graph g;
- luci::CircleOptimizer o;
-
- auto options = o.options();
-
- options->enable(Algorithms::QuantizeWithMinMax);
- options->param(AlgorithmParameters::Quantize_input_model_dtype, "invalid");
- options->param(AlgorithmParameters::Quantize_output_model_dtype, "uint8");
- options->param(AlgorithmParameters::Quantize_granularity, "layer");
-
- EXPECT_THROW(o.quantize(&g), std::runtime_error);
-}
-
-TEST(CircleOptimizerTest, quantize_minmax_output_NEG)
-{
- loco::Graph g;
- luci::CircleOptimizer o;
-
- auto options = o.options();
-
- options->enable(Algorithms::QuantizeWithMinMax);
- options->param(AlgorithmParameters::Quantize_input_model_dtype, "float32");
- options->param(AlgorithmParameters::Quantize_output_model_dtype, "invalid");
- options->param(AlgorithmParameters::Quantize_granularity, "layer");
-
- EXPECT_THROW(o.quantize(&g), std::runtime_error);
-}
-
-TEST(CircleOptimizerTest, quantize_minmax_gran_NEG)
-{
- loco::Graph g;
- luci::CircleOptimizer o;
-
- auto options = o.options();
-
- options->enable(Algorithms::QuantizeWithMinMax);
- options->param(AlgorithmParameters::Quantize_input_model_dtype, "float32");
- options->param(AlgorithmParameters::Quantize_output_model_dtype, "uint8");
- options->param(AlgorithmParameters::Quantize_granularity, "invalid");
-
- EXPECT_THROW(o.quantize(&g), std::runtime_error);
-}
-
-TEST(CircleOptimizerTest, quantize_requant_simple)
-{
- loco::Graph g;
- luci::CircleOptimizer o;
-
- auto options = o.options();
-
- options->enable(Algorithms::Requantize);
- options->param(AlgorithmParameters::Quantize_input_model_dtype, "int8");
- options->param(AlgorithmParameters::Quantize_output_model_dtype, "uint8");
-
- o.quantize(&g);
-
- SUCCEED();
-}
-
-TEST(CircleOptimizerTest, quantize_requant_input_NEG)
-{
- loco::Graph g;
- luci::CircleOptimizer o;
-
- auto options = o.options();
-
- options->enable(Algorithms::Requantize);
- options->param(AlgorithmParameters::Quantize_input_model_dtype, "invalid");
- options->param(AlgorithmParameters::Quantize_output_model_dtype, "uint8");
-
- EXPECT_THROW(o.quantize(&g), std::runtime_error);
-}
-
-TEST(CircleOptimizerTest, quantize_requant_output_NEG)
-{
- loco::Graph g;
- luci::CircleOptimizer o;
-
- auto options = o.options();
-
- options->enable(Algorithms::Requantize);
- options->param(AlgorithmParameters::Quantize_input_model_dtype, "int8");
- options->param(AlgorithmParameters::Quantize_output_model_dtype, "invalid");
-
- EXPECT_THROW(o.quantize(&g), std::runtime_error);
-}
diff --git a/compiler/luci/pass/src/CircleQuantizer.cpp b/compiler/luci/pass/src/CircleQuantizer.cpp
new file mode 100644
index 000000000..ce38a90b9
--- /dev/null
+++ b/compiler/luci/pass/src/CircleQuantizer.cpp
@@ -0,0 +1,458 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/CircleQuantizer.h"
+
+#include "luci/Pass/CopyQuantParamPass.h"
+#include "luci/Pass/ForceQuantParamPass.h"
+#include "luci/Pass/PropagateQParamForwardPass.h"
+#include "luci/Pass/RequantizePass.h"
+#include "luci/Pass/ConvertToFakeQuantizedModelPass.h"
+#include "luci/Pass/FoldDequantizePass.h"
+#include "luci/Pass/QuantizePreCheckerPass.h"
+#include "luci/Pass/QuantizeWithMinMaxPass.h"
+#include "luci/Pass/QuantizeDequantizeWeightsPass.h"
+
+#include "luci/Pass/CircleShapeInferencePass.h"
+#include "luci/Pass/CircleTypeInferencePass.h"
+
+// logo passes
+#include <logo/RemoveDeadNodeWithQueryPass.h>
+
+#include "ProgressReporter.h"
+#include "helpers/Strings.h"
+
+#include "QuantizedModelVerifier.h"
+
+#include <luci/IR/CircleNode.h>
+#include <logo/Phase.h>
+
+#include <memory>
+
+namespace
+{
+
+using namespace luci;
+using LayerParam = luci::CircleQuantizer::Options::LayerParam;
+
+template <typename T> T lexical_cast(const std::string &str)
+{
+ std::istringstream ss;
+ ss.str(str);
+ T data;
+ ss >> data;
+ return data;
+}
+
+template <typename T> std::vector<T> lexical_cast(std::vector<std::string> &sv)
+{
+ std::vector<T> result;
+ std::transform(sv.begin(), sv.end(), std::back_inserter(result),
+ [](std::string str) -> T { return lexical_cast<T>(str); });
+ return result;
+}
+
+class QuantizeOptionsImpl final : public luci::CircleQuantizer::Options
+{
+public:
+ void enable(Algorithm) final;
+ void param(AlgorithmParameters, const std::string &) final;
+ const std::string param(AlgorithmParameters) const final;
+ void params(AlgorithmParameters, std::vector<std::string> &) final;
+ std::vector<std::string> params(AlgorithmParameters) const final;
+ void layer_params(AlgorithmParameters, std::vector<std::shared_ptr<LayerParam>> &) final;
+ std::vector<std::shared_ptr<LayerParam>> layer_params(AlgorithmParameters) const final;
+ bool query(Algorithm) final;
+
+private:
+ std::vector<Algorithm> _algorithms;
+ std::map<AlgorithmParameters, const std::string> _algorithm_params;
+ std::map<AlgorithmParameters, std::vector<std::string>> _multiple_params;
+ std::map<AlgorithmParameters, std::vector<std::shared_ptr<LayerParam>>> _layer_params;
+};
+
+void QuantizeOptionsImpl::enable(Algorithm algo) { _algorithms.push_back(algo); }
+
+void QuantizeOptionsImpl::param(AlgorithmParameters param, const std::string &str)
+{
+ _algorithm_params.insert(std::pair<AlgorithmParameters, const std::string>(param, str));
+}
+
+const std::string QuantizeOptionsImpl::param(AlgorithmParameters param) const
+{
+ auto param_str = _algorithm_params.find(param);
+ if (param_str != _algorithm_params.end())
+ {
+ return param_str->second;
+ }
+ else
+ {
+ return std::string();
+ }
+}
+
+void QuantizeOptionsImpl::params(AlgorithmParameters param, std::vector<std::string> &vec)
+{
+ _multiple_params[param] = vec;
+}
+
+std::vector<std::string> QuantizeOptionsImpl::params(AlgorithmParameters param) const
+{
+ auto param_vec = _multiple_params.find(param);
+ if (param_vec != _multiple_params.end())
+ {
+ return param_vec->second;
+ }
+ else
+ {
+ return std::vector<std::string>();
+ }
+}
+
+void QuantizeOptionsImpl::layer_params(AlgorithmParameters param,
+ std::vector<std::shared_ptr<LayerParam>> &vec)
+{
+ _layer_params[param] = vec;
+}
+
+std::vector<std::shared_ptr<LayerParam>>
+QuantizeOptionsImpl::layer_params(AlgorithmParameters param) const
+{
+ auto param_vec = _layer_params.find(param);
+ if (param_vec != _layer_params.end())
+ {
+ return param_vec->second;
+ }
+ else
+ {
+ return std::vector<std::shared_ptr<LayerParam>>();
+ }
+}
+
+bool QuantizeOptionsImpl::query(Algorithm algo)
+{
+ std::vector<Algorithm>::iterator it = std::find(_algorithms.begin(), _algorithms.end(), algo);
+ if (it == _algorithms.end())
+ return false;
+
+ return true;
+}
+
+} // namespace
+
+namespace luci
+{
+
+CircleQuantizer::Options *CircleQuantizer::options(void)
+{
+ if (_options == nullptr)
+ {
+ _options = std::make_unique<QuantizeOptionsImpl>();
+ }
+
+ return _options.get();
+}
+
+// Apply the enabled quantization algorithm(s) to graph 'g'.
+//
+// Stages run in this order; each stage runs only if its Algorithm was enabled
+// via options():
+//   1. QuantizeDequantizeWeights   - fake-quantize weights (validates dtype and
+//      granularity params, clears pre-existing qparams, then runs
+//      QuantizeDequantizeWeightsPass)
+//   2. QuantizeWithMinMax          - actual quantization of weights/bias/
+//      activations; the result is checked with QuantizedModelVerifier
+//   3. Requantize                  - int8 -> uint8 requantization
+//   4. ForceQuantParam             - overwrite per-tensor qparams of the named
+//      tensors (per-tensor only, not per-channel)
+//   5. CopyQuantParam              - copy qparams between named tensor pairs
+//   6. ConvertToFakeQuantizedModel - rewrite a quantized model into a
+//      fake-quantized fp32 model and fold the generated Dequantize ops
+// Finally, shape/type inference is re-run over the whole graph.
+//
+// Throws std::runtime_error when a parameter value is unsupported
+// (dtype, granularity, or per-layer overrides).
+void CircleQuantizer::quantize(loco::Graph *g) const
+{
+ // Fake quantization of weights
+ if (_options->query(Options::Algorithm::QuantizeDequantizeWeights))
+ {
+ static const std::vector<std::string> fakeq_supported_input_model_dtype{"float32"};
+ static const std::vector<std::string> fakeq_supported_output_model_dtype{"uint8", "int16"};
+ static const std::vector<std::string> fakeq_supported_granularity{"layer", "channel"};
+
+ auto input_model_dtype =
+ _options->param(Options::AlgorithmParameters::Quantize_input_model_dtype);
+ auto output_model_dtype =
+ _options->param(Options::AlgorithmParameters::Quantize_output_model_dtype);
+ auto granularity = _options->param(Options::AlgorithmParameters::Quantize_granularity);
+ auto layer_params = _options->layer_params(Options::AlgorithmParameters::Quantize_layer_params);
+
+ if (!in_array(to_lower_case(input_model_dtype), fakeq_supported_input_model_dtype))
+ throw std::runtime_error("Unsupported input type. List of supported input type: " +
+ to_string(fakeq_supported_input_model_dtype));
+
+ if (!in_array(to_lower_case(output_model_dtype), fakeq_supported_output_model_dtype))
+ throw std::runtime_error("Unsupported output type. List of supported output type: " +
+ to_string(fakeq_supported_output_model_dtype));
+
+ if (!in_array(to_lower_case(granularity), fakeq_supported_granularity))
+ throw std::runtime_error("Unsupported granularity. List of supported granularity: " +
+ to_string(fakeq_supported_granularity));
+
+ if (str_to_granularity(granularity) == QuantizationGranularity::LayerWise &&
+ str_to_dtype(output_model_dtype) != loco::DataType::U8)
+ throw std::runtime_error("Layer-wise quantization only supports uint8 dtype.");
+
+ // Check dtype/granularity of layer params
+ for (auto layer_param : layer_params)
+ {
+ auto name = layer_param->name;
+ if (!in_array(to_lower_case(layer_param->dtype), fakeq_supported_output_model_dtype))
+ {
+ throw std::runtime_error("Unsupported dtype in " + name + ". List of supported dtype: " +
+ to_string(fakeq_supported_output_model_dtype));
+ }
+ if (!in_array(to_lower_case(layer_param->granularity), fakeq_supported_granularity))
+ {
+ throw std::runtime_error(
+ "Unsupported granularity in " + name +
+ ". List of supported granularity: " + to_string(fakeq_supported_granularity));
+ }
+ }
+
+ // Clear existing quantparams before doing fake quantization
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+ if (circle_node->quantparam() != nullptr)
+ circle_node->quantparam(nullptr);
+ }
+
+ auto ctx = std::make_unique<luci::QuantizeDequantizeWeightsPass::Context>();
+ {
+ ctx->input_model_dtype = str_to_dtype(input_model_dtype);
+ ctx->output_model_dtype = str_to_dtype(output_model_dtype);
+ ctx->granularity = str_to_granularity(granularity);
+
+ for (auto layer_param : layer_params)
+ {
+ LayerInfo info;
+ {
+ info.name = layer_param->name;
+ info.dtype = str_to_dtype(layer_param->dtype);
+ info.granularity = str_to_granularity(layer_param->granularity);
+ }
+ ctx->layers_info.emplace_back(info);
+ }
+ }
+
+ luci::QuantizeDequantizeWeightsPass fake_quantizer(std::move(ctx));
+
+ fake_quantizer.run(g);
+ }
+
+ // Actual quantization of weights, bias, and activation
+ if (_options->query(Options::Algorithm::QuantizeWithMinMax))
+ {
+ static const std::vector<std::string> qwmm_supported_input_model_dtype{"float32"};
+ static const std::vector<std::string> qwmm_supported_output_model_dtype{"uint8", "int16"};
+ static const std::vector<std::string> qwmm_supported_granularity{"layer", "channel"};
+ static const std::vector<std::string> qwmm_supported_input_type{"uint8", "int16"};
+ static const std::vector<std::string> qwmm_supported_output_type{"uint8", "int16"};
+
+ auto input_model_dtype =
+ _options->param(Options::AlgorithmParameters::Quantize_input_model_dtype);
+ auto output_model_dtype =
+ _options->param(Options::AlgorithmParameters::Quantize_output_model_dtype);
+ auto granularity = _options->param(Options::AlgorithmParameters::Quantize_granularity);
+ // input_type/output_type default to the output model dtype when not given
+ auto input_type = _options->param(Options::AlgorithmParameters::Quantize_input_type);
+ if (input_type.empty())
+ input_type = output_model_dtype;
+ auto output_type = _options->param(Options::AlgorithmParameters::Quantize_output_type);
+ if (output_type.empty())
+ output_type = output_model_dtype;
+
+ bool TF_style_maxpool =
+ _options->param(Options::AlgorithmParameters::Quantize_TF_style_maxpool) == "True";
+
+ auto layer_params = _options->layer_params(Options::AlgorithmParameters::Quantize_layer_params);
+
+ if (!in_array(to_lower_case(input_model_dtype), qwmm_supported_input_model_dtype))
+ throw std::runtime_error("Unsupported input type. List of supported input types: " +
+ to_string(qwmm_supported_input_model_dtype));
+
+ if (!in_array(to_lower_case(output_model_dtype), qwmm_supported_output_model_dtype))
+ throw std::runtime_error("Unsupported output type. List of supported output types: " +
+ to_string(qwmm_supported_output_model_dtype));
+
+ if (!in_array(to_lower_case(granularity), qwmm_supported_granularity))
+ throw std::runtime_error("Unsupported granularity. List of supported granularity: " +
+ to_string(qwmm_supported_granularity));
+
+ if (!in_array(to_lower_case(input_type), qwmm_supported_input_type))
+ throw std::runtime_error("Unsupported input type. List of supported input types: " +
+ to_string(qwmm_supported_input_type));
+
+ if (!in_array(to_lower_case(output_type), qwmm_supported_output_type))
+ throw std::runtime_error("Unsupported output type. List of supported output types: " +
+ to_string(qwmm_supported_output_type));
+
+ if (str_to_granularity(granularity) == QuantizationGranularity::LayerWise &&
+ str_to_dtype(output_model_dtype) != loco::DataType::U8)
+ throw std::runtime_error("Layer-wise quantization only supports uint8 dtype.");
+
+ // Check dtype/granularity of layer params
+ for (auto layer_param : layer_params)
+ {
+ auto name = layer_param->name;
+ if (!in_array(to_lower_case(layer_param->dtype), qwmm_supported_output_model_dtype))
+ {
+ throw std::runtime_error("Unsupported dtype in " + name + ". List of supported dtype: " +
+ to_string(qwmm_supported_output_model_dtype));
+ }
+ if (!in_array(to_lower_case(layer_param->granularity), qwmm_supported_granularity))
+ {
+ throw std::runtime_error(
+ "Unsupported granularity in " + name +
+ ". List of supported granularity: " + to_string(qwmm_supported_granularity));
+ }
+ }
+
+ // Input model checker for quantization
+ luci::QuantizePreCheckerPass input_model_checker{};
+ input_model_checker.run(g);
+
+ auto ctx = std::make_unique<luci::QuantizeWithMinMaxPass::Context>();
+ {
+ ctx->input_model_dtype = str_to_dtype(input_model_dtype);
+ ctx->output_model_dtype = str_to_dtype(output_model_dtype);
+ ctx->granularity = str_to_granularity(granularity);
+ ctx->input_type = str_to_dtype(input_type);
+ ctx->output_type = str_to_dtype(output_type);
+ ctx->TF_style_maxpool = TF_style_maxpool;
+
+ for (auto layer_param : layer_params)
+ {
+ LayerInfo info;
+ {
+ info.name = layer_param->name;
+ info.dtype = str_to_dtype(layer_param->dtype);
+ info.granularity = str_to_granularity(layer_param->granularity);
+ }
+ ctx->layers_info.emplace_back(info);
+ }
+ }
+
+ luci::QuantizeWithMinMaxPass quantizer(std::move(ctx));
+
+ quantizer.run(g);
+
+ auto verify_ctx = std::make_unique<luci::QuantizedModelVerifier::Context>();
+ {
+ verify_ctx->output_model_dtype = str_to_dtype(output_model_dtype);
+ verify_ctx->granularity = str_to_granularity(granularity);
+ verify_ctx->input_type = str_to_dtype(input_type);
+ verify_ctx->output_type = str_to_dtype(output_type);
+ verify_ctx->TF_style_maxpool = TF_style_maxpool;
+
+ for (auto layer_param : layer_params)
+ {
+ LayerInfo info;
+ {
+ info.name = layer_param->name;
+ info.dtype = str_to_dtype(layer_param->dtype);
+ info.granularity = str_to_granularity(layer_param->granularity);
+ }
+ verify_ctx->layers_info.emplace_back(info);
+ }
+ }
+
+ // Verify the type/granularity of the quantized model
+ luci::QuantizedModelVerifier verifier(std::move(verify_ctx));
+
+ verifier.verify(g);
+ }
+
+ // Requantize
+ if (_options->query(Options::Algorithm::Requantize))
+ {
+ static const std::vector<std::string> rq_supported_input_model_dtype{"int8"};
+ static const std::vector<std::string> rq_supported_output_model_dtype{"uint8"};
+
+ auto input_model_dtype =
+ _options->param(Options::AlgorithmParameters::Quantize_input_model_dtype);
+ auto output_model_dtype =
+ _options->param(Options::AlgorithmParameters::Quantize_output_model_dtype);
+
+ if (!in_array(to_lower_case(input_model_dtype), rq_supported_input_model_dtype))
+ throw std::runtime_error("Unsupported input type. List of supported input types: " +
+ to_string(rq_supported_input_model_dtype));
+
+ if (!in_array(to_lower_case(output_model_dtype), rq_supported_output_model_dtype))
+ throw std::runtime_error("Unsupported output type. List of supported output types: " +
+ to_string(rq_supported_output_model_dtype));
+
+ luci::RequantizePass requantizer(str_to_dtype(input_model_dtype),
+ str_to_dtype(output_model_dtype));
+ requantizer.run(g);
+ }
+
+ // Force to write quantparam to specified tensors
+ // NOTE Only per-tensor (not per-channel) qparam can be written
+ if (_options->query(Options::Algorithm::ForceQuantParam))
+ {
+ ForceQuantParamPass::TensorVector tensors =
+ _options->params(Options::AlgorithmParameters::Quantize_tensor_names);
+ auto str_scales = _options->params(Options::AlgorithmParameters::Quantize_scales);
+ auto str_zero_points = _options->params(Options::AlgorithmParameters::Quantize_zero_points);
+
+ // Cast scales/zero_points to proper types
+ ForceQuantParamPass::ScaleVector scales = lexical_cast<float>(str_scales);
+ ForceQuantParamPass::ZPVector zero_points = lexical_cast<int64_t>(str_zero_points);
+
+ ForceQuantParamPass fq(tensors, scales, zero_points);
+ fq.run(g);
+ }
+
+ // Copy quantparam of a tensor to another tensor
+ if (_options->query(Options::Algorithm::CopyQuantParam))
+ {
+ CopyQuantParamPass::TensorVector src_tensors =
+ _options->params(Options::AlgorithmParameters::Quantize_src_tensor_names);
+ CopyQuantParamPass::TensorVector dst_tensors =
+ _options->params(Options::AlgorithmParameters::Quantize_dst_tensor_names);
+
+ CopyQuantParamPass cq(src_tensors, dst_tensors);
+ cq.run(g);
+ }
+
+ // Convert quantized model to fake-quantized model
+ if (_options->query(Options::Algorithm::ConvertToFakeQuantizedModel))
+ {
+ luci::ConvertToFakeQuantizedModelPass fake_quantizer;
+ fake_quantizer.run(g);
+
+ logo::Phase phase;
+
+ // Default passes
+ phase.emplace_back(std::make_unique<logo::RemoveDeadNodeWithQueryPass>());
+ phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>());
+ phase.emplace_back(std::make_unique<luci::CircleTypeInferencePass>());
+
+ // Fold Dequantize Ops generated during fake quantization
+ phase.emplace_back(std::make_unique<luci::FoldDequantizePass>());
+
+ ProgressReporter prog(g, logo::PhaseStrategy::Restart);
+ logo::PhaseRunner<logo::PhaseStrategy::Restart> phase_runner{g};
+ phase_runner.attach(&prog);
+ phase_runner.run(phase);
+ }
+
+ logo::Phase phase;
+
+ // Do Shape/Type inference
+ phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>());
+ phase.emplace_back(std::make_unique<luci::CircleTypeInferencePass>());
+
+ ProgressReporter prog(g, logo::PhaseStrategy::Saturate);
+ logo::PhaseRunner<logo::PhaseStrategy::Saturate> phase_runner{g};
+ phase_runner.attach(&prog);
+ phase_runner.run(phase);
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/CircleQuantizer.test.cpp b/compiler/luci/pass/src/CircleQuantizer.test.cpp
new file mode 100644
index 000000000..5766d5fe5
--- /dev/null
+++ b/compiler/luci/pass/src/CircleQuantizer.test.cpp
@@ -0,0 +1,191 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/CircleQuantizer.h"
+
+#include <gtest/gtest.h>
+
+using namespace luci;
+using Algorithms = luci::CircleQuantizer::Options::Algorithm;
+using AlgorithmParameters = luci::CircleQuantizer::Options::AlgorithmParameters;
+
+// Tests for Options::Algorithm::QuantizeDequantizeWeights.
+// The positive case runs on an empty graph and only checks that quantize()
+// completes; each _NEG case passes one invalid parameter value and expects
+// std::runtime_error from the parameter validation in quantize().
+TEST(CircleQuantizerTest, quantize_quantdequant_simple)
+{
+ loco::Graph g;
+ luci::CircleQuantizer o;
+
+ auto options = o.options();
+
+ options->enable(Algorithms::QuantizeDequantizeWeights);
+ options->param(AlgorithmParameters::Quantize_input_model_dtype, "float32");
+ options->param(AlgorithmParameters::Quantize_output_model_dtype, "uint8");
+ options->param(AlgorithmParameters::Quantize_granularity, "layer");
+
+ o.quantize(&g);
+
+ SUCCEED();
+}
+
+// Invalid input model dtype must be rejected
+TEST(CircleQuantizerTest, quantize_quantdequant_input_NEG)
+{
+ loco::Graph g;
+ luci::CircleQuantizer o;
+
+ auto options = o.options();
+
+ options->enable(Algorithms::QuantizeDequantizeWeights);
+ options->param(AlgorithmParameters::Quantize_input_model_dtype, "invalid");
+ options->param(AlgorithmParameters::Quantize_output_model_dtype, "uint8");
+ options->param(AlgorithmParameters::Quantize_granularity, "layer");
+
+ EXPECT_THROW(o.quantize(&g), std::runtime_error);
+}
+
+// Invalid output model dtype must be rejected
+TEST(CircleQuantizerTest, quantize_quantdequant_output_NEG)
+{
+ loco::Graph g;
+ luci::CircleQuantizer o;
+
+ auto options = o.options();
+
+ options->enable(Algorithms::QuantizeDequantizeWeights);
+ options->param(AlgorithmParameters::Quantize_input_model_dtype, "float32");
+ options->param(AlgorithmParameters::Quantize_output_model_dtype, "invalid");
+ options->param(AlgorithmParameters::Quantize_granularity, "layer");
+
+ EXPECT_THROW(o.quantize(&g), std::runtime_error);
+}
+
+// Invalid granularity must be rejected
+TEST(CircleQuantizerTest, quantize_quantdequant_gran_NEG)
+{
+ loco::Graph g;
+ luci::CircleQuantizer o;
+
+ auto options = o.options();
+
+ options->enable(Algorithms::QuantizeDequantizeWeights);
+ options->param(AlgorithmParameters::Quantize_input_model_dtype, "float32");
+ options->param(AlgorithmParameters::Quantize_output_model_dtype, "uint8");
+ options->param(AlgorithmParameters::Quantize_granularity, "invalid");
+
+ EXPECT_THROW(o.quantize(&g), std::runtime_error);
+}
+
+// Tests for Options::Algorithm::QuantizeWithMinMax, mirroring the
+// quantdequant cases above: one positive run on an empty graph plus one
+// _NEG case per invalid parameter (input dtype, output dtype, granularity).
+TEST(CircleQuantizerTest, quantize_minmax_simple)
+{
+ loco::Graph g;
+ luci::CircleQuantizer o;
+
+ auto options = o.options();
+
+ options->enable(Algorithms::QuantizeWithMinMax);
+ options->param(AlgorithmParameters::Quantize_input_model_dtype, "float32");
+ options->param(AlgorithmParameters::Quantize_output_model_dtype, "uint8");
+ options->param(AlgorithmParameters::Quantize_granularity, "layer");
+
+ o.quantize(&g);
+
+ SUCCEED();
+}
+
+// Invalid input model dtype must be rejected
+TEST(CircleQuantizerTest, quantize_minmax_input_NEG)
+{
+ loco::Graph g;
+ luci::CircleQuantizer o;
+
+ auto options = o.options();
+
+ options->enable(Algorithms::QuantizeWithMinMax);
+ options->param(AlgorithmParameters::Quantize_input_model_dtype, "invalid");
+ options->param(AlgorithmParameters::Quantize_output_model_dtype, "uint8");
+ options->param(AlgorithmParameters::Quantize_granularity, "layer");
+
+ EXPECT_THROW(o.quantize(&g), std::runtime_error);
+}
+
+// Invalid output model dtype must be rejected
+TEST(CircleQuantizerTest, quantize_minmax_output_NEG)
+{
+ loco::Graph g;
+ luci::CircleQuantizer o;
+
+ auto options = o.options();
+
+ options->enable(Algorithms::QuantizeWithMinMax);
+ options->param(AlgorithmParameters::Quantize_input_model_dtype, "float32");
+ options->param(AlgorithmParameters::Quantize_output_model_dtype, "invalid");
+ options->param(AlgorithmParameters::Quantize_granularity, "layer");
+
+ EXPECT_THROW(o.quantize(&g), std::runtime_error);
+}
+
+// Invalid granularity must be rejected
+TEST(CircleQuantizerTest, quantize_minmax_gran_NEG)
+{
+ loco::Graph g;
+ luci::CircleQuantizer o;
+
+ auto options = o.options();
+
+ options->enable(Algorithms::QuantizeWithMinMax);
+ options->param(AlgorithmParameters::Quantize_input_model_dtype, "float32");
+ options->param(AlgorithmParameters::Quantize_output_model_dtype, "uint8");
+ options->param(AlgorithmParameters::Quantize_granularity, "invalid");
+
+ EXPECT_THROW(o.quantize(&g), std::runtime_error);
+}
+
+// Tests for Options::Algorithm::Requantize (int8 -> uint8). Requantize takes
+// no granularity parameter, so only the dtype _NEG cases exist here.
+TEST(CircleQuantizerTest, quantize_requant_simple)
+{
+ loco::Graph g;
+ luci::CircleQuantizer o;
+
+ auto options = o.options();
+
+ options->enable(Algorithms::Requantize);
+ options->param(AlgorithmParameters::Quantize_input_model_dtype, "int8");
+ options->param(AlgorithmParameters::Quantize_output_model_dtype, "uint8");
+
+ o.quantize(&g);
+
+ SUCCEED();
+}
+
+// Invalid input model dtype must be rejected
+TEST(CircleQuantizerTest, quantize_requant_input_NEG)
+{
+ loco::Graph g;
+ luci::CircleQuantizer o;
+
+ auto options = o.options();
+
+ options->enable(Algorithms::Requantize);
+ options->param(AlgorithmParameters::Quantize_input_model_dtype, "invalid");
+ options->param(AlgorithmParameters::Quantize_output_model_dtype, "uint8");
+
+ EXPECT_THROW(o.quantize(&g), std::runtime_error);
+}
+
+// Invalid output model dtype must be rejected
+TEST(CircleQuantizerTest, quantize_requant_output_NEG)
+{
+ loco::Graph g;
+ luci::CircleQuantizer o;
+
+ auto options = o.options();
+
+ options->enable(Algorithms::Requantize);
+ options->param(AlgorithmParameters::Quantize_input_model_dtype, "int8");
+ options->param(AlgorithmParameters::Quantize_output_model_dtype, "invalid");
+
+ EXPECT_THROW(o.quantize(&g), std::runtime_error);
+}
diff --git a/compiler/luci/pass/src/ConvertNCHWToNHWCPass.cpp b/compiler/luci/pass/src/ConvertNCHWToNHWCPass.cpp
index 270714049..ce4f54035 100644
--- a/compiler/luci/pass/src/ConvertNCHWToNHWCPass.cpp
+++ b/compiler/luci/pass/src/ConvertNCHWToNHWCPass.cpp
@@ -228,6 +228,9 @@ bool check_4d_reshape(loco::Node *node, const std::vector<int32_t> indices)
if (input->shape_status() != luci::ShapeStatus::VALID)
return false;
+ if (input->rank() != 4)
+ return false;
+
if (reshape->shape_status() != luci::ShapeStatus::VALID)
return false;
@@ -804,6 +807,8 @@ class ConvertNCHWToNHWC final : public luci::CircleNodeMutableVisitor<bool>
return true;
}
+ bool visit(luci::CircleElu *node) { return convert_unary_features<luci::CircleElu>(node); }
+
bool visit(luci::CircleLeakyRelu *node)
{
return convert_unary_features<luci::CircleLeakyRelu>(node);
@@ -1240,6 +1245,7 @@ bool ConvertNCHWToNHWCPass::run(loco::Graph *g)
break;
case luci::CircleOpcode::ADD:
case luci::CircleOpcode::CONCATENATION:
+ case luci::CircleOpcode::ELU:
case luci::CircleOpcode::LEAKY_RELU:
case luci::CircleOpcode::LOGISTIC:
case luci::CircleOpcode::MAXIMUM:
diff --git a/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp b/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp
index c9412fbb1..dd81d1380 100644
--- a/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp
+++ b/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp
@@ -264,6 +264,22 @@ public:
luci::CircleConst *input2 = nullptr;
};
+// Test graph containing a single CircleElu node, used by the
+// ConvertNCHWToNHWC Elu test below. SimpleGraph provides the input/output
+// scaffolding; insertGraphBody() plugs the Elu between them.
+class EluGraph final : public SimpleGraph
+{
+protected:
+ loco::Node *insertGraphBody(loco::Node *input) override
+ {
+ elu = g.nodes()->create<luci::CircleElu>();
+ elu->features(input);
+ elu->name("elu");
+
+ return elu;
+ }
+
+public:
+ luci::CircleElu *elu = nullptr;
+};
+
class LeakyReluGraph final : public SimpleGraph
{
protected:
@@ -941,6 +957,26 @@ TEST(ConvertNCHWToNHWC, Concatenation)
EXPECT_EQ(3, g.concat->axis());
}
+// Elu is a unary feature op: after the pass, a pre-transpose must feed its
+// input, a post-transpose must consume its output, and its shape must be
+// the NHWC permutation (1,4,4,16) of the original NCHW shape.
+TEST(ConvertNCHWToNHWC, Elu)
+{
+ EluGraph g;
+ g.init();
+
+ run_phase(&g.g, true, true);
+
+ check_pre_trans(g.elu->features());
+
+ auto elu_succs = loco::succs(g.elu);
+ EXPECT_EQ(1, elu_succs.size());
+ check_post_trans(*elu_succs.begin());
+
+ // Check elu shape
+ EXPECT_EQ(1, g.elu->dim(0).value());
+ EXPECT_EQ(4, g.elu->dim(1).value());
+ EXPECT_EQ(4, g.elu->dim(2).value());
+ EXPECT_EQ(16, g.elu->dim(3).value());
+}
+
TEST(ConvertNCHWToNHWC, LeakyRelu)
{
LeakyReluGraph g;
diff --git a/compiler/luci/pass/src/ConvertToFakeQuantizedModelPass.cpp b/compiler/luci/pass/src/ConvertToFakeQuantizedModelPass.cpp
new file mode 100644
index 000000000..11970fff5
--- /dev/null
+++ b/compiler/luci/pass/src/ConvertToFakeQuantizedModelPass.cpp
@@ -0,0 +1,214 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/ConvertToFakeQuantizedModelPass.h"
+#include "luci/Pass/QuantizationParameters.h"
+
+#include "QuantizationUtils.h"
+
+#include <luci/Profile/CircleNodeOrigin.h>
+#include <luci/IR/CircleNodes.h>
+#include <luci/IR/CircleNodeVisitor.h>
+#include <luci/Log.h>
+
+namespace
+{
+
+// Create Quantize Op whose dtype/shape/qparam are the same with node
+// The new op is created in node's graph, named "<node name>_Quantize",
+// inherits node's origin for profiling, and is NOT yet wired into the graph.
+luci::CircleQuantize *create_quantize(luci::CircleNode *node)
+{
+ auto quantize = node->graph()->nodes()->create<luci::CircleQuantize>();
+ quantize->name(node->name() + "_Quantize");
+ quantize->dtype(node->dtype());
+ quantize->rank(node->rank());
+ for (uint32_t i = 0; i < node->rank(); i++)
+ quantize->dim(i).set(node->dim(i).value());
+
+ quantize->shape_status(luci::ShapeStatus::VALID);
+
+ copy_quantparam(node, quantize);
+
+ luci::add_origin(quantize, luci::get_origin(node));
+
+ return quantize;
+}
+
+// Create Dequantize Op whose shape is the same with node
+// Unlike create_quantize(), dtype is forced to FLOAT32 and no qparam is
+// copied (Dequantize outputs float). Not yet wired into the graph.
+luci::CircleDequantize *create_dequantize(luci::CircleNode *node)
+{
+ auto dequantize = node->graph()->nodes()->create<luci::CircleDequantize>();
+ dequantize->name(node->name() + "_Dequantize");
+ dequantize->dtype(loco::DataType::FLOAT32);
+ dequantize->rank(node->rank());
+ for (uint32_t i = 0; i < node->rank(); i++)
+ dequantize->dim(i).set(node->dim(i).value());
+
+ dequantize->shape_status(luci::ShapeStatus::VALID);
+
+ luci::add_origin(dequantize, luci::get_origin(node));
+
+ return dequantize;
+}
+
+// Return true if node is quantized activation
+// 1. dtype is u8 or s16
+// 2. node has qparam
+bool is_quant_act(const luci::CircleNode *node)
+{
+ if (node->dtype() != loco::DataType::U8 and node->dtype() != loco::DataType::S16)
+ return false;
+
+ if (not node->quantparam())
+ return false;
+
+ return true;
+}
+
+// Return true if node is quantized const
+// 1. dtype is not fp32
+// 2. node has qparam
+// NOTE Quantized const can have the following types
+// u8 (weights, activation), s16 (weights, activation), s32 (bias), s64 (bias)
+// (hence the check is "not FLOAT32" rather than an allow-list)
+bool is_quant_const(const luci::CircleConst *node)
+{
+ if (node->dtype() == loco::DataType::FLOAT32)
+ return false;
+
+ if (not node->quantparam())
+ return false;
+
+ return true;
+}
+
+// Insert dequantize Op after node
+// All existing users of node are redirected to the new Dequantize, whose
+// input then becomes node itself (replace-then-reconnect pattern).
+void insert_dequantize(loco::Node *lnode)
+{
+ auto node = loco::must_cast<luci::CircleNode *>(lnode);
+ auto dequant = create_dequantize(node);
+ loco::replace(node).with(dequant);
+ dequant->input(node);
+}
+
+// Insert quantize Op after node and return the quantize Op
+// Same replace-then-reconnect pattern as insert_dequantize(); the returned
+// op lets the caller chain a Dequantize behind it.
+luci::CircleQuantize *insert_quantize(loco::Node *lnode)
+{
+ auto node = loco::must_cast<luci::CircleNode *>(lnode);
+ auto quant = create_quantize(node);
+ loco::replace(node).with(quant);
+ quant->input(node);
+ return quant;
+}
+
+// Dequantize node in place: reset dtype to fp32 and drop its qparam
+void dequantize(luci::CircleNode *node)
+{
+ node->dtype(loco::DataType::FLOAT32);
+ node->quantparam(nullptr);
+}
+
+// Do fake quantization on quantized activation
+// 1. Insert Quantize-Dequantize Ops
+// 2. Update dtype/quantparam of node
+// No-op when node is not a quantized activation (see is_quant_act).
+// NOTE The Q-DQ pair is inserted BEFORE dequantizing the node, so the
+// Quantize op captures the node's original dtype/qparam.
+void fq_activation(luci::CircleNode *node)
+{
+ if (not is_quant_act(node))
+ return;
+
+ auto quant = insert_quantize(node);
+ insert_dequantize(quant);
+
+ dequantize(node);
+}
+
+#define RETURN_UNLESS(COND) \
+ if (not(COND)) \
+ return;
+
+// Visitor to do fake quantization for each Op
+// For non-const activation, insert Quantize-Dequantize after the ofm
+// For quantized const, insert Dequantize after the const
+// The default visit() throws, so only the ops listed below are supported
+// for fake quantization; anything else is a hard error.
+struct FakeQuantize final : public luci::CircleNodeMutableVisitor<void>
+{
+ void visit(luci::CircleNode *node)
+ {
+ throw std::runtime_error("Unsupported op for fake quantization in " + node->name());
+ }
+
+ void visit(luci::CircleInput *node)
+ {
+ RETURN_UNLESS(is_quant_act(node));
+
+ auto quant = insert_quantize(node);
+ insert_dequantize(quant);
+
+ dequantize(node);
+
+ // Update graph input
+ const auto inputs = node->graph()->inputs();
+ auto graph_input = inputs->at(node->index());
+ graph_input->dtype(loco::DataType::FLOAT32);
+ }
+
+ void visit(luci::CircleOutput *node)
+ {
+ RETURN_UNLESS(is_quant_act(node));
+
+ // NOTE No Q-DQ pair here: the producer feeding the output already got
+ // one, so the output node itself is only dequantized
+ dequantize(node);
+
+ // Update graph output
+ const auto outputs = node->graph()->outputs();
+ auto graph_output = outputs->at(node->index());
+ graph_output->dtype(loco::DataType::FLOAT32);
+ }
+
+ // For quantized const, insert Dequantize Op
+ void visit(luci::CircleConst *node)
+ {
+ RETURN_UNLESS(is_quant_const(node));
+
+ insert_dequantize(node);
+ }
+
+ // For non-const activation, insert Quantize-Dequantize Ops
+ // and dequantize the node
+ void visit(luci::CircleConv2D *node) { fq_activation(node); }
+ void visit(luci::CircleAdd *node) { fq_activation(node); }
+};
+
+#undef RETURN_UNLESS
+
+} // namespace
+
+namespace luci
+{
+
+// Visit every active node with the FakeQuantize visitor, converting the
+// quantized graph into a fake-quantized fp32 graph.
+// Returns false so the surrounding phase runner does not re-run this pass.
+bool ConvertToFakeQuantizedModelPass::run(loco::Graph *g)
+{
+ LOGGER(l);
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+ INFO(l) << "ConvertToFakeQuantizedModelPass visit node: " << circle_node->name() << std::endl;
+
+ FakeQuantize fq;
+ circle_node->accept(&fq);
+ }
+
+ // One time run
+ return false;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/ConvertToFakeQuantizedModelPass.test.cpp b/compiler/luci/pass/src/ConvertToFakeQuantizedModelPass.test.cpp
new file mode 100644
index 000000000..560d68a74
--- /dev/null
+++ b/compiler/luci/pass/src/ConvertToFakeQuantizedModelPass.test.cpp
@@ -0,0 +1,277 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <logo/Phase.h>
+
+#include "luci/Pass/ConvertToFakeQuantizedModelPass.h"
+#include <luci/IR/CircleNodes.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+// Check the below pattern
+// Quantize (scale, zp) -> Dequantize (node)
+// i.e. 'node' must be a Dequantize fed by a Quantize whose first-channel
+// scale/zerop match the given values.
+void check_q_dq(loco::Node *node, float scale, int64_t zp)
+{
+ auto dequant = dynamic_cast<luci::CircleDequantize *>(node);
+ EXPECT_TRUE(dequant != nullptr);
+ auto quant = dynamic_cast<luci::CircleQuantize *>(dequant->input());
+ EXPECT_TRUE(quant != nullptr);
+ auto qparam = quant->quantparam();
+ EXPECT_EQ(scale, qparam->scale[0]);
+ EXPECT_EQ(zp, qparam->zerop[0]);
+}
+
+// Check the below pattern
+// Dequantize (node)
+// i.e. 'node' must be a Dequantize op (qparam is not checked here).
+void check_dq(loco::Node *node)
+{
+ auto dequant = dynamic_cast<luci::CircleDequantize *>(node);
+ EXPECT_TRUE(dequant != nullptr);
+}
+
+// Attach a single-channel (per-tensor) quantparam with the given
+// scale/zero-point to node, replacing any existing quantparam.
+void set_qparam(luci::CircleNode *node, float scale, int64_t zp)
+{
+ auto qparam = std::make_unique<luci::CircleQuantParam>();
+ {
+ qparam->scale.push_back(scale);
+ qparam->zerop.push_back(zp);
+ }
+ node->quantparam(std::move(qparam));
+}
+
+/**
+ * SimpleGraph for testing
+ * - Child class should implement insertGraphBody()
+ *
+ * Example (U8ConvGraph inherits SimpleGraph and create Conv2D Op)
+ *
+ * BEFORE
+ * - A model is quantized (ex: u8)
+ *
+ * [Input(u8)] [Filter(u8)] [Bias(s32)]
+ * \ | /
+ * \ | /
+ * \ | /
+ * [Conv2D(u8)]
+ * |
+ * [Output(u8)]
+ *
+ * AFTER
+ * - Ops are converted to fp32
+ * - Quantize/Dequantize Ops are inserted properly
+ * - Q-DQ is inserted after non-const activation
+ * - DQ is inserted after const
+ *
+ * [Input(u8)]
+ * |
+ * [Quant(u8)] [Filter(u8)] [Bias(s32)]
+ * | | |
+ * [Dequant(fp32)] [Dequant(fp32)] [Dequant(fp32)]
+ * \ | /
+ * \ | /
+ * \ | /
+ * [Conv2D(fp32)]
+ * |
+ * [Quant(u8)]
+ * |
+ * [Dequant(fp32)]
+ * |
+ * [Output(fp32)]
+ */
+template <loco::DataType T> class SimpleGraph
+{
+public:
+ // Build the common scaffolding: a (1,4,4,4) input/output pair of dtype T
+ // with per-tensor qparam (scale=1.0, zp=0), then let the subclass insert
+ // the graph body between them via insertGraphBody().
+ void init()
+ {
+ input = g.nodes()->create<luci::CircleInput>();
+ output = g.nodes()->create<luci::CircleOutput>();
+ input->name("input");
+ output->name("output");
+
+ // Bind the nodes to graph-level input/output slots
+ auto graph_input = g.inputs()->create();
+ input->index(graph_input->index());
+ auto graph_output = g.outputs()->create();
+ output->index(graph_output->index());
+
+ graph_input->dtype(T);
+ input->dtype(T);
+ output->dtype(T);
+ graph_output->dtype(T);
+
+ graph_input->shape({1, 4, 4, 4});
+ input->shape({1, 4, 4, 4});
+ output->shape({1, 4, 4, 4});
+ graph_output->shape({1, 4, 4, 4});
+
+ set_qparam(input, 1.0, 0);
+ set_qparam(output, 1.0, 0);
+
+ auto graph_body = insertGraphBody(input);
+ output->from(graph_body);
+ }
+
+ virtual ~SimpleGraph() = default;
+
+protected:
+ // Subclass hook: build the body and return the node feeding the output
+ virtual loco::Node *insertGraphBody(loco::Node *input) = 0;
+
+public:
+ loco::Graph g;
+ luci::CircleInput *input = nullptr;
+ luci::CircleOutput *output = nullptr;
+};
+
+// u8-quantized Conv2D graph: u8 weights, s32 bias, all with qparam
+// (scale=2.0, zp=127), matching the BEFORE diagram in the class comment
+// of SimpleGraph above.
+class U8ConvGraph final : public SimpleGraph<loco::DataType::U8>
+{
+protected:
+ loco::Node *insertGraphBody(loco::Node *input) override
+ {
+ conv = g.nodes()->create<luci::CircleConv2D>();
+ weights = g.nodes()->create<luci::CircleConst>();
+ bias = g.nodes()->create<luci::CircleConst>();
+
+ conv->dtype(loco::DataType::U8);
+ weights->dtype(loco::DataType::U8);
+ bias->dtype(loco::DataType::S32);
+
+ conv->shape({1, 4, 4, 4});
+ weights->shape({4, 1, 1, 4});
+ bias->shape({4});
+
+ // Fill weights/bias with deterministic ramp data (0..N-1)
+ weights->size<loco::DataType::U8>(16);
+ for (uint32_t i = 0; i < 16; i++)
+ weights->at<loco::DataType::U8>(i) = i;
+
+ bias->size<loco::DataType::S32>(4);
+ for (uint32_t i = 0; i < 4; i++)
+ bias->at<loco::DataType::S32>(i) = i;
+
+ set_qparam(conv, 2.0, 127);
+ set_qparam(weights, 2.0, 127);
+ set_qparam(bias, 2.0, 127);
+
+ conv->input(input);
+ conv->filter(weights);
+ conv->bias(bias);
+
+ conv->name("conv");
+ weights->name("weights");
+ bias->name("bias");
+
+ return conv;
+ }
+
+public:
+ luci::CircleConv2D *conv = nullptr;
+ luci::CircleConst *weights = nullptr;
+ luci::CircleConst *bias = nullptr;
+};
+
+// All-fp32 Conv2D graph with NO qparams anywhere — the negative fixture:
+// the pass must leave such a graph untouched (no Quantize/Dequantize ops).
+class FP32ConvGraph final : public SimpleGraph<loco::DataType::FLOAT32>
+{
+protected:
+ loco::Node *insertGraphBody(loco::Node *input) override
+ {
+ conv = g.nodes()->create<luci::CircleConv2D>();
+ weights = g.nodes()->create<luci::CircleConst>();
+ bias = g.nodes()->create<luci::CircleConst>();
+
+ conv->dtype(loco::DataType::FLOAT32);
+ weights->dtype(loco::DataType::FLOAT32);
+ bias->dtype(loco::DataType::FLOAT32);
+
+ conv->shape({1, 4, 4, 4});
+ weights->shape({4, 1, 1, 4});
+ bias->shape({4});
+
+ // Fill weights/bias with deterministic ramp data (0..N-1)
+ weights->size<loco::DataType::FLOAT32>(16);
+ for (uint32_t i = 0; i < 16; i++)
+ weights->at<loco::DataType::FLOAT32>(i) = i;
+
+ bias->size<loco::DataType::FLOAT32>(4);
+ for (uint32_t i = 0; i < 4; i++)
+ bias->at<loco::DataType::FLOAT32>(i) = i;
+
+ conv->input(input);
+ conv->filter(weights);
+ conv->bias(bias);
+
+ conv->name("conv");
+ weights->name("weights");
+ bias->name("bias");
+
+ return conv;
+ }
+
+public:
+ luci::CircleConv2D *conv = nullptr;
+ luci::CircleConst *weights = nullptr;
+ luci::CircleConst *bias = nullptr;
+};
+
+} // namespace
+
+// A u8 Conv2D graph is converted to the AFTER pattern documented on
+// SimpleGraph: Q-DQ after the input (qparam 1.0/0), plain DQ after each
+// const, and Q-DQ between conv and output (conv's qparam 2.0/127).
+TEST(ConvertToFakeQuantizedModelTest, U8Conv2D)
+{
+ U8ConvGraph g;
+ g.init();
+
+ luci::ConvertToFakeQuantizedModelPass fq;
+ fq.run(&g.g);
+
+ // Check ifm
+ check_q_dq(g.conv->input(), 1.0, 0);
+
+ // Check weights
+ check_dq(g.conv->filter());
+
+ // Check bias
+ check_dq(g.conv->bias());
+
+ // Check ofm
+ check_q_dq(g.output->from(), 2.0, 127);
+
+ SUCCEED();
+}
+
+// Negative case: a graph that is already fp32 (no qparams) must pass
+// through unchanged — count Quantize/Dequantize ops and expect zero.
+TEST(ConvertToFakeQuantizedModelTest, F32Conv2D_NEG)
+{
+ FP32ConvGraph g;
+ g.init();
+
+ luci::ConvertToFakeQuantizedModelPass fq;
+ fq.run(&g.g);
+
+ uint32_t dequant_count = 0;
+ uint32_t quant_count = 0;
+
+ for (auto node : loco::active_nodes(loco::output_nodes(&g.g)))
+ {
+ auto cnode = loco::must_cast<luci::CircleNode *>(node);
+ auto opcode = cnode->opcode();
+ if (opcode == luci::CircleOpcode::DEQUANTIZE)
+ dequant_count++;
+ if (opcode == luci::CircleOpcode::QUANTIZE)
+ quant_count++;
+ }
+
+ // Check no quant/dequant Op is inserted
+ EXPECT_EQ(0, quant_count);
+ EXPECT_EQ(0, dequant_count);
+}
diff --git a/compiler/luci/pass/src/CopyQuantParamPass.cpp b/compiler/luci/pass/src/CopyQuantParamPass.cpp
new file mode 100644
index 000000000..9b1bb0ea9
--- /dev/null
+++ b/compiler/luci/pass/src/CopyQuantParamPass.cpp
@@ -0,0 +1,82 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/CopyQuantParamPass.h"
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/Log.h>
+
+namespace luci
+{
+
+namespace
+{
+
+struct SrcDst
+{
+ CircleNode *src = nullptr;
+ CircleNode *dst = nullptr;
+};
+
+} // namespace
+
+bool CopyQuantParamPass::run(loco::Graph *g)
+{
+ LOGGER(l);
+
+ INFO(l) << "CopyQuantParamPass Start" << std::endl;
+
+ if (_src_tensors.size() != _dst_tensors.size())
+ throw std::runtime_error("The numbers of Source/Destination tensors do not match.");
+
+ // Return src/dst CircleNodes
+ auto get_src_dst = [&g](std::string src, std::string dst) {
+ SrcDst src_dst;
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ auto const cnode = loco::must_cast<CircleNode *>(node);
+ auto const name = cnode->name();
+ if (name == src)
+ src_dst.src = cnode;
+
+ if (name == dst)
+ src_dst.dst = cnode;
+ }
+ return src_dst;
+ };
+
+ for (uint32_t i = 0; i < _src_tensors.size(); i++)
+ {
+ auto src = _src_tensors[i];
+ auto dst = _dst_tensors[i];
+
+ auto nodes = get_src_dst(src, dst);
+ if (not nodes.src)
+ throw std::runtime_error("The tensor named " + src + " does not exist.");
+
+ if (not nodes.dst)
+ throw std::runtime_error("The tensor named " + dst + " does not exist.");
+
+ copy_quantparam(nodes.src, nodes.dst);
+
+ INFO(l) << "Quantparam of " << src << " is copied to " << dst << std::endl;
+ }
+
+ INFO(l) << "CopyQuantParamPass End" << std::endl;
+
+ return false; // one time run
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/FoldGatherPass.cpp b/compiler/luci/pass/src/FoldGatherPass.cpp
new file mode 100644
index 000000000..f179d74bd
--- /dev/null
+++ b/compiler/luci/pass/src/FoldGatherPass.cpp
@@ -0,0 +1,185 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/FoldGatherPass.h"
+#include "CircleOptimizerUtils.h"
+
+#include <luci/IR/CircleNodes.h>
+
+namespace
+{
+
+/**
+ * Fold to const if
+ *
+ * 1. params: const and dtype = S32 or S64
+ * 2. indices: const and dtype = S32 or S64
+ *
+ * BEFORE
+ *
+ * [CircleConst] [CircleConst]
+ * | |
+ * +---------[Gather]---------+
+ *
+ * AFTER
+ *
+ * [CircleConst]
+ *
+ **/
+template <loco::DataType InputT, loco::DataType IndexT>
+bool fold_gather(luci::CircleGather *gather_node)
+{
+ const auto params = loco::must_cast<luci::CircleConst *>(gather_node->params());
+ const auto indices = loco::must_cast<luci::CircleConst *>(gather_node->indices());
+
+ const auto rank = params->rank();
+ auto axis = gather_node->axis();
+ if (axis < 0)
+ {
+ axis += static_cast<int32_t>(rank);
+ }
+
+ if (axis < 0 or axis >= static_cast<int32_t>(rank))
+ throw std::runtime_error("Unsupported axis value");
+
+ const auto name = gather_node->name();
+ assert(name.length() > 0);
+
+ auto constant = gather_node->graph()->nodes()->create<luci::CircleConst>();
+ constant->dtype(InputT);
+ constant->name(name + "_folded");
+
+ constant->rank(rank + indices->rank() - 1);
+
+ assert(constant->rank() > 0);
+
+ std::vector<uint32_t> shape;
+ for (uint32_t i = 0; i < rank; ++i)
+ {
+ if (i != static_cast<uint32_t>(axis))
+ {
+ const auto dim = params->dim(i).value();
+ shape.push_back(dim);
+ }
+ else
+ {
+ for (uint32_t j = 0; j < indices->rank(); ++j)
+ {
+ const auto dim = indices->dim(j).value();
+ shape.push_back(dim);
+ }
+ }
+ }
+
+ uint32_t size = 1;
+ for (uint32_t i = 0; i < shape.size(); ++i)
+ {
+ constant->dim(i).set(shape.at(i));
+ size *= shape.at(i);
+ }
+
+ constant->size<InputT>(size);
+
+ uint32_t outer_size = 1;
+ for (uint32_t i = 0; i < static_cast<uint32_t>(axis); ++i)
+ {
+ outer_size *= params->dim(i).value();
+ }
+
+ uint32_t inner_size = 1;
+ for (uint32_t i = axis + 1; i < rank; ++i)
+ {
+ inner_size *= params->dim(i).value();
+ }
+
+ uint32_t coord_size = 1;
+ for (uint32_t i = 0; i < indices->rank(); ++i)
+ {
+ coord_size *= indices->dim(i).value();
+ }
+
+ const auto axis_size = params->dim(axis).value();
+
+ for (uint32_t outer = 0; outer < outer_size; ++outer)
+ {
+ for (uint32_t i = 0; i < coord_size; ++i)
+ {
+ constant->at<InputT>((outer * coord_size + i) * inner_size) =
+ params->at<InputT>((outer * axis_size + indices->at<IndexT>(i)) * inner_size);
+ }
+ }
+ loco::replace(gather_node).with(constant);
+
+ return true;
+}
+
+bool fold_gather(luci::CircleGather *gather_node)
+{
+ const auto params = dynamic_cast<luci::CircleConst *>(gather_node->params());
+ if (not params)
+ return false;
+
+ const auto indices = dynamic_cast<luci::CircleConst *>(gather_node->indices());
+ if (not indices)
+ return false;
+
+ // TODO: support more types
+ if (params->dtype() != loco::DataType::S32 and params->dtype() != loco::DataType::S64)
+ return false;
+
+ if (indices->dtype() != loco::DataType::S32 and indices->dtype() != loco::DataType::S64)
+ throw std::runtime_error("Unsupported type");
+
+ if (params->dtype() == loco::DataType::S64)
+ {
+ if (indices->dtype() == loco::DataType::S64)
+ return fold_gather<loco::DataType::S64, loco::DataType::S64>(gather_node);
+ else
+ return fold_gather<loco::DataType::S64, loco::DataType::S32>(gather_node);
+ }
+ else
+ {
+ if (indices->dtype() == loco::DataType::S64)
+ return fold_gather<loco::DataType::S32, loco::DataType::S64>(gather_node);
+ else
+ return fold_gather<loco::DataType::S32, loco::DataType::S32>(gather_node);
+ }
+}
+
+} // namespace
+
+namespace luci
+{
+
+/**
+ * Constant Folding for Gather Op
+ **/
+bool FoldGatherPass::run(loco::Graph *g)
+{
+ bool changed = false;
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ if (auto gather_node = dynamic_cast<luci::CircleGather *>(node))
+ {
+ if (fold_gather(gather_node))
+ changed = true;
+ }
+ }
+
+ return changed;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/FoldGatherPass.test.cpp b/compiler/luci/pass/src/FoldGatherPass.test.cpp
new file mode 100644
index 000000000..b02c034a5
--- /dev/null
+++ b/compiler/luci/pass/src/FoldGatherPass.test.cpp
@@ -0,0 +1,214 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/FoldGatherPass.h"
+#include "PassTestGraphs.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+/**
+ *
+ * Graph that has a Gather S64 Op with const inputs
+ *
+ * BEFORE
+ * params: [Const] (shape: [3], values: [1, 2, 3])
+ * indices: [Const] (shape: [1], values: [1])
+ *
+ * [params] [indices]
+ * | |
+ * ---[Gather]---
+ *
+ * AFTER
+ * [Const] (shape: [1], values: [2])
+ *
+ */
+class S64FoldGatherSimpleTest : public luci::ConstantFoldingAddTestGraph, public ::testing::Test
+{
+public:
+ S64FoldGatherSimpleTest() : luci::ConstantFoldingAddTestGraph({1}, loco::DataType::S64) {}
+
+ virtual void SetUp() { init(); }
+
+ loco::Node *createFoldedPattern() override
+ {
+ _gather = _g.nodes()->create<luci::CircleGather>();
+ _params = _g.nodes()->create<luci::CircleConst>();
+ _indices = _g.nodes()->create<luci::CircleConst>();
+
+ _gather->dtype(loco::DataType::S64);
+ _params->dtype(loco::DataType::S64);
+ _indices->dtype(loco::DataType::S64);
+
+ _params->shape({3});
+ _indices->shape({1});
+
+ _params->size<loco::DataType::S64>(3);
+ _params->at<loco::DataType::S64>(0) = 1;
+ _params->at<loco::DataType::S64>(1) = 2;
+ _params->at<loco::DataType::S64>(2) = 3;
+
+ _indices->size<loco::DataType::S64>(1);
+ _indices->at<loco::DataType::S64>(0) = 1;
+
+ _gather->params(_params);
+ _gather->indices(_indices);
+
+ _gather->name("gather");
+ _params->name("params");
+ _indices->name("indices");
+
+ return _gather;
+ }
+
+protected:
+ luci::CircleGather *_gather = nullptr;
+ luci::CircleConst *_params = nullptr;
+ luci::CircleConst *_indices = nullptr;
+};
+
+/**
+ *
+ * Graph that has a Gather S32 Op with axis = 1 and with const inputs
+ *
+ * BEFORE
+ * params: [Const] (shape: [2, 3], values: [0, 1, 2, 3, 4, 5])
+ * indices: [Const] (shape: [2], values: [2, 1])
+ *
+ * [params] [indices]
+ * | |
+ * ---[Gather]---
+ *
+ * AFTER
+ * [Const] (shape: [2, 2], values: [2, 1, 5, 4])
+ *
+ */
+
+class S32FoldGatherTwoDimsTest : public luci::ConstantFoldingAddTestGraph, public ::testing::Test
+{
+public:
+ S32FoldGatherTwoDimsTest() : luci::ConstantFoldingAddTestGraph({4, 2}, loco::DataType::S32) {}
+
+ virtual void SetUp() { init(); }
+
+ loco::Node *createFoldedPattern() override
+ {
+ _gather = _g.nodes()->create<luci::CircleGather>();
+ _params = _g.nodes()->create<luci::CircleConst>();
+ _indices = _g.nodes()->create<luci::CircleConst>();
+
+ _gather->dtype(loco::DataType::S32);
+ _params->dtype(loco::DataType::S32);
+ _indices->dtype(loco::DataType::S32);
+
+ _params->shape({2, 3});
+ _indices->shape({2});
+
+ _params->size<loco::DataType::S32>(6);
+ _params->at<loco::DataType::S32>(0) = 0;
+ _params->at<loco::DataType::S32>(1) = 1;
+ _params->at<loco::DataType::S32>(2) = 2;
+ _params->at<loco::DataType::S32>(3) = 3;
+ _params->at<loco::DataType::S32>(4) = 4;
+ _params->at<loco::DataType::S32>(5) = 5;
+
+ _indices->size<loco::DataType::S32>(2);
+ _indices->at<loco::DataType::S32>(0) = 2;
+ _indices->at<loco::DataType::S32>(1) = 1;
+
+ _gather->params(_params);
+ _gather->indices(_indices);
+
+ _gather->axis(1);
+
+ _gather->name("gather");
+ _params->name("params");
+ _indices->name("indices");
+
+ return _gather;
+ }
+
+protected:
+ luci::CircleGather *_gather = nullptr;
+ luci::CircleConst *_params = nullptr;
+ luci::CircleConst *_indices = nullptr;
+};
+
+} // namespace
+
+TEST(FoldGatherTest, name)
+{
+ luci::FoldGatherPass pass;
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
+
+TEST_F(S64FoldGatherSimpleTest, fold_gather_simple)
+{
+ luci::FoldGatherPass pass;
+ while (pass.run(graph()))
+ ;
+
+ auto folded_const = getFoldedPattern();
+ EXPECT_NE(nullptr, folded_const);
+
+  // Check type, shape, values of folded const
+ EXPECT_EQ(loco::DataType::S64, folded_const->dtype());
+ EXPECT_EQ(1, folded_const->rank());
+ EXPECT_EQ(1, folded_const->dim(0).value());
+ EXPECT_EQ(2, folded_const->at<loco::DataType::S64>(0));
+}
+
+TEST_F(S32FoldGatherTwoDimsTest, fold_gather_with_two_dim)
+{
+ luci::FoldGatherPass pass;
+ while (pass.run(graph()))
+ ;
+
+ auto folded_const = getFoldedPattern();
+ EXPECT_NE(nullptr, folded_const);
+
+  // Check type, shape, values of folded const
+ EXPECT_EQ(loco::DataType::S32, folded_const->dtype());
+ EXPECT_EQ(2, folded_const->rank());
+ EXPECT_EQ(2, folded_const->dim(0).value());
+ EXPECT_EQ(2, folded_const->dim(1).value());
+
+ EXPECT_EQ(2, folded_const->at<loco::DataType::S32>(0));
+ EXPECT_EQ(1, folded_const->at<loco::DataType::S32>(1));
+ EXPECT_EQ(5, folded_const->at<loco::DataType::S32>(2));
+ EXPECT_EQ(4, folded_const->at<loco::DataType::S32>(3));
+}
+
+TEST_F(S64FoldGatherSimpleTest, illegal_input_NEG)
+{
+ _indices->dtype(loco::DataType::FLOAT32);
+
+ luci::FoldGatherPass pass;
+ EXPECT_ANY_THROW(pass.run(graph()));
+}
+
+TEST_F(S64FoldGatherSimpleTest, illegal_axis_NEG)
+{
+ _gather->axis(1);
+
+ luci::FoldGatherPass pass;
+ EXPECT_ANY_THROW(pass.run(graph()));
+}
diff --git a/compiler/luci/pass/src/PropagateConcatenationQparam.test.cpp b/compiler/luci/pass/src/PropagateConcatenationQparam.test.cpp
index de973a431..68136b244 100644
--- a/compiler/luci/pass/src/PropagateConcatenationQparam.test.cpp
+++ b/compiler/luci/pass/src/PropagateConcatenationQparam.test.cpp
@@ -186,12 +186,12 @@ TEST(PropagateConcatenationQparam, propagate_concat_quantparam_u8)
// (1) normal case: qparam is propagated to input_1 and input_2
// (2) input used by other Op: input_1 is an input of input_2. qparam is propagated only to
// input_2
- // (3) subsequent concat: input_1 is concat. qparam is propagated only to input_2
+ // (3) subsequent concat: input_1 is concat. qparam is propagated to subsequent concat
// (4) const input: input_1 is const. constant values are quantized
// normal case: qparam of concat_node is propagated to input_1 and input_2
SimpleConcatGraph g(loco::DataType::U8);
- luci::propagate_concat_quantparam(&g.concat_node, loco::DataType::U8);
+ luci::propagate_concat_quantparam(&g.concat_node);
EXPECT_FLOAT_EQ(3.14, g.concat_node.quantparam()->scale[0]);
EXPECT_EQ(77, g.concat_node.quantparam()->zerop[0]);
EXPECT_FLOAT_EQ(3.14, g.input_1.quantparam()->scale[0]);
@@ -202,7 +202,7 @@ TEST(PropagateConcatenationQparam, propagate_concat_quantparam_u8)
// input_1 is an input of input_2. qparam is propagated only to input_2
SimpleConcatGraph g2(loco::DataType::U8);
g2.input_2.input(&g2.input_1);
- luci::propagate_concat_quantparam(&g2.concat_node, loco::DataType::U8);
+ luci::propagate_concat_quantparam(&g2.concat_node);
EXPECT_FLOAT_EQ(3.14, g2.concat_node.quantparam()->scale[0]);
EXPECT_EQ(77, g2.concat_node.quantparam()->zerop[0]);
EXPECT_FLOAT_EQ(1.0, g2.input_1.quantparam()->scale[0]);
@@ -210,19 +210,19 @@ TEST(PropagateConcatenationQparam, propagate_concat_quantparam_u8)
EXPECT_FLOAT_EQ(3.14, g2.input_2.quantparam()->scale[0]);
EXPECT_EQ(77, g2.input_2.quantparam()->zerop[0]);
- // input_1 is concat. qparam is propagated only to input_2
+ // input_1 is concat. qparam is propagated to subsequent concat
SubsequentConcatGraph sg(loco::DataType::U8);
- luci::propagate_concat_quantparam(&sg.concat_node, loco::DataType::U8);
+ luci::propagate_concat_quantparam(&sg.concat_node);
EXPECT_FLOAT_EQ(3.14, sg.concat_node.quantparam()->scale[0]);
EXPECT_EQ(77, sg.concat_node.quantparam()->zerop[0]);
- EXPECT_FLOAT_EQ(1.0, sg.input_1.quantparam()->scale[0]);
- EXPECT_EQ(1, sg.input_1.quantparam()->zerop[0]);
+ EXPECT_FLOAT_EQ(3.14, sg.input_1.quantparam()->scale[0]);
+ EXPECT_EQ(77, sg.input_1.quantparam()->zerop[0]);
EXPECT_FLOAT_EQ(3.14, sg.input_2.quantparam()->scale[0]);
EXPECT_EQ(77, sg.input_2.quantparam()->zerop[0]);
// input_1 is const. const values are quantized with the qparam of concat
ConstInputConcatGraph cg(loco::DataType::U8);
- luci::propagate_concat_quantparam(cg.concat_node, loco::DataType::U8);
+ luci::propagate_concat_quantparam(cg.concat_node);
EXPECT_FLOAT_EQ(0.1, cg.concat_node->quantparam()->scale[0]);
EXPECT_EQ(10, cg.concat_node->quantparam()->zerop[0]);
const auto cg_input_1 = loco::must_cast<luci::CircleConst *>(cg.concat_node->values(0));
@@ -248,7 +248,7 @@ TEST(PropagateConcatenationQparam, propagate_concat_quantparam_u8_NEG)
// concat has fused activation function
g.concat_node.fusedActivationFunction(luci::FusedActFunc::RELU);
- luci::propagate_concat_quantparam(&g.concat_node, loco::DataType::U8);
+ luci::propagate_concat_quantparam(&g.concat_node);
EXPECT_FLOAT_EQ(3.14, g.concat_node.quantparam()->scale[0]);
EXPECT_EQ(77, g.concat_node.quantparam()->zerop[0]);
EXPECT_FLOAT_EQ(1.0, g.input_1.quantparam()->scale[0]);
@@ -261,7 +261,7 @@ TEST(PropagateConcatenationQparam, propagate_concat_quantparam_u8_NEG)
// const values are quantized using its min/max
ConstInputConcatGraph cg(loco::DataType::U8);
cg.concat_node->fusedActivationFunction(luci::FusedActFunc::RELU);
- luci::propagate_concat_quantparam(cg.concat_node, loco::DataType::U8);
+ luci::propagate_concat_quantparam(cg.concat_node);
EXPECT_FLOAT_EQ(0.1, cg.concat_node->quantparam()->scale[0]);
EXPECT_EQ(10, cg.concat_node->quantparam()->zerop[0]);
const auto cg_input_1 = loco::must_cast<luci::CircleConst *>(cg.concat_node->values(0));
@@ -283,12 +283,12 @@ TEST(PropagateConcatenationQparam, propagate_concat_quantparam_i16)
// (1) normal case: qparam is propagated to input_1 and input_2
// (2) input used by other Op: input_1 is an input of input_2. qparam is propagated only to
// input_2
- // (3) subsequent concat: input_1 is concat. qparam is propagated only to input_2
+ // (3) subsequent concat: input_1 is concat. qparam is propagated to subsequent concat
// (4) const input: input_1 is const. constant values are quantized
// normal case: qparam of concat_node is propagated to input_1 and input_2
SimpleConcatGraph g(loco::DataType::S16);
- luci::propagate_concat_quantparam(&g.concat_node, loco::DataType::S16);
+ luci::propagate_concat_quantparam(&g.concat_node);
EXPECT_FLOAT_EQ(3.14, g.concat_node.quantparam()->scale[0]);
EXPECT_EQ(0, g.concat_node.quantparam()->zerop[0]);
EXPECT_FLOAT_EQ(3.14, g.input_1.quantparam()->scale[0]);
@@ -299,7 +299,7 @@ TEST(PropagateConcatenationQparam, propagate_concat_quantparam_i16)
// input_1 is an input of input_2. qparam is propagated only to input_2
SimpleConcatGraph g2(loco::DataType::S16);
g2.input_2.input(&g2.input_1);
- luci::propagate_concat_quantparam(&g2.concat_node, loco::DataType::S16);
+ luci::propagate_concat_quantparam(&g2.concat_node);
EXPECT_FLOAT_EQ(3.14, g2.concat_node.quantparam()->scale[0]);
EXPECT_EQ(0, g2.concat_node.quantparam()->zerop[0]);
EXPECT_FLOAT_EQ(1.0, g2.input_1.quantparam()->scale[0]);
@@ -309,17 +309,17 @@ TEST(PropagateConcatenationQparam, propagate_concat_quantparam_i16)
// input_1 is concat. qparam is propagated only to input_2
SubsequentConcatGraph sg(loco::DataType::S16);
- luci::propagate_concat_quantparam(&sg.concat_node, loco::DataType::S16);
+ luci::propagate_concat_quantparam(&sg.concat_node);
EXPECT_FLOAT_EQ(3.14, sg.concat_node.quantparam()->scale[0]);
EXPECT_EQ(0, sg.concat_node.quantparam()->zerop[0]);
- EXPECT_FLOAT_EQ(1.0, sg.input_1.quantparam()->scale[0]);
+ EXPECT_FLOAT_EQ(3.14, sg.input_1.quantparam()->scale[0]);
EXPECT_EQ(0, sg.input_1.quantparam()->zerop[0]);
EXPECT_FLOAT_EQ(3.14, sg.input_2.quantparam()->scale[0]);
EXPECT_EQ(0, sg.input_2.quantparam()->zerop[0]);
// input_1 is const. const values are quantized with the qparam of concat
ConstInputConcatGraph cg(loco::DataType::S16);
- luci::propagate_concat_quantparam(cg.concat_node, loco::DataType::S16);
+ luci::propagate_concat_quantparam(cg.concat_node);
EXPECT_FLOAT_EQ(0.1, cg.concat_node->quantparam()->scale[0]);
EXPECT_EQ(0, cg.concat_node->quantparam()->zerop[0]);
const auto cg_input_1 = loco::must_cast<luci::CircleConst *>(cg.concat_node->values(0));
@@ -345,7 +345,7 @@ TEST(PropagateConcatenationQparam, propagate_concat_quantparam_i16_NEG)
// concat has fused activation function
g.concat_node.fusedActivationFunction(luci::FusedActFunc::RELU);
- luci::propagate_concat_quantparam(&g.concat_node, loco::DataType::S16);
+ luci::propagate_concat_quantparam(&g.concat_node);
EXPECT_FLOAT_EQ(3.14, g.concat_node.quantparam()->scale[0]);
EXPECT_EQ(0, g.concat_node.quantparam()->zerop[0]);
EXPECT_FLOAT_EQ(1.0, g.input_1.quantparam()->scale[0]);
@@ -358,7 +358,7 @@ TEST(PropagateConcatenationQparam, propagate_concat_quantparam_i16_NEG)
// const values are quantized using its min/max
ConstInputConcatGraph cg(loco::DataType::S16);
cg.concat_node->fusedActivationFunction(luci::FusedActFunc::RELU);
- luci::propagate_concat_quantparam(cg.concat_node, loco::DataType::S16);
+ luci::propagate_concat_quantparam(cg.concat_node);
EXPECT_FLOAT_EQ(0.1, cg.concat_node->quantparam()->scale[0]);
EXPECT_EQ(0, cg.concat_node->quantparam()->zerop[0]);
const auto cg_input_1 = loco::must_cast<luci::CircleConst *>(cg.concat_node->values(0));
diff --git a/compiler/luci/pass/src/PropagateQParamBackwardPass.cpp b/compiler/luci/pass/src/PropagateQParamBackwardPass.cpp
new file mode 100644
index 000000000..b4975486d
--- /dev/null
+++ b/compiler/luci/pass/src/PropagateQParamBackwardPass.cpp
@@ -0,0 +1,482 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/PropagateQParamBackwardPass.h"
+#include "QuantizationUtils.h"
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/IR/CircleNodeVisitor.h>
+#include <luci/Service/Nodes/CircleConst.h>
+#include <luci/Log.h>
+
+#include <cmath>
+
+namespace
+{
+
+void quant_const_values(luci::CircleConst *const_node, float scaling_factor, float zerop,
+ loco::DataType quant_type)
+{
+ uint32_t size = const_node->size<loco::DataType::FLOAT32>();
+
+ const float scaling_factor_inv = 1.0 / scaling_factor;
+ std::vector<int32_t> quantized_values(size);
+ for (uint32_t i = 0; i < size; ++i)
+ {
+ auto data = static_cast<double>(const_node->at<loco::DataType::FLOAT32>(i));
+ double quantized_data = std::round(data * scaling_factor_inv) + zerop;
+ constexpr double int_max = static_cast<double>(std::numeric_limits<int32_t>::max());
+ constexpr double int_min = static_cast<double>(std::numeric_limits<int32_t>::min());
+ quantized_data = std::min(int_max, std::max(int_min, quantized_data));
+
+ quantized_values[i] = static_cast<int32_t>(quantized_data);
+ }
+
+ switch (quant_type)
+ {
+ case loco::DataType::U8:
+ const_node->dtype(loco::DataType::U8); // change the type of tensor
+ const_node->size<loco::DataType::U8>(size); // resize tensor
+ for (uint32_t i = 0; i < size; ++i)
+ const_node->at<loco::DataType::U8>(i) = std::min(255, std::max(0, quantized_values[i]));
+ break;
+ case loco::DataType::S16:
+ assert(zerop == 0);
+ const_node->dtype(loco::DataType::S16); // change the type of tensor
+ const_node->size<loco::DataType::S16>(size); // resize tensor
+ for (uint32_t i = 0; i < size; ++i)
+ const_node->at<loco::DataType::S16>(i) =
+ std::min(32767, std::max(-32767, quantized_values[i]));
+ break;
+ default:
+ throw std::runtime_error("Unsupported data type");
+ }
+}
+
+void overwrite_quantparam(const luci::CircleNode *source, luci::CircleNode *target)
+{
+ auto source_qparam = source->quantparam();
+ if (source_qparam == nullptr)
+ throw std::runtime_error("source quantparam is not found during overwrite");
+
+ auto target_qparam = target->quantparam();
+ if (target_qparam == nullptr)
+ {
+ auto quantparam = std::make_unique<luci::CircleQuantParam>();
+ target->quantparam(std::move(quantparam));
+ target_qparam = target->quantparam();
+
+ if (target_qparam == nullptr)
+ throw std::runtime_error("Creating new quant param failed");
+ }
+ target_qparam->min = source_qparam->min;
+ target_qparam->max = source_qparam->max;
+ target_qparam->scale = source_qparam->scale;
+ target_qparam->zerop = source_qparam->zerop;
+ target_qparam->quantized_dimension = source_qparam->quantized_dimension;
+}
+
+/**
+ * Tells if pad_v2 quantization should ignore padding value
+ * In that case padding const will be quantized with input parameters, and probably clipped
+ */
+bool ignore_pad_v2_const_quantization(const luci::CirclePadV2 *pad)
+{
+ // This is a workaround to quantize pad generated from MaxPoolWithArgmax operation properly
+ // TODO use metadata hints to detect this case
+ auto const_value_node = dynamic_cast<const luci::CircleConst *>(pad->arg(2));
+ if (!const_value_node)
+ return false;
+ if (const_value_node->dtype() == loco::DataType::FLOAT32)
+ {
+ float const_value = const_value_node->at<loco::DataType::FLOAT32>(0);
+ if (const_value == std::numeric_limits<float>::lowest())
+ return true;
+ }
+ return false;
+}
+
+/** EXAMPLE
+ *
+ * BEFORE
+ *
+ * [CircleNode] [CircleConst]
+ * (qparam1) (FP32)
+ * \ /
+ * \ /
+ * [CirclePack]
+ * (qparam2)
+ *
+ * AFTER
+ *
+ * [CircleNode] [CircleConst] [CircleConst] <- Dead node
+ * (qparam2) (qparam2) (FP32)
+ * \ /
+ * \ /
+ * [CirclePack]
+ * (qparam2)
+ *
+ * NOTE Quantization parameter of CirclePack (qparam2) is propagated to the inputs.
+ */
+void propagate_pack_quantparam(luci::CirclePack *pack)
+{
+ assert(pack->quantparam() != nullptr);
+
+ const auto num_inputs = pack->values_count();
+
+ for (uint32_t i = 0; i < num_inputs; i++)
+ {
+ auto node = loco::must_cast<luci::CircleNode *>(pack->arg(i));
+
+ // Quantize constant values
+ if (node->opcode() == luci::CircleOpcode::CIRCLECONST)
+ {
+ luci::CircleConst *const_node = loco::must_cast<luci::CircleConst *>(node);
+ if (const_node->dtype() != loco::DataType::FLOAT32)
+ throw std::runtime_error("Unsupported data type for constant input of pack Op");
+
+ const auto pack_qparam = pack->quantparam();
+ if (pack_qparam == nullptr)
+ throw std::runtime_error("quantparam of pack is not found during propagation");
+
+ assert(pack_qparam->scale.size() == 1);
+ assert(pack_qparam->zerop.size() == 1);
+ const auto scaling_factor = pack_qparam->scale[0];
+ const auto zerop = pack_qparam->zerop[0];
+
+ auto new_const = luci::clone(const_node);
+ quant_const_values(new_const, scaling_factor, zerop, pack->dtype());
+ pack->values(i, new_const);
+ overwrite_quantparam(pack, new_const);
+ }
+ else
+ {
+ const auto succs = loco::succs(node);
+ if (succs.size() > 1)
+ continue;
+
+ // Non-const input must have been quantized
+ assert(node->quantparam() != nullptr);
+ overwrite_quantparam(pack, node);
+ }
+ }
+}
+
+/** EXAMPLE
+ *
+ *
+ *
+ * BEFORE
+ *
+ * [CircleNode] [CircleConst] [CircleConst] [CircleNode]
+ * (S32) (S32) (FP32) (U8 qparam1)
+ * \ \ / /
+ * \ \ / /
+ * \ \ / /
+ * -------[CircleOneHot]-------
+ * (U8 qparam2)
+ *
+ * AFTER
+ *
+ * [CircleNode] [CircleConst] [CircleConst] [CircleNode] [CircleConst] <- Dead node
+ * (S32) (S32) (U8 qparam2) (U8 qparam2) (FP32)
+ * \ \ / /
+ * \ \ / /
+ * \ \ / /
+ * -------[CircleOneHot]-------
+ * (U8 qparam2)
+ *
+ * NOTE Quantization parameter of CircleOneHot (qparam2) is propagated to on_value/off_value.
+ */
+void propagate_one_hot_quantparam(luci::CircleOneHot *one_hot)
+{
+ assert(one_hot->quantparam() != nullptr);
+
+ // Propagate quantization parameters from output to inputs,
+  // to fit both input and constant_value in one quant range.
+ auto quant_input = [one_hot](void (luci::CircleOneHot::*arg_setter)(loco::Node *),
+ loco::Node *(luci::CircleOneHot::*arg_getter)() const) {
+ auto node = loco::must_cast<luci::CircleNode *>((one_hot->*arg_getter)());
+
+ // Quantize constant values
+ if (node->opcode() == luci::CircleOpcode::CIRCLECONST)
+ {
+ luci::CircleConst *const_node = loco::must_cast<luci::CircleConst *>(node);
+ if (is_quantized(const_node))
+ return;
+
+ if (const_node->dtype() != loco::DataType::FLOAT32)
+ throw std::runtime_error("Unsupported data type for constant input of OneHot Op");
+
+ const auto qparam = one_hot->quantparam();
+ if (qparam == nullptr)
+ throw std::runtime_error("quantparam of OneHot is not found during propagation");
+
+ assert(qparam->scale.size() == 1);
+ const auto scaling_factor = qparam->scale.at(0);
+ const auto zerop = qparam->zerop.at(0);
+
+ auto new_const = luci::clone(const_node);
+ quant_const_values(new_const, scaling_factor, zerop, one_hot->dtype());
+ overwrite_quantparam(one_hot, new_const);
+ (one_hot->*arg_setter)(new_const);
+ }
+ else
+ {
+ const auto succs = loco::succs(node);
+ if (succs.size() > 1)
+ return;
+
+ // Non-const input must have been quantized
+ assert(node->quantparam() != nullptr);
+ overwrite_quantparam(one_hot, node);
+ }
+ };
+
+ quant_input(&luci::CircleOneHot::on_value, &luci::CircleOneHot::on_value);
+ quant_input(&luci::CircleOneHot::off_value, &luci::CircleOneHot::off_value);
+}
+
+} // namespace
+
+namespace luci
+{
+
+/** BEFORE
+ *
+ * [CircleNode] [CircleConst]
+ * (U8 qparam1) (FP32)
+ * \ /
+ * \ /
+ * [CircleConcatenation]
+ * (U8 qparam2)
+ *
+ * AFTER
+ * [CircleNode] [CircleConst] [CircleConst] <- Dead node
+ * (U8 qparam2) (U8 qparam2) (FP32)
+ * \ /
+ * \ /
+ * [CircleConcatenation]
+ * (U8 qparam2)
+ */
+void propagate_concat_quantparam(luci::CircleConcatenation *concat)
+{
+ assert(concat->quantparam() != nullptr);
+
+ const auto num_inputs = concat->numValues();
+
+ // Quantize const inputs using their values if concat has fused act function
+ if (concat->fusedActivationFunction() != luci::FusedActFunc::NONE)
+ {
+ for (uint32_t i = 0; i < num_inputs; i++)
+ {
+ auto node = concat->arg(i);
+ auto const_node = dynamic_cast<luci::CircleConst *>(node);
+ if (const_node != nullptr)
+ {
+ auto new_const = luci::clone(const_node);
+ quant_const(new_const, concat->dtype());
+ concat->values(i, new_const);
+ }
+ }
+ return;
+ }
+
+ for (uint32_t i = 0; i < num_inputs; i++)
+ {
+ auto node = loco::must_cast<luci::CircleNode *>(concat->arg(i));
+
+ // Quantize constant values
+ if (node->opcode() == luci::CircleOpcode::CIRCLECONST)
+ {
+ luci::CircleConst *const_node = loco::must_cast<luci::CircleConst *>(node);
+
+ const auto concat_qparam = concat->quantparam();
+ assert(concat_qparam->scale.size() == 1);
+ const auto scaling_factor = concat_qparam->scale[0];
+ const auto zerop = concat_qparam->zerop[0];
+
+ auto new_const = luci::clone(const_node);
+ quant_const_values(new_const, scaling_factor, zerop, concat->dtype());
+ concat->values(i, new_const);
+ overwrite_quantparam(concat, new_const);
+ }
+ else
+ {
+ const auto succs = loco::succs(node);
+ if (succs.size() > 1)
+ continue;
+
+ // Non-const input must have been quantized
+ assert(node->quantparam() != nullptr);
+ overwrite_quantparam(concat, node);
+ }
+ }
+}
+
+/** BEFORE
+ *
+ * [CircleNode] [CircleConst] [CircleConst]
+ * (U8 qparam1) (S32) (FP32)
+ * \ | /
+ * \ | /
+ * [CirclePadV2]
+ * (U8 qparam2)
+ *
+ * AFTER (case 1)
+ *
+ * By default qparam is propagated from output to inputs to meet backend requirements.
+ *
+ * [CircleNode] [CircleConst] [CircleConst] [CircleConst] <- Dead node
+ * (U8 qparam2) (S32) (U8 qparam2) (FP32)
+ * \ | /
+ * \ | /
+ * [CirclePadV2]
+ * (U8 qparam2)
+ *
+ * AFTER (case 2)
+ *
+ * In case padded value is the lowest float value
+ * Qparam is propagated from input to output and constant.
+ *
+ * This is a special case for optimization constructed pad, needed to guarantee that
+ * extremely large negative constant do not stretch output quantization range.
+ *
+ * [CircleNode] [CircleConst] [CircleConst] [CircleConst] <- Dead node
+ * (U8 qparam1) (S32) (U8 qparam1) (FP32)
+ * \ | /
+ * \ | /
+ * [CirclePadV2]
+ * (U8 qparam1)
+ */
+void propagate_pad_v2_quantparam(luci::CirclePadV2 *pad_v2)
+{
+ if (ignore_pad_v2_const_quantization(pad_v2))
+ {
+    // propagate input quantization parameters from input to output and padding const value
+ auto pad_v2_input = loco::must_cast<luci::CircleNode *>(pad_v2->arg(0));
+ overwrite_quantparam(pad_v2_input, pad_v2);
+
+ auto const_value_node = loco::must_cast<luci::CircleConst *>(
+ pad_v2->arg(2)); // FIX ignore_pad_v2_const_quantization UNLESS
+ auto new_const = luci::clone(const_value_node);
+
+ const auto pad_v2_input_qparam = pad_v2_input->quantparam();
+ assert(pad_v2_input_qparam != nullptr);
+ assert(pad_v2_input_qparam->scale.size() == 1);
+ const auto scaling_factor = pad_v2_input_qparam->scale.at(0);
+ const auto zerop = pad_v2_input_qparam->zerop.at(0);
+
+ quant_const_values(new_const, scaling_factor, zerop, pad_v2->dtype());
+ overwrite_quantparam(pad_v2_input, new_const);
+ pad_v2->constant_values(new_const);
+ return;
+ }
+
+  // Propagate quantization parameters from output to inputs,
+  // to fit both input and constant_value in one quant range.
+ auto quant_input = [pad_v2](void (CirclePadV2::*arg_setter)(loco::Node *), uint32_t arg) {
+ auto node = loco::must_cast<luci::CircleNode *>(pad_v2->arg(arg));
+
+ // Quantize constant values
+ if (node->opcode() == luci::CircleOpcode::CIRCLECONST)
+ {
+ luci::CircleConst *const_node = loco::must_cast<luci::CircleConst *>(node);
+ if (is_quantized(const_node))
+ return;
+
+ if (const_node->dtype() != loco::DataType::FLOAT32)
+ throw std::runtime_error("Unsupported data type for constant input of PadV2 Op");
+
+ const auto pad_v2_qparam = pad_v2->quantparam();
+ if (pad_v2_qparam == nullptr)
+ throw std::runtime_error("quantparam of PadV2 is not found during propagation");
+
+ assert(pad_v2_qparam->scale.size() == 1);
+ const auto scaling_factor = pad_v2_qparam->scale.at(0);
+ const auto zerop = pad_v2_qparam->zerop.at(0);
+
+ auto new_const = luci::clone(const_node);
+ quant_const_values(new_const, scaling_factor, zerop, pad_v2->dtype());
+ overwrite_quantparam(pad_v2, new_const);
+ (pad_v2->*arg_setter)(new_const);
+ }
+ else
+ {
+ const auto succs = loco::succs(node);
+ if (succs.size() > 1)
+ return;
+
+ // Non-const input must have been quantized
+ assert(node->quantparam() != nullptr);
+ overwrite_quantparam(pad_v2, node);
+ }
+ };
+
+ quant_input(&CirclePadV2::input, 0);
+ quant_input(&CirclePadV2::constant_values, 2);
+}
+
+} // namespace luci
+
+namespace
+{
+
+// Visitor to propagate quantization parameters backwards
+struct PropagateQParamBackward final : public luci::CircleNodeMutableVisitor<void>
+{
+ void visit(luci::CircleNode *) {}
+
+ void visit(luci::CircleConcatenation *node) { propagate_concat_quantparam(node); }
+
+ void visit(luci::CircleOneHot *node) { propagate_one_hot_quantparam(node); }
+
+ void visit(luci::CirclePack *node) { propagate_pack_quantparam(node); }
+
+ void visit(luci::CirclePadV2 *node) { propagate_pad_v2_quantparam(node); }
+};
+
+} // namespace
+
+namespace luci
+{
+
+bool PropagateQParamBackwardPass::run(loco::Graph *g)
+{
+ LOGGER(l);
+
+ // We use reverse post-order traversal as qparam is propagated backward
+ auto nodes = loco::postorder_traversal(loco::output_nodes(g));
+ std::reverse(nodes.begin(), nodes.end());
+ for (auto node : nodes)
+ {
+ auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+ INFO(l) << "PropagateQParamBackwardPass visit node: " << circle_node->name() << std::endl;
+
+ // We can't propagate non-existent qparam
+ if (circle_node->quantparam() == nullptr)
+ continue;
+
+ PropagateQParamBackward pqb;
+ circle_node->accept(&pqb);
+ }
+
+ // This pass is only run once, so return false
+ // TODO Refactoring not to return meaningless value
+ return false;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/PropagateQParamBackwardPass.test.cpp b/compiler/luci/pass/src/PropagateQParamBackwardPass.test.cpp
new file mode 100644
index 000000000..33af70449
--- /dev/null
+++ b/compiler/luci/pass/src/PropagateQParamBackwardPass.test.cpp
@@ -0,0 +1,167 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/PropagateQParamBackwardPass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <gtest/gtest.h>
+
+using namespace luci;
+
+namespace
+{
+
+void set_qparam(luci::CircleNode *node, float scale, int64_t zp)
+{
+ auto qparam = std::make_unique<luci::CircleQuantParam>();
+ qparam->scale.emplace_back(scale);
+ qparam->zerop.emplace_back(zp);
+
+ node->quantparam(std::move(qparam));
+}
+
+/**
+ * @brief Base Test Graph
+ */
+struct TestGraph
+{
+public:
+ virtual void init(void) = 0;
+};
+
+/**
+ * Graph with two concats
+ *
+ * [CircleInput] [CircleConst]
+ * \ /
+ * [CircleConcatenation] [CircleConst]
+ * | |
+ * [CircleConcatenation]
+ * |
+ * [CircleOutput]
+ *
+ * BEFORE
+ * - Concat1 and Concat2 have different qparams
+ *
+ * AFTER
+ * - All Ops have the same qparam
+ */
+struct SubsequentConcatGraph : public TestGraph
+{
+public:
+ void init(void) final
+ {
+ // graph input and output
+ auto graph_input = g.inputs()->create();
+ auto graph_output = g.outputs()->create();
+
+ // input
+ input = g.nodes()->create<luci::CircleInput>();
+ input->index(graph_input->index());
+ input->shape({1, 4, 4, 3});
+ input->dtype(loco::DataType::U8);
+ set_qparam(input, 1.0, 1);
+
+ // const1
+ const1 = g.nodes()->create<luci::CircleConst>();
+ const1->shape({1, 4, 4, 3});
+ const1->dtype(loco::DataType::FLOAT32);
+ const1->size<loco::DataType::FLOAT32>(48);
+ for (uint32_t i = 0; i < 48; i++)
+ const1->at<loco::DataType::FLOAT32>(i) = i;
+
+ // concat1
+ concat1 = g.nodes()->create<luci::CircleConcatenation>(2);
+ concat1->shape({1, 4, 4, 6});
+ concat1->dtype(loco::DataType::U8);
+ set_qparam(concat1, 2.0, 2);
+ concat1->values(0, input);
+ concat1->values(1, const1);
+ concat1->fusedActivationFunction(luci::FusedActFunc::NONE);
+
+ // const2
+ const2 = g.nodes()->create<luci::CircleConst>();
+ const2->shape({1, 4, 4, 3});
+ const2->dtype(loco::DataType::FLOAT32);
+ const2->size<loco::DataType::FLOAT32>(48);
+ for (uint32_t i = 0; i < 48; i++)
+ const2->at<loco::DataType::FLOAT32>(i) = i;
+
+ // concat2
+ concat2 = g.nodes()->create<luci::CircleConcatenation>(2);
+ concat2->shape({1, 4, 4, 9});
+ concat2->dtype(loco::DataType::U8);
+ set_qparam(concat2, 3.0, 3);
+ concat2->values(0, concat1);
+ concat2->values(1, const2);
+ concat2->fusedActivationFunction(luci::FusedActFunc::NONE);
+
+ // output
+ output = g.nodes()->create<luci::CircleOutput>();
+ output->index(graph_output->index());
+ output->from(concat2);
+ output->shape({1, 4, 4, 9});
+ output->dtype(loco::DataType::U8);
+ set_qparam(output, 3.0, 3);
+ }
+
+public:
+ loco::Graph g;
+ CircleInput *input = nullptr;
+ CircleConcatenation *concat1 = nullptr;
+ CircleConcatenation *concat2 = nullptr;
+ CircleConst *const1 = nullptr;
+ CircleConst *const2 = nullptr;
+ CircleOutput *output = nullptr;
+};
+
+} // namespace
+
+TEST(PropagateQParamBackwardPassTest, name)
+{
+ luci::PropagateQParamBackwardPass pass(loco::DataType::U8);
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
+
+TEST(PropagateQParamBackwardPassTest, subsequent_propagation)
+{
+ SubsequentConcatGraph graph;
+
+ graph.init();
+
+ luci::PropagateQParamBackwardPass pass(loco::DataType::U8);
+
+ pass.run(&graph.g);
+
+ EXPECT_EQ(3.0, graph.concat2->quantparam()->scale[0]);
+ EXPECT_EQ(3, graph.concat2->quantparam()->zerop[0]);
+
+ auto const2 = loco::must_cast<CircleNode *>(graph.concat2->values(1));
+ EXPECT_EQ(3.0, const2->quantparam()->scale[0]);
+ EXPECT_EQ(3, const2->quantparam()->zerop[0]);
+
+ EXPECT_EQ(3.0, graph.concat1->quantparam()->scale[0]);
+ EXPECT_EQ(3, graph.concat1->quantparam()->zerop[0]);
+
+ auto const1 = loco::must_cast<CircleNode *>(graph.concat1->values(1));
+ EXPECT_EQ(3.0, const1->quantparam()->scale[0]);
+ EXPECT_EQ(3, const1->quantparam()->zerop[0]);
+
+ EXPECT_EQ(3.0, graph.input->quantparam()->scale[0]);
+ EXPECT_EQ(3, graph.input->quantparam()->zerop[0]);
+}
diff --git a/compiler/luci/pass/src/PropagateQParamForwardPass.cpp b/compiler/luci/pass/src/PropagateQParamForwardPass.cpp
new file mode 100644
index 000000000..003e4c293
--- /dev/null
+++ b/compiler/luci/pass/src/PropagateQParamForwardPass.cpp
@@ -0,0 +1,194 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/PropagateQParamForwardPass.h"
+
+#include "QuantizationUtils.h"
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/IR/CircleNodeVisitor.h>
+#include <luci/Log.h>
+
+#include <iostream>
+
+namespace
+{
+
+bool copy_qparam(luci::CircleQuantParam *src, luci::CircleQuantParam *dst)
+{
+ assert(src->scale.size() == dst->scale.size());
+ assert(src->zerop.size() == dst->zerop.size());
+
+ // src and dst have the same qparam
+ if (std::equal(src->scale.begin(), src->scale.end(), dst->scale.begin()) &&
+ std::equal(src->zerop.begin(), src->zerop.end(), dst->zerop.begin()) &&
+ src->quantized_dimension == dst->quantized_dimension)
+ return false;
+
+ dst->scale.assign(src->scale.begin(), src->scale.end());
+ dst->zerop.assign(src->zerop.begin(), src->zerop.end());
+ dst->quantized_dimension = src->quantized_dimension;
+ return true;
+}
+
+bool copy_qparam(luci::CircleNode *src, luci::CircleNode *dst)
+{
+ // Skip nodes that do not have quantparams
+ auto src_qparam = src->quantparam();
+ if (not src_qparam)
+ return false;
+
+ auto dst_qparam = dst->quantparam();
+ if (not dst_qparam)
+ return false;
+
+ return copy_qparam(src_qparam, dst_qparam);
+}
+
+// Visitor to propagate quantization parameters
+struct PropagateQParamForward final : public luci::CircleNodeMutableVisitor<bool>
+{
+ PropagateQParamForward() = default;
+
+ bool visit(luci::CircleNode *) { return false; }
+
+ bool visit(luci::CircleGather *node)
+ {
+ auto input_node = loco::must_cast<luci::CircleNode *>(node->params());
+ return copy_qparam(input_node, node);
+ }
+
+ bool visit(luci::CircleReshape *node)
+ {
+ auto input_node = loco::must_cast<luci::CircleNode *>(node->tensor());
+ return copy_qparam(input_node, node);
+ }
+
+ bool visit(luci::CircleTranspose *node)
+ {
+ auto input_node = loco::must_cast<luci::CircleNode *>(node->a());
+ return copy_qparam(input_node, node);
+ }
+
+ bool visit(luci::CircleStridedSlice *node)
+ {
+ auto input_node = loco::must_cast<luci::CircleNode *>(node->input());
+ return copy_qparam(input_node, node);
+ }
+
+ bool visit(luci::CircleSplitOut *node)
+ {
+ auto split = loco::must_cast<luci::CircleSplit *>(node->input());
+ auto input_node = loco::must_cast<luci::CircleNode *>(split->input());
+ return copy_qparam(input_node, node);
+ }
+
+ bool visit(luci::CircleSplitVOut *node)
+ {
+ auto splitv = loco::must_cast<luci::CircleSplitV *>(node->input());
+ auto input_node = loco::must_cast<luci::CircleNode *>(splitv->input());
+ return copy_qparam(input_node, node);
+ }
+
+ bool visit(luci::CircleUnpackOut *node)
+ {
+ auto unpack = loco::must_cast<luci::CircleUnpack *>(node->input());
+ auto input_node = loco::must_cast<luci::CircleNode *>(unpack->value());
+ return copy_qparam(input_node, node);
+ }
+
+ // Propagate qparam across Quantize op to ensure
+ // special qparams (pre-defined values, integer scale)
+ bool visit(luci::CircleQuantize *node)
+ {
+ auto input_node = loco::must_cast<luci::CircleNode *>(node->input());
+
+ // Skip if input_node is not quantized activation
+ if (input_node->dtype() != loco::DataType::U8 and input_node->dtype() != loco::DataType::S16)
+ return false;
+
+ // If input_node and node have the same dtype, Quantize op
+ // will do rescale, not requantize for mixed-precision
+ if (input_node->dtype() == node->dtype())
+ return false;
+
+ assert(node->dtype() == loco::DataType::U8 or node->dtype() == loco::DataType::S16);
+
+ auto prev_qparam = node->quantparam();
+ assert(prev_qparam);
+ assert(prev_qparam->scale.size() == 1);
+ assert(prev_qparam->zerop.size() == 1);
+
+ const auto prev_scale = prev_qparam->scale[0];
+ const auto prev_zerop = prev_qparam->zerop[0];
+
+ auto qtype = luci::activation_qtype(input_node);
+ switch (qtype)
+ {
+ case luci::ActivationQType::PreDefinedValue:
+ node->quantparam(luci::make_predefined_qparam(input_node->opcode(), node->dtype()));
+ break;
+ case luci::ActivationQType::IntScale:
+ luci::set_int_scale(node);
+ break;
+ default:
+ break;
+ }
+
+ assert(node->quantparam());
+ assert(node->quantparam()->scale.size() == 1);
+ assert(node->quantparam()->zerop.size() == 1);
+
+ const auto scale = node->quantparam()->scale[0];
+ const auto zerop = node->quantparam()->zerop[0];
+
+ // Compare qparam with saved values to detect update
+ return scale != prev_scale or zerop != prev_zerop;
+ }
+};
+
+} // namespace
+
+namespace luci
+{
+
+bool PropagateQParamForwardPass::run(loco::Graph *g)
+{
+ bool changed = false;
+ LOGGER(l);
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+ INFO(l) << "PropagateQParamForwardPass visit node: " << circle_node->name() << std::endl;
+
+ PropagateQParamForward pqp;
+ if (circle_node->accept(&pqp))
+ changed = true;
+
+ if (_TF_style_maxpool)
+ {
+ if (auto maxpool = dynamic_cast<luci::CircleMaxPool2D *>(node))
+ {
+ auto input = loco::must_cast<luci::CircleNode *>(maxpool->value());
+ copy_qparam(input, maxpool);
+ }
+ }
+ }
+
+ return changed;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/PropagateQParamForwardPass.test.cpp b/compiler/luci/pass/src/PropagateQParamForwardPass.test.cpp
new file mode 100644
index 000000000..a734c0873
--- /dev/null
+++ b/compiler/luci/pass/src/PropagateQParamForwardPass.test.cpp
@@ -0,0 +1,260 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/PropagateQParamForwardPass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+void addQuantParam(luci::CircleNode *node, const std::vector<float> &scale,
+ const std::vector<int64_t> &zp)
+{
+ assert(node->quantparam() == nullptr);
+
+ auto quantparam = std::make_unique<luci::CircleQuantParam>();
+ quantparam->scale = scale;
+ quantparam->zerop = zp;
+ node->quantparam(std::move(quantparam));
+}
+
+/**
+ * Simple graph for test
+ *
+ * BEFORE
+ *
+ * [Conv] (qparam 1)
+ * |
+ * [Reshape] (qparam 2)
+ *
+ * AFTER
+ *
+ * [Conv] (qparam 2)
+ * |
+ * [Reshape] (qparam 2)
+ *
+ */
+class SimpleGraph
+{
+public:
+ SimpleGraph()
+ {
+ input = g.nodes()->create<luci::CircleInput>();
+ conv = g.nodes()->create<luci::CircleConv2D>();
+ reshape = g.nodes()->create<luci::CircleReshape>();
+ output = g.nodes()->create<luci::CircleOutput>();
+
+ auto graph_input = g.inputs()->create();
+ input->index(graph_input->index());
+ auto graph_output = g.outputs()->create();
+ output->index(graph_output->index());
+
+ addQuantParam(conv, {0.1, 0.2, 0.3}, {0, 10, 20});
+ addQuantParam(reshape, {0.2, 0.4, 0.6}, {-10, 0, 10});
+
+ conv->input(input);
+ reshape->tensor(conv);
+ output->from(reshape);
+ }
+
+public:
+ loco::Graph g;
+ luci::CircleInput *input = nullptr;
+ luci::CircleConv2D *conv = nullptr;
+ luci::CircleReshape *reshape = nullptr;
+ luci::CircleOutput *output = nullptr;
+};
+
+/**
+ * Test graph for forward propagation in Quantize Op
+ *
+ * BEFORE
+ *
+ * [Tanh U8] (qparam 1 - pre-defined for U8)
+ * |
+ * [Quantize S16] (qparam 2 - not pre-defined value)
+ *
+ * AFTER
+ *
+ * [Tanh U8] (qparam 1 - pre-defined for U8)
+ * |
+ * [Quantize S16] (qparam 3 - pre-defined for S16)
+ *
+ */
+class TanhQuantizeGraph
+{
+public:
+ TanhQuantizeGraph()
+ {
+ input = g.nodes()->create<luci::CircleInput>();
+ tanh = g.nodes()->create<luci::CircleTanh>();
+ quantize = g.nodes()->create<luci::CircleQuantize>();
+ output = g.nodes()->create<luci::CircleOutput>();
+
+ auto graph_input = g.inputs()->create();
+ input->index(graph_input->index());
+ auto graph_output = g.outputs()->create();
+ output->index(graph_output->index());
+
+ tanh->dtype(loco::DataType::U8);
+ quantize->dtype(loco::DataType::S16);
+
+ addQuantParam(tanh, {2.0f / 256.0f}, {128}); // pre-defined qparam for U8
+ addQuantParam(quantize, {1.0}, {0}); // not pre-defined values
+
+ tanh->x(input);
+ quantize->input(tanh);
+ output->from(quantize);
+ }
+
+public:
+ loco::Graph g;
+ luci::CircleInput *input = nullptr;
+ luci::CircleTanh *tanh = nullptr;
+ luci::CircleQuantize *quantize = nullptr;
+ luci::CircleOutput *output = nullptr;
+};
+
+/**
+ * Test graph for forward propagation in Quantize Op
+ *
+ * BEFORE
+ *
+ * [Floor U8] (qparam 1 - int scale)
+ * |
+ * [Quantize S16] (qparam 2 - not int scale)
+ *
+ * AFTER
+ *
+ * [Floor U8] (qparam 1 - int scale)
+ * |
+ * [Quantize S16] (qparam 3 - int scale)
+ *
+ */
+class FloorQuantizeGraph
+{
+public:
+ FloorQuantizeGraph()
+ {
+ input = g.nodes()->create<luci::CircleInput>();
+ floor = g.nodes()->create<luci::CircleFloor>();
+ quantize = g.nodes()->create<luci::CircleQuantize>();
+ output = g.nodes()->create<luci::CircleOutput>();
+
+ auto graph_input = g.inputs()->create();
+ input->index(graph_input->index());
+ auto graph_output = g.outputs()->create();
+ output->index(graph_output->index());
+
+ floor->dtype(loco::DataType::U8);
+ quantize->dtype(loco::DataType::S16);
+
+ addQuantParam(floor, {4.0f}, {128}); // int scale
+ addQuantParam(quantize, {0.3}, {0}); // not int scale
+
+ floor->x(input);
+ quantize->input(floor);
+ output->from(quantize);
+ }
+
+public:
+ loco::Graph g;
+ luci::CircleInput *input = nullptr;
+ luci::CircleFloor *floor = nullptr;
+ luci::CircleQuantize *quantize = nullptr;
+ luci::CircleOutput *output = nullptr;
+};
+
+} // namespace
+
+TEST(PropagateQParamForwardPassTest, name)
+{
+ luci::PropagateQParamForwardPass pass;
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
+
+TEST(PropagateQParamForward, simple)
+{
+ SimpleGraph g;
+
+ luci::PropagateQParamForwardPass pass;
+ while (pass.run(&g.g))
+ ;
+
+ EXPECT_FLOAT_EQ(0.1, g.reshape->quantparam()->scale[0]);
+ EXPECT_FLOAT_EQ(0.2, g.reshape->quantparam()->scale[1]);
+ EXPECT_FLOAT_EQ(0.3, g.reshape->quantparam()->scale[2]);
+ EXPECT_EQ(0, g.reshape->quantparam()->zerop[0]);
+ EXPECT_EQ(10, g.reshape->quantparam()->zerop[1]);
+ EXPECT_EQ(20, g.reshape->quantparam()->zerop[2]);
+}
+
+TEST(PropagateQParamForward, wrong_op_NEG)
+{
+ SimpleGraph g;
+ g.output->from(g.conv);
+ g.reshape->drop();
+
+ luci::PropagateQParamForwardPass pass;
+ while (pass.run(&g.g))
+ ;
+
+ EXPECT_FLOAT_EQ(0.1, g.conv->quantparam()->scale[0]);
+ EXPECT_FLOAT_EQ(0.2, g.conv->quantparam()->scale[1]);
+ EXPECT_FLOAT_EQ(0.3, g.conv->quantparam()->scale[2]);
+ EXPECT_EQ(0, g.conv->quantparam()->zerop[0]);
+ EXPECT_EQ(10, g.conv->quantparam()->zerop[1]);
+ EXPECT_EQ(20, g.conv->quantparam()->zerop[2]);
+}
+
+TEST(PropagateQParamForward, tanh_predefined_value)
+{
+ TanhQuantizeGraph g;
+
+ luci::PropagateQParamForwardPass pass;
+ while (pass.run(&g.g))
+ ;
+
+ EXPECT_FLOAT_EQ(1.0f / 32768.0f, g.quantize->quantparam()->scale[0]);
+}
+
+TEST(PropagateQParamForward, floor_int_scale)
+{
+ FloorQuantizeGraph g;
+
+ luci::PropagateQParamForwardPass pass;
+ while (pass.run(&g.g))
+ ;
+
+ EXPECT_FLOAT_EQ(1.0f, g.quantize->quantparam()->scale[0]);
+}
+
+TEST(PropagateQParamForward, same_dtype_NEG)
+{
+ FloorQuantizeGraph g;
+ g.quantize->dtype(loco::DataType::U8);
+
+ luci::PropagateQParamForwardPass pass;
+ while (pass.run(&g.g))
+ ;
+
+ // Qparam is not propagated as ifm/ofm of Quantize Op have the same dtype
+ EXPECT_FLOAT_EQ(0.3f, g.quantize->quantparam()->scale[0]);
+}
diff --git a/compiler/luci/pass/src/PropagateQuantParamPass.cpp b/compiler/luci/pass/src/PropagateQuantParamPass.cpp
deleted file mode 100644
index b1cb7a418..000000000
--- a/compiler/luci/pass/src/PropagateQuantParamPass.cpp
+++ /dev/null
@@ -1,107 +0,0 @@
-/*
- * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "luci/Pass/PropagateQuantParamPass.h"
-
-#include <luci/IR/CircleNodes.h>
-#include <luci/IR/CircleNodeVisitor.h>
-#include <luci/Log.h>
-
-#include <iostream>
-
-namespace
-{
-
-bool copy_qparam(luci::CircleQuantParam *src, luci::CircleQuantParam *dst)
-{
- assert(src->scale.size() == dst->scale.size());
- assert(src->zerop.size() == dst->zerop.size());
-
- // src and dst have the same qparam
- if (std::equal(src->scale.begin(), src->scale.end(), dst->scale.begin()) &&
- std::equal(src->zerop.begin(), src->zerop.end(), dst->zerop.begin()) &&
- src->quantized_dimension == dst->quantized_dimension)
- return false;
-
- dst->scale.assign(src->scale.begin(), src->scale.end());
- dst->zerop.assign(src->zerop.begin(), src->zerop.end());
- dst->quantized_dimension = src->quantized_dimension;
- return true;
-}
-
-bool copy_qparam(luci::CircleNode *src, luci::CircleNode *dst)
-{
- // Skip nodes that do not have quantparams
- auto src_qparam = src->quantparam();
- if (not src_qparam)
- return false;
-
- auto dst_qparam = dst->quantparam();
- if (not dst_qparam)
- return false;
-
- return copy_qparam(src_qparam, dst_qparam);
-}
-
-// Visitor to propagate quantization parameters
-struct PropagateQuantParam final : public luci::CircleNodeMutableVisitor<bool>
-{
- PropagateQuantParam() = default;
-
- bool visit(luci::CircleNode *) { return false; }
-
- bool visit(luci::CircleReshape *node)
- {
- auto input = node->tensor();
- if (loco::succs(input).size() != 1)
- return false;
-
- auto input_node = loco::must_cast<luci::CircleNode *>(input);
- return copy_qparam(input_node, node);
- }
-
- bool visit(luci::CircleTranspose *node)
- {
- auto input_node = loco::must_cast<luci::CircleNode *>(node->a());
- return copy_qparam(input_node, node);
- }
-
- // TODO : Add more Ops (e.g., layout-changing Ops)
-};
-
-} // namespace
-
-namespace luci
-{
-
-bool PropagateQuantParamPass::run(loco::Graph *g)
-{
- bool changed = false;
- LOGGER(l);
- for (auto node : loco::active_nodes(loco::output_nodes(g)))
- {
- auto circle_node = loco::must_cast<luci::CircleNode *>(node);
- INFO(l) << "PropagateQuantParamPass visit node: " << circle_node->name() << std::endl;
-
- PropagateQuantParam pqp;
- if (circle_node->accept(&pqp))
- changed = true;
- }
-
- return changed;
-}
-
-} // namespace luci
diff --git a/compiler/luci/pass/src/PropagateQuantParamPass.test.cpp b/compiler/luci/pass/src/PropagateQuantParamPass.test.cpp
deleted file mode 100644
index 0f1564223..000000000
--- a/compiler/luci/pass/src/PropagateQuantParamPass.test.cpp
+++ /dev/null
@@ -1,125 +0,0 @@
-/*
- * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "luci/Pass/PropagateQuantParamPass.h"
-
-#include <luci/IR/CircleNodes.h>
-
-#include <gtest/gtest.h>
-
-namespace
-{
-
-void addQuantParam(luci::CircleNode *node, const std::vector<float> &scale,
- const std::vector<int64_t> &zp)
-{
- assert(node->quantparam() == nullptr);
-
- auto quantparam = std::make_unique<luci::CircleQuantParam>();
- quantparam->scale = scale;
- quantparam->zerop = zp;
- node->quantparam(std::move(quantparam));
-}
-
-/**
- * Simple graph for test
- *
- * BEFORE
- *
- * [Conv] (qparam 1)
- * |
- * [Reshape] (qparam 2)
- *
- * AFTER
- *
- * [Conv] (qparam 2)
- * |
- * [Reshape] (qparam 2)
- *
- */
-class SimpleGraph
-{
-public:
- SimpleGraph()
- {
- input = g.nodes()->create<luci::CircleInput>();
- conv = g.nodes()->create<luci::CircleConv2D>();
- reshape = g.nodes()->create<luci::CircleReshape>();
- output = g.nodes()->create<luci::CircleOutput>();
-
- auto graph_input = g.inputs()->create();
- input->index(graph_input->index());
- auto graph_output = g.outputs()->create();
- output->index(graph_output->index());
-
- addQuantParam(conv, {0.1, 0.2, 0.3}, {0, 10, 20});
- addQuantParam(reshape, {0.2, 0.4, 0.6}, {-10, 0, 10});
-
- conv->input(input);
- reshape->tensor(conv);
- output->from(reshape);
- }
-
-public:
- loco::Graph g;
- luci::CircleInput *input;
- luci::CircleConv2D *conv;
- luci::CircleReshape *reshape;
- luci::CircleOutput *output;
-};
-
-} // namespace
-
-TEST(PropagateQuantParamPassTest, name)
-{
- luci::PropagateQuantParamPass pass;
- auto const name = pass.name();
- ASSERT_NE(nullptr, name);
-}
-
-TEST(PropagateQuantParam, simple)
-{
- SimpleGraph g;
-
- luci::PropagateQuantParamPass pass;
- while (pass.run(&g.g))
- ;
-
- EXPECT_FLOAT_EQ(0.1, g.reshape->quantparam()->scale[0]);
- EXPECT_FLOAT_EQ(0.2, g.reshape->quantparam()->scale[1]);
- EXPECT_FLOAT_EQ(0.3, g.reshape->quantparam()->scale[2]);
- EXPECT_EQ(0, g.reshape->quantparam()->zerop[0]);
- EXPECT_EQ(10, g.reshape->quantparam()->zerop[1]);
- EXPECT_EQ(20, g.reshape->quantparam()->zerop[2]);
-}
-
-TEST(PropagateQuantParam, wrong_op_NEG)
-{
- SimpleGraph g;
- g.output->from(g.conv);
- g.reshape->drop();
-
- luci::PropagateQuantParamPass pass;
- while (pass.run(&g.g))
- ;
-
- EXPECT_FLOAT_EQ(0.1, g.conv->quantparam()->scale[0]);
- EXPECT_FLOAT_EQ(0.2, g.conv->quantparam()->scale[1]);
- EXPECT_FLOAT_EQ(0.3, g.conv->quantparam()->scale[2]);
- EXPECT_EQ(0, g.conv->quantparam()->zerop[0]);
- EXPECT_EQ(10, g.conv->quantparam()->zerop[1]);
- EXPECT_EQ(20, g.conv->quantparam()->zerop[2]);
-}
diff --git a/compiler/luci/pass/src/QuantizationUtils.cpp b/compiler/luci/pass/src/QuantizationUtils.cpp
index 2f6fed46e..ad86cedf4 100644
--- a/compiler/luci/pass/src/QuantizationUtils.cpp
+++ b/compiler/luci/pass/src/QuantizationUtils.cpp
@@ -33,43 +33,6 @@ bool is_quantized(const CircleNode *node)
node->dtype() == loco::DataType::S64); // bias (int16 quant)
}
-// Check if node is weights of conv2d, depthwise_conv2d, or fully_connected layer
-bool is_weights(CircleNode *node)
-{
- auto circle_const = dynamic_cast<CircleConst *>(node);
- if (circle_const == nullptr)
- return false;
-
- auto succs = loco::succs(node);
-
- // Node is weights if it is the weights of all of its successors
- for (auto out : succs)
- {
- bool is_weights = false;
-
- auto conv = dynamic_cast<CircleConv2D *>(out);
- if (conv != nullptr && conv->filter() == circle_const)
- is_weights = true;
-
- auto dw_conv = dynamic_cast<CircleDepthwiseConv2D *>(out);
- if (dw_conv != nullptr && dw_conv->filter() == circle_const)
- is_weights = true;
-
- auto t_conv = dynamic_cast<CircleTransposeConv *>(out);
- if (t_conv != nullptr && t_conv->filter() == circle_const && circle_const->rank() == 4)
- is_weights = true;
-
- auto fc = dynamic_cast<CircleFullyConnected *>(out);
- if (fc != nullptr && fc->weights() == circle_const)
- is_weights = true;
-
- if (!is_weights)
- return false;
- }
-
- return true;
-}
-
uint8_t fp32_to_uint8_cast(float f)
{
assert(std::numeric_limits<uint8_t>::min() <= f);
@@ -77,7 +40,6 @@ uint8_t fp32_to_uint8_cast(float f)
return static_cast<uint8_t>(f);
}
-// Per-layer quantization of weights (const tensor) using given min/max values
void asymmetric_wquant_with_minmax_per_layer(CircleConst *node, float min, float max,
float &scaling_factor, int64_t &zp, float &nudged_min,
float &nudged_max)
@@ -107,7 +69,6 @@ void asymmetric_wquant_with_minmax_per_layer(CircleConst *node, float min, float
}
}
-// Per-layer quantization of weights (const tensor) using given min/max values
void symmetric_wquant_with_minmax_per_layer(CircleConst *node, float min, float max,
float &scaling_factor, int64_t &zp, float &nudged_min,
float &nudged_max)
@@ -315,4 +276,123 @@ uint32_t cal_offset(loco::TensorShape &dimension, uint32_t *indices)
indices[2] * dimension.dim(3).value() + indices[3];
}
+ActivationQType activation_qtype(const CircleNode *node)
+{
+ auto fused_act_node = dynamic_cast<const CircleNodeMixin<CircleNodeTrait::FusedActFunc> *>(node);
+ if (fused_act_node && fused_act_node->fusedActivationFunction() == FusedActFunc::TANH)
+ return ActivationQType::PreDefinedValue;
+
+ switch (node->opcode())
+ {
+ case CircleOpcode::LOGISTIC:
+ case CircleOpcode::TANH:
+ case CircleOpcode::SOFTMAX:
+ return ActivationQType::PreDefinedValue;
+ case CircleOpcode::FLOOR:
+ case CircleOpcode::FLOOR_DIV:
+ case CircleOpcode::FLOOR_MOD:
+ case CircleOpcode::CEIL:
+ return ActivationQType::IntScale;
+ default:
+ break;
+ }
+
+ return ActivationQType::MinMax;
+}
+
+std::unique_ptr<CircleQuantParam> make_predefined_qparam(CircleOpcode opcode, loco::DataType dtype)
+{
+ auto qparam = std::make_unique<CircleQuantParam>();
+
+ auto set_qparam = [&qparam](float scale, int64_t zp) {
+ qparam->scale.emplace_back(scale);
+ qparam->zerop.emplace_back(zp);
+ };
+
+ switch (opcode)
+ {
+ case CircleOpcode::LOGISTIC:
+ if (dtype == loco::DataType::U8)
+ set_qparam(1.0f / 256.0f, 0);
+ else
+ {
+ assert(dtype == loco::DataType::S16);
+ set_qparam(1.0f / 32768.0f, 0);
+ }
+ break;
+ case CircleOpcode::TANH:
+ if (dtype == loco::DataType::U8)
+ set_qparam(2.0f / 256.0f, 128);
+ else
+ {
+ assert(dtype == loco::DataType::S16);
+ set_qparam(1.0f / 32768.0f, 0);
+ }
+ break;
+ case CircleOpcode::SOFTMAX:
+ if (dtype == loco::DataType::U8)
+ set_qparam(1.0f / 255.0f, 0);
+ else
+ {
+ assert(dtype == loco::DataType::S16);
+ set_qparam(1.0f / 32767.0f, 0);
+ }
+ break;
+ default:
+ throw std::runtime_error("Unsupported opcode with pre-defined qparam");
+ }
+ return std::move(qparam);
+}
+
+// For nodes with integer output, we use integer scale
+void set_int_scale(luci::CircleNode *node)
+{
+ assert(node); // FIX_CALLER_UNLESS
+
+ auto qparam = node->quantparam();
+ assert(qparam); // FIX_CALLER_UNLESS
+ assert(qparam->scale.size() == 1); // FIX_CALLER_UNLESS
+
+ auto fp_scale = qparam->scale[0];
+ qparam->scale[0] = fp_scale < 1 ? 1.0f : std::round(fp_scale);
+}
+
+void quant_const(luci::CircleConst *node, loco::DataType quant_type)
+{
+ assert(node->dtype() == loco::DataType::FLOAT32);
+
+ float min = std::numeric_limits<float>::max();
+ float max = std::numeric_limits<float>::lowest();
+ for (uint32_t i = 0; i < node->size<loco::DataType::FLOAT32>(); i++)
+ {
+ auto data = node->at<loco::DataType::FLOAT32>(i);
+ min = data < min ? data : min;
+ max = data > max ? data : max;
+ }
+
+ float scaling_factor{0.0};
+ int64_t zp{0};
+ float nudged_min{0.0};
+ float nudged_max{0.0};
+
+ switch (quant_type)
+ {
+ case loco::DataType::U8:
+ asymmetric_wquant_with_minmax_per_layer(node, min, max, scaling_factor, zp, nudged_min,
+ nudged_max);
+ break;
+ case loco::DataType::S16:
+ symmetric_wquant_with_minmax_per_layer(node, min, max, scaling_factor, zp, nudged_min,
+ nudged_max);
+ break;
+ default:
+ throw std::runtime_error("Unsupported data type");
+ }
+
+ auto quantparam = std::make_unique<luci::CircleQuantParam>();
+ quantparam->scale.push_back(scaling_factor);
+ quantparam->zerop.push_back(zp);
+ node->quantparam(std::move(quantparam));
+}
+
} // namespace luci
diff --git a/compiler/luci/pass/src/QuantizationUtils.h b/compiler/luci/pass/src/QuantizationUtils.h
index 605f6a77e..cd8cec95a 100644
--- a/compiler/luci/pass/src/QuantizationUtils.h
+++ b/compiler/luci/pass/src/QuantizationUtils.h
@@ -23,33 +23,61 @@
namespace luci
{
+// Compute scale/zp using given min/max for symmetric quantization (int16)
void compute_sym_scale_zp(float min, float max, float &scaling_factor, int64_t &zp,
float &nudged_min, float &nudged_max);
+// Compute scale/zp using given min/max for asymmetric quantization (uint8)
void compute_asym_scale_zp(float min, float max, float &scaling_factor, int64_t &zp,
float &nudged_min, float &nudged_max);
+// Asymmetric per-layer quantization of weights (const tensor) using given min/max values
+// NOTE: in-place update of node data
void asymmetric_wquant_with_minmax_per_layer(CircleConst *node, float min, float max,
float &scaling_factor, int64_t &zp, float &nudged_min,
float &nudged_max);
+// Symmetric per-layer quantization of weights (const tensor) using given min/max values
+// NOTE: in-place update of node data
void symmetric_wquant_with_minmax_per_layer(CircleConst *node, float min, float max,
float &scaling_factor, int64_t &zp, float &nudged_min,
float &nudged_max);
+// Helper function to get channel dimension
+// TODO Embed this function into iterate_per_channel
bool get_channel_dim_index(CircleConst *node, loco::TensorShape &dimension,
int32_t &channel_dim_index);
+// Calculate offset of the given indices in dimension
uint32_t cal_offset(loco::TensorShape &dimension, uint32_t *indices);
-void propagate_concat_quantparam(luci::CircleConcatenation *concat, loco::DataType quant_type);
+// Backward propagation of concatenation qparam
+void propagate_concat_quantparam(luci::CircleConcatenation *concat);
-void propagate_pad_v2_quantparam(luci::CirclePadV2 *pad_v2, loco::DataType quant_type);
-
-bool is_weights(CircleNode *node);
+// Backward propagation of pad_v2 qparam
+void propagate_pad_v2_quantparam(luci::CirclePadV2 *pad_v2);
+// Return true if the node is quantized
bool is_quantized(const CircleNode *node);
+enum ActivationQType
+{
+ MinMax, // Quantize using recorded min/max
+ PreDefinedValue, // Quantize using pre-defined values
+ IntScale, // Round scale to a positive integer
+};
+
+ActivationQType activation_qtype(const CircleNode *node);
+
+// Create qparam with pre-defined values for special operators
+std::unique_ptr<CircleQuantParam> make_predefined_qparam(CircleOpcode opcode, loco::DataType dtype);
+
+// Update node's scale to a positive integer (for special Ops e.g., Floor, Ceil)
+void set_int_scale(luci::CircleNode *node);
+
+// Quantize const tensor using its min/max values
+void quant_const(luci::CircleConst *node, loco::DataType quant_type);
+
} // namespace luci
#endif // __LUCI_QUANTIZATION_UTILS_H__
diff --git a/compiler/luci/pass/src/QuantizeActivation.cpp b/compiler/luci/pass/src/QuantizeActivation.cpp
new file mode 100644
index 000000000..149331824
--- /dev/null
+++ b/compiler/luci/pass/src/QuantizeActivation.cpp
@@ -0,0 +1,296 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "QuantizeActivation.h"
+#include "QuantizationUtils.h"
+
+#include <luci/Service/Nodes/CircleConst.h>
+#include <luci/Log.h>
+
+#include <algorithm>
+#include <cmath>
+
+using namespace luci;
+
+namespace
+{
+
+bool has_min_max(const CircleNode *node)
+{
+ return node->quantparam() && !node->quantparam()->min.empty() && !node->quantparam()->max.empty();
+}
+
+} // namespace
+
+// QuantizeActivation
+namespace luci
+{
+
+void QuantizeActivation::visit(luci::CircleNode *node)
+{
+ LOGGER(l);
+ INFO(l) << "QuantizeActivation visit node: " << node->name() << std::endl;
+
+ // Check if this is already quantized
+ if (is_quantized(node))
+ return;
+
+ // Check if this is bool type (bool type is not quantized)
+ if (node->dtype() == loco::DataType::BOOL)
+ return;
+
+ // Check if this is const (const activation is handled by QuantizeConstInputActivation)
+ // NOTE QuantizePreChecker guarantees weights/bias are const.
+ // Update this code when we accept non-const weights/bias.
+ if (node->opcode() == luci::CircleOpcode::CIRCLECONST)
+ return;
+
+ // Check if this is activation
+ // We assume min/max are recorded only for activations
+ if (has_min_max(node))
+ {
+ // Quantize using recorded min/max
+ auto quantparam = node->quantparam();
+ assert(quantparam);
+ assert(quantparam->min.size() == 1); // only support layer-wise quant
+ assert(quantparam->max.size() == 1); // only support layer-wise quant
+ auto min = quantparam->min[0];
+ auto max = quantparam->max[0];
+
+ float scaling_factor{0};
+ int64_t zp{0};
+ float nudged_min{0};
+ float nudged_max{0};
+
+ if (output_type == loco::DataType::U8)
+ {
+ compute_asym_scale_zp(min, max, scaling_factor, zp, nudged_min, nudged_max);
+ node->dtype(loco::DataType::U8);
+ }
+ else
+ {
+ compute_sym_scale_zp(min, max, scaling_factor, zp, nudged_min, nudged_max);
+ node->dtype(loco::DataType::S16);
+ }
+
+ node->quantparam()->scale.push_back(scaling_factor);
+ node->quantparam()->zerop.push_back(zp);
+ }
+ // Fix special attributes
+ if (node->opcode() == luci::CircleOpcode::CAST)
+ {
+ auto *cast = loco::must_cast<luci::CircleCast *>(node);
+ auto *cast_input = loco::must_cast<luci::CircleNode *>(cast->x());
+
+ // make sure that cast_input is already quantized
+ assert(cast_input->dtype() != loco::DataType::FLOAT32);
+ cast->in_data_type(cast_input->dtype());
+ cast->out_data_type(cast->dtype());
+ }
+}
+
+} // namespace luci
+
+// QuantizeSpecialActivation
+namespace luci
+{
+
+void QuantizeSpecialActivation::visit(luci::CircleNode *node)
+{
+ // Nodes fused with activation functions which need special quantization
+ auto fused_act_node = dynamic_cast<CircleNodeMixin<CircleNodeTrait::FusedActFunc> *>(node);
+ if (fused_act_node != nullptr && fused_act_node->fusedActivationFunction() == FusedActFunc::TANH)
+ {
+ auto qparam = make_predefined_qparam(luci::CircleOpcode::TANH, output_type);
+ node->quantparam(std::move(qparam));
+ }
+}
+
+void QuantizeSpecialActivation::visit(luci::CircleLogistic *node)
+{
+ assert(activation_qtype(node) == luci::ActivationQType::PreDefinedValue);
+ auto qparam = make_predefined_qparam(luci::CircleOpcode::LOGISTIC, output_type);
+ node->quantparam(std::move(qparam));
+}
+
+void QuantizeSpecialActivation::visit(luci::CircleTanh *node)
+{
+ assert(activation_qtype(node) == luci::ActivationQType::PreDefinedValue);
+ auto qparam = make_predefined_qparam(luci::CircleOpcode::TANH, output_type);
+ node->quantparam(std::move(qparam));
+}
+
+void QuantizeSpecialActivation::visit(luci::CircleSoftmax *node)
+{
+ assert(activation_qtype(node) == luci::ActivationQType::PreDefinedValue);
+ auto qparam = make_predefined_qparam(luci::CircleOpcode::SOFTMAX, output_type);
+ node->quantparam(std::move(qparam));
+}
+
+void QuantizeSpecialActivation::visit(luci::CircleFloor *node)
+{
+ assert(activation_qtype(node) == luci::ActivationQType::IntScale);
+ set_int_scale(node);
+}
+
+void QuantizeSpecialActivation::visit(luci::CircleFloorDiv *node)
+{
+ assert(activation_qtype(node) == luci::ActivationQType::IntScale);
+ set_int_scale(node);
+}
+
+void QuantizeSpecialActivation::visit(luci::CircleFloorMod *node)
+{
+ assert(activation_qtype(node) == luci::ActivationQType::IntScale);
+ set_int_scale(node);
+}
+
+void QuantizeSpecialActivation::visit(luci::CircleCeil *node)
+{
+ assert(activation_qtype(node) == luci::ActivationQType::IntScale);
+ set_int_scale(node);
+}
+
+} // namespace luci
+
+// QuantizeConstInputActivation
+namespace luci
+{
+
+// Default behavior (NYI)
+void QuantizeConstInputActivation::visit(luci::CircleNode *node)
+{
+ for (uint32_t i = 0; i < node->arity(); i++)
+ {
+ auto input_node = node->arg(i);
+ auto const_node = dynamic_cast<luci::CircleConst *>(input_node);
+ if (const_node != nullptr)
+ throw std::runtime_error("Unsupported Op for const inputs");
+ }
+}
+
+// INPUT_NAME is the only activation of NODE
+#define QUANTIZE_SINGLE_CONST_INPUT(NODE, INPUT_NAME) \
+ void QuantizeConstInputActivation::visit(NODE *node) \
+ { \
+ auto input = node->INPUT_NAME(); \
+ auto const_node = dynamic_cast<luci::CircleConst *>(input); \
+ if (const_node && !is_quantized(const_node)) \
+ { \
+ auto new_const = luci::clone(const_node); \
+ quant_const(new_const, _output_type); \
+ node->INPUT_NAME(new_const); \
+ } \
+ }
+
+// INPUT_NAME1 and INPUT_NAME2 are the only activations of NODE
+#define QUANTIZE_TWO_CONST_INPUTS(NODE, INPUT_NAME1, INPUT_NAME2) \
+ void QuantizeConstInputActivation::visit(NODE *node) \
+ { \
+ auto input1 = node->INPUT_NAME1(); \
+ auto const_node1 = dynamic_cast<luci::CircleConst *>(input1); \
+ if (const_node1 && !is_quantized(const_node1)) \
+ { \
+ auto new_const1 = luci::clone(const_node1); \
+ quant_const(new_const1, _output_type); \
+ node->INPUT_NAME1(new_const1); \
+ } \
+ auto input2 = node->INPUT_NAME2(); \
+ auto const_node2 = dynamic_cast<luci::CircleConst *>(input2); \
+ if (const_node2 && !is_quantized(const_node2)) \
+ { \
+ auto new_const2 = luci::clone(const_node2); \
+ quant_const(new_const2, _output_type); \
+ node->INPUT_NAME2(new_const2); \
+ } \
+ }
+
+// Ops that receive a single activation as an input
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleArgMax, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleArgMin, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleBatchToSpaceND, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleDepthToSpace, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleElu, features)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleExp, x)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleFloor, x)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleGather, params)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleLocalResponseNormalization, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleLogistic, x)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleMean, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleMirrorPad, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CirclePad, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleReduceAny, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleReduceProd, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleReduceMax, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleReduceMin, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleReshape, tensor)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleResizeBilinear, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleResizeNearestNeighbor, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleReverseSequence, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleRsqrt, x)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleSlice, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleSoftmax, logits)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleSpaceToBatchND, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleSpaceToDepth, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleSplit, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleSplitV, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleSqrt, x)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleStridedSlice, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleSum, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleTanh, x)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleTile, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleTopKV2, input)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleTranspose, a)
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleUnpack, value)
+
+// Ops that receive two activations as inputs
+QUANTIZE_TWO_CONST_INPUTS(luci::CircleAdd, x, y)
+QUANTIZE_TWO_CONST_INPUTS(luci::CircleBatchMatMul, x, y)
+QUANTIZE_TWO_CONST_INPUTS(luci::CircleDiv, x, y)
+QUANTIZE_TWO_CONST_INPUTS(luci::CircleEqual, x, y)
+QUANTIZE_TWO_CONST_INPUTS(luci::CircleFloorDiv, x, y)
+QUANTIZE_TWO_CONST_INPUTS(luci::CircleGreater, x, y)
+QUANTIZE_TWO_CONST_INPUTS(luci::CircleGreaterEqual, x, y)
+QUANTIZE_TWO_CONST_INPUTS(luci::CircleLess, x, y)
+QUANTIZE_TWO_CONST_INPUTS(luci::CircleLessEqual, x, y)
+QUANTIZE_TWO_CONST_INPUTS(luci::CircleMaximum, x, y)
+QUANTIZE_TWO_CONST_INPUTS(luci::CircleMinimum, x, y)
+QUANTIZE_TWO_CONST_INPUTS(luci::CircleMul, x, y)
+QUANTIZE_TWO_CONST_INPUTS(luci::CircleNotEqual, x, y)
+QUANTIZE_TWO_CONST_INPUTS(luci::CirclePow, x, y)
+QUANTIZE_TWO_CONST_INPUTS(luci::CircleSub, x, y)
+
+// AddN has arbitrary number of inputs
+void QuantizeConstInputActivation::visit(luci::CircleAddN *node)
+{
+ auto arity = node->arity();
+ for (uint32_t i = 0; i < arity; i++)
+ {
+ auto input_node = node->inputs(i);
+ auto const_node = dynamic_cast<luci::CircleConst *>(input_node);
+ if (const_node && !is_quantized(const_node))
+ {
+ auto new_const = luci::clone(const_node);
+ quant_const(new_const, _output_type);
+ node->inputs(i, new_const);
+ }
+ }
+}
+
+#undef QUANTIZE_SINGLE_CONST_INPUT
+#undef QUANTIZE_TWO_CONST_INPUTS
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/QuantizeActivation.h b/compiler/luci/pass/src/QuantizeActivation.h
new file mode 100644
index 000000000..fc32d1cde
--- /dev/null
+++ b/compiler/luci/pass/src/QuantizeActivation.h
@@ -0,0 +1,165 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_QUANTIZATION_ACTIVATION_H__
+#define __LUCI_QUANTIZATION_ACTIVATION_H__
+
+#include <luci/IR/CircleNodeVisitor.h>
+
+namespace luci
+{
+
+/**
+ * @brief Quantize non-const activation using recorded min/max values
+ */
+struct QuantizeActivation final : public luci::CircleNodeMutableVisitor<void>
+{
+ QuantizeActivation(loco::DataType input, loco::DataType output)
+ : input_type(input), output_type(output)
+ {
+ }
+
+ loco::DataType input_type;
+ loco::DataType output_type;
+
+ // Quantize each node using recorded min/max
+ void visit(luci::CircleNode *node);
+};
+
+/**
+ * @brief Quantize non-const activation using pre-defined scale/zp for special Ops
+ */
+struct QuantizeSpecialActivation final : public luci::CircleNodeMutableVisitor<void>
+{
+ QuantizeSpecialActivation(loco::DataType input, loco::DataType output)
+ : input_type(input), output_type(output)
+ {
+ }
+
+ loco::DataType input_type;
+ loco::DataType output_type;
+
+ void visit(luci::CircleNode *node);
+ void visit(luci::CircleLogistic *node);
+ void visit(luci::CircleTanh *node);
+ void visit(luci::CircleSoftmax *node);
+ void visit(luci::CircleFloor *node);
+ void visit(luci::CircleFloorDiv *node);
+ void visit(luci::CircleFloorMod *node);
+ void visit(luci::CircleCeil *node);
+};
+
+// Quantize constant input activation of a node
+// The input of a node is quantized if it is
+// 1. Constant (instance of CircleConst*)
+// 2. Activation (other inputs e.g., weights, bias, axis, etc should not be quantized here)
+struct QuantizeConstInputActivation final : public luci::CircleNodeMutableVisitor<void>
+{
+ QuantizeConstInputActivation(loco::DataType output_type) : _output_type(output_type) {}
+
+private:
+ loco::DataType _output_type;
+
+// Skip NODE
+#define SKIP(NODE) \
+ void visit(NODE *) {}
+
+ // Handled in QuantizeWeights and QuantizeBias
+ SKIP(luci::CircleConv2D)
+ SKIP(luci::CircleDepthwiseConv2D)
+ SKIP(luci::CircleFullyConnected)
+ SKIP(luci::CircleInstanceNorm)
+ SKIP(luci::CirclePRelu)
+ SKIP(luci::CircleTransposeConv)
+
+ // Handled in PropagateQParamBackwardPass
+ SKIP(luci::CircleConcatenation)
+ SKIP(luci::CirclePadV2)
+ SKIP(luci::CirclePack)
+ SKIP(luci::CircleOneHot)
+
+ // Inputs of logical Ops are bool, thus not quantized
+ SKIP(luci::CircleLogicalOr)
+ SKIP(luci::CircleLogicalAnd)
+ SKIP(luci::CircleLogicalNot)
+
+#undef SKIP
+
+ // Default behavior (NYI)
+ void visit(luci::CircleNode *node);
+
+ // Ops that receive a single activation as an input
+ void visit(luci::CircleArgMax *node);
+ void visit(luci::CircleArgMin *node);
+ void visit(luci::CircleBatchToSpaceND *node);
+ void visit(luci::CircleDepthToSpace *node);
+ void visit(luci::CircleElu *node);
+ void visit(luci::CircleExp *node);
+ void visit(luci::CircleFloor *node);
+ void visit(luci::CircleGather *node);
+ void visit(luci::CircleLocalResponseNormalization *node);
+ void visit(luci::CircleLogistic *node);
+ void visit(luci::CircleMean *node);
+ void visit(luci::CircleMirrorPad *node);
+ void visit(luci::CirclePad *node);
+ void visit(luci::CircleReduceAny *node);
+ void visit(luci::CircleReduceProd *node);
+ void visit(luci::CircleReduceMax *node);
+ void visit(luci::CircleReduceMin *node);
+ void visit(luci::CircleReshape *node);
+ void visit(luci::CircleResizeBilinear *node);
+ void visit(luci::CircleResizeNearestNeighbor *node);
+ void visit(luci::CircleReverseSequence *node);
+ void visit(luci::CircleRsqrt *node);
+ void visit(luci::CircleSlice *node);
+ void visit(luci::CircleSoftmax *node);
+ void visit(luci::CircleSpaceToBatchND *node);
+ void visit(luci::CircleSpaceToDepth *node);
+ void visit(luci::CircleSplit *node);
+ void visit(luci::CircleSplitV *node);
+ void visit(luci::CircleSqrt *node);
+ void visit(luci::CircleStridedSlice *node);
+ void visit(luci::CircleSum *node);
+ void visit(luci::CircleTanh *node);
+ void visit(luci::CircleTile *node);
+ void visit(luci::CircleTopKV2 *node);
+ void visit(luci::CircleTranspose *node);
+ void visit(luci::CircleUnpack *node);
+
+ // Ops that receive two activations as inputs
+ void visit(luci::CircleAdd *node);
+ void visit(luci::CircleBatchMatMul *node);
+ void visit(luci::CircleDiv *node);
+ void visit(luci::CircleEqual *node);
+ void visit(luci::CircleFloorDiv *node);
+ void visit(luci::CircleGreater *node);
+ void visit(luci::CircleGreaterEqual *node);
+ void visit(luci::CircleLess *node);
+ void visit(luci::CircleLessEqual *node);
+ void visit(luci::CircleMaximum *node);
+ void visit(luci::CircleMinimum *node);
+ void visit(luci::CircleMul *node);
+ void visit(luci::CircleNotEqual *node);
+ void visit(luci::CirclePow *node);
+ void visit(luci::CircleSub *node);
+
+ // AddN has arbitrary number of inputs
+ void visit(luci::CircleAddN *node);
+};
+
+} // namespace luci
+
+#endif // __LUCI_QUANTIZATION_ACTIVATION_H__
diff --git a/compiler/luci/pass/src/QuantizeBias.cpp b/compiler/luci/pass/src/QuantizeBias.cpp
new file mode 100644
index 000000000..aa496232a
--- /dev/null
+++ b/compiler/luci/pass/src/QuantizeBias.cpp
@@ -0,0 +1,300 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "QuantizeBias.h"
+#include "QuantizationUtils.h"
+
+#include <luci/Service/Nodes/CircleConst.h>
+#include <luci/Log.h>
+
+#include <algorithm>
+#include <cmath>
+
+using namespace luci;
+
+namespace
+{
+
+// struct to carry Input/Weights/Bias
+struct IWB
+{
+ CircleNode *input = nullptr;
+ CircleNode *weights = nullptr;
+ CircleConst *bias = nullptr;
+
+ IWB(loco::Node *i, loco::Node *w, loco::Node *b)
+ {
+ input = dynamic_cast<luci::CircleNode *>(i);
+ weights = dynamic_cast<luci::CircleNode *>(w);
+ bias = dynamic_cast<luci::CircleConst *>(b);
+ }
+
+  // Return true if bias can be quantized with valid input and weights
+ operator bool()
+ {
+ if (bias == nullptr || is_quantized(bias))
+ return false;
+ if (input == nullptr || weights == nullptr)
+ return false;
+ return true;
+ }
+};
+
+// Create a new const node from an existing node.
+// The new node has the following characteristics
+// type: T
+// shape: same with 'node' (given as an argument)
+// buffer size: 'size' (given as an argument)
+// Note that contents are not filled in this function.
+template <loco::DataType T>
+luci::CircleConst *create_empty_const_from(luci::CircleConst *node, uint32_t size)
+{
+ auto new_node = node->graph()->nodes()->create<CircleConst>();
+ // TODO: We don't have any naming convention for quantized nodes yet.
+ // Fix this when we have one.
+ new_node->name(node->name());
+ new_node->dtype(T);
+ new_node->rank(node->rank());
+ for (uint32_t i = 0; i < node->rank(); i++)
+ new_node->dim(i).set(node->dim(i).value());
+
+ new_node->size<T>(size);
+ new_node->shape_status(luci::ShapeStatus::VALID);
+
+ return new_node;
+}
+
+CircleConst *asym_quant_bias_per_layer(CircleConst *node, float input_scale, float weight_scale,
+ float *scaling_factor, int64_t *zp)
+{
+ float scale = input_scale * weight_scale;
+ const float scaling_factor_inv = (scale == 0) ? 0 : 1.0 / scale;
+
+ uint32_t size = node->size<loco::DataType::FLOAT32>();
+ std::vector<int32_t> quantized_values(size);
+ for (uint32_t i = 0; i < size; ++i)
+ {
+ quantized_values[i] =
+ static_cast<int32_t>(std::round(node->at<loco::DataType::FLOAT32>(i) * scaling_factor_inv));
+ }
+
+ auto new_bias = create_empty_const_from<loco::DataType::S32>(node, size);
+
+ const int32_t kMinScale = std::numeric_limits<int32_t>::lowest();
+ const int32_t kMaxScale = std::numeric_limits<int32_t>::max();
+ for (uint32_t i = 0; i < size; ++i)
+ {
+ new_bias->at<loco::DataType::S32>(i) =
+ std::min(kMaxScale, std::max(kMinScale, quantized_values[i]));
+ }
+ *scaling_factor = scale;
+ *zp = 0;
+
+ return new_bias;
+}
+
+CircleConst *quant_bias_per_channel(CircleConst *node, float input_scale,
+ std::vector<float> &weight_scale,
+ std::vector<float> &scaling_factor, std::vector<int64_t> &zp)
+{
+ float scaling_factor_inv{0};
+
+ uint32_t size = node->size<loco::DataType::FLOAT32>();
+ std::vector<int32_t> quantized_values(size);
+
+ for (uint32_t i = 0; i < size; ++i)
+ {
+ scaling_factor[i] = input_scale * weight_scale[i];
+ scaling_factor_inv = (scaling_factor[i] == 0) ? 0 : 1.0 / scaling_factor[i];
+ quantized_values[i] =
+ static_cast<int32_t>(std::round(node->at<loco::DataType::FLOAT32>(i) * scaling_factor_inv));
+ zp[i] = 0;
+ }
+
+ auto new_bias = create_empty_const_from<loco::DataType::S32>(node, size);
+
+ const int32_t kMinScale = std::numeric_limits<int32_t>::lowest();
+ const int32_t kMaxScale = std::numeric_limits<int32_t>::max();
+ for (uint32_t i = 0; i < size; ++i)
+ {
+ new_bias->at<loco::DataType::S32>(i) =
+ std::min(kMaxScale, std::max(kMinScale, quantized_values[i]));
+ }
+
+ return new_bias;
+}
+
+CircleConst *int16_quant_bias_per_channel(CircleConst *node, float input_scale,
+ std::vector<float> &weight_scale,
+ std::vector<float> &scaling_factor,
+ std::vector<int64_t> &zp)
+{
+ float scaling_factor_inv{0};
+
+ uint32_t size = node->size<loco::DataType::FLOAT32>();
+ std::vector<int64_t> quantized_values(size);
+
+ for (uint32_t i = 0; i < size; ++i)
+ {
+ scaling_factor[i] = input_scale * weight_scale[i];
+ scaling_factor_inv = (scaling_factor[i] == 0) ? 0 : 1.0 / scaling_factor[i];
+ quantized_values[i] =
+ static_cast<int64_t>(std::round(node->at<loco::DataType::FLOAT32>(i) * scaling_factor_inv));
+ zp[i] = 0;
+ }
+
+ auto new_bias = create_empty_const_from<loco::DataType::S64>(node, size);
+
+ for (uint32_t i = 0; i < size; ++i)
+ {
+ new_bias->at<loco::DataType::S64>(i) = quantized_values[i];
+ }
+
+ return new_bias;
+}
+
+} // namespace
+
+namespace luci
+{
+
+// Return a quantized bias node
+CircleConst *QuantizeBias::quantized_bias(CircleNode *input, const CircleNode *weight,
+ CircleNode *bias)
+{
+ auto const_bias = loco::must_cast<luci::CircleConst *>(bias);
+ assert(const_bias->dtype() == loco::DataType::FLOAT32);
+
+ // If input is const, it is quantized here, not in QuantizeActivation
+ if (auto const_input = dynamic_cast<luci::CircleConst *>(input))
+ {
+ quant_const(const_input, output_type);
+ }
+
+ CircleConst *new_bias = nullptr;
+
+ if (granularity == QuantizationGranularity::ChannelWise)
+ {
+ auto input_q = input->quantparam();
+ assert(input_q);
+ assert(input_q->scale.size() == 1); // input scale's layer-wise
+ auto input_scale = input_q->scale[0];
+
+ assert(weight->quantparam() != nullptr); // weight scale's channel-wise
+ auto weight_scale = weight->quantparam()->scale;
+
+ uint32_t size = const_bias->size<loco::DataType::FLOAT32>();
+ assert(size == weight_scale.size());
+ std::vector<float> scaling_factor(size);
+ std::vector<int64_t> zp(size);
+
+ if (output_type == loco::DataType::U8)
+ {
+ new_bias = quant_bias_per_channel(const_bias, input_scale, weight_scale, scaling_factor, zp);
+ }
+ else if (output_type == loco::DataType::S16)
+ {
+ new_bias =
+ int16_quant_bias_per_channel(const_bias, input_scale, weight_scale, scaling_factor, zp);
+ }
+ else
+ {
+ throw std::runtime_error("Unsupported quantization type.");
+ }
+
+ auto quantparam = std::make_unique<CircleQuantParam>();
+ quantparam->scale = scaling_factor;
+ quantparam->zerop = zp;
+ assert(new_bias->quantparam() == nullptr); // bias should not be quantized before
+ new_bias->quantparam(std::move(quantparam));
+
+ return new_bias;
+ }
+ else
+ {
+ auto input_q = input->quantparam();
+ assert(input_q);
+ assert(input_q->scale.size() == 1); // Only support per-layer quant
+ auto input_scale = input_q->scale[0];
+
+ auto weight_q = weight->quantparam();
+ assert(weight_q);
+ assert(weight_q->scale.size() == 1); // Only support per-layer quant
+ auto weight_scale = weight_q->scale[0];
+
+ float scaling_factor{0};
+ int64_t zp{0};
+ new_bias =
+ asym_quant_bias_per_layer(const_bias, input_scale, weight_scale, &scaling_factor, &zp);
+ auto quantparam = std::make_unique<CircleQuantParam>();
+ quantparam->scale.push_back(scaling_factor);
+ quantparam->zerop.push_back(zp);
+ assert(new_bias->quantparam() == nullptr); // bias should not be quantized before
+ new_bias->quantparam(std::move(quantparam));
+
+ return new_bias;
+ }
+}
+
+void QuantizeBias::visit(luci::CircleConv2D *node)
+{
+ LOGGER(l);
+ INFO(l) << "QuantizeBias QuantizeBias::visit node: " << node->name() << std::endl;
+
+ if (auto iwb = IWB(node->input(), node->filter(), node->bias()))
+ {
+ auto new_bias = quantized_bias(iwb.input, iwb.weights, iwb.bias);
+ node->bias(new_bias);
+ }
+}
+
+void QuantizeBias::visit(luci::CircleDepthwiseConv2D *node)
+{
+ LOGGER(l);
+ INFO(l) << "QuantizeBias QuantizeBias::visit node: " << node->name() << std::endl;
+
+ if (auto iwb = IWB(node->input(), node->filter(), node->bias()))
+ {
+ auto new_bias = quantized_bias(iwb.input, iwb.weights, iwb.bias);
+ node->bias(new_bias);
+ }
+}
+
+void QuantizeBias::visit(luci::CircleTransposeConv *node)
+{
+ LOGGER(l);
+ INFO(l) << "QuantizeBias QuantizeBias::visit node: " << node->name() << std::endl;
+
+ if (auto iwb = IWB(node->outBackprop(), node->filter(), node->bias()))
+ {
+ auto new_bias = quantized_bias(iwb.input, iwb.weights, iwb.bias);
+ node->bias(new_bias);
+ }
+}
+
+void QuantizeBias::visit(luci::CircleFullyConnected *node)
+{
+ LOGGER(l);
+ INFO(l) << "QuantizeBias visit node: " << node->name() << std::endl;
+
+ if (auto iwb = IWB(node->input(), node->weights(), node->bias()))
+ {
+ auto new_bias = quantized_bias(iwb.input, iwb.weights, iwb.bias);
+ node->bias(new_bias);
+ }
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/QuantizeBias.h b/compiler/luci/pass/src/QuantizeBias.h
new file mode 100644
index 000000000..8de09df72
--- /dev/null
+++ b/compiler/luci/pass/src/QuantizeBias.h
@@ -0,0 +1,56 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_QUANTIZE_BIAS_H__
+#define __LUCI_QUANTIZE_BIAS_H__
+
+#include <luci/Pass/QuantizationParameters.h>
+#include <luci/IR/CircleNodeVisitor.h>
+
+namespace luci
+{
+
+/**
+ * @brief QuantizeBias quantizes tensors for bias
+ * @details Use input/weights scale to quantize values
+ */
+struct QuantizeBias final : public luci::CircleNodeMutableVisitor<void>
+{
+ QuantizeBias(loco::DataType input, loco::DataType output, QuantizationGranularity gr)
+ : input_type(input), output_type(output), granularity(gr)
+ {
+ }
+
+ loco::DataType input_type;
+ loco::DataType output_type;
+ QuantizationGranularity granularity;
+
+private:
+ // Return a quantized bias node
+ CircleConst *quantized_bias(CircleNode *input, const CircleNode *weight, CircleNode *bias);
+
+ void visit(luci::CircleConv2D *node);
+ void visit(luci::CircleDepthwiseConv2D *node);
+ void visit(luci::CircleTransposeConv *node);
+ void visit(luci::CircleFullyConnected *node);
+
+ // Default behavior
+ void visit(luci::CircleNode *) {}
+};
+
+} // namespace luci
+
+#endif // __LUCI_QUANTIZE_BIAS_H__
diff --git a/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.cpp b/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.cpp
index c8ad87e3d..c9b35e0be 100644
--- a/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.cpp
+++ b/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.cpp
@@ -16,9 +16,11 @@
#include "luci/Pass/QuantizeDequantizeWeightsPass.h"
#include "QuantizationUtils.h"
+#include "helpers/LayerInfoMap.h"
#include <luci/IR/CircleNodes.h>
#include <luci/IR/CircleNodeVisitor.h>
+#include <luci/Service/Nodes/CircleConst.h>
#include <luci/Log.h>
#include <loco/IR/TensorShape.h>
@@ -251,7 +253,7 @@ void asymmetric_wdequant_with_minmax_per_layer(CircleConst *node, float scaling_
* @brief QuantizeDequantizeWeights quantizes and dequantizes tensors for weights
* @details Find min/max values on the fly, quantize the model, and dequantize the model
*/
-struct QuantizeDequantizeWeights final : public luci::CircleNodeMutableVisitor<bool>
+struct QuantizeDequantizeWeights final : public luci::CircleNodeMutableVisitor<void>
{
QuantizeDequantizeWeights(loco::DataType input, loco::DataType output,
QuantizationGranularity granularity)
@@ -263,88 +265,164 @@ struct QuantizeDequantizeWeights final : public luci::CircleNodeMutableVisitor<b
loco::DataType output_type;
QuantizationGranularity granularity;
- // Quantize and dequantize input tensors of each node
- bool visit(luci::CircleNode *node)
+private:
+ // Fake quantize weights (Only u8 quantization is supported for LWQ)
+ void fake_quantize_lwq(luci::CircleConst *weights) const
{
- assert(output_type == loco::DataType::U8 || output_type == loco::DataType::S16);
- LOGGER(l);
- INFO(l) << "QuantizeDequantizeWeights visit node: " << node->name() << std::endl;
- auto arity = node->arity();
- for (uint32_t i = 0; i < arity; i++)
+ assert(output_type == loco::DataType::U8); // FIX_CALLER_UNLESS
+
+ // Find min/max per layer
+ float min = std::numeric_limits<float>::max();
+ float max = std::numeric_limits<float>::lowest();
+ for (uint32_t i = 0; i < weights->size<loco::DataType::FLOAT32>(); i++)
{
- auto input_node = node->arg(i);
- auto circle_node = loco::must_cast<luci::CircleNode *>(input_node);
+ auto data = weights->at<loco::DataType::FLOAT32>(i);
+ min = data < min ? data : min;
+ max = data > max ? data : max;
+ }
+ float scaling_factor{0};
+ int64_t zp{0};
+ float nudged_min{0};
+ float nudged_max{0};
+
+ asymmetric_wquant_with_minmax_per_layer(weights, min, max, scaling_factor, zp, nudged_min,
+ nudged_max);
+ asymmetric_wdequant_with_minmax_per_layer(weights, scaling_factor, nudged_min);
+ auto quantparam = std::make_unique<CircleQuantParam>();
+ quantparam->min.push_back(nudged_min);
+ quantparam->max.push_back(nudged_max);
+ quantparam->scale.push_back(scaling_factor);
+ quantparam->zerop.push_back(zp);
+ weights->quantparam(std::move(quantparam));
+ }
- // Check if this is already quantized
- if (is_quantized(circle_node))
- continue;
+private:
+ // Fake quantize weights (u8/s16 quantization are supported for CWQ)
+ void fake_quantize_cwq(luci::CircleConst *weights) const
+ {
+ assert(output_type == loco::DataType::U8 ||
+ output_type == loco::DataType::S16); // FIX_CALLER_UNLESS
- if (is_weights(circle_node))
- {
- auto circle_const = loco::must_cast<luci::CircleConst *>(circle_node);
+ // Find min/max per channel
+ std::vector<float> min;
+ std::vector<float> max;
- // Find min/max per channel-wise
- if (granularity == QuantizationGranularity::ChannelWise)
- {
- std::vector<float> min;
- std::vector<float> max;
-
- cal_minmax_per_channel(circle_const, min, max);
-
- std::vector<float> nudged_min(min.size());
- std::vector<float> nudged_max(min.size());
- std::vector<float> scaling_factor(min.size());
- std::vector<int64_t> zp(min.size());
-
- if (output_type == loco::DataType::U8)
- {
- asymmetric_wquant_per_channel(circle_const, min, max, scaling_factor, zp, nudged_min,
- nudged_max);
- asymmetric_wdequant_per_channel(circle_const, scaling_factor, nudged_min);
- }
- else
- {
- sym_wquant_per_channel(circle_const, min, max, scaling_factor, zp, nudged_min,
- nudged_max);
- sym_wdequant_per_channel(circle_const, scaling_factor);
- }
-
- auto quantparam = std::make_unique<CircleQuantParam>();
- quantparam->min = nudged_min;
- quantparam->max = nudged_max;
- quantparam->scale = scaling_factor;
- quantparam->zerop = zp;
- circle_node->quantparam(std::move(quantparam));
- }
- // Find min/max per layer-wise
- else
- {
- float min = std::numeric_limits<float>::max();
- float max = std::numeric_limits<float>::lowest();
- for (uint32_t i = 0; i < circle_const->size<loco::DataType::FLOAT32>(); i++)
- {
- auto data = circle_const->at<loco::DataType::FLOAT32>(i);
- min = data < min ? data : min;
- max = data > max ? data : max;
- }
- float scaling_factor{0};
- int64_t zp{0};
- float nudged_min{0};
- float nudged_max{0};
-
- asymmetric_wquant_with_minmax_per_layer(circle_const, min, max, scaling_factor, zp,
- nudged_min, nudged_max);
- asymmetric_wdequant_with_minmax_per_layer(circle_const, scaling_factor, nudged_min);
- auto quantparam = std::make_unique<CircleQuantParam>();
- quantparam->min.push_back(nudged_min);
- quantparam->max.push_back(nudged_max);
- quantparam->scale.push_back(scaling_factor);
- quantparam->zerop.push_back(zp);
- circle_node->quantparam(std::move(quantparam));
- }
- }
+ cal_minmax_per_channel(weights, min, max);
+
+ std::vector<float> nudged_min(min.size());
+ std::vector<float> nudged_max(min.size());
+ std::vector<float> scaling_factor(min.size());
+ std::vector<int64_t> zp(min.size());
+
+ if (output_type == loco::DataType::U8)
+ {
+ asymmetric_wquant_per_channel(weights, min, max, scaling_factor, zp, nudged_min, nudged_max);
+ asymmetric_wdequant_per_channel(weights, scaling_factor, nudged_min);
+ }
+ else
+ {
+ sym_wquant_per_channel(weights, min, max, scaling_factor, zp, nudged_min, nudged_max);
+ sym_wdequant_per_channel(weights, scaling_factor);
}
- return false;
+
+ auto quantparam = std::make_unique<CircleQuantParam>();
+ quantparam->min = nudged_min;
+ quantparam->max = nudged_max;
+ quantparam->scale = scaling_factor;
+ quantparam->zerop = zp;
+ weights->quantparam(std::move(quantparam));
+ }
+
+private:
+ void fake_quantize(luci::CircleConst *weights) const
+ {
+ switch (granularity)
+ {
+ case luci::QuantizationGranularity::ChannelWise:
+ fake_quantize_cwq(weights);
+ break;
+ case luci::QuantizationGranularity::LayerWise:
+ fake_quantize_lwq(weights);
+ break;
+ default:
+ throw std::invalid_argument("Unsupported granularity");
+ }
+ }
+
+private:
+ // Check if
+ // 1. node is const
+ // 2. node was not quantized
+ bool is_quantizable(loco::Node *node)
+ {
+ auto const_node = dynamic_cast<luci::CircleConst *>(node);
+ if (not const_node)
+ return false;
+
+ // Skip if this is already quantized
+ if (is_quantized(const_node))
+ return false;
+
+ return true;
+ }
+
+ // Default behavior (Do nothing)
+ void visit(luci::CircleNode *) {}
+
+ void visit(luci::CircleConv2D *node)
+ {
+ LOGGER(l);
+ INFO(l) << "QuantizeDequantizeWeights visit node: " << node->name() << std::endl;
+
+ if (not is_quantizable(node->filter()))
+ return;
+
+ auto weights = loco::must_cast<luci::CircleConst *>(node->filter());
+ auto new_weights = luci::clone(weights);
+ node->filter(new_weights);
+ fake_quantize(new_weights);
+ }
+
+ void visit(luci::CircleDepthwiseConv2D *node)
+ {
+ LOGGER(l);
+ INFO(l) << "QuantizeDequantizeWeights visit node: " << node->name() << std::endl;
+
+ if (not is_quantizable(node->filter()))
+ return;
+
+ auto weights = loco::must_cast<luci::CircleConst *>(node->filter());
+ auto new_weights = luci::clone(weights);
+ node->filter(new_weights);
+ fake_quantize(new_weights);
+ }
+
+ void visit(luci::CircleTransposeConv *node)
+ {
+ LOGGER(l);
+ INFO(l) << "QuantizeDequantizeWeights visit node: " << node->name() << std::endl;
+
+ if (not is_quantizable(node->filter()))
+ return;
+
+ auto weights = loco::must_cast<luci::CircleConst *>(node->filter());
+ auto new_weights = luci::clone(weights);
+ node->filter(new_weights);
+ fake_quantize(new_weights);
+ }
+
+ void visit(luci::CircleFullyConnected *node)
+ {
+ LOGGER(l);
+ INFO(l) << "QuantizeDequantizeWeights visit node: " << node->name() << std::endl;
+
+ if (not is_quantizable(node->weights()))
+ return;
+
+ auto weights = loco::must_cast<luci::CircleConst *>(node->weights());
+ auto new_weights = luci::clone(weights);
+ node->weights(new_weights);
+ fake_quantize(new_weights);
}
};
@@ -355,11 +433,36 @@ bool QuantizeDequantizeWeightsPass::run(loco::Graph *g)
LOGGER(l);
INFO(l) << "QuantizeDequantizeWeightsPass Start" << std::endl;
+ auto info_by_name = layer_info_map(g, _ctx->layers_info);
+
+ auto quantize_dtype = [&](const luci::CircleNode *node) {
+ auto iter = info_by_name.find(node->name());
+
+ // Return designated quantization dtype
+ if (iter != info_by_name.end())
+ return iter->second.dtype;
+
+ // Return default quantization dtype
+ return _ctx->output_model_dtype;
+ };
+
+ auto quantize_granularity = [&](const luci::CircleNode *node) {
+ auto iter = info_by_name.find(node->name());
+
+ // Return designated quantization granularity
+ if (iter != info_by_name.end())
+ return iter->second.granularity;
+
+ // Return default quantization granularity
+ return _ctx->granularity;
+ };
+
// Quantize weights
for (auto node : loco::active_nodes(loco::output_nodes(g)))
{
- QuantizeDequantizeWeights qw(_input_model_dtype, _output_model_dtype, _granularity);
auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+ QuantizeDequantizeWeights qw(_ctx->input_model_dtype, quantize_dtype(circle_node),
+ quantize_granularity(circle_node));
circle_node->accept(&qw);
}
diff --git a/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.test.cpp b/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.test.cpp
index f226253c2..15f5ca7ac 100644
--- a/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.test.cpp
+++ b/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.test.cpp
@@ -25,3 +25,17 @@ TEST(QuantizeDequantizeWeightsPassTest, name)
auto const name = pass.name();
ASSERT_NE(nullptr, name);
}
+
+TEST(QuantizeDequantizeWeightsPassTest, name_ctx)
+{
+ auto ctx = std::make_unique<luci::QuantizeDequantizeWeightsPass::Context>();
+ {
+ ctx->input_model_dtype = loco::DataType::FLOAT32;
+ ctx->output_model_dtype = loco::DataType::U8;
+ ctx->granularity = luci::QuantizationGranularity::LayerWise;
+ }
+
+ luci::QuantizeDequantizeWeightsPass pass(std::move(ctx));
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
diff --git a/compiler/luci/pass/src/QuantizePreCheckerPass.cpp b/compiler/luci/pass/src/QuantizePreCheckerPass.cpp
new file mode 100644
index 000000000..4b3b7e330
--- /dev/null
+++ b/compiler/luci/pass/src/QuantizePreCheckerPass.cpp
@@ -0,0 +1,119 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/QuantizePreCheckerPass.h"
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/IR/CircleNodeVisitor.h>
+
+#include <luci/Log.h>
+
+namespace luci
+{
+
+namespace
+{
+
+void check_const_opcode(luci::CircleNode *node)
+{
+ if (node == nullptr)
+ return;
+
+ if (node->opcode() != luci::CircleOpcode::CIRCLECONST and
+ node->opcode() != luci::CircleOpcode::CIRCLEOUTPUTEXCLUDE)
+ {
+ throw std::runtime_error("Unsupported non const input " + node->name());
+ }
+}
+
+struct ConstInputChecker final : public luci::CircleNodeMutableVisitor<void>
+{
+// INPUT_NAME is the name of the const input of the current NODE
+#define CHECK_NODE_WITH_ONE_INPUT_CONST(NODE, INPUT_NAME) \
+ void visit(NODE *node) \
+ { \
+ const auto input = dynamic_cast<luci::CircleNode *>(node->INPUT_NAME()); \
+ check_const_opcode(input); \
+ }
+
+// INPUT_NAME_1 and INPUT_NAME_2 are names for input const for current NODE
+#define CHECK_NODE_WITH_TWO_INPUT_CONST(NODE, INPUT_NAME_1, INPUT_NAME_2) \
+ void visit(NODE *node) \
+ { \
+ const auto input_1 = dynamic_cast<luci::CircleNode *>(node->INPUT_NAME_1()); \
+ const auto input_2 = dynamic_cast<luci::CircleNode *>(node->INPUT_NAME_2()); \
+ \
+ check_const_opcode(input_1); \
+ check_const_opcode(input_2); \
+ }
+
+// INPUT_NAME_1, INPUT_NAME_2 and INPUT_NAME_3 are names for input const for current NODE
+#define CHECK_NODE_WITH_THREE_INPUT_CONST(NODE, INPUT_NAME_1, INPUT_NAME_2, INPUT_NAME_3) \
+ void visit(NODE *node) \
+ { \
+ const auto input_1 = dynamic_cast<luci::CircleNode *>(node->INPUT_NAME_1()); \
+ const auto input_2 = dynamic_cast<luci::CircleNode *>(node->INPUT_NAME_2()); \
+ const auto input_3 = dynamic_cast<luci::CircleNode *>(node->INPUT_NAME_3()); \
+ \
+ check_const_opcode(input_1); \
+ check_const_opcode(input_2); \
+ check_const_opcode(input_3); \
+ }
+
+ // Skip other circle node
+ void visit(luci::CircleNode *) {}
+
+ // Ops that receive one const node as an input
+ CHECK_NODE_WITH_ONE_INPUT_CONST(luci::CirclePRelu, alpha)
+
+ // Ops that receive two const nodes as inputs
+ CHECK_NODE_WITH_TWO_INPUT_CONST(luci::CircleConv2D, filter, bias)
+ CHECK_NODE_WITH_TWO_INPUT_CONST(luci::CircleDepthwiseConv2D, filter, bias)
+ CHECK_NODE_WITH_TWO_INPUT_CONST(luci::CircleFullyConnected, weights, bias)
+ CHECK_NODE_WITH_TWO_INPUT_CONST(luci::CircleInstanceNorm, gamma, beta)
+
+ // Ops that receive three const nodes as inputs
+ CHECK_NODE_WITH_THREE_INPUT_CONST(luci::CircleTransposeConv, inputSizes, filter, bias)
+
+#undef CHECK_NODE_WITH_ONE_INPUT_CONST
+#undef CHECK_NODE_WITH_TWO_INPUT_CONST
+#undef CHECK_NODE_WITH_THREE_INPUT_CONST
+};
+
+} // namespace
+
+/**
+ * Verify the input model has the form acceptable by quantizer
+ */
+bool QuantizePreCheckerPass::run(loco::Graph *g)
+{
+ LOGGER(l);
+ INFO(l) << "QuantizePreCheckerPass Start" << std::endl;
+
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ // Check const inputs
+ auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+ ConstInputChecker checker{};
+ circle_node->accept(&checker);
+ }
+
+ INFO(l) << "QuantizePreCheckerPass End" << std::endl;
+
+ return false; // one time run
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/QuantizePreCheckerPass.test.cpp b/compiler/luci/pass/src/QuantizePreCheckerPass.test.cpp
new file mode 100644
index 000000000..788353cd8
--- /dev/null
+++ b/compiler/luci/pass/src/QuantizePreCheckerPass.test.cpp
@@ -0,0 +1,401 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/QuantizePreCheckerPass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <gtest/gtest.h>
+
+class SimpleConv2DGraph
+{
+public:
+ SimpleConv2DGraph(bool make_valid)
+ {
+ conv2d_node = g.nodes()->create<luci::CircleConv2D>();
+ input_1 = g.nodes()->create<luci::CircleInput>();
+ filter = g.nodes()->create<luci::CircleConst>();
+
+ conv2d_node->input(input_1);
+ conv2d_node->filter(filter);
+
+ if (make_valid)
+ {
+ bias = g.nodes()->create<luci::CircleConst>();
+ conv2d_node->bias(bias);
+ }
+ else
+ {
+ input_2 = g.nodes()->create<luci::CircleInput>();
+ conv2d_node->bias(input_2);
+ }
+
+ output = g.nodes()->create<luci::CircleOutput>();
+
+ auto graph_output = g.outputs()->create();
+ output->index(graph_output->index());
+
+ output->from(conv2d_node);
+ }
+
+public:
+ loco::Graph g;
+
+private:
+ luci::CircleConv2D *conv2d_node = nullptr;
+ luci::CircleInput *input_1 = nullptr;
+ luci::CircleInput *input_2 = nullptr;
+ luci::CircleConst *filter = nullptr;
+ luci::CircleConst *bias = nullptr;
+ luci::CircleOutput *output = nullptr;
+};
+
+class SimpleDepthConv2DGraph
+{
+public:
+ SimpleDepthConv2DGraph(bool make_valid)
+ {
+ depth_conv2d_node = g.nodes()->create<luci::CircleDepthwiseConv2D>();
+ input_1 = g.nodes()->create<luci::CircleInput>();
+ filter = g.nodes()->create<luci::CircleConst>();
+
+ depth_conv2d_node->input(input_1);
+ depth_conv2d_node->filter(filter);
+
+ if (make_valid)
+ {
+ bias = g.nodes()->create<luci::CircleConst>();
+ depth_conv2d_node->bias(bias);
+ }
+ else
+ {
+ input_2 = g.nodes()->create<luci::CircleInput>();
+ depth_conv2d_node->bias(input_2);
+ }
+
+ output = g.nodes()->create<luci::CircleOutput>();
+
+ auto graph_output = g.outputs()->create();
+ output->index(graph_output->index());
+
+ output->from(depth_conv2d_node);
+ }
+
+public:
+ loco::Graph g;
+
+private:
+ luci::CircleDepthwiseConv2D *depth_conv2d_node = nullptr;
+ luci::CircleInput *input_1 = nullptr;
+ luci::CircleInput *input_2 = nullptr;
+ luci::CircleConst *filter = nullptr;
+ luci::CircleConst *bias = nullptr;
+ luci::CircleOutput *output = nullptr;
+};
+
+class SimpleFCGraph
+{
+public:
+ SimpleFCGraph(bool make_valid)
+ {
+ fc_node = g.nodes()->create<luci::CircleFullyConnected>();
+ input_1 = g.nodes()->create<luci::CircleInput>();
+ weights = g.nodes()->create<luci::CircleConst>();
+
+ fc_node->input(input_1);
+ fc_node->weights(weights);
+
+ if (make_valid)
+ {
+ bias = g.nodes()->create<luci::CircleConst>();
+ fc_node->bias(bias);
+ }
+ else
+ {
+ input_2 = g.nodes()->create<luci::CircleInput>();
+ fc_node->bias(input_2);
+ }
+
+ output = g.nodes()->create<luci::CircleOutput>();
+
+ auto graph_output = g.outputs()->create();
+ output->index(graph_output->index());
+
+ output->from(fc_node);
+ }
+
+public:
+ loco::Graph g;
+
+private:
+ luci::CircleFullyConnected *fc_node = nullptr;
+ luci::CircleInput *input_1 = nullptr;
+ luci::CircleInput *input_2 = nullptr;
+ luci::CircleConst *weights = nullptr;
+ luci::CircleConst *bias = nullptr;
+ luci::CircleOutput *output = nullptr;
+};
+
+class SimpleInstanceNormGraph
+{
+public:
+ SimpleInstanceNormGraph(bool make_valid)
+ {
+ instance_norm_node = g.nodes()->create<luci::CircleInstanceNorm>();
+ input_1 = g.nodes()->create<luci::CircleInput>();
+ gamma = g.nodes()->create<luci::CircleConst>();
+
+ instance_norm_node->input(input_1);
+ instance_norm_node->gamma(gamma);
+
+ if (make_valid)
+ {
+ beta = g.nodes()->create<luci::CircleConst>();
+ instance_norm_node->beta(beta);
+ }
+ else
+ {
+ input_2 = g.nodes()->create<luci::CircleInput>();
+ instance_norm_node->beta(input_2);
+ }
+
+ output = g.nodes()->create<luci::CircleOutput>();
+
+ auto graph_output = g.outputs()->create();
+ output->index(graph_output->index());
+
+ output->from(instance_norm_node);
+ }
+
+public:
+ loco::Graph g;
+
+private:
+ luci::CircleInstanceNorm *instance_norm_node = nullptr;
+ luci::CircleInput *input_1 = nullptr;
+ luci::CircleInput *input_2 = nullptr;
+ luci::CircleConst *gamma = nullptr;
+ luci::CircleConst *beta = nullptr;
+ luci::CircleOutput *output = nullptr;
+};
+
+class SimpleTransposeConvGraph
+{
+public:
+ SimpleTransposeConvGraph(bool make_valid)
+ {
+ transpose_conv = g.nodes()->create<luci::CircleTransposeConv>();
+ input_1 = g.nodes()->create<luci::CircleInput>();
+
+ input_sizes = g.nodes()->create<luci::CircleConst>();
+ filter = g.nodes()->create<luci::CircleConst>();
+
+ transpose_conv->outBackprop(input_1);
+ transpose_conv->filter(filter);
+ transpose_conv->inputSizes(input_sizes);
+
+ if (make_valid)
+ {
+ bias = g.nodes()->create<luci::CircleConst>();
+ transpose_conv->bias(bias);
+ }
+ else
+ {
+ input_2 = g.nodes()->create<luci::CircleInput>();
+ transpose_conv->bias(input_2);
+ }
+
+ output = g.nodes()->create<luci::CircleOutput>();
+
+ auto graph_output = g.outputs()->create();
+ output->index(graph_output->index());
+
+ output->from(transpose_conv);
+ }
+
+public:
+ loco::Graph g;
+
+private:
+ luci::CircleTransposeConv *transpose_conv = nullptr;
+ luci::CircleInput *input_1 = nullptr;
+ luci::CircleInput *input_2 = nullptr;
+ luci::CircleConst *input_sizes = nullptr;
+ luci::CircleConst *filter = nullptr;
+ luci::CircleConst *bias = nullptr;
+ luci::CircleOutput *output = nullptr;
+};
+
+class SimplePReluGraph
+{
+public:
+ SimplePReluGraph(bool make_valid)
+ {
+ prelu = g.nodes()->create<luci::CirclePRelu>();
+ input_1 = g.nodes()->create<luci::CircleInput>();
+
+ prelu->input(input_1);
+
+ if (make_valid)
+ {
+ alpha = g.nodes()->create<luci::CircleConst>();
+ prelu->alpha(alpha);
+ }
+ else
+ {
+ input_2 = g.nodes()->create<luci::CircleInput>();
+ prelu->alpha(input_2);
+ }
+
+ output = g.nodes()->create<luci::CircleOutput>();
+
+ auto graph_output = g.outputs()->create();
+ output->index(graph_output->index());
+
+ output->from(prelu);
+ }
+
+public:
+ loco::Graph g;
+
+private:
+ luci::CirclePRelu *prelu = nullptr;
+ luci::CircleInput *input_1 = nullptr;
+ luci::CircleInput *input_2 = nullptr;
+ luci::CircleConst *alpha = nullptr;
+ luci::CircleOutput *output = nullptr;
+};
+
+TEST(QuantizePreCheckerPassTest, name)
+{
+ luci::QuantizePreCheckerPass pass{};
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
+
+// Test Conv2d
+TEST(QuantizePreCheckerPassTest, conv2d)
+{
+ SimpleConv2DGraph valid_graph(true);
+
+ luci::QuantizePreCheckerPass checker{};
+
+ EXPECT_NO_THROW(checker.run(&valid_graph.g));
+}
+
+TEST(QuantizePreCheckerPassTest, conv2d_NEG)
+{
+ SimpleConv2DGraph invalid_graph(false);
+
+ luci::QuantizePreCheckerPass checker{};
+
+ EXPECT_ANY_THROW(checker.run(&invalid_graph.g));
+}
+
+// Test DepthwiseConv2d
+TEST(QuantizePreCheckerPassTest, depthwise_conv2d)
+{
+ SimpleDepthConv2DGraph valid_graph(true);
+
+ luci::QuantizePreCheckerPass checker{};
+
+ EXPECT_NO_THROW(checker.run(&valid_graph.g));
+}
+
+TEST(QuantizePreCheckerPassTest, depthwise_conv2d_NEG)
+{
+ SimpleDepthConv2DGraph invalid_graph(false);
+
+ luci::QuantizePreCheckerPass checker{};
+
+ EXPECT_ANY_THROW(checker.run(&invalid_graph.g));
+}
+
+// Test FullyConnected
+TEST(QuantizePreCheckerPassTest, fully_connected)
+{
+ SimpleFCGraph valid_graph(true);
+
+ luci::QuantizePreCheckerPass checker{};
+
+ EXPECT_NO_THROW(checker.run(&valid_graph.g));
+}
+
+TEST(QuantizePreCheckerPassTest, fully_connected_NEG)
+{
+ SimpleFCGraph invalid_graph(false);
+
+ luci::QuantizePreCheckerPass checker{};
+
+ EXPECT_ANY_THROW(checker.run(&invalid_graph.g));
+}
+
+// Test InstanceNorm
+TEST(QuantizePreCheckerPassTest, instance_norm)
+{
+ SimpleInstanceNormGraph valid_graph(true);
+
+ luci::QuantizePreCheckerPass checker{};
+
+ EXPECT_NO_THROW(checker.run(&valid_graph.g));
+}
+
+TEST(QuantizePreCheckerPassTest, instance_norm_NEG)
+{
+ SimpleInstanceNormGraph invalid_graph(false);
+
+ luci::QuantizePreCheckerPass checker{};
+
+ EXPECT_ANY_THROW(checker.run(&invalid_graph.g));
+}
+
+// Test TransposeConv
+TEST(QuantizePreCheckerPassTest, transpose_conv)
+{
+ SimpleTransposeConvGraph valid_graph(true);
+
+ luci::QuantizePreCheckerPass checker{};
+
+ EXPECT_NO_THROW(checker.run(&valid_graph.g));
+}
+
+TEST(QuantizePreCheckerPassTest, transpose_conv_NEG)
+{
+ SimpleTransposeConvGraph invalid_graph(false);
+
+ luci::QuantizePreCheckerPass checker{};
+
+ EXPECT_ANY_THROW(checker.run(&invalid_graph.g));
+}
+
+// Test PRelu
+TEST(QuantizePreCheckerPassTest, prelu)
+{
+ SimplePReluGraph valid_graph(true);
+
+ luci::QuantizePreCheckerPass checker{};
+
+ EXPECT_NO_THROW(checker.run(&valid_graph.g));
+}
+
+TEST(QuantizePreCheckerPassTest, prelu_NEG)
+{
+ SimplePReluGraph invalid_graph(false);
+
+ luci::QuantizePreCheckerPass checker{};
+
+ EXPECT_ANY_THROW(checker.run(&invalid_graph.g));
+}
diff --git a/compiler/luci/pass/src/QuantizeWeights.cpp b/compiler/luci/pass/src/QuantizeWeights.cpp
new file mode 100644
index 000000000..11322ab44
--- /dev/null
+++ b/compiler/luci/pass/src/QuantizeWeights.cpp
@@ -0,0 +1,394 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "QuantizeWeights.h"
+#include "QuantizationUtils.h"
+
+#include <luci/Service/Nodes/CircleConst.h>
+#include <luci/Log.h>
+
+#include <cmath>
+#include <vector>
+#include <functional>
+
+using namespace luci;
+
+namespace
+{
+
+using IterFunc = std::function<void(uint32_t *, loco::TensorShape &, int32_t)>;
+
+void iterate_per_channel(CircleConst *node, int32_t &channel_dim_index, IterFunc func)
+{
+ loco::TensorShape dimension;
+ dimension.rank(4);
+ uint32_t indices[4] = {
+ 0,
+ };
+
+ if (!get_channel_dim_index(node, dimension, channel_dim_index))
+ {
+ assert(false);
+ return;
+ }
+
+ for (indices[0] = 0; indices[0] < dimension.dim(0).value(); indices[0]++)
+ {
+ for (indices[1] = 0; indices[1] < dimension.dim(1).value(); indices[1]++)
+ {
+ for (indices[2] = 0; indices[2] < dimension.dim(2).value(); indices[2]++)
+ {
+ for (indices[3] = 0; indices[3] < dimension.dim(3).value(); indices[3]++)
+ {
+ func(indices, dimension, channel_dim_index);
+ }
+ }
+ }
+ }
+}
+
+void asym_wquant_per_channel(CircleConst *node, std::vector<float> &min,
+ std::vector<float> &scaling_factor, int32_t &channel_dim_index)
+{
+ assert(node->dtype() == loco::DataType::FLOAT32);
+
+ const int32_t kMinScale = 0;
+ const int32_t kMaxScale = 255;
+
+ uint32_t size = node->size<loco::DataType::FLOAT32>();
+ std::vector<int32_t> quantized_values(size);
+
+ auto quantize = [&](uint32_t *indices, loco::TensorShape &dimension, int32_t channel_dim_index) {
+ int channel_idx = indices[channel_dim_index];
+ const float scaling_factor_inv = 1.0 / scaling_factor[channel_idx];
+ auto data = node->at<loco::DataType::FLOAT32>(cal_offset(dimension, indices));
+ quantized_values[cal_offset(dimension, indices)] =
+ static_cast<int32_t>(std::round((data - min[channel_idx]) * scaling_factor_inv));
+ };
+
+ iterate_per_channel(node, channel_dim_index, quantize);
+
+ node->dtype(loco::DataType::U8); // change the type of tensor
+ node->size<loco::DataType::U8>(size); // resize tensor
+ for (uint32_t i = 0; i < size; ++i)
+ {
+ node->at<loco::DataType::U8>(i) = std::min(kMaxScale, std::max(kMinScale, quantized_values[i]));
+ }
+}
+
+void sym_wquant_per_channel(CircleConst *node, std::vector<float> &scaling_factor,
+ int32_t &channel_dim_index)
+{
+ assert(node->dtype() == loco::DataType::FLOAT32);
+
+ const int32_t kMaxScale = std::numeric_limits<int16_t>::max();
+ const int32_t kMinScale = -kMaxScale;
+
+ uint32_t size = node->size<loco::DataType::FLOAT32>();
+ std::vector<int32_t> quantized_values(size);
+
+ auto quantize = [&](uint32_t *indices, loco::TensorShape &dimension, int32_t channel_dim_index) {
+ int channel_idx = indices[channel_dim_index];
+ const float scaling_factor_inv = 1.0 / scaling_factor[channel_idx];
+ auto data = node->at<loco::DataType::FLOAT32>(cal_offset(dimension, indices));
+ quantized_values[cal_offset(dimension, indices)] =
+ static_cast<int32_t>(std::round(data * scaling_factor_inv));
+ };
+
+ iterate_per_channel(node, channel_dim_index, quantize);
+
+ node->dtype(loco::DataType::S16); // change the type of tensor
+ node->size<loco::DataType::S16>(size); // resize tensor
+ for (uint32_t i = 0; i < size; ++i)
+ {
+ node->at<loco::DataType::S16>(i) =
+ std::min(kMaxScale, std::max(kMinScale, quantized_values[i]));
+ }
+}
+
+void asym_wquant_per_layer(CircleConst *node, float min, float scaling_factor)
+{
+ const int32_t kMinScale = 0;
+ const int32_t kMaxScale = 255;
+
+ uint32_t size = node->size<loco::DataType::FLOAT32>();
+
+ const float scaling_factor_inv = 1.0 / scaling_factor;
+ std::vector<int32_t> quantized_values(size);
+ for (uint32_t i = 0; i < size; ++i)
+ {
+ auto data = node->at<loco::DataType::FLOAT32>(i);
+ quantized_values[i] = static_cast<int32_t>(std::round((data - min) * scaling_factor_inv));
+ }
+
+ node->dtype(loco::DataType::U8); // change the type of tensor
+ node->size<loco::DataType::U8>(size); // resize tensor
+ for (uint32_t i = 0; i < size; ++i)
+ {
+ node->at<loco::DataType::U8>(i) = std::min(kMaxScale, std::max(kMinScale, quantized_values[i]));
+ }
+}
+
+// Quantize const per channel
+//
+// The last dimension of const is the same as the dimension of channel
+// And the rest of the const dimensions should be 1
+// So, a 'single value' is quantized per channel
+//
+// Quantization spec (f: fp value, q: quantized value)
+//
+// uint8
+// Positive f: f = f * (q - 0) [q = 1, scale = f, zp = 0]
+// Negative f: f = (-f) * (q - 1) [q = 0, scale = -f, zp = 1]
+//
+// int16
+// Positive f: f = f * (q - 0) [q = 1, scale = f, zp = 0]
+// Negative f: f = (-f) * (q - 0) [q = -1, scale = -f, zp = 0]
+void quant_const_per_channel(CircleConst *node, loco::DataType quant_type)
+{
+ assert(node->dtype() == loco::DataType::FLOAT32);
+ assert(node->rank() > 0);
+
+ for (uint32_t i = 0; i < node->rank() - 1; i++)
+ {
+ // Caller should call this function when the below condition is satisfied
+ if (node->dim(i).value() != 1)
+ throw std::runtime_error("Non-channel dimension of const node must be 1");
+ }
+
+ uint32_t size = node->size<loco::DataType::FLOAT32>();
+ assert(size == node->dim(node->rank() - 1).value());
+
+ auto quantparam = std::make_unique<CircleQuantParam>();
+ quantparam->quantized_dimension = node->rank() - 1;
+ std::vector<int32_t> quantized_data(size);
+
+ for (uint32_t i = 0; i < size; ++i)
+ {
+ auto data = node->at<loco::DataType::FLOAT32>(i);
+ if (quant_type == loco::DataType::U8)
+ {
+ if (data >= 0)
+ {
+ quantparam->scale.push_back(data);
+ quantparam->zerop.push_back(0);
+ quantized_data[i] = 1;
+ }
+ else
+ {
+ quantparam->scale.push_back(-data);
+ quantparam->zerop.push_back(1);
+ quantized_data[i] = 0;
+ }
+ }
+ else if (quant_type == loco::DataType::S16)
+ {
+ if (data >= 0)
+ {
+ quantparam->scale.push_back(data);
+ quantized_data[i] = 1;
+ }
+ else
+ {
+ quantparam->scale.push_back(-data);
+ quantized_data[i] = -1;
+ }
+ quantparam->zerop.push_back(0);
+ }
+ }
+ node->quantparam(std::move(quantparam));
+
+ switch (quant_type)
+ {
+ case loco::DataType::U8:
+ node->dtype(loco::DataType::U8);
+ node->size<loco::DataType::U8>(size);
+ for (uint32_t i = 0; i < size; ++i)
+ {
+ assert(quantized_data[i] == 0 || quantized_data[i] == 1);
+ node->at<loco::DataType::U8>(i) = quantized_data[i];
+ }
+ break;
+ case loco::DataType::S16:
+ node->dtype(loco::DataType::S16);
+ node->size<loco::DataType::S16>(size);
+ for (uint32_t i = 0; i < size; ++i)
+ {
+ assert(quantized_data[i] == -1 || quantized_data[i] == 1);
+ node->at<loco::DataType::S16>(i) = quantized_data[i];
+ }
+ break;
+ default:
+ throw std::runtime_error("Unsupported data type");
+ }
+}
+
+} // namespace
+
+namespace luci
+{
+
+void QuantizeWeights::quantize_weights(luci::CircleConst *weights)
+{
+ // Find min/max per channel-wise
+ if (granularity == QuantizationGranularity::ChannelWise)
+ {
+ auto quantparam = weights->quantparam();
+ if (quantparam == nullptr)
+ {
+ assert(false && "quantparam is nullptr");
+ return;
+ }
+
+ auto min = quantparam->min;
+ auto scaling_factor = quantparam->scale;
+ int32_t channel_dim_index = 0;
+
+ if (output_type == loco::DataType::U8)
+ {
+ asym_wquant_per_channel(weights, min, scaling_factor, channel_dim_index);
+ }
+ else
+ {
+ sym_wquant_per_channel(weights, scaling_factor, channel_dim_index);
+ }
+ quantparam->min.clear();
+ quantparam->max.clear();
+ quantparam->quantized_dimension = channel_dim_index;
+ }
+ // Find min/max per layer-wise
+ else
+ {
+ // Quantize using recorded quantparam
+ auto quantparam = weights->quantparam();
+ assert(quantparam != nullptr);
+ assert(quantparam->min.size() == 1); // only support layer-wise quant
+ assert(quantparam->scale.size() == 1); // only support layer-wise quant
+ auto min = quantparam->min[0];
+ auto scaling_factor = quantparam->scale[0];
+ asym_wquant_per_layer(weights, min, scaling_factor);
+ quantparam->min.clear();
+ quantparam->max.clear();
+ }
+}
+void QuantizeWeights::visit(luci::CircleConv2D *node)
+{
+ LOGGER(l);
+ INFO(l) << "QuantizeWeights QuantizeWeights::visit node: " << node->name() << std::endl;
+
+ auto weights = loco::must_cast<luci::CircleConst *>(node->filter());
+ if (!is_quantized(weights))
+ {
+ auto new_weights = luci::clone(weights);
+ node->filter(new_weights);
+ quantize_weights(new_weights);
+ }
+}
+
+void QuantizeWeights::visit(luci::CircleDepthwiseConv2D *node)
+{
+ LOGGER(l);
+ INFO(l) << "QuantizeWeights QuantizeWeights::visit node: " << node->name() << std::endl;
+
+ auto weights = loco::must_cast<luci::CircleConst *>(node->filter());
+ if (!is_quantized(weights))
+ {
+ auto new_weights = luci::clone(weights);
+ node->filter(new_weights);
+ quantize_weights(new_weights);
+ }
+}
+
+void QuantizeWeights::visit(luci::CircleInstanceNorm *node)
+{
+ LOGGER(l);
+ INFO(l) << "QuantizeWeights QuantizeWeights::visit node: " << node->name() << std::endl;
+
+ auto gamma = loco::must_cast<luci::CircleConst *>(node->gamma());
+ auto beta = loco::must_cast<luci::CircleConst *>(node->beta());
+
+ if (!is_quantized(gamma))
+ {
+ assert(gamma->dtype() == loco::DataType::FLOAT32);
+ auto new_gamma = luci::clone(gamma);
+ if (granularity == QuantizationGranularity::LayerWise)
+ quant_const(new_gamma, output_type);
+ else if (granularity == QuantizationGranularity::ChannelWise)
+ quant_const_per_channel(new_gamma, output_type);
+ node->gamma(new_gamma);
+ }
+ if (!is_quantized(beta))
+ {
+ assert(beta->dtype() == loco::DataType::FLOAT32);
+ auto new_beta = luci::clone(beta);
+ if (granularity == QuantizationGranularity::LayerWise)
+ quant_const(new_beta, output_type);
+ else if (granularity == QuantizationGranularity::ChannelWise)
+ quant_const_per_channel(new_beta, output_type);
+ node->beta(new_beta);
+ }
+}
+
+void QuantizeWeights::visit(luci::CirclePRelu *node)
+{
+ LOGGER(l);
+ INFO(l) << "QuantizeWeights QuantizeWeights::visit node: " << node->name() << std::endl;
+
+ auto alpha = loco::must_cast<luci::CircleConst *>(node->alpha());
+
+ if (!is_quantized(alpha))
+ {
+ assert(alpha->dtype() == loco::DataType::FLOAT32);
+ auto new_alpha = luci::clone(alpha);
+ if (granularity == QuantizationGranularity::LayerWise)
+ quant_const(new_alpha, output_type);
+ else if (granularity == QuantizationGranularity::ChannelWise)
+ quant_const_per_channel(new_alpha, output_type);
+ node->alpha(new_alpha);
+ }
+}
+
+void QuantizeWeights::visit(luci::CircleTransposeConv *node)
+{
+ LOGGER(l);
+ INFO(l) << "QuantizeWeights QuantizeWeights::visit node: " << node->name() << std::endl;
+
+ auto weights = loco::must_cast<luci::CircleConst *>(node->filter());
+ if (!is_quantized(weights))
+ {
+ auto new_weights = luci::clone(weights);
+ node->filter(new_weights);
+ quantize_weights(new_weights);
+ }
+}
+
+void QuantizeWeights::visit(luci::CircleFullyConnected *node)
+{
+ LOGGER(l);
+ INFO(l) << "QuantizeWeights QuantizeWeights::visit node: " << node->name() << std::endl;
+
+ auto weights = loco::must_cast<luci::CircleConst *>(node->weights());
+ if (!is_quantized(weights))
+ {
+ auto new_weights = luci::clone(weights);
+ node->weights(new_weights);
+ quantize_weights(new_weights);
+ }
+}
+
+void QuantizeWeights::visit(luci::CircleNode *) {}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/QuantizeWeights.h b/compiler/luci/pass/src/QuantizeWeights.h
new file mode 100644
index 000000000..f62cd40f3
--- /dev/null
+++ b/compiler/luci/pass/src/QuantizeWeights.h
@@ -0,0 +1,55 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_QUANTIZE_WEIGHTS_H__
+#define __LUCI_QUANTIZE_WEIGHTS_H__
+
+#include <luci/Pass/QuantizationParameters.h>
+#include <luci/IR/CircleNodeVisitor.h>
+
+namespace luci
+{
+
+/**
+ * @brief QuantizeWeights quantizes tensors for weights
+ * @details Find min/max values on the fly and then quantize
+ */
+struct QuantizeWeights final : public luci::CircleNodeMutableVisitor<void>
+{
+ QuantizeWeights(loco::DataType input, loco::DataType output, QuantizationGranularity gr)
+ : input_type(input), output_type(output), granularity(gr)
+ {
+ }
+
+ loco::DataType input_type;
+ loco::DataType output_type;
+ QuantizationGranularity granularity;
+
+private:
+ void quantize_weights(luci::CircleConst *weights);
+
+ void visit(luci::CircleConv2D *node);
+ void visit(luci::CircleDepthwiseConv2D *node);
+ void visit(luci::CircleInstanceNorm *node);
+ void visit(luci::CirclePRelu *node);
+ void visit(luci::CircleTransposeConv *node);
+ void visit(luci::CircleFullyConnected *node);
+ void visit(luci::CircleNode *);
+};
+
+} // namespace luci
+
+#endif // __LUCI_QUANTIZE_WEIGHTS_H__
diff --git a/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp b/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp
index c3552ec52..d9a9d4db7 100644
--- a/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp
+++ b/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp
@@ -15,55 +15,32 @@
*/
#include "luci/Pass/QuantizeWithMinMaxPass.h"
+#include "luci/Pass/PropagateQParamForwardPass.h"
+#include "luci/Pass/PropagateQParamBackwardPass.h"
+#include "luci/Pass/RemoveRedundantQuantizePass.h"
+#include "QuantizeActivation.h"
+#include "QuantizeWeights.h"
+#include "QuantizeBias.h"
#include "QuantizationUtils.h"
+#include "ProgressReporter.h"
+#include "helpers/LayerInfoMap.h"
#include <luci/IR/CircleNodes.h>
#include <luci/IR/CircleNodeVisitor.h>
#include <luci/Service/Nodes/CircleConst.h>
#include <luci/Profile/CircleNodeOrigin.h>
#include <luci/Log.h>
+#include <logo/Phase.h>
#include <oops/UserExn.h>
#include <iostream>
#include <cmath>
-#include <functional>
namespace
{
using namespace luci;
-using IterFunc = std::function<void(uint32_t *, loco::TensorShape &, int32_t)>;
-
-void iterate_per_channel(CircleConst *node, int32_t &channel_dim_index, IterFunc func)
-{
- loco::TensorShape dimension;
- dimension.rank(4);
- uint32_t indices[4] = {
- 0,
- };
-
- if (!get_channel_dim_index(node, dimension, channel_dim_index))
- {
- assert(false);
- return;
- }
-
- for (indices[0] = 0; indices[0] < dimension.dim(0).value(); indices[0]++)
- {
- for (indices[1] = 0; indices[1] < dimension.dim(1).value(); indices[1]++)
- {
- for (indices[2] = 0; indices[2] < dimension.dim(2).value(); indices[2]++)
- {
- for (indices[3] = 0; indices[3] < dimension.dim(3).value(); indices[3]++)
- {
- func(indices, dimension, channel_dim_index);
- }
- }
- }
- }
-}
-
// Create a Quantize Op whose
// dtype is out_type
// shape is the same with node
@@ -80,7 +57,17 @@ luci::CircleQuantize *create_quantize_op(luci::CircleNode *node, loco::DataType
quantize->shape_status(luci::ShapeStatus::VALID);
auto qparam = node->quantparam();
- assert(qparam); // FIX_CALLER_UNLESS
+ assert(qparam); // FIX_CALLER_UNLESS
+
+ auto qtype = luci::activation_qtype(node);
+ if (qtype == ActivationQType::PreDefinedValue)
+ {
+ quantize->quantparam(luci::make_predefined_qparam(node->opcode(), out_type));
+ return quantize;
+ }
+
+ assert(qtype == ActivationQType::MinMax or qtype == ActivationQType::IntScale);
+
assert(qparam->min.size() == 1); // FIX_CALLER_UNLESS
assert(qparam->max.size() == 1); // FIX_CALLER_UNLESS
auto min = qparam->min[0];
@@ -104,9 +91,17 @@ luci::CircleQuantize *create_quantize_op(luci::CircleNode *node, loco::DataType
auto quantparam = std::make_unique<CircleQuantParam>();
quantparam->scale.push_back(scaling_factor);
quantparam->zerop.push_back(zp);
+ // Save original min/max (not nudged_min/max). Nudged min/max
+ // is different from the real min/max values, causing wrong
+ // qparam when quantization dtype is changed.
+ quantparam->min.push_back(min);
+ quantparam->max.push_back(max);
quantize->quantparam(std::move(quantparam));
+ if (qtype == ActivationQType::IntScale)
+ set_int_scale(quantize);
+
return quantize;
}
@@ -118,1412 +113,232 @@ namespace luci
namespace
{
-// Create a new const node from an existing node.
-// The new node has the following characteristics
-// type: T
-// shape: same with 'node' (given as an argument)
-// buffer size: 'size' (given as an argument)
-// Note that contents are not filled in this function.
-template <loco::DataType T>
-luci::CircleConst *create_empty_const_from(luci::CircleConst *node, uint32_t size)
-{
- auto new_node = node->graph()->nodes()->create<CircleConst>();
- // TODO: We don't have any naming convention for quantized nodes yet.
- // Fix this when we have one.
- new_node->name(node->name());
- new_node->dtype(T);
- new_node->rank(node->rank());
- for (uint32_t i = 0; i < node->rank(); i++)
- new_node->dim(i).set(node->dim(i).value());
-
- new_node->size<T>(size);
- new_node->shape_status(luci::ShapeStatus::VALID);
-
- return new_node;
-}
-
-void overwrite_quantparam(luci::CircleNode *source, luci::CircleNode *target)
-{
- auto source_qparam = source->quantparam();
- if (source_qparam == nullptr)
- throw std::runtime_error("source quantparam is not found during overwrite");
-
- auto target_qparam = target->quantparam();
- if (target_qparam == nullptr)
- {
- auto quantparam = std::make_unique<CircleQuantParam>();
- target->quantparam(std::move(quantparam));
- target_qparam = target->quantparam();
-
- if (target_qparam == nullptr)
- throw std::runtime_error("Creating new quant param failed");
- }
- target_qparam->min = source_qparam->min;
- target_qparam->max = source_qparam->max;
- target_qparam->scale = source_qparam->scale;
- target_qparam->zerop = source_qparam->zerop;
- target_qparam->quantized_dimension = source_qparam->quantized_dimension;
-}
-
-void quant_const_values(luci::CircleConst *const_node, float scaling_factor, float zerop,
- loco::DataType quant_type)
-{
- uint32_t size = const_node->size<loco::DataType::FLOAT32>();
-
- const float scaling_factor_inv = 1.0 / scaling_factor;
- std::vector<int32_t> quantized_values(size);
- for (uint32_t i = 0; i < size; ++i)
- {
- auto data = static_cast<double>(const_node->at<loco::DataType::FLOAT32>(i));
- double quantized_float = std::round(data * scaling_factor_inv) + zerop;
- constexpr auto int_max = static_cast<double>(std::numeric_limits<int32_t>::max());
- constexpr auto int_min = static_cast<double>(std::numeric_limits<int32_t>::min());
- quantized_float = std::min(int_max, std::max(int_min, quantized_float));
-
- quantized_values[i] = static_cast<int32_t>(quantized_float);
- }
-
- switch (quant_type)
- {
- case loco::DataType::U8:
- const_node->dtype(loco::DataType::U8); // change the type of tensor
- const_node->size<loco::DataType::U8>(size); // resize tensor
- for (uint32_t i = 0; i < size; ++i)
- const_node->at<loco::DataType::U8>(i) = std::min(255, std::max(0, quantized_values[i]));
- break;
- case loco::DataType::S16:
- assert(zerop == 0);
- const_node->dtype(loco::DataType::S16); // change the type of tensor
- const_node->size<loco::DataType::S16>(size); // resize tensor
- for (uint32_t i = 0; i < size; ++i)
- const_node->at<loco::DataType::S16>(i) =
- std::min(32767, std::max(-32767, quantized_values[i]));
- break;
- default:
- throw std::runtime_error("Unsupported data type");
- }
-}
-
-// Quantize const per channel
-//
-// The last dimension of const is the same as the dimension of channel
-// And the rest of the const dimensions should be 1
-// So, a 'single value' is quantized per channel
-//
-// Quantization spec (f: fp value, q: quantized value)
-//
-// uint8
-// Positive f: f = f * (q - 0) [q = 1, scale = f, zp = 0]
-// Negative f: f = (-f) * (q - 1) [q = 0, scale = -f, zp = 1]
-//
-// int16
-// Positive f: f = f * (q - 0) [q = 1, scale = f, zp = 0]
-// Negative f: f = (-f) * (q - 0) [q = -1, scale = -f, zp = 0]
-void quant_const_per_channel(CircleConst *node, loco::DataType quant_type)
-{
- assert(node->dtype() == loco::DataType::FLOAT32);
- assert(node->rank() > 0);
-
- for (uint32_t i = 0; i < node->rank() - 1; i++)
- {
- // Caller should call this function when the below condition is satisfied
- if (node->dim(i).value() != 1)
- throw std::runtime_error("Non-channel dimension of const node must be 1");
- }
-
- uint32_t size = node->size<loco::DataType::FLOAT32>();
- assert(size == node->dim(node->rank() - 1).value());
-
- auto quantparam = std::make_unique<CircleQuantParam>();
- quantparam->quantized_dimension = node->rank() - 1;
- std::vector<int32_t> quantized_data(size);
-
- for (uint32_t i = 0; i < size; ++i)
- {
- auto data = node->at<loco::DataType::FLOAT32>(i);
- if (quant_type == loco::DataType::U8)
- {
- if (data >= 0)
- {
- quantparam->scale.push_back(data);
- quantparam->zerop.push_back(0);
- quantized_data[i] = 1;
- }
- else
- {
- quantparam->scale.push_back(-data);
- quantparam->zerop.push_back(1);
- quantized_data[i] = 0;
- }
- }
- else if (quant_type == loco::DataType::S16)
- {
- if (data >= 0)
- {
- quantparam->scale.push_back(data);
- quantized_data[i] = 1;
- }
- else
- {
- quantparam->scale.push_back(-data);
- quantized_data[i] = -1;
- }
- quantparam->zerop.push_back(0);
- }
- }
- node->quantparam(std::move(quantparam));
-
- switch (quant_type)
- {
- case loco::DataType::U8:
- node->dtype(loco::DataType::U8);
- node->size<loco::DataType::U8>(size);
- for (uint32_t i = 0; i < size; ++i)
- {
- assert(quantized_data[i] == 0 || quantized_data[i] == 1);
- node->at<loco::DataType::U8>(i) = quantized_data[i];
- }
- break;
- case loco::DataType::S16:
- node->dtype(loco::DataType::S16);
- node->size<loco::DataType::S16>(size);
- for (uint32_t i = 0; i < size; ++i)
- {
- assert(quantized_data[i] == -1 || quantized_data[i] == 1);
- node->at<loco::DataType::S16>(i) = quantized_data[i];
- }
- break;
- default:
- throw std::runtime_error("Unsupported data type");
- }
-}
-
-void quant_const(CircleConst *node, loco::DataType quant_type)
-{
- assert(node->dtype() == loco::DataType::FLOAT32);
-
- float min = std::numeric_limits<float>::max();
- float max = std::numeric_limits<float>::lowest();
- for (uint32_t i = 0; i < node->size<loco::DataType::FLOAT32>(); i++)
- {
- auto data = node->at<loco::DataType::FLOAT32>(i);
- min = data < min ? data : min;
- max = data > max ? data : max;
- }
-
- float scaling_factor{0.0};
- int64_t zp{0};
- float nudged_min{0.0};
- float nudged_max{0.0};
-
- switch (quant_type)
- {
- case loco::DataType::U8:
- asymmetric_wquant_with_minmax_per_layer(node, min, max, scaling_factor, zp, nudged_min,
- nudged_max);
- break;
- case loco::DataType::S16:
- symmetric_wquant_with_minmax_per_layer(node, min, max, scaling_factor, zp, nudged_min,
- nudged_max);
- break;
- default:
- throw std::runtime_error("Unsupported data type");
- }
-
- auto quantparam = std::make_unique<CircleQuantParam>();
- quantparam->scale.push_back(scaling_factor);
- quantparam->zerop.push_back(zp);
- node->quantparam(std::move(quantparam));
-}
-
-// Check if the node is the bias of Conv2D, DepthwiseConv2D, FullyConnected, or TransposeConv layer
-// Returns a list of <input, weights, output> vectors for the above operators.
-// Note that it returns a 'list' because bias can be used by multiple operators.
-std::vector<std::vector<loco::Node *>> get_input_weight_output_of_bias(CircleNode *node)
-{
- std::vector<std::vector<loco::Node *>> result;
- auto circle_const = dynamic_cast<CircleConst *>(node);
- if (circle_const == nullptr)
- return result;
-
- auto succs = loco::succs(node);
-
- for (auto out : succs)
- {
- auto conv = dynamic_cast<CircleConv2D *>(out);
- if (conv != nullptr && conv->bias() == circle_const)
- {
- assert(conv->input() != nullptr);
- assert(conv->filter() != nullptr);
- result.push_back({conv->input(), conv->filter(), conv});
- continue;
- }
- auto dw_conv = dynamic_cast<CircleDepthwiseConv2D *>(out);
- if (dw_conv != nullptr && dw_conv->bias() == circle_const)
- {
- assert(dw_conv->input() != nullptr);
- assert(dw_conv->filter() != nullptr);
- result.push_back({dw_conv->input(), dw_conv->filter(), dw_conv});
- continue;
- }
- auto fc = dynamic_cast<CircleFullyConnected *>(out);
- if (fc != nullptr && fc->bias() == circle_const)
- {
- assert(fc->input() != nullptr);
- assert(fc->weights() != nullptr);
- result.push_back({fc->input(), fc->weights(), fc});
- continue;
- }
- auto tconv = dynamic_cast<CircleTransposeConv *>(out);
- if (tconv != nullptr && tconv->bias() == circle_const)
- {
- assert(tconv->outBackprop() != nullptr);
- assert(tconv->filter() != nullptr);
- result.push_back({tconv->outBackprop(), tconv->filter(), tconv});
- continue;
- }
- }
- return result;
-}
-
-CircleConst *asym_quant_bias_per_layer(CircleConst *node, float input_scale, float weight_scale,
- float *scaling_factor, int64_t *zp)
-{
- float scale = input_scale * weight_scale;
- const float scaling_factor_inv = (scale == 0) ? 0 : 1.0 / scale;
-
- uint32_t size = node->size<loco::DataType::FLOAT32>();
- std::vector<int32_t> quantized_values(size);
- for (uint32_t i = 0; i < size; ++i)
- {
- quantized_values[i] =
- static_cast<int32_t>(std::round(node->at<loco::DataType::FLOAT32>(i) * scaling_factor_inv));
- }
-
- auto new_bias = create_empty_const_from<loco::DataType::S32>(node, size);
-
- const int32_t kMinScale = std::numeric_limits<int32_t>::lowest();
- const int32_t kMaxScale = std::numeric_limits<int32_t>::max();
- for (uint32_t i = 0; i < size; ++i)
- {
- new_bias->at<loco::DataType::S32>(i) =
- std::min(kMaxScale, std::max(kMinScale, quantized_values[i]));
- }
- *scaling_factor = scale;
- *zp = 0;
-
- return new_bias;
-}
-
-CircleConst *quant_bias_per_channel(CircleConst *node, float input_scale,
- std::vector<float> &weight_scale,
- std::vector<float> &scaling_factor, std::vector<int64_t> &zp)
-{
- float scaling_factor_inv{0};
-
- uint32_t size = node->size<loco::DataType::FLOAT32>();
- std::vector<int32_t> quantized_values(size);
-
- for (uint32_t i = 0; i < size; ++i)
- {
- scaling_factor[i] = input_scale * weight_scale[i];
- scaling_factor_inv = (scaling_factor[i] == 0) ? 0 : 1.0 / scaling_factor[i];
- quantized_values[i] =
- static_cast<int32_t>(std::round(node->at<loco::DataType::FLOAT32>(i) * scaling_factor_inv));
- zp[i] = 0;
- }
-
- auto new_bias = create_empty_const_from<loco::DataType::S32>(node, size);
-
- const int32_t kMinScale = std::numeric_limits<int32_t>::lowest();
- const int32_t kMaxScale = std::numeric_limits<int32_t>::max();
- for (uint32_t i = 0; i < size; ++i)
- {
- new_bias->at<loco::DataType::S32>(i) =
- std::min(kMaxScale, std::max(kMinScale, quantized_values[i]));
- }
-
- return new_bias;
-}
-
-CircleConst *int16_quant_bias_per_channel(CircleConst *node, float input_scale,
- std::vector<float> &weight_scale,
- std::vector<float> &scaling_factor,
- std::vector<int64_t> &zp)
-{
- float scaling_factor_inv{0};
-
- uint32_t size = node->size<loco::DataType::FLOAT32>();
- std::vector<int64_t> quantized_values(size);
-
- for (uint32_t i = 0; i < size; ++i)
- {
- scaling_factor[i] = input_scale * weight_scale[i];
- scaling_factor_inv = (scaling_factor[i] == 0) ? 0 : 1.0 / scaling_factor[i];
- quantized_values[i] =
- static_cast<int64_t>(std::round(node->at<loco::DataType::FLOAT32>(i) * scaling_factor_inv));
- zp[i] = 0;
- }
-
- auto new_bias = create_empty_const_from<loco::DataType::S64>(node, size);
-
- for (uint32_t i = 0; i < size; ++i)
- {
- new_bias->at<loco::DataType::S64>(i) = quantized_values[i];
- }
-
- return new_bias;
-}
-
-bool has_min_max(const CircleNode *node)
-{
- return node->quantparam() && !node->quantparam()->min.empty() && !node->quantparam()->max.empty();
-}
-
-void sym_wquant_per_channel(CircleConst *node, std::vector<float> &scaling_factor,
- int32_t &channel_dim_index)
-{
- assert(node->dtype() == loco::DataType::FLOAT32);
-
- const int32_t kMaxScale = std::numeric_limits<int16_t>::max();
- const int32_t kMinScale = -kMaxScale;
-
- uint32_t size = node->size<loco::DataType::FLOAT32>();
- std::vector<int32_t> quantized_values(size);
-
- auto quantize = [&](uint32_t *indices, loco::TensorShape &dimension, int32_t channel_dim_index) {
- int channel_idx = indices[channel_dim_index];
- const float scaling_factor_inv = 1.0 / scaling_factor[channel_idx];
- auto data = node->at<loco::DataType::FLOAT32>(cal_offset(dimension, indices));
- quantized_values[cal_offset(dimension, indices)] =
- static_cast<int32_t>(std::round(data * scaling_factor_inv));
- };
-
- iterate_per_channel(node, channel_dim_index, quantize);
-
- node->dtype(loco::DataType::S16); // change the type of tensor
- node->size<loco::DataType::S16>(size); // resize tensor
- for (uint32_t i = 0; i < size; ++i)
- {
- node->at<loco::DataType::S16>(i) =
- std::min(kMaxScale, std::max(kMinScale, quantized_values[i]));
- }
-}
-
-void asym_wquant_per_channel(CircleConst *node, std::vector<float> &min,
- std::vector<float> &scaling_factor, int32_t &channel_dim_index)
-{
- assert(node->dtype() == loco::DataType::FLOAT32);
-
- const int32_t kMinScale = 0;
- const int32_t kMaxScale = 255;
-
- uint32_t size = node->size<loco::DataType::FLOAT32>();
- std::vector<int32_t> quantized_values(size);
-
- auto quantize = [&](uint32_t *indices, loco::TensorShape &dimension, int32_t channel_dim_index) {
- int channel_idx = indices[channel_dim_index];
- const float scaling_factor_inv = 1.0 / scaling_factor[channel_idx];
- auto data = node->at<loco::DataType::FLOAT32>(cal_offset(dimension, indices));
- quantized_values[cal_offset(dimension, indices)] =
- static_cast<int32_t>(std::round((data - min[channel_idx]) * scaling_factor_inv));
- };
-
- iterate_per_channel(node, channel_dim_index, quantize);
-
- node->dtype(loco::DataType::U8); // change the type of tensor
- node->size<loco::DataType::U8>(size); // resize tensor
- for (uint32_t i = 0; i < size; ++i)
- {
- node->at<loco::DataType::U8>(i) = std::min(kMaxScale, std::max(kMinScale, quantized_values[i]));
- }
-}
-
-void asym_wquant_per_layer(CircleConst *node, float min, float scaling_factor)
-{
- const int32_t kMinScale = 0;
- const int32_t kMaxScale = 255;
-
- uint32_t size = node->size<loco::DataType::FLOAT32>();
-
- const float scaling_factor_inv = 1.0 / scaling_factor;
- std::vector<int32_t> quantized_values(size);
- for (uint32_t i = 0; i < size; ++i)
- {
- auto data = node->at<loco::DataType::FLOAT32>(i);
- quantized_values[i] = static_cast<int32_t>(std::round((data - min) * scaling_factor_inv));
- }
-
- node->dtype(loco::DataType::U8); // change the type of tensor
- node->size<loco::DataType::U8>(size); // resize tensor
- for (uint32_t i = 0; i < size; ++i)
- {
- node->at<loco::DataType::U8>(i) = std::min(kMaxScale, std::max(kMinScale, quantized_values[i]));
- }
-}
-
-void set_bias(luci::CircleNode *node, luci::CircleConst *bias)
-{
- if (auto conv = dynamic_cast<CircleConv2D *>(node))
- conv->bias(bias);
- else if (auto dconv = dynamic_cast<CircleDepthwiseConv2D *>(node))
- dconv->bias(bias);
- else if (auto tconv = dynamic_cast<CircleTransposeConv *>(node))
- tconv->bias(bias);
- else if (auto fc = dynamic_cast<CircleFullyConnected *>(node))
- fc->bias(bias);
- else
- throw std::runtime_error("Only convolution, depthwise convolution, transposed convolution, and "
- "fully-connected layer have bias");
-}
-
-void set_act_qparam(luci::CircleNode *node, float scale, int64_t zp)
-{
- assert(node); // FIX_CALLER_UNLESS
- assert(node->quantparam()); // FIX_CALLER_UNLESS
-
- auto qparam = node->quantparam();
- assert(qparam->scale.size() == 1); // FIX_CALLER_UNLESS
- assert(qparam->zerop.size() == 1); // FIX_CALLER_UNLESS
- qparam->scale[0] = scale;
- qparam->zerop[0] = zp;
-}
-
-/**
- * @brief Manually set scale/zp of output tensor of special Ops
- */
-struct QuantizeSpecialActivation final : public luci::CircleNodeMutableVisitor<void>
-{
- QuantizeSpecialActivation(loco::DataType input, loco::DataType output)
- : input_type(input), output_type(output)
- {
- }
-
- loco::DataType input_type;
- loco::DataType output_type;
-
- void visit(luci::CircleNode *)
- {
- // Do nothing by default
- }
-
- void visit(luci::CircleLogistic *node)
- {
- if (output_type == loco::DataType::U8)
- set_act_qparam(node, 1.0f / 256.0f, 0);
- else
- {
- assert(output_type == loco::DataType::S16);
- set_act_qparam(node, 1.0f / 32768.0f, 0);
- }
- }
-
- void visit(luci::CircleTanh *node)
- {
- if (output_type == loco::DataType::U8)
- set_act_qparam(node, 2.0f / 256.0f, 128);
- else
- {
- assert(output_type == loco::DataType::S16);
- set_act_qparam(node, 1.0f / 32768.0f, 0);
- }
- }
-
- void visit(luci::CircleStridedSlice *node)
- {
- auto input = loco::must_cast<luci::CircleNode *>(node->input());
- auto i_qparam = input->quantparam();
- assert(i_qparam);
- assert(i_qparam->scale.size() == 1); // FIX_CALLER_UNLESS
- assert(i_qparam->zerop.size() == 1); // FIX_CALLER_UNLESS
- auto i_scale = i_qparam->scale[0];
- auto i_zp = i_qparam->zerop[0];
-
- set_act_qparam(node, i_scale, i_zp);
- }
-
- void visit(luci::CircleSplitOut *node)
- {
- auto split = loco::must_cast<luci::CircleSplit *>(node->input());
- auto input = loco::must_cast<luci::CircleNode *>(split->input());
- auto i_qparam = input->quantparam();
- assert(i_qparam);
- assert(i_qparam->scale.size() == 1); // FIX_CALLER_UNLESS
- assert(i_qparam->zerop.size() == 1); // FIX_CALLER_UNLESS
- auto i_scale = i_qparam->scale[0];
- auto i_zp = i_qparam->zerop[0];
-
- set_act_qparam(node, i_scale, i_zp);
- }
-
- void visit(luci::CircleSplitVOut *node)
- {
- auto splitv = loco::must_cast<luci::CircleSplitV *>(node->input());
- auto input = loco::must_cast<luci::CircleNode *>(splitv->input());
- auto i_qparam = input->quantparam();
- assert(i_qparam);
- assert(i_qparam->scale.size() == 1); // FIX_CALLER_UNLESS
- assert(i_qparam->zerop.size() == 1); // FIX_CALLER_UNLESS
- auto i_scale = i_qparam->scale[0];
- auto i_zp = i_qparam->zerop[0];
-
- set_act_qparam(node, i_scale, i_zp);
- }
-
- void visit(luci::CircleUnpackOut *node)
- {
- auto unpack = loco::must_cast<luci::CircleUnpack *>(node->input());
- auto input = loco::must_cast<luci::CircleNode *>(unpack->value());
- auto i_qparam = input->quantparam();
- assert(i_qparam);
- assert(i_qparam->scale.size() == 1); // FIX_CALLER_UNLESS
- assert(i_qparam->zerop.size() == 1); // FIX_CALLER_UNLESS
- auto i_scale = i_qparam->scale[0];
- auto i_zp = i_qparam->zerop[0];
-
- set_act_qparam(node, i_scale, i_zp);
- }
-
- // TODO Move Softmax, Floor, Ceil from QuantizeActivation to here
-};
-
/**
- * @brief QuantizeActivation quantizes tensors for activations
- * @details Quantize using recorded min/max values
+ * Insert Quantize operator for mixed-precision quantization
+ * 1. Before input feature map (only for non-const)
+ * 2. After output feature map
+ *
+ * For example, if default_dtype = U8 and op_dtype = S16,
+ * 1. Quantize Op for U8->S16 is inserted before ifm
+ * 2. Quantize Op for S16->U8 is inserted after ofm
+ *
+ * Why not insert Quantize Op for const ifm?
+ * We quantize const tensor at once to preserve precision.
+ * For example, if default dtype = U8, op_dtype = S16, and op is CONV2D,
+ * We directly quantize weights to 16 bits, not 8->16 bits.
*/
-struct QuantizeActivation final : public luci::CircleNodeMutableVisitor<bool>
+struct InsertQuantizeOp final : public luci::CircleNodeMutableVisitor<void>
{
- QuantizeActivation(loco::DataType input, loco::DataType output)
- : input_type(input), output_type(output)
+ InsertQuantizeOp(loco::DataType default_dtype, loco::DataType op_dtype)
+ : _default_dtype(default_dtype), _op_dtype(op_dtype)
{
+ assert(default_dtype != op_dtype); // FIX_CALLER_UNLESS
}
- loco::DataType input_type;
- loco::DataType output_type;
+private:
+ loco::DataType _default_dtype;
+ loco::DataType _op_dtype;
- // Quantize input tensors of each node
- bool visit(luci::CircleNode *node)
+private:
+ luci::CircleQuantize *create_in_quantize(loco::Node *in, loco::Node *origin)
+ {
+ auto input = loco::must_cast<luci::CircleNode *>(in);
+ if (input->opcode() == luci::CircleOpcode::CIRCLECONST)
+ return nullptr;
+
+ auto input_quant = create_quantize_op(input, _op_dtype);
+ input_quant->input(input);
+ auto origin_node = loco::must_cast<luci::CircleNode *>(origin);
+ luci::add_origin(input_quant, luci::get_origin(origin_node));
+ return input_quant;
+ }
+
+ void insert_out_quantize(loco::Node *node)
+ {
+ auto output = loco::must_cast<luci::CircleNode *>(node);
+ assert(output->opcode() != luci::CircleOpcode::CIRCLECONST); // FIX_CALLER_UNLESS
+ auto output_quant = create_quantize_op(output, _default_dtype);
+
+ luci::add_origin(output_quant, luci::get_origin(output));
+ loco::replace(node).with(output_quant);
+ output_quant->input(node);
+ }
+
+// INPUT_NAME is the only activation of NODE
+#define INSERT_QUANTIZE_TO_UNARY_OP(NODE, INPUT_NAME) \
+ void visit(NODE *node) \
+ { \
+ if (auto input_quant = create_in_quantize(node->INPUT_NAME(), node)) \
+ node->INPUT_NAME(input_quant); \
+ \
+ insert_out_quantize(node); \
+ }
+
+// INPUT_NAME is the only activation of NODE
+#define INSERT_QUANTIZE_TO_UNARY_MULTI_OUTPUT_OP(NODE, INPUT_NAME, OUT_NAME) \
+ void visit(NODE *node) \
+ { \
+ if (auto input_quant = create_in_quantize(node->INPUT_NAME(), node)) \
+ node->INPUT_NAME(input_quant); \
+ \
+ auto out_nodes = loco::succs(node); \
+ for (auto out_node : out_nodes) \
+ { \
+ auto out_circle = loco::must_cast<OUT_NAME *>(out_node); \
+ insert_out_quantize(out_circle); \
+ } \
+ }
+
+// INPUT_NAME1 and INPUT_NAME2 are the only activations of NODE
+#define INSERT_QUANTIZE_TO_BINARY_OP(NODE, INPUT_NAME1, INPUT_NAME2) \
+ void visit(NODE *node) \
+ { \
+ if (auto input1_quant = create_in_quantize(node->INPUT_NAME1(), node)) \
+ node->INPUT_NAME1(input1_quant); \
+ \
+ if (auto input2_quant = create_in_quantize(node->INPUT_NAME2(), node)) \
+ node->INPUT_NAME2(input2_quant); \
+ \
+ insert_out_quantize(node); \
+ }
+
+ // Default behavior (NYI)
+ void visit(luci::CircleNode *node)
+ {
+ throw std::runtime_error("Unsupported Op for mixed-precision quantization. Layer name: " +
+ node->name());
+ }
+
+ // Skip output layer
+ void visit(luci::CircleOutput *) {}
+ void visit(luci::CircleSplitVOut *) {}
+ void visit(luci::CircleSplitOut *) {}
+ void visit(luci::CircleTopKV2Out *) {}
+ void visit(luci::CircleUniqueOut *) {}
+ void visit(luci::CircleUnpackOut *) {}
+
+ // Ops that receive a single activation as an input
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleAveragePool2D, value)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleBatchToSpaceND, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleConv2D, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleDepthToSpace, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleDepthwiseConv2D, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleElu, features)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleExp, x)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleFloor, x)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleFullyConnected, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleGather, params)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleInstanceNorm, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleLocalResponseNormalization, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleLogistic, x)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleMaxPool2D, value)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleMean, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleMirrorPad, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CirclePad, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CirclePadV2, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CirclePRelu, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleReduceProd, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleReduceMax, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleReduceMin, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleRelu, features)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleReshape, tensor)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleResizeBilinear, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleResizeNearestNeighbor, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleReverseSequence, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleRsqrt, x)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleSlice, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleSoftmax, logits)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleSpaceToBatchND, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleSpaceToDepth, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleSqrt, x)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleStridedSlice, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleSum, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleTanh, x)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleTile, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleTranspose, a)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleTransposeConv, outBackprop)
+
+ // Ops that receive two activations as inputs
+ INSERT_QUANTIZE_TO_BINARY_OP(luci::CircleAdd, x, y)
+ INSERT_QUANTIZE_TO_BINARY_OP(luci::CircleBatchMatMul, x, y)
+ INSERT_QUANTIZE_TO_BINARY_OP(luci::CircleDiv, x, y)
+ INSERT_QUANTIZE_TO_BINARY_OP(luci::CircleFloorDiv, x, y)
+ INSERT_QUANTIZE_TO_BINARY_OP(luci::CircleMaximum, x, y)
+ INSERT_QUANTIZE_TO_BINARY_OP(luci::CircleMinimum, x, y)
+ INSERT_QUANTIZE_TO_BINARY_OP(luci::CircleMul, x, y)
+ INSERT_QUANTIZE_TO_BINARY_OP(luci::CircleOneHot, on_value, off_value)
+ INSERT_QUANTIZE_TO_BINARY_OP(luci::CirclePow, x, y)
+ INSERT_QUANTIZE_TO_BINARY_OP(luci::CircleSub, x, y)
+
+ // Multiple-output ops that receive one activation as inputs
+ INSERT_QUANTIZE_TO_UNARY_MULTI_OUTPUT_OP(luci::CircleSplit, input, luci::CircleSplitOut)
+ INSERT_QUANTIZE_TO_UNARY_MULTI_OUTPUT_OP(luci::CircleSplitV, input, luci::CircleSplitVOut)
+ INSERT_QUANTIZE_TO_UNARY_MULTI_OUTPUT_OP(luci::CircleTopKV2, input, luci::CircleTopKV2Out)
+ INSERT_QUANTIZE_TO_UNARY_MULTI_OUTPUT_OP(luci::CircleUnique, input, luci::CircleUniqueOut)
+ INSERT_QUANTIZE_TO_UNARY_MULTI_OUTPUT_OP(luci::CircleUnpack, value, luci::CircleUnpackOut)
+
+ // AddN has arbitrary number of inputs
+ void visit(luci::CircleAddN *node)
{
- LOGGER(l);
- INFO(l) << "QuantizeActivation visit node: " << node->name() << std::endl;
auto arity = node->arity();
for (uint32_t i = 0; i < arity; i++)
{
- auto input_node = node->arg(i);
- auto circle_node = loco::must_cast<luci::CircleNode *>(input_node);
-
- // Check if this is already quantized
- if (is_quantized(circle_node))
- continue;
-
- // Check if this is bias (bias is quantized later)
- auto iwo = get_input_weight_output_of_bias(circle_node);
- if (iwo.size() > 0)
- continue;
-
- // Check if this is bool type (bool type is not quantized)
- if (circle_node->dtype() == loco::DataType::BOOL)
- continue;
-
- // Check if this is activation
- // We assume min/max are recorded only for activations
- if (has_min_max(circle_node) && !is_weights(circle_node))
- {
- // Quantize using recorded min/max
- auto quantparam = circle_node->quantparam();
- assert(quantparam);
- assert(quantparam->min.size() == 1); // only support layer-wise quant
- assert(quantparam->max.size() == 1); // only support layer-wise quant
- auto min = quantparam->min[0];
- auto max = quantparam->max[0];
-
- // Special values
- if (circle_node->opcode() == luci::CircleOpcode::SOFTMAX)
- {
- min = 0.0f;
- max = 1.0f;
- }
-
- float scaling_factor{0};
- int64_t zp{0};
- float nudged_min{0};
- float nudged_max{0};
-
- if (output_type == loco::DataType::U8)
- {
- compute_asym_scale_zp(min, max, scaling_factor, zp, nudged_min, nudged_max);
- circle_node->dtype(loco::DataType::U8);
- }
- else
- {
- compute_sym_scale_zp(min, max, scaling_factor, zp, nudged_min, nudged_max);
- circle_node->dtype(loco::DataType::S16);
- }
-
- // Nodes fused with activation functions which need special quantization
- auto fused_act_node =
- dynamic_cast<CircleNodeMixin<CircleNodeTrait::FusedActFunc> *>(circle_node);
- if (fused_act_node != nullptr &&
- fused_act_node->fusedActivationFunction() == FusedActFunc::TANH)
- {
- if (output_type == loco::DataType::U8)
- {
- scaling_factor = 2.0f / 256.0f;
- zp = 128;
- }
- else
- {
- assert(output_type == loco::DataType::S16);
- scaling_factor = 1.0f / 32768.0f;
- zp = 0;
- }
- }
-
- // The output of these Ops should be integer, so scale should be integer
- // TODO Handle cases where the integer scale needs to be propagated
- if (circle_node->opcode() == CircleOpcode::FLOOR ||
- circle_node->opcode() == CircleOpcode::FLOOR_DIV ||
- circle_node->opcode() == CircleOpcode::FLOOR_MOD ||
- circle_node->opcode() == CircleOpcode::CEIL)
- {
- assert(scaling_factor >= 0); // FIX_ME_UNLESS
- scaling_factor = scaling_factor < 1 ? 1.0f : std::round(scaling_factor);
- }
-
- circle_node->quantparam()->scale.push_back(scaling_factor);
- circle_node->quantparam()->zerop.push_back(zp);
- }
- // Fix special attributes
- if (circle_node->opcode() == luci::CircleOpcode::CAST)
- {
- auto *cast = loco::must_cast<luci::CircleCast *>(circle_node);
- auto *cast_input = loco::must_cast<luci::CircleNode *>(cast->x());
-
- // make sure that cast_input is already quantized
- assert(cast_input->dtype() != loco::DataType::FLOAT32);
- cast->in_data_type(cast_input->dtype());
- cast->out_data_type(cast->dtype());
- }
- }
- return false;
- }
-};
-
-struct QuantizeBias final : public luci::CircleNodeMutableVisitor<bool>
-{
- QuantizeBias(loco::DataType input, loco::DataType output, QuantizationGranularity gr)
- : input_type(input), output_type(output), granularity(gr)
- {
- }
-
- loco::DataType input_type;
- loco::DataType output_type;
- QuantizationGranularity granularity;
-
- // Quantize bias node
- bool visit(luci::CircleNode *node)
- {
- // Check if this is already quantized
- if (is_quantized(node))
- return false;
-
- auto iwo_list = get_input_weight_output_of_bias(node);
-
- for (auto iwo : iwo_list)
- {
- assert(iwo.size() == 3);
-
- auto input = loco::must_cast<luci::CircleNode *>(iwo[0]);
- auto weight = loco::must_cast<luci::CircleNode *>(iwo[1]);
- auto output = loco::must_cast<luci::CircleNode *>(iwo[2]);
-
- auto const_bias = loco::must_cast<luci::CircleConst *>(node);
- assert(const_bias->dtype() == loco::DataType::FLOAT32);
-
- // If input is const, it is quantized here, not in QuantizeActivation
- if (auto const_input = dynamic_cast<luci::CircleConst *>(input))
- {
- quant_const(const_input, output_type);
- }
-
- CircleConst *new_bias = nullptr;
-
- if (granularity == QuantizationGranularity::ChannelWise)
- {
- auto input_q = input->quantparam();
- assert(input_q);
- assert(input_q->scale.size() == 1); // input scale's layer-wise
- auto input_scale = input_q->scale[0];
-
- assert(weight->quantparam() != nullptr); // weight scale's channel-wise
- auto weight_scale = weight->quantparam()->scale;
-
- uint32_t size = const_bias->size<loco::DataType::FLOAT32>();
- assert(size == weight_scale.size());
- std::vector<float> scaling_factor(size);
- std::vector<int64_t> zp(size);
-
- if (output_type == loco::DataType::U8)
- {
- new_bias =
- quant_bias_per_channel(const_bias, input_scale, weight_scale, scaling_factor, zp);
- }
- else if (output_type == loco::DataType::S16)
- {
- new_bias =
- int16_quant_bias_per_channel(const_bias, input_scale, weight_scale, scaling_factor, zp);
- }
- else
- {
- throw std::runtime_error("Unsupported quantization type.");
- }
-
- auto quantparam = std::make_unique<CircleQuantParam>();
- quantparam->scale = scaling_factor;
- quantparam->zerop = zp;
- assert(new_bias->quantparam() == nullptr); // bias should not be quantized before
- new_bias->quantparam(std::move(quantparam));
-
- set_bias(output, new_bias);
- }
- else
- {
- auto input_q = input->quantparam();
- assert(input_q);
- assert(input_q->scale.size() == 1); // Only support per-layer quant
- auto input_scale = input_q->scale[0];
-
- auto weight_q = weight->quantparam();
- assert(weight_q);
- assert(weight_q->scale.size() == 1); // Only support per-layer quant
- auto weight_scale = weight_q->scale[0];
-
- float scaling_factor{0};
- int64_t zp{0};
- new_bias =
- asym_quant_bias_per_layer(const_bias, input_scale, weight_scale, &scaling_factor, &zp);
- auto quantparam = std::make_unique<CircleQuantParam>();
- quantparam->scale.push_back(scaling_factor);
- quantparam->zerop.push_back(zp);
- assert(new_bias->quantparam() == nullptr); // bias should not be quantized before
- new_bias->quantparam(std::move(quantparam));
-
- set_bias(output, new_bias);
- }
- }
- return false;
- }
-};
-
-/**
- * @brief QuantizeWeights quantizes tensors for weights
- * @details Find min/max values on the fly and then quantize
- */
-struct QuantizeWeights final : public luci::CircleNodeMutableVisitor<bool>
-{
- QuantizeWeights(loco::DataType input, loco::DataType output, QuantizationGranularity gr)
- : input_type(input), output_type(output), granularity(gr)
- {
- }
-
- loco::DataType input_type;
- loco::DataType output_type;
- QuantizationGranularity granularity;
-
-private:
- void quantize_weights(luci::CircleConst *weights)
- {
- // Find min/max per channel-wise
- if (granularity == QuantizationGranularity::ChannelWise)
- {
- auto quantparam = weights->quantparam();
- if (quantparam == nullptr)
- {
- assert(false && "quantparam is nullptr");
- return;
- }
-
- auto min = quantparam->min;
- auto scaling_factor = quantparam->scale;
- int32_t channel_dim_index = 0;
-
- if (output_type == loco::DataType::U8)
- {
- asym_wquant_per_channel(weights, min, scaling_factor, channel_dim_index);
- }
- else
- {
- sym_wquant_per_channel(weights, scaling_factor, channel_dim_index);
- }
- quantparam->min.clear();
- quantparam->max.clear();
- quantparam->quantized_dimension = channel_dim_index;
- }
- // Find min/max per layer-wise
- else
- {
- // Quantize using recorded quantparam
- auto quantparam = weights->quantparam();
- assert(quantparam != nullptr);
- assert(quantparam->min.size() == 1); // only support layer-wise quant
- assert(quantparam->scale.size() == 1); // only support layer-wise quant
- auto min = quantparam->min[0];
- auto scaling_factor = quantparam->scale[0];
- asym_wquant_per_layer(weights, min, scaling_factor);
- quantparam->min.clear();
- quantparam->max.clear();
- }
- }
-
- bool visit(luci::CircleConv2D *node)
- {
- LOGGER(l);
- INFO(l) << "QuantizeWeights visit node: " << node->name() << std::endl;
-
- auto weights = loco::must_cast<luci::CircleConst *>(node->filter());
- if (!is_quantized(weights))
- {
- auto new_weights = luci::clone(weights);
- node->filter(new_weights);
- quantize_weights(new_weights);
- return true;
+ if (auto input_quant = create_in_quantize(node->inputs(i), node))
+ node->inputs(i, input_quant);
}
- return false;
- }
-
- bool visit(luci::CircleDepthwiseConv2D *node)
- {
- LOGGER(l);
- INFO(l) << "QuantizeWeights visit node: " << node->name() << std::endl;
- auto weights = loco::must_cast<luci::CircleConst *>(node->filter());
- if (!is_quantized(weights))
- {
- auto new_weights = luci::clone(weights);
- node->filter(new_weights);
- quantize_weights(new_weights);
- return true;
- }
- return false;
+ insert_out_quantize(node);
}
- bool visit(luci::CircleInstanceNorm *node)
+ // Concat has arbitrary number of inputs
+ void visit(luci::CircleConcatenation *node)
{
- LOGGER(l);
- INFO(l) << "QuantizeWeights visit node: " << node->name() << std::endl;
-
- auto gamma = loco::must_cast<luci::CircleConst *>(node->gamma());
- auto beta = loco::must_cast<luci::CircleConst *>(node->beta());
-
- bool changed = false;
- if (!is_quantized(gamma))
- {
- assert(gamma->dtype() == loco::DataType::FLOAT32);
- auto new_gamma = luci::clone(gamma);
- if (granularity == QuantizationGranularity::LayerWise)
- quant_const(new_gamma, output_type);
- else if (granularity == QuantizationGranularity::ChannelWise)
- quant_const_per_channel(new_gamma, output_type);
- node->gamma(new_gamma);
- changed = true;
- }
- if (!is_quantized(beta))
- {
- assert(beta->dtype() == loco::DataType::FLOAT32);
- auto new_beta = luci::clone(beta);
- if (granularity == QuantizationGranularity::LayerWise)
- quant_const(new_beta, output_type);
- else if (granularity == QuantizationGranularity::ChannelWise)
- quant_const_per_channel(new_beta, output_type);
- node->beta(new_beta);
- changed = true;
- }
-
- return changed;
- }
-
- bool visit(luci::CirclePRelu *node)
- {
- LOGGER(l);
- INFO(l) << "QuantizeWeights visit node: " << node->name() << std::endl;
-
- auto alpha = loco::must_cast<luci::CircleConst *>(node->alpha());
-
- if (!is_quantized(alpha))
+ auto arity = node->arity();
+ for (uint32_t i = 0; i < arity; i++)
{
- assert(alpha->dtype() == loco::DataType::FLOAT32);
- auto new_alpha = luci::clone(alpha);
- if (granularity == QuantizationGranularity::LayerWise)
- quant_const(new_alpha, output_type);
- else if (granularity == QuantizationGranularity::ChannelWise)
- quant_const_per_channel(new_alpha, output_type);
- node->alpha(new_alpha);
- return true;
+ if (auto input_quant = create_in_quantize(node->values(i), node))
+ node->values(i, input_quant);
}
- return false;
+ insert_out_quantize(node);
}
- bool visit(luci::CircleTransposeConv *node)
+ // Pack has arbitrary number of inputs
+ void visit(luci::CirclePack *node)
{
- LOGGER(l);
- INFO(l) << "QuantizeWeights visit node: " << node->name() << std::endl;
-
- auto weights = loco::must_cast<luci::CircleConst *>(node->filter());
- if (!is_quantized(weights))
+ auto arity = node->arity();
+ for (uint32_t i = 0; i < arity; i++)
{
- auto new_weights = luci::clone(weights);
- node->filter(new_weights);
- quantize_weights(new_weights);
- return true;
+ if (auto input_quant = create_in_quantize(node->values(i), node))
+ node->values(i, input_quant);
}
- return false;
- }
-
- bool visit(luci::CircleFullyConnected *node)
- {
- LOGGER(l);
- INFO(l) << "QuantizeWeights visit node: " << node->name() << std::endl;
- auto weights = loco::must_cast<luci::CircleConst *>(node->weights());
- if (!is_quantized(weights))
- {
- auto new_weights = luci::clone(weights);
- node->weights(new_weights);
- quantize_weights(new_weights);
- return true;
- }
- return false;
+ insert_out_quantize(node);
}
- bool visit(luci::CircleNode *) { return false; }
+#undef INSERT_QUANTIZE_TO_UNARY_OP
+#undef INSERT_QUANTIZE_TO_BINARY_OP
+#undef INSERT_QUANTIZE_TO_UNARY_MULTI_OUTPUT_OP
};
-/** EXAMPLE
- *
- * BEFORE
- *
- * [CircleNode] [CircleConst]
- * (qparam1) (FP32)
- * \ /
- * \ /
- * [CirclePack]
- * (qparam2)
- *
- * AFTER
- *
- * [CircleNode] [CircleConst] [CircleConst] <- Dead node
- * (qparam2) (qparam2) (FP32)
- * \ /
- * \ /
- * [CirclePack]
- * (qparam2)
- *
- * NOTE Quantization parameter of CirclePack (qparam2) is propagated to the inputs.
- */
-void propagate_pack_quantparam(luci::CirclePack *pack, loco::DataType quant_type)
-{
- assert(pack->quantparam() != nullptr);
-
- const auto num_inputs = pack->values_count();
-
- for (uint32_t i = 0; i < num_inputs; i++)
- {
- auto node = loco::must_cast<luci::CircleNode *>(pack->arg(i));
-
- // Skip if this input is PACK Op
- if (node->opcode() == luci::CircleOpcode::PACK)
- continue;
-
- // Quantize constant values
- if (node->opcode() == luci::CircleOpcode::CIRCLECONST)
- {
- luci::CircleConst *const_node = loco::must_cast<luci::CircleConst *>(node);
- if (const_node->dtype() != loco::DataType::FLOAT32)
- throw std::runtime_error("Unsupported data type for constant input of pack Op");
-
- const auto pack_qparam = pack->quantparam();
- if (pack_qparam == nullptr)
- throw std::runtime_error("quantparam of pack is not found during propagation");
-
- assert(pack_qparam->scale.size() == 1);
- assert(pack_qparam->zerop.size() == 1);
- const auto scaling_factor = pack_qparam->scale[0];
- const auto zerop = pack_qparam->zerop[0];
-
- auto new_const = luci::clone(const_node);
- quant_const_values(new_const, scaling_factor, zerop, quant_type);
- pack->values(i, new_const);
- overwrite_quantparam(pack, new_const);
- }
- else
- {
- const auto succs = loco::succs(node);
- if (succs.size() > 1)
- continue;
-
- // Non-const input must have been quantized
- assert(node->quantparam() != nullptr);
- overwrite_quantparam(pack, node);
- }
- }
-}
-
-/**
- * @brief Quantize const input tensors using min/max of const values
- */
-void quantize_const_inputs(luci::CircleNode *node, loco::DataType output_type)
-{
- auto opcode = node->opcode();
- auto arity = node->arity();
-
- loco::Node *input_node{nullptr};
- luci::CircleConst *const_node{nullptr};
-
- switch (opcode)
- {
- case luci::CircleOpcode::CONV_2D:
- case luci::CircleOpcode::DEPTHWISE_CONV_2D:
- case luci::CircleOpcode::FULLY_CONNECTED:
- case luci::CircleOpcode::INSTANCE_NORM:
- case luci::CircleOpcode::PRELU:
- case luci::CircleOpcode::TRANSPOSE_CONV:
- // Handled in QuantizeWeights and QuantizeBias
- break;
-
- case luci::CircleOpcode::CONCATENATION:
- // Handled in propagate_concat_quantparam
- break;
-
- case luci::CircleOpcode::LOGICAL_OR:
- // Inputs of logical Ops are bool, thus not quantized
- break;
-
- case luci::CircleOpcode::ARG_MAX:
- case luci::CircleOpcode::ARG_MIN:
- case luci::CircleOpcode::BATCH_TO_SPACE_ND:
- case luci::CircleOpcode::LOCAL_RESPONSE_NORMALIZATION:
- case luci::CircleOpcode::MEAN:
- case luci::CircleOpcode::MIRROR_PAD:
- case luci::CircleOpcode::PAD:
- case luci::CircleOpcode::REDUCE_ANY:
- case luci::CircleOpcode::REDUCE_PROD:
- case luci::CircleOpcode::REDUCE_MAX:
- case luci::CircleOpcode::REDUCE_MIN:
- case luci::CircleOpcode::RESHAPE:
- case luci::CircleOpcode::RESIZE_BILINEAR:
- case luci::CircleOpcode::RESIZE_NEAREST_NEIGHBOR:
- case luci::CircleOpcode::REVERSE_SEQUENCE:
- case luci::CircleOpcode::SLICE:
- case luci::CircleOpcode::SPACE_TO_BATCH_ND:
- case luci::CircleOpcode::SPLIT_V:
- case luci::CircleOpcode::STRIDED_SLICE:
- case luci::CircleOpcode::SUM:
- case luci::CircleOpcode::TILE:
- case luci::CircleOpcode::TOPK_V2:
- case luci::CircleOpcode::TRANSPOSE:
- // The second input of these Ops should not be quantized
- // Ex: axis, paddings
- input_node = node->arg(0);
- const_node = dynamic_cast<luci::CircleConst *>(input_node);
- if (const_node != nullptr && !is_quantized(const_node))
- quant_const(const_node, output_type);
- break;
-
- case luci::CircleOpcode::ADD:
- case luci::CircleOpcode::ADD_N:
- case luci::CircleOpcode::DEPTH_TO_SPACE:
- case luci::CircleOpcode::DIV:
- case luci::CircleOpcode::ELU:
- case luci::CircleOpcode::EQUAL:
- case luci::CircleOpcode::EXP:
- case luci::CircleOpcode::FLOOR:
- case luci::CircleOpcode::FLOOR_DIV:
- case luci::CircleOpcode::GREATER:
- case luci::CircleOpcode::GREATER_EQUAL:
- case luci::CircleOpcode::LESS:
- case luci::CircleOpcode::LESS_EQUAL:
- case luci::CircleOpcode::LOGISTIC:
- case luci::CircleOpcode::MAXIMUM:
- case luci::CircleOpcode::MINIMUM:
- case luci::CircleOpcode::MUL:
- case luci::CircleOpcode::NOT_EQUAL:
- case luci::CircleOpcode::POW:
- case luci::CircleOpcode::RSQRT:
- case luci::CircleOpcode::SOFTMAX:
- case luci::CircleOpcode::SPACE_TO_DEPTH:
- case luci::CircleOpcode::SQRT:
- case luci::CircleOpcode::SUB:
- case luci::CircleOpcode::TANH:
- case luci::CircleOpcode::UNPACK:
- // Quantize all const inputs using their values
- for (uint32_t i = 0; i < arity; i++)
- {
- input_node = node->arg(i);
- const_node = dynamic_cast<luci::CircleConst *>(input_node);
- if (const_node != nullptr && !is_quantized(const_node))
- quant_const(const_node, output_type);
- }
- break;
-
- case luci::CircleOpcode::SPLIT:
- // Only the second input is quantized
- // First input should not be quantized (e.g., split_dim)
- input_node = node->arg(1);
- const_node = dynamic_cast<luci::CircleConst *>(input_node);
- if (const_node != nullptr && !is_quantized(const_node))
- quant_const(const_node, output_type);
- break;
-
- case luci::CircleOpcode::PADV2:
- // First and third constant inputs are quantized
- // Second input should not be quantized (e.g., paddings)
- // Quant params are propagated either from output range to the non-constant input
- // or from input to output and constant values
- propagate_pad_v2_quantparam(loco::must_cast<CirclePadV2 *>(node), output_type);
- break;
-
- case luci::CircleOpcode::PACK:
- // Quant param is propagated from output to inputs
- propagate_pack_quantparam(loco::must_cast<CirclePack *>(node), output_type);
- break;
-
- default:
- for (uint32_t i = 0; i < arity; i++)
- {
- input_node = node->arg(i);
- const_node = dynamic_cast<luci::CircleConst *>(input_node);
- if (const_node != nullptr)
- throw std::runtime_error("Unsupported Op for const inputs");
- }
- break;
- }
-}
-
} // namespace
-/** BEFORE
- *
- * [CircleNode] [CircleConst]
- * (U8 qparam1) (FP32)
- * \ /
- * \ /
- * [CircleConcatenation]
- * (U8 qparam2)
- *
- * AFTER
- * [CircleNode] [CircleConst] [CircleConst] <- Dead node
- * (U8 qparam2) (U8 qparam2) (FP32)
- * \ /
- * \ /
- * [CircleConcatenation]
- * (U8 qparam2)
- */
-void propagate_concat_quantparam(luci::CircleConcatenation *concat, loco::DataType quant_type)
-{
- assert(concat->quantparam() != nullptr);
-
- const auto num_inputs = concat->numValues();
-
- // Quantize const inputs using their values if concat has fused act function
- if (concat->fusedActivationFunction() != luci::FusedActFunc::NONE)
- {
- for (uint32_t i = 0; i < num_inputs; i++)
- {
- auto node = concat->arg(i);
- auto const_node = dynamic_cast<luci::CircleConst *>(node);
- if (const_node != nullptr)
- {
- auto new_const = luci::clone(const_node);
- quant_const(new_const, quant_type);
- concat->values(i, new_const);
- }
- }
- return;
- }
-
- for (uint32_t i = 0; i < num_inputs; i++)
- {
- auto node = loco::must_cast<luci::CircleNode *>(concat->arg(i));
-
- // Skip if this input is CONCAT Op
- if (node->opcode() == luci::CircleOpcode::CONCATENATION)
- continue;
-
- // Quantize constant values
- if (node->opcode() == luci::CircleOpcode::CIRCLECONST)
- {
- luci::CircleConst *const_node = loco::must_cast<luci::CircleConst *>(node);
- if (const_node->dtype() != loco::DataType::FLOAT32)
- throw std::runtime_error("Unsupported data type for constant input of concatenation Op");
-
- const auto concat_qparam = concat->quantparam();
- if (concat_qparam == nullptr)
- throw std::runtime_error("quantparam of concat is not found during propagation");
-
- assert(concat_qparam->scale.size() == 1);
- const auto scaling_factor = concat_qparam->scale[0];
- const auto zerop = concat_qparam->zerop[0];
-
- auto new_const = luci::clone(const_node);
- quant_const_values(new_const, scaling_factor, zerop, quant_type);
- concat->values(i, new_const);
- overwrite_quantparam(concat, new_const);
- }
- else
- {
- const auto succs = loco::succs(node);
- if (succs.size() > 1)
- continue;
-
- // Non-const input must have been quantized
- assert(node->quantparam() != nullptr);
- overwrite_quantparam(concat, node);
- }
- }
-}
-
-/**
- * tells if pad_v2 quantization should ignore padding value
- * In that case padding const will be quantized with input parameters, and probably clipped
- */
-bool ignore_pad_v2_const_quantization(luci::CirclePadV2 *pad)
-{
- // This is a workaround to quantize pad generated from MaxPoolWithArgmax operation properly
- // TODO use metadata hints to detect this case
- auto const_value_node = dynamic_cast<luci::CircleConst *>(pad->arg(2));
- if (!const_value_node)
- return false;
- if (const_value_node->dtype() == loco::DataType::FLOAT32)
- {
- float const_value = const_value_node->at<loco::DataType::FLOAT32>(0);
- if (const_value == std::numeric_limits<float>::lowest())
- return true;
- }
- return false;
-}
-
-/** BEFORE
- *
- * [CircleNode] [CircleConst] [CircleConst]
- * (U8 qparam1) (S32) (FP32)
- * \ | /
- * \ | /
- * [CirclePadV2]
- * (U8 qparam2)
- *
- * AFTER (case 1)
- *
- * By default qparam is propagated from output to inputs to meet backend requirements.
- *
- * [CircleNode] [CircleConst] [CircleConst] [CircleConst] <- Dead node
- * (U8 qparam2) (S32) (U8 qparam2) (FP32)
- * \ | /
- * \ | /
- * [CirclePadV2]
- * (U8 qparam2)
- *
- * AFTER (case 2)
- *
- * In case padded value is the lowest float value
- * Qparam is propagated from input to output and constant.
- *
- * This is a special case for optimization constructed pad, needed to guarantee that
- * extremely large negative constant do not stretch output quantization range.
- *
- * [CircleNode] [CircleConst] [CircleConst] [CircleConst] <- Dead node
- * (U8 qparam1) (S32) (U8 qparam1) (FP32)
- * \ | /
- * \ | /
- * [CirclePadV2]
- * (U8 qparam1)
- */
-void propagate_pad_v2_quantparam(luci::CirclePadV2 *pad_v2, loco::DataType quant_type)
-{
- if (ignore_pad_v2_const_quantization(pad_v2))
- {
- // propagate input quantization paramters from input to output and padding const value
- auto pad_v2_input = loco::must_cast<luci::CircleNode *>(pad_v2->arg(0));
- overwrite_quantparam(pad_v2_input, pad_v2);
-
- auto const_value_node = loco::must_cast<luci::CircleConst *>(
- pad_v2->arg(2)); // FIX ignore_pad_v2_const_quantization UNLESS
- auto new_const = luci::clone(const_value_node);
-
- const auto pad_v2_input_qparam = pad_v2_input->quantparam();
- assert(pad_v2_input_qparam != nullptr);
- assert(pad_v2_input_qparam->scale.size() == 1);
- const auto scaling_factor = pad_v2_input_qparam->scale.at(0);
- const auto zerop = pad_v2_input_qparam->zerop.at(0);
-
- quant_const_values(new_const, scaling_factor, zerop, quant_type);
- overwrite_quantparam(pad_v2_input, new_const);
- pad_v2->constant_values(new_const);
- return;
- }
-
- // Propagate quantization paramters from output to inputs,
- // to fit both input and counstant_value in one quant range.
- auto quant_input = [pad_v2, quant_type](void (CirclePadV2::*arg_setter)(loco::Node *),
- uint32_t arg) {
- auto node = loco::must_cast<luci::CircleNode *>(pad_v2->arg(arg));
-
- // Quantize constant values
- if (node->opcode() == luci::CircleOpcode::CIRCLECONST)
- {
- luci::CircleConst *const_node = loco::must_cast<luci::CircleConst *>(node);
- if (is_quantized(const_node))
- return;
-
- if (const_node->dtype() != loco::DataType::FLOAT32)
- throw std::runtime_error("Unsupported data type for constant input of PadV2 Op");
-
- const auto pad_v2_qparam = pad_v2->quantparam();
- if (pad_v2_qparam == nullptr)
- throw std::runtime_error("quantparam of PadV2 is not found during propagation");
-
- assert(pad_v2_qparam->scale.size() == 1);
- const auto scaling_factor = pad_v2_qparam->scale.at(0);
- const auto zerop = pad_v2_qparam->zerop.at(0);
-
- auto new_const = luci::clone(const_node);
- quant_const_values(new_const, scaling_factor, zerop, quant_type);
- overwrite_quantparam(pad_v2, new_const);
- (pad_v2->*arg_setter)(new_const);
- }
- // Subsequent PadV2 Ops quant params are not propagated
- else if (node->opcode() == luci::CircleOpcode::PADV2)
- {
- return;
- }
- else
- {
- const auto succs = loco::succs(node);
- if (succs.size() > 1)
- return;
-
- // Non-const input must have been quantized
- assert(node->quantparam() != nullptr);
- overwrite_quantparam(pad_v2, node);
- }
- };
-
- quant_input(&CirclePadV2::input, 0);
- quant_input(&CirclePadV2::constant_values, 2);
-}
-
void QuantizeWithMinMaxPass::set_input_type(loco::Graph *g) const
{
auto inputs = g->inputs();
for (auto node : loco::input_nodes(g))
{
auto input = loco::must_cast<luci::CircleInput *>(node);
- if (input->dtype() == _input_type)
+ if (input->dtype() == _ctx->input_type)
continue;
// Bool type is not quantizable
if (input->dtype() == loco::DataType::BOOL)
continue;
+ if (input->dtype() == loco::DataType::S32)
+ continue;
+ if (input->dtype() == loco::DataType::S64)
+ continue;
// Insert Quantize Op
auto quant_op = create_quantize_op(input, input->dtype());
@@ -1552,22 +367,22 @@ void QuantizeWithMinMaxPass::set_input_type(loco::Graph *g) const
float nudged_min{0};
float nudged_max{0};
- if (_input_type == loco::DataType::U8)
+ if (_ctx->input_type == loco::DataType::U8)
{
compute_asym_scale_zp(min, max, scaling_factor, zp, nudged_min, nudged_max);
}
else
{
- assert(_input_type == loco::DataType::S16);
+ assert(_ctx->input_type == loco::DataType::S16);
compute_sym_scale_zp(min, max, scaling_factor, zp, nudged_min, nudged_max);
}
- input->dtype(_input_type);
+ input->dtype(_ctx->input_type);
input->quantparam()->scale[0] = scaling_factor;
input->quantparam()->zerop[0] = zp;
}
auto graph_input = inputs->at(input->index());
- graph_input->dtype(_input_type);
+ graph_input->dtype(_ctx->input_type);
}
}
@@ -1577,7 +392,7 @@ void QuantizeWithMinMaxPass::set_output_type(loco::Graph *g) const
for (auto node : loco::output_nodes(g))
{
auto output = loco::must_cast<luci::CircleOutput *>(node);
- if (output->dtype() == _output_type)
+ if (output->dtype() == _ctx->output_type)
continue;
// Bool type is not quantizable
@@ -1591,7 +406,7 @@ void QuantizeWithMinMaxPass::set_output_type(loco::Graph *g) const
continue;
// Insert Quantize Op
- auto quant_op = create_quantize_op(from, _output_type);
+ auto quant_op = create_quantize_op(from, _ctx->output_type);
loco::replace(from).with(quant_op);
quant_op->input(from);
@@ -1599,67 +414,165 @@ void QuantizeWithMinMaxPass::set_output_type(loco::Graph *g) const
luci::add_origin(quant_op, luci::get_origin(from));
auto graph_output = outputs->at(output->index());
- graph_output->dtype(_output_type);
+ graph_output->dtype(_ctx->output_type);
}
}
+/**
+ * How QuantizeWithMinMax works?
+ *
+ * We categorized tensors into four groups
+ * - Activation: Feature maps (both Const/Non-const)
+ * - Weights: Const tensors of specific Ops (Conv, FC, ...)
+ * - Bias: Const tensors of specific Ops (Conv, FC, ...)
+ * - Others: padding value, one_hot value, axis, ..
+ *
+ * Activation is quantized in different ways
+ * 1. For non-constant activation, quantize using recorded min/max
+ * 2. For constant activation, quantize using min/max of its value
+ * 3. For some Ops (ex: pad_v2), output qparam is used as input qparam (backward propagation)
+ * 4. For some Ops (ex: reshape), input qparam is used as output qparam (forward propagation)
+ * 5. For some Ops (ex: tanh), output qparam has pre-defined values
+ *
+ * Weights is quantized using min/max of its value
+ *
+ * Bias is quantized using input scale (s_i) and weights scale (s_w)
+ * - Activation and weights should be quantized earlier than bias
+ *
+ * Quantization Steps
+ * 1. Quantize Activation
+ * - Quantize using recorded min/max (QuantizeActivation)
+ * - Insert Quantize Ops for mixed-precision quantization (InsertQuantizeOp)
+ * - Remove redundant Quantize Ops (RemoveRedundantQuantizePass)
+ * - Propagate qparam backward (PropagateQParamBackwardPass)
+ * - Quantize const inputs (QuantizeConstInputActivation)
+ * - Quantize using pre-defined values (QuantizeSpecialActivation)
+ * - Propagate qparam forward (PropagateQParamForwardPass)
+ * 2. Quantize Weights
+ * 3. Quantize Bias
+ * 4. Set input dtype
+ * 5. Set output dtype
+ *
+ * Why quantization sequence was determined as above?
+ * - Activation and weights should be quantized before bias (1->2->3). Input/Output
+ * dtype can be updated at the end (4->5).
+ * - During activation quantization,
+ * - Backward propagation is performed earlier than forward propagation. This allows
+ * backward-propagated qpram to be overwritten during forward propagation.
+ * We made this decision as Ops for forward propagation (reshape, transpose, ..)
+ * are more common than backward propagation. TODO Check this decision is safe.
+ * - QuantizeSpecialActivation is called before forward propagation to make sure that
+ * the pre-defined qparam values are propagated.
+ */
bool QuantizeWithMinMaxPass::run(loco::Graph *g)
{
LOGGER(l);
INFO(l) << "QuantizeWithMinMaxPass Start" << std::endl;
+ auto info_by_name = layer_info_map(g, _ctx->layers_info);
+
+ auto quantize_dtype = [&](const luci::CircleNode *node) {
+ auto iter = info_by_name.find(node->name());
+
+ // Return designated quantization dtype
+ if (iter != info_by_name.end())
+ return iter->second.dtype;
+
+ // Return default quantization dtype
+ return _ctx->output_model_dtype;
+ };
+
+ auto quantize_granularity = [&](const luci::CircleNode *node) {
+ auto iter = info_by_name.find(node->name());
+
+ // Return designated quantization granularity
+ if (iter != info_by_name.end())
+ return iter->second.granularity;
+
+ // Return default quantization granularity
+ return _ctx->granularity;
+ };
+
// Quantize activation
for (auto node : loco::active_nodes(loco::output_nodes(g)))
{
- QuantizeActivation qa(_input_model_dtype, _output_model_dtype);
auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+ QuantizeActivation qa(_ctx->input_model_dtype, quantize_dtype(circle_node));
circle_node->accept(&qa);
}
- // Quantize weights
+ // Insert Quantize Op
for (auto node : loco::active_nodes(loco::output_nodes(g)))
{
- QuantizeWeights qw(_input_model_dtype, _output_model_dtype, _granularity);
auto circle_node = loco::must_cast<luci::CircleNode *>(node);
- circle_node->accept(&qw);
+ auto op_dtype = quantize_dtype(circle_node);
+ if (op_dtype != _ctx->output_model_dtype)
+ {
+ InsertQuantizeOp iqo(_ctx->output_model_dtype, op_dtype);
+ circle_node->accept(&iqo);
+ }
}
- // Quantize bias
+ // Remove redundant Quantize Op
+ {
+ logo::Phase phase;
+
+ phase.emplace_back(std::make_unique<luci::RemoveRedundantQuantizePass>());
+
+ ProgressReporter prog(g, logo::PhaseStrategy::Saturate);
+ logo::PhaseRunner<logo::PhaseStrategy::Saturate> phase_runner{g};
+ phase_runner.attach(&prog);
+ phase_runner.run(phase);
+ }
+
+ // Backward propagation of activation qparam
+ {
+ PropagateQParamBackwardPass pqbp(_ctx->output_model_dtype);
+ pqbp.run(g);
+ }
+
+ // Quantize const input activation
for (auto node : loco::active_nodes(loco::output_nodes(g)))
{
- QuantizeBias qb(_input_model_dtype, _output_model_dtype, _granularity);
auto circle_node = loco::must_cast<luci::CircleNode *>(node);
- circle_node->accept(&qb);
+ QuantizeConstInputActivation qcia(quantize_dtype(circle_node));
+ circle_node->accept(&qcia);
}
- // Propagate quantization parameters of concat Op
+ // Update qparam of output of special Ops
for (auto node : loco::active_nodes(loco::output_nodes(g)))
{
- auto concat = dynamic_cast<luci::CircleConcatenation *>(node);
- if (not concat)
- continue;
-
- // Propagate qparam of concat to its inputs if
- // (1) concat is uint8-quantized
- // (2) concat has no fused activation function
- // (3) the input is not concatenation Op
- // (4) the input is not produced to Ops other than concat
- propagate_concat_quantparam(concat, _output_model_dtype);
+ auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+ QuantizeSpecialActivation qsa(_ctx->input_model_dtype, quantize_dtype(circle_node));
+ circle_node->accept(&qsa);
}
- // Quantize const inputs other than weights and bias
+ // Forward propagation of activation qparam
+ logo::Phase phase;
+
+ phase.emplace_back(std::make_unique<luci::PropagateQParamForwardPass>(_ctx->TF_style_maxpool));
+
+ ProgressReporter prog(g, logo::PhaseStrategy::Saturate);
+ logo::PhaseRunner<logo::PhaseStrategy::Saturate> phase_runner{g};
+ phase_runner.attach(&prog);
+ phase_runner.run(phase);
+
+ // Quantize weights
for (auto node : loco::active_nodes(loco::output_nodes(g)))
{
auto circle_node = loco::must_cast<luci::CircleNode *>(node);
- quantize_const_inputs(circle_node, _output_model_dtype);
+ QuantizeWeights qw(_ctx->input_model_dtype, quantize_dtype(circle_node),
+ quantize_granularity(circle_node));
+ circle_node->accept(&qw);
}
- // Update qparam of output of special Ops
+ // Quantize bias
for (auto node : loco::active_nodes(loco::output_nodes(g)))
{
- QuantizeSpecialActivation qsa(_input_model_dtype, _output_model_dtype);
auto circle_node = loco::must_cast<luci::CircleNode *>(node);
- circle_node->accept(&qsa);
+ QuantizeBias qb(_ctx->input_model_dtype, quantize_dtype(circle_node),
+ quantize_granularity(circle_node));
+ circle_node->accept(&qb);
}
// Update output dtype
@@ -1667,11 +580,11 @@ bool QuantizeWithMinMaxPass::run(loco::Graph *g)
for (auto node : loco::output_nodes(g))
{
auto circle_node = loco::must_cast<luci::CircleOutput *>(node);
- if (static_cast<luci::CircleNode *>(circle_node->from())->dtype() == _output_model_dtype)
+ if (static_cast<luci::CircleNode *>(circle_node->from())->dtype() == _ctx->output_model_dtype)
{
- circle_node->dtype(_output_model_dtype);
+ circle_node->dtype(_ctx->output_model_dtype);
auto graph_output = graph_outputs->at(circle_node->index());
- graph_output->dtype(_output_model_dtype);
+ graph_output->dtype(_ctx->output_model_dtype);
}
}
diff --git a/compiler/luci/pass/src/QuantizeWithMinMaxPass.test.cpp b/compiler/luci/pass/src/QuantizeWithMinMaxPass.test.cpp
index 75ec0cfd8..d5fa21ffd 100644
--- a/compiler/luci/pass/src/QuantizeWithMinMaxPass.test.cpp
+++ b/compiler/luci/pass/src/QuantizeWithMinMaxPass.test.cpp
@@ -16,8 +16,41 @@
#include "luci/Pass/QuantizeWithMinMaxPass.h"
+#include <luci/IR/CircleNodes.h>
+
#include <gtest/gtest.h>
+class SimpleConcatGraph
+{
+public:
+ SimpleConcatGraph(loco::DataType quant_type)
+ {
+ concat_node = g.nodes()->create<luci::CircleConcatenation>(2);
+ input_1 = g.nodes()->create<luci::CircleConst>();
+ input_2 = g.nodes()->create<luci::CircleConst>();
+
+ concat_node->dtype(quant_type);
+ concat_node->fusedActivationFunction(luci::FusedActFunc::NONE);
+ input_1->dtype(quant_type);
+ input_2->dtype(quant_type);
+
+ concat_node->values(0, input_1);
+ concat_node->values(1, input_2);
+ }
+
+ ~SimpleConcatGraph()
+ {
+ concat_node->values(0, nullptr);
+ concat_node->values(1, nullptr);
+ }
+
+public:
+ loco::Graph g;
+ luci::CircleConcatenation *concat_node = nullptr;
+ luci::CircleConst *input_1 = nullptr;
+ luci::CircleConst *input_2 = nullptr;
+};
+
TEST(QuantizeWithMinMaxPassTest, name)
{
luci::QuantizeWithMinMaxPass pass(loco::DataType::FLOAT32, loco::DataType::U8,
@@ -25,3 +58,19 @@ TEST(QuantizeWithMinMaxPassTest, name)
auto const name = pass.name();
ASSERT_NE(nullptr, name);
}
+
+// Test concat of integer tensors
+// Integer tensors are not quantized
+TEST(QuantizeWithMinMaxPassTest, int_concat)
+{
+ SimpleConcatGraph g(loco::DataType::S32);
+
+ luci::QuantizeWithMinMaxPass qwmm(loco::DataType::FLOAT32, loco::DataType::U8,
+ luci::QuantizationGranularity::LayerWise);
+
+ qwmm.run(&g.g);
+
+ EXPECT_EQ(nullptr, g.concat_node->quantparam());
+ EXPECT_EQ(nullptr, g.input_1->quantparam());
+ EXPECT_EQ(nullptr, g.input_2->quantparam());
+}
diff --git a/compiler/luci/pass/src/QuantizedModelVerifier.cpp b/compiler/luci/pass/src/QuantizedModelVerifier.cpp
index f02301ed1..684d5d48a 100644
--- a/compiler/luci/pass/src/QuantizedModelVerifier.cpp
+++ b/compiler/luci/pass/src/QuantizedModelVerifier.cpp
@@ -15,10 +15,10 @@
#include "QuantizedModelVerifier.h"
-#include "VerifyQuantizedNodeLayerWiseGranularity.h"
-#include "VerifyQuantizedNodeChannelWiseGranularity.h"
-#include "VerifyQuantizedNodeU8Type.h"
-#include "VerifyQuantizedNodeS16Type.h"
+#include "VerifyQuantizedNodeGranularity.h"
+#include "VerifyQuantizedNodeType.h"
+#include "VerifyQuantizedBiasScale.h"
+#include "helpers/LayerInfoMap.h"
#include <luci/IR/CircleNodes.h>
#include <luci/IR/CircleNodeVisitor.h>
@@ -28,12 +28,33 @@ namespace luci
void QuantizedModelVerifier::verify(loco::Graph *g)
{
- if (_quantized_dtype != Type::U8 && _quantized_dtype != Type::S16)
- throw std::runtime_error("Unsupported quantized dtype");
-
- if (_granularity != Granularity::ChannelWise && _granularity != Granularity::LayerWise)
+ if (_ctx->granularity != Granularity::ChannelWise && _ctx->granularity != Granularity::LayerWise)
throw std::runtime_error("Unsupported granularity");
+ auto info_by_name = layer_info_map(g, _ctx->layers_info);
+
+ auto quantize_dtype = [&](const luci::CircleNode *node) {
+ auto iter = info_by_name.find(node->name());
+
+ // Return designated quantization dtype
+ if (iter != info_by_name.end())
+ return iter->second.dtype;
+
+ // Return default quantization dtype
+ return _ctx->output_model_dtype;
+ };
+
+ auto quantize_granularity = [&](const luci::CircleNode *node) {
+ auto iter = info_by_name.find(node->name());
+
+ // Return designated quantization granularity
+ if (iter != info_by_name.end())
+ return iter->second.granularity;
+
+ // Return default quantization granularity
+ return _ctx->granularity;
+ };
+
for (auto node : loco::active_nodes(loco::output_nodes(g)))
{
auto circle_node = loco::must_cast<luci::CircleNode *>(node);
@@ -46,32 +67,17 @@ void QuantizedModelVerifier::verify(loco::Graph *g)
};
// Verify Type
- if (_quantized_dtype == Type::U8)
- {
- VerifyQuantizedNodeU8Type vt;
- if (!circle_node->accept(&vt))
- throw std::runtime_error("Wrong data type detected in " + node_name());
- }
- else if (_quantized_dtype == Type::S16)
- {
- VerifyQuantizedNodeS16Type vt;
- if (!circle_node->accept(&vt))
- throw std::runtime_error("Wrong data type detected in " + node_name());
- }
+ if (!VerifyQuantizedNodeType::create(quantize_dtype(circle_node))->verify(circle_node))
+ throw std::runtime_error("Wrong data type detected in " + node_name());
// Verify Granularity
- if (_granularity == Granularity::LayerWise)
- {
- VerifyQuantizedNodeLayerWiseGranularity vg;
- if (!circle_node->accept(&vg))
- throw std::runtime_error("Wrong granularity detected in " + node_name());
- }
- else if (_granularity == Granularity::ChannelWise)
- {
- VerifyQuantizedNodeChannelWiseGranularity vg;
- if (!circle_node->accept(&vg))
- throw std::runtime_error("Wrong granularity detected in " + node_name());
- }
+ if (!circle_node->accept(
+ VerifyQuantizedNodeGranularity::create(quantize_granularity(circle_node)).get()))
+ throw std::runtime_error("Wrong granularity detected in " + node_name());
+
+ // Verify Bias scale
+ if (!VerifyQuantizedBiasScale::create()->verify(circle_node))
+ throw std::runtime_error("Wrong bias scale detected in " + node_name());
}
}
diff --git a/compiler/luci/pass/src/QuantizedModelVerifier.h b/compiler/luci/pass/src/QuantizedModelVerifier.h
index d5fbb8e74..7409a51d7 100644
--- a/compiler/luci/pass/src/QuantizedModelVerifier.h
+++ b/compiler/luci/pass/src/QuantizedModelVerifier.h
@@ -21,6 +21,8 @@
#include <loco.h>
+#include <memory>
+
namespace luci
{
@@ -31,18 +33,40 @@ namespace luci
*/
struct QuantizedModelVerifier
{
+public:
+ struct Context
+ {
+ loco::DataType output_model_dtype = loco::DataType::Unknown;
+ QuantizationGranularity granularity = QuantizationGranularity::ChannelWise;
+ loco::DataType input_type = loco::DataType::Unknown;
+ loco::DataType output_type = loco::DataType::Unknown;
+ bool TF_style_maxpool = false;
+ std::vector<LayerInfo> layers_info;
+ };
public:
QuantizedModelVerifier(loco::DataType quantized_dtype, QuantizationGranularity granularity)
- : _quantized_dtype(quantized_dtype), _granularity(granularity)
{
+ _ctx = std::make_unique<Context>();
+ {
+ _ctx->output_model_dtype = quantized_dtype;
+ _ctx->granularity = granularity;
+ _ctx->input_type = quantized_dtype;
+ _ctx->output_type = quantized_dtype;
+ _ctx->TF_style_maxpool = false;
+ }
+ }
+
+public:
+ QuantizedModelVerifier(std::unique_ptr<Context> &&ctx) : _ctx{std::move(ctx)}
+ {
+ // DO NOTHING
}
void verify(loco::Graph *g);
private:
- loco::DataType _quantized_dtype;
- QuantizationGranularity _granularity;
+ std::unique_ptr<Context> _ctx;
};
} // namespace luci
diff --git a/compiler/luci/pass/src/QuantizedModelVerifier.test.cpp b/compiler/luci/pass/src/QuantizedModelVerifier.test.cpp
index 3a6d86c33..cebafd32b 100644
--- a/compiler/luci/pass/src/QuantizedModelVerifier.test.cpp
+++ b/compiler/luci/pass/src/QuantizedModelVerifier.test.cpp
@@ -17,6 +17,7 @@
#include "QuantizedModelVerifier.h"
#include "luci/Pass/QuantizeWithMinMaxPass.h"
+#include "luci/Pass/QuantizationParameters.h"
#include <luci/test/TestIOGraph.h>
@@ -112,57 +113,77 @@ void quantize_and_verify(loco::Graph *g, Type quantized_dtype, Granularity granu
verifier.verify(g);
}
-// Helper function to reduce duplicate test codes
-// Assumption: g->output()->from() is the target node
-void quantize_and_verify_with_wrong_type(luci::test::TestIOGraph *g, Type quantized_dtype,
- Granularity granularity, Type wrong_dtype)
+void quantize_and_verify_with_layer_info(loco::Graph *g, Type quantized_dtype,
+ Granularity granularity)
{
- luci::QuantizeWithMinMaxPass pass(Type::FLOAT32, quantized_dtype, granularity);
- pass.run(g->g());
-
- auto node = loco::must_cast<luci::CircleNode *>(g->output()->from());
- node->dtype(wrong_dtype);
+ // A layer named "test" has dtype different from quantized_dtype
+ luci::LayerInfo info;
+ {
+ info.name = "test";
+ // dtype is different from quantized_dtype
+ info.dtype = quantized_dtype == Type::U8 ? Type::S16 : Type::U8;
+ info.granularity = Granularity::ChannelWise;
+ }
- luci::QuantizedModelVerifier verifier(quantized_dtype, granularity);
- verifier.verify(g->g());
-}
+ // Do quantization
+ {
+ auto ctx = std::make_unique<luci::QuantizeWithMinMaxPass::Context>();
+ {
+ ctx->input_model_dtype = Type::FLOAT32;
+ ctx->output_model_dtype = quantized_dtype;
+ ctx->granularity = granularity;
+ ctx->input_type = quantized_dtype;
+ ctx->output_type = quantized_dtype;
+ ctx->TF_style_maxpool = false;
+ ctx->layers_info.push_back(info);
+ }
-void quantize_and_verify_with_wrong_type(luci::test::TestIOGraph *g, Type quantized_dtype,
- Granularity granularity, Type wrong_dtype,
- luci::CircleNode *target)
-{
- luci::QuantizeWithMinMaxPass pass(Type::FLOAT32, quantized_dtype, granularity);
- pass.run(g->g());
+ luci::QuantizeWithMinMaxPass pass(std::move(ctx));
+ pass.run(g);
+ }
- target->dtype(wrong_dtype);
+ // Do verification
+ {
+ auto ctx = std::make_unique<luci::QuantizedModelVerifier::Context>();
+ {
+ ctx->output_model_dtype = quantized_dtype;
+ ctx->granularity = granularity;
+ ctx->input_type = quantized_dtype;
+ ctx->output_type = quantized_dtype;
+ ctx->TF_style_maxpool = false;
+ ctx->layers_info.push_back(info);
+ }
- luci::QuantizedModelVerifier verifier(quantized_dtype, granularity);
- verifier.verify(g->g());
+ luci::QuantizedModelVerifier verifier(std::move(ctx));
+ verifier.verify(g);
+ }
}
// Helper function to reduce duplicate test codes
// Assumption: g->output()->from() is the target node
-void quantize_and_verify_with_wrong_granularity(luci::test::TestIOGraph *g, Type quantized_dtype,
- Granularity granularity)
+void quantize_and_verify_with_wrong_type(luci::test::TestIOGraph *g, Type quantized_dtype,
+ Granularity granularity, Type wrong_dtype)
{
luci::QuantizeWithMinMaxPass pass(Type::FLOAT32, quantized_dtype, granularity);
pass.run(g->g());
auto node = loco::must_cast<luci::CircleNode *>(g->output()->from());
- insert_scale_zp(node, 1.0, 1);
+ node->dtype(wrong_dtype);
luci::QuantizedModelVerifier verifier(quantized_dtype, granularity);
verifier.verify(g->g());
}
// Helper function to reduce duplicate test codes
+// Assumption: g->output()->from() is the target node
void quantize_and_verify_with_wrong_granularity(luci::test::TestIOGraph *g, Type quantized_dtype,
- Granularity granularity, luci::CircleNode *target)
+ Granularity granularity)
{
luci::QuantizeWithMinMaxPass pass(Type::FLOAT32, quantized_dtype, granularity);
pass.run(g->g());
- insert_scale_zp(target, 1.0, 1);
+ auto node = loco::must_cast<luci::CircleNode *>(g->output()->from());
+ insert_scale_zp(node, 1.0, 1);
luci::QuantizedModelVerifier verifier(quantized_dtype, granularity);
verifier.verify(g->g());
@@ -230,6 +251,8 @@ public:
_instnorm->input(input());
_instnorm->gamma(_gamma);
_instnorm->beta(_beta);
+ _instnorm->fusedActivationFunction(luci::FusedActFunc::NONE);
+ _instnorm->name("test");
}
output()->from(_instnorm);
@@ -256,6 +279,7 @@ public:
_logistic = g()->nodes()->create<luci::CircleLogistic>();
{
_logistic->x(input());
+ _logistic->name("test");
}
output()->from(_logistic);
@@ -275,6 +299,7 @@ public:
_lrn = g()->nodes()->create<luci::CircleLocalResponseNormalization>();
{
_lrn->input(input());
+ _lrn->name("test");
}
output()->from(_lrn);
@@ -295,6 +320,7 @@ public:
{
_softmax->logits(input());
_softmax->beta(0.1);
+ _softmax->name("test");
}
output()->from(_softmax);
@@ -324,6 +350,7 @@ public:
_stob->input(input());
_stob->block_shape(_block_shape);
_stob->paddings(_paddings);
+ _stob->name("test");
}
output()->from(_stob);
@@ -346,6 +373,7 @@ public:
{
_stod->input(input());
_stod->block_size(2);
+ _stod->name("test");
}
output()->from(_stod);
@@ -375,6 +403,7 @@ public:
_slice->input(input());
_slice->begin(_begin);
_slice->size(_size);
+ _slice->name("test");
}
output()->from(_slice);
@@ -472,6 +501,7 @@ public:
_slice->begin(_begin);
_slice->end(_end);
_slice->strides(_strides);
+ _slice->name("test");
}
output()->from(_slice);
@@ -499,6 +529,7 @@ public:
{
_reshape->tensor(input());
_reshape->shape(_shape);
+ _reshape->name("test");
}
output()->from(_reshape);
@@ -519,6 +550,7 @@ public:
_tanh = g()->nodes()->create<luci::CircleTanh>();
{
_tanh->x(input());
+ _tanh->name("test");
}
output()->from(_tanh);
@@ -538,6 +570,7 @@ public:
_floor = g()->nodes()->create<luci::CircleFloor>();
{
_floor->x(input());
+ _floor->name("test");
}
output()->from(_floor);
@@ -601,6 +634,7 @@ public:
_btos->input(input());
_btos->block_shape(_block_shape);
_btos->crops(_crops);
+ _btos->name("test");
}
output()->from(_btos);
@@ -623,6 +657,7 @@ public:
{
_dtos->input(input());
_dtos->block_size(2);
+ _dtos->name("test");
}
output()->from(_dtos);
@@ -645,6 +680,7 @@ public:
_pack->values(0, input());
_pack->values(1, _param);
_pack->axis(0);
+ _pack->name("test");
}
output()->from(_pack);
@@ -680,6 +716,7 @@ public:
{
_pad->input(input());
_pad->paddings(_paddings);
+ _pad->name("test");
}
output()->from(_pad);
@@ -707,6 +744,7 @@ public:
_pad->input(input());
_pad->paddings(_paddings);
_pad->constant_values(_constant_values);
+ _pad->name("test");
}
output()->from(_pad);
@@ -735,6 +773,7 @@ public:
_mirror_pad->input(input());
_mirror_pad->paddings(_paddings);
_mirror_pad->mode(luci::MirrorPadMode::REFLECT);
+ _mirror_pad->name("test");
}
output()->from(_mirror_pad);
@@ -761,6 +800,7 @@ public:
{
_transpose->a(input());
_transpose->perm(_perm);
+ _transpose->name("test");
}
output()->from(_transpose);
@@ -784,6 +824,8 @@ public:
_concat->values(0, input());
_concat->values(1, _param);
_concat->axis(0);
+ _concat->fusedActivationFunction(luci::FusedActFunc::NONE);
+ _concat->name("test");
}
output()->from(_concat);
@@ -795,6 +837,54 @@ private:
luci::CircleConst *_param = nullptr;
};
+template <Type indexT> class OneHotTestGraph final : public SimpleTestGraph
+{
+public:
+ void init(void) override
+ {
+ TestIOGraph::init({32}, {32, 10});
+ {
+ // input dtype is float by default, but OneHot's input should have indexType (s32/s64)
+ input()->dtype(indexT);
+ }
+
+ _depth = g()->nodes()->template create<luci::CircleConst>();
+ {
+ _depth->dtype(loco::DataType::S32);
+ }
+
+ _on_value = g()->nodes()->template create<luci::CircleConst>();
+ {
+ _on_value->dtype(loco::DataType::FLOAT32);
+ }
+
+ _off_value = g()->nodes()->template create<luci::CircleConst>();
+ {
+ _off_value->dtype(loco::DataType::FLOAT32);
+ }
+
+ _one_hot = g()->nodes()->template create<luci::CircleOneHot>();
+ {
+ _one_hot->indices(input());
+ _one_hot->depth(_depth);
+ _one_hot->on_value(_on_value);
+ _one_hot->off_value(_off_value);
+ _one_hot->axis(-1);
+ _one_hot->dtype(loco::DataType::FLOAT32);
+ _one_hot->name("test");
+ }
+ output()->from(_one_hot);
+
+ set_minmax_to_non_const(g(), -1, 1);
+ }
+
+private:
+ luci::CircleOneHot *_one_hot = nullptr;
+ luci::CircleConst *_depth = nullptr;
+ luci::CircleConst *_on_value = nullptr;
+ luci::CircleConst *_off_value = nullptr;
+};
+
// Test graph for comparison Ops
// GREATER, GREATER_EQUAL, LESS, LESS_EQUAL, EQUAL, NOT_EQUAL
template <class Op> class ComparisonOpTestGraph final : public SimpleTestGraph
@@ -866,6 +956,7 @@ public:
{
_div->x(input());
_div->y(_const);
+ _div->name("test");
}
output()->from(_div);
@@ -893,6 +984,7 @@ public:
{
_floor_div->x(input());
_floor_div->y(_const);
+ _floor_div->name("test");
}
output()->from(_floor_div);
@@ -917,6 +1009,7 @@ public:
_rsqrt = g()->nodes()->create<luci::CircleRsqrt>();
{
_rsqrt->x(input());
+ _rsqrt->name("test");
}
output()->from(_rsqrt);
@@ -936,6 +1029,7 @@ public:
_sqrt = g()->nodes()->create<luci::CircleSqrt>();
{
_sqrt->x(input());
+ _sqrt->name("test");
}
output()->from(_sqrt);
@@ -955,6 +1049,7 @@ public:
_elu = g()->nodes()->create<luci::CircleElu>();
{
_elu->features(input());
+ _elu->name("test");
}
output()->from(_elu);
@@ -977,6 +1072,7 @@ public:
{
_pow->x(input());
_pow->y(_const);
+ _pow->name("test");
}
output()->from(_pow);
@@ -1004,6 +1100,7 @@ public:
{
_resize_bilinear->input(input());
_resize_bilinear->size(_size);
+ _resize_bilinear->name("test");
}
output()->from(_resize_bilinear);
@@ -1027,6 +1124,7 @@ public:
{
_resize_nearest_neighbor->input(input());
_resize_nearest_neighbor->size(_size);
+ _resize_nearest_neighbor->name("test");
}
output()->from(_resize_nearest_neighbor);
@@ -1067,6 +1165,62 @@ private:
luci::CircleConst *_unpack_dim = nullptr;
};
+class MulTestGraph final : public SimpleTestGraph
+{
+public:
+ void init(void) override
+ {
+ TestIOGraph::init({32}, {32});
+
+ _const = create_dummy_const<Type::FLOAT32>(g(), {32});
+ _mul = g()->nodes()->create<luci::CircleMul>();
+ {
+ _mul->x(input());
+ _mul->y(_const);
+ _mul->fusedActivationFunction(luci::FusedActFunc::NONE);
+ _mul->name("test");
+ }
+ output()->from(_mul);
+
+ set_minmax_to_non_const(g(), -1, 1);
+ }
+
+ loco::Node *x() { return _mul->x(); }
+ loco::Node *y() { return _mul->y(); }
+
+private:
+ luci::CircleMul *_mul = nullptr;
+ luci::CircleConst *_const = nullptr;
+};
+
+class AddTestGraph final : public SimpleTestGraph
+{
+public:
+ void init(void) override
+ {
+ TestIOGraph::init({32}, {32});
+
+ _const = create_dummy_const<Type::FLOAT32>(g(), {32});
+ _add = g()->nodes()->create<luci::CircleAdd>();
+ {
+ _add->x(input());
+ _add->y(_const);
+ _add->fusedActivationFunction(luci::FusedActFunc::NONE);
+ _add->name("test");
+ }
+ output()->from(_add);
+
+ set_minmax_to_non_const(g(), -1, 1);
+ }
+
+ loco::Node *x() { return _add->x(); }
+ loco::Node *y() { return _add->y(); }
+
+private:
+ luci::CircleAdd *_add = nullptr;
+ luci::CircleConst *_const = nullptr;
+};
+
} // namespace
// Quantize and verify with given configurations
@@ -1078,6 +1232,15 @@ private:
EXPECT_NO_THROW(quantize_and_verify(g.g(), type, granularity)); \
} while (0)
+// Quantize and verify with layer info
+#define TEST_WITH_LAYER_INFO(graph, type, granularity) \
+ do \
+ { \
+ graph g; \
+ g.init(); \
+ EXPECT_NO_THROW(quantize_and_verify_with_layer_info(g.g(), type, granularity)); \
+ } while (0)
+
// Quantize and verify with wrong type
#define TEST_WITH_WRONG_TYPE(graph, type, granularity, wrong_dtype) \
do \
@@ -1098,25 +1261,34 @@ private:
// Quantize and verify with wrong type
// Users can specify the test target
-#define TEST_WITH_WRONG_TYPE_TARGET(graph, type, granularity, wrong_dtype, target) \
- do \
- { \
- graph g; \
- g.init(); \
- auto node = loco::must_cast<luci::CircleNode *>(target); \
- EXPECT_ANY_THROW( \
- quantize_and_verify_with_wrong_type(&g, type, granularity, wrong_dtype, node)); \
+#define TEST_WITH_WRONG_TYPE_TARGET(graph, type, granularity, wrong_dtype, target) \
+ do \
+ { \
+ graph g; \
+ g.init(); \
+ auto node = loco::must_cast<luci::CircleNode *>(target); \
+ luci::QuantizeWithMinMaxPass pass(Type::FLOAT32, type, granularity); \
+ pass.run(g.g()); \
+ auto after_node = loco::must_cast<luci::CircleNode *>(target); \
+ after_node->dtype(wrong_dtype); \
+ luci::QuantizedModelVerifier verifier(type, granularity); \
+ EXPECT_ANY_THROW(verifier.verify(g.g())); \
} while (0)
// Quantize and verify with wrong granularity
// Users can specify the test target
-#define TEST_WITH_WRONG_GRANULARITY_TARGET(graph, type, granularity, target) \
- do \
- { \
- graph g; \
- g.init(); \
- auto node = loco::must_cast<luci::CircleNode *>(target); \
- EXPECT_ANY_THROW(quantize_and_verify_with_wrong_granularity(&g, type, granularity, node)); \
+#define TEST_WITH_WRONG_GRANULARITY_TARGET(graph, type, granularity, target) \
+ do \
+ { \
+ graph g; \
+ g.init(); \
+ auto node = loco::must_cast<luci::CircleNode *>(target); \
+ luci::QuantizeWithMinMaxPass pass(Type::FLOAT32, type, granularity); \
+ pass.run(g.g()); \
+ auto after_node = loco::must_cast<luci::CircleNode *>(target); \
+ insert_scale_zp(after_node, 1.0, 1); \
+ luci::QuantizedModelVerifier verifier(type, granularity); \
+ EXPECT_ANY_THROW(verifier.verify(g.g())); \
} while (0)
// Test a local helper function
@@ -1145,6 +1317,10 @@ TEST(QuantizedModelVerifierTest, InstanceNorm)
TEST_WITH_GRAPH(InstanceNormTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(InstanceNormTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(InstanceNormTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(InstanceNormTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(InstanceNormTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(InstanceNormTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -1169,6 +1345,10 @@ TEST(QuantizedModelVerifierTest, LocalResponseNormalization)
TEST_WITH_GRAPH(LocalResponseNormalizationTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(LocalResponseNormalizationTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(LocalResponseNormalizationTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(LocalResponseNormalizationTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(LocalResponseNormalizationTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(LocalResponseNormalizationTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -1199,6 +1379,10 @@ TEST(QuantizedModelVerifierTest, Logistic)
TEST_WITH_GRAPH(LogisticTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(LogisticTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(LogisticTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(LogisticTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(LogisticTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(LogisticTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -1223,6 +1407,10 @@ TEST(QuantizedModelVerifierTest, Softmax)
TEST_WITH_GRAPH(SoftmaxTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(SoftmaxTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(SoftmaxTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(SoftmaxTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(SoftmaxTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(SoftmaxTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -1247,6 +1435,10 @@ TEST(QuantizedModelVerifierTest, SpaceToBatchND)
TEST_WITH_GRAPH(SpaceToBatchNDTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(SpaceToBatchNDTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(SpaceToBatchNDTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(SpaceToBatchNDTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(SpaceToBatchNDTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(SpaceToBatchNDTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -1271,6 +1463,10 @@ TEST(QuantizedModelVerifierTest, SpaceToDepth)
TEST_WITH_GRAPH(SpaceToDepthTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(SpaceToDepthTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(SpaceToDepthTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(SpaceToDepthTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(SpaceToDepthTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(SpaceToDepthTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -1299,6 +1495,14 @@ TEST(QuantizedModelVerifierTest, Slice)
TEST_WITH_GRAPH(SliceTestGraph<Type::S64>, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(SliceTestGraph<Type::S64>, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(SliceTestGraph<Type::S64>, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(SliceTestGraph<Type::S32>, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(SliceTestGraph<Type::S32>, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(SliceTestGraph<Type::S32>, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(SliceTestGraph<Type::S64>, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(SliceTestGraph<Type::S64>, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(SliceTestGraph<Type::S64>, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -1379,6 +1583,10 @@ TEST(QuantizedModelVerifierTest, StridedSlice)
TEST_WITH_GRAPH(StridedSliceTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(StridedSliceTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(StridedSliceTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(StridedSliceTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(StridedSliceTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(StridedSliceTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -1463,6 +1671,10 @@ TEST(QuantizedModelVerifierTest, BatchToSpaceND)
TEST_WITH_GRAPH(BatchToSpaceNDTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(BatchToSpaceNDTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(BatchToSpaceNDTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(BatchToSpaceNDTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(BatchToSpaceNDTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(BatchToSpaceNDTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -1487,6 +1699,10 @@ TEST(QuantizedModelVerifierTest, DepthToSpace)
TEST_WITH_GRAPH(DepthToSpaceTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(DepthToSpaceTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(DepthToSpaceTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(DepthToSpaceTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(DepthToSpaceTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(DepthToSpaceTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -1511,6 +1727,10 @@ TEST(QuantizedModelVerifierTest, Concatenation)
TEST_WITH_GRAPH(ConcatenationTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(ConcatenationTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(ConcatenationTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(ConcatenationTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(ConcatenationTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(ConcatenationTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -1557,6 +1777,10 @@ TEST(QuantizedModelVerifierTest, Reshape)
TEST_WITH_GRAPH(ReshapeTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(ReshapeTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(ReshapeTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(ReshapeTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(ReshapeTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(ReshapeTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -1581,6 +1805,10 @@ TEST(QuantizedModelVerifierTest, Tanh)
TEST_WITH_GRAPH(TanhTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(TanhTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(TanhTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(TanhTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(TanhTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(TanhTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -1606,6 +1834,10 @@ TEST(QuantizedModelVerifierTest, Pack)
TEST_WITH_GRAPH(PackTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(PackTestGraph, Type::S16, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(PackTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(PackTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(PackTestGraph, Type::S16, Granularity::ChannelWise);
+
// Test if Pack's qparam is propagated to the input
{
PackTestGraph g;
@@ -1640,6 +1872,10 @@ TEST(QuantizedModelVerifierTest, Pad)
TEST_WITH_GRAPH(PadTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(PadTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(PadTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(PadTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(PadTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(PadTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -1664,6 +1900,10 @@ TEST(QuantizedModelVerifierTest, PadV2)
TEST_WITH_GRAPH(PadV2TestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(PadV2TestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(PadV2TestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(PadV2TestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(PadV2TestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(PadV2TestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -1688,6 +1928,10 @@ TEST(QuantizedModelVerifierTest, MirrorPad)
TEST_WITH_GRAPH(MirrorPadTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(MirrorPadTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(MirrorPadTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(MirrorPadTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(MirrorPadTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(MirrorPadTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -1712,6 +1956,10 @@ TEST(QuantizedModelVerifierTest, Transpose)
TEST_WITH_GRAPH(TransposeTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(TransposeTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(TransposeTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(TransposeTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(TransposeTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(TransposeTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -1736,6 +1984,10 @@ TEST(QuantizedModelVerifierTest, Floor)
TEST_WITH_GRAPH(FloorTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(FloorTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(FloorTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(FloorTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(FloorTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(FloorTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -1869,11 +2121,59 @@ TEST(QuantizedModelVerifierTest, NotEqual_wrong_granularity_NEG)
SUCCEED();
}
+TEST(QuantizedModelVerifierTest, OneHot)
+{
+ TEST_WITH_GRAPH(OneHotTestGraph<Type::S32>, Type::U8, Granularity::LayerWise);
+ TEST_WITH_GRAPH(OneHotTestGraph<Type::S32>, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_GRAPH(OneHotTestGraph<Type::S32>, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_GRAPH(OneHotTestGraph<Type::S64>, Type::U8, Granularity::LayerWise);
+ TEST_WITH_GRAPH(OneHotTestGraph<Type::S64>, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_GRAPH(OneHotTestGraph<Type::S64>, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(OneHotTestGraph<Type::S32>, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(OneHotTestGraph<Type::S32>, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(OneHotTestGraph<Type::S32>, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(OneHotTestGraph<Type::S64>, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(OneHotTestGraph<Type::S64>, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(OneHotTestGraph<Type::S64>, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, OneHot_wrong_input_type_NEG)
+{
+ TEST_WITH_WRONG_TYPE(OneHotTestGraph<Type::S32>, Type::U8, Granularity::LayerWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(OneHotTestGraph<Type::S32>, Type::U8, Granularity::ChannelWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(OneHotTestGraph<Type::S32>, Type::S16, Granularity::ChannelWise, Type::U8);
+
+ TEST_WITH_WRONG_TYPE(OneHotTestGraph<Type::S64>, Type::U8, Granularity::LayerWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(OneHotTestGraph<Type::S64>, Type::U8, Granularity::ChannelWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(OneHotTestGraph<Type::S64>, Type::S16, Granularity::ChannelWise, Type::U8);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, OneHot_wrong_granularity_NEG)
+{
+ TEST_WITH_WRONG_GRANULARITY(OneHotTestGraph<Type::S32>, Type::U8, Granularity::LayerWise);
+ TEST_WITH_WRONG_GRANULARITY(OneHotTestGraph<Type::S32>, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_WRONG_GRANULARITY(OneHotTestGraph<Type::S32>, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_WRONG_GRANULARITY(OneHotTestGraph<Type::S64>, Type::U8, Granularity::LayerWise);
+ TEST_WITH_WRONG_GRANULARITY(OneHotTestGraph<Type::S64>, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_WRONG_GRANULARITY(OneHotTestGraph<Type::S64>, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
TEST(QuantizedModelVerifierTest, Div)
{
TEST_WITH_GRAPH(DivTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(DivTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(DivTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(DivTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(DivTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(DivTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -1902,6 +2202,10 @@ TEST(QuantizedModelVerifierTest, FloorDiv)
TEST_WITH_GRAPH(FloorDivTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(FloorDivTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(FloorDivTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(FloorDivTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(FloorDivTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(FloorDivTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -1930,6 +2234,10 @@ TEST(QuantizedModelVerifierTest, Rsqrt)
TEST_WITH_GRAPH(RsqrtTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(RsqrtTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(RsqrtTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(RsqrtTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(RsqrtTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(RsqrtTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -1954,6 +2262,10 @@ TEST(QuantizedModelVerifierTest, Sqrt)
TEST_WITH_GRAPH(SqrtTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(SqrtTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(SqrtTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(SqrtTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(SqrtTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(SqrtTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -1978,6 +2290,10 @@ TEST(QuantizedModelVerifierTest, Elu)
TEST_WITH_GRAPH(EluTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(EluTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(EluTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(EluTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(EluTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(EluTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -2002,6 +2318,10 @@ TEST(QuantizedModelVerifierTest, Pow)
TEST_WITH_GRAPH(PowTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(PowTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(PowTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(PowTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(PowTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(PowTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -2030,6 +2350,10 @@ TEST(QuantizedModelVerifierTest, ResizeBilinear)
TEST_WITH_GRAPH(ResizeBilinearTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(ResizeBilinearTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(ResizeBilinearTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(ResizeBilinearTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(ResizeBilinearTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(ResizeBilinearTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -2054,6 +2378,10 @@ TEST(QuantizedModelVerifierTest, ResizeNearestNeighbor)
TEST_WITH_GRAPH(ResizeNearestNeighborTestGraph, Type::U8, Granularity::LayerWise);
TEST_WITH_GRAPH(ResizeNearestNeighborTestGraph, Type::U8, Granularity::ChannelWise);
TEST_WITH_GRAPH(ResizeNearestNeighborTestGraph, Type::S16, Granularity::ChannelWise);
+
+  TEST_WITH_LAYER_INFO(ResizeNearestNeighborTestGraph, Type::U8, Granularity::LayerWise);
+  TEST_WITH_LAYER_INFO(ResizeNearestNeighborTestGraph, Type::U8, Granularity::ChannelWise);
+  TEST_WITH_LAYER_INFO(ResizeNearestNeighborTestGraph, Type::S16, Granularity::ChannelWise);
SUCCEED();
}
@@ -2099,6 +2427,93 @@ TEST(QuantizedModelVerifierTest, Unpack_wrong_granularity_NEG)
SUCCEED();
}
+TEST(QuantizedModelVerifierTest, Add)
+{
+ TEST_WITH_GRAPH(AddTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_GRAPH(AddTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_GRAPH(AddTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(AddTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(AddTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(AddTestGraph, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Add_wrong_type_NEG)
+{
+ TEST_WITH_WRONG_TYPE(AddTestGraph, Type::U8, Granularity::LayerWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(AddTestGraph, Type::U8, Granularity::ChannelWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(AddTestGraph, Type::S16, Granularity::ChannelWise, Type::U8);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Add_wrong_granularity_NEG)
+{
+ TEST_WITH_WRONG_GRANULARITY_TARGET(AddTestGraph, Type::U8, Granularity::LayerWise, g.x());
+ TEST_WITH_WRONG_GRANULARITY_TARGET(AddTestGraph, Type::U8, Granularity::ChannelWise, g.x());
+ TEST_WITH_WRONG_GRANULARITY_TARGET(AddTestGraph, Type::S16, Granularity::ChannelWise, g.x());
+
+ TEST_WITH_WRONG_GRANULARITY_TARGET(AddTestGraph, Type::U8, Granularity::LayerWise, g.y());
+ TEST_WITH_WRONG_GRANULARITY_TARGET(AddTestGraph, Type::U8, Granularity::ChannelWise, g.y());
+ TEST_WITH_WRONG_GRANULARITY_TARGET(AddTestGraph, Type::S16, Granularity::ChannelWise, g.y());
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Mul)
+{
+ TEST_WITH_GRAPH(MulTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_GRAPH(MulTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_GRAPH(MulTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(MulTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(MulTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(MulTestGraph, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Mul_wrong_type_NEG)
+{
+ TEST_WITH_WRONG_TYPE(MulTestGraph, Type::U8, Granularity::LayerWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(MulTestGraph, Type::U8, Granularity::ChannelWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(MulTestGraph, Type::S16, Granularity::ChannelWise, Type::U8);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Mul_wrong_granularity_NEG)
+{
+ TEST_WITH_WRONG_GRANULARITY_TARGET(MulTestGraph, Type::U8, Granularity::LayerWise, g.x());
+ TEST_WITH_WRONG_GRANULARITY_TARGET(MulTestGraph, Type::U8, Granularity::ChannelWise, g.x());
+ TEST_WITH_WRONG_GRANULARITY_TARGET(MulTestGraph, Type::S16, Granularity::ChannelWise, g.x());
+
+ TEST_WITH_WRONG_GRANULARITY_TARGET(MulTestGraph, Type::U8, Granularity::LayerWise, g.y());
+ TEST_WITH_WRONG_GRANULARITY_TARGET(MulTestGraph, Type::U8, Granularity::ChannelWise, g.y());
+ TEST_WITH_WRONG_GRANULARITY_TARGET(MulTestGraph, Type::S16, Granularity::ChannelWise, g.y());
+ SUCCEED();
+}
+
+// TODO Add following testcases
+//
+// CircleConv2D
+//
+// CircleDepthwiseConv2D
+//
+// CirclePRelu
+//
+// CircleTransposeConv
+//
+// CircleFullyConnected
+//
+// CircleAveragePool2D
+//
+// CircleMaxPool2D
+//
+// CircleMean
+//
+// CircleRelu
+//
+// CircleCast
+//
+
#undef TEST_WITH_GRAPH
#undef TEST_WITH_WRONG_TYPE
#undef TEST_WITH_WRONG_GRANULARITY
diff --git a/compiler/luci/pass/src/RemoveRedundantQuantizePass.cpp b/compiler/luci/pass/src/RemoveRedundantQuantizePass.cpp
new file mode 100644
index 000000000..8a10ad4a0
--- /dev/null
+++ b/compiler/luci/pass/src/RemoveRedundantQuantizePass.cpp
@@ -0,0 +1,104 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/RemoveRedundantQuantizePass.h"
+
+#include <luci/IR/CircleNode.h>
+
+/**
+ * Remove redundant quantize operations. For subsequent Quantize Ops,
+ * only the last Quantize Op is valid, so we can remove the rest of the Quantize Op.
+ *
+ * BEFORE
+ * [CircleNode_1]
+ * |
+ * [CircleQuantize, dtype_1, scale_1, zero_point_1]
+ * |
+ * [CircleQuantize, dtype_2, scale_2, zero_point_2]
+ * |
+ * [CircleNode_2]
+ *
+ * AFTER
+ * [CircleNode_1]
+ * / \
+ * / \
+ * / \
+ * / \
+ * / \
+ * [CircleQuantize, dtype_2, scale_2, zero_point_2] [CircleQuantize, dtype_1, scale_1, zero_point_1]
+ * |
+ * [CircleNode_2]
+ *
+ */
+
+namespace
+{
+
+bool remove_redundant_quantize(luci::CircleQuantize *node)
+{
+ auto pred_node = loco::must_cast<luci::CircleNode *>(node->input());
+
+ if (node->quantparam() == nullptr or pred_node->quantparam() == nullptr)
+ return false;
+
+ if (node->quantparam()->scale.size() != 1 or node->quantparam()->zerop.size() != 1 or
+ pred_node->quantparam()->scale.size() != 1 or pred_node->quantparam()->zerop.size() != 1)
+ {
+ return false;
+ }
+
+ if (node->dtype() != pred_node->dtype() or
+ pred_node->quantparam()->scale.at(0) != node->quantparam()->scale.at(0) or
+ pred_node->quantparam()->zerop.at(0) != node->quantparam()->zerop.at(0))
+ {
+ return false;
+ }
+
+ replace(node).with(pred_node);
+
+ return true;
+}
+
+bool remove_redundant_subsequent_quantize(luci::CircleQuantize *node)
+{
+ auto pred_node = dynamic_cast<luci::CircleQuantize *>(node->input());
+ if (pred_node == nullptr)
+ return remove_redundant_quantize(node);
+
+ node->input(pred_node->input());
+ return true;
+}
+
+} // namespace
+
+namespace luci
+{
+
+bool RemoveRedundantQuantizePass::run(loco::Graph *g)
+{
+ bool changed = false;
+ for (auto node : loco::postorder_traversal(loco::output_nodes(g)))
+ {
+ if (auto quantize_node = dynamic_cast<luci::CircleQuantize *>(node))
+ {
+ if (remove_redundant_subsequent_quantize(quantize_node))
+ changed = true;
+ }
+ }
+ return changed;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/RemoveRedundantQuantizePass.test.cpp b/compiler/luci/pass/src/RemoveRedundantQuantizePass.test.cpp
new file mode 100644
index 000000000..d0166bd20
--- /dev/null
+++ b/compiler/luci/pass/src/RemoveRedundantQuantizePass.test.cpp
@@ -0,0 +1,166 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/RemoveRedundantQuantizePass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <luci/test/TestIOGraph.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class QuantizeGraphlet
+{
+public:
+ QuantizeGraphlet() = default;
+
+public:
+ void init(loco::Graph *g)
+ {
+ _first_quantize = g->nodes()->create<luci::CircleQuantize>();
+ _first_quantize->dtype(loco::DataType::U8);
+ {
+ auto quantize_param = std::make_unique<luci::CircleQuantParam>();
+ quantize_param->scale = {0.5};
+ quantize_param->zerop = {0};
+ _first_quantize->quantparam(std::move(quantize_param));
+ }
+ _first_quantize->name("first_quantize");
+
+ _second_quantize = g->nodes()->create<luci::CircleQuantize>();
+ _second_quantize->dtype(loco::DataType::U8);
+ {
+ auto quantize_param = std::make_unique<luci::CircleQuantParam>();
+ quantize_param->scale = {0.5};
+ quantize_param->zerop = {0};
+ _second_quantize->quantparam(std::move(quantize_param));
+ }
+ _second_quantize->name("second_quantize");
+ }
+
+protected:
+ luci::CircleQuantize *_first_quantize = nullptr;
+ luci::CircleQuantize *_second_quantize = nullptr;
+};
+
+class RedundantSubsequentQuantizeGraph : public TestIOGraph, public QuantizeGraphlet
+{
+public:
+ RedundantSubsequentQuantizeGraph() = default;
+
+public:
+ void init(void)
+ {
+ TestIOGraph::init({1}, {1});
+ QuantizeGraphlet::init(g());
+
+ input()->dtype(loco::DataType::U8);
+ {
+ auto quantize_param = std::make_unique<luci::CircleQuantParam>();
+ quantize_param->scale = {1};
+ quantize_param->zerop = {1};
+ input()->quantparam(std::move(quantize_param));
+ }
+
+ _first_quantize->input(input());
+ _second_quantize->input(_first_quantize);
+
+ output()->from(_second_quantize);
+ output()->dtype(loco::DataType::U8);
+ }
+};
+
+class RedundantQuantizeGraph : public TestIOGraph, public QuantizeGraphlet
+{
+public:
+ RedundantQuantizeGraph() = default;
+
+public:
+ void init(void)
+ {
+ TestIOGraph::init({1}, {1});
+ QuantizeGraphlet::init(g());
+
+ input()->dtype(loco::DataType::U8);
+ {
+ auto quantize_param = std::make_unique<luci::CircleQuantParam>();
+ quantize_param->scale = {0.5};
+ quantize_param->zerop = {0};
+ input()->quantparam(std::move(quantize_param));
+ }
+
+ _first_quantize->input(input());
+
+ output()->from(_first_quantize);
+ output()->dtype(loco::DataType::U8);
+ }
+};
+
+} // namespace
+
+TEST(RemoveRedundantQuantizePass, name)
+{
+ luci::RemoveRedundantQuantizePass pass;
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
+
+TEST(RemoveRedundantQuantizePass, remove_subsequent_quantize)
+{
+ RedundantSubsequentQuantizeGraph g;
+ luci::RemoveRedundantQuantizePass pass;
+
+ g.init();
+
+ EXPECT_TRUE(pass.run(g.g()));
+
+ int count = 0;
+ for (auto node : loco::active_nodes(loco::output_nodes(g.g())))
+ {
+ if (dynamic_cast<luci::CircleQuantize *>(node))
+ {
+ count++;
+ }
+ }
+
+ ASSERT_EQ(1, count);
+}
+
+TEST(RemoveRedundantQuantizePass, remove_quantize)
+{
+ RedundantQuantizeGraph g;
+ luci::RemoveRedundantQuantizePass pass;
+
+ g.init();
+
+ EXPECT_TRUE(pass.run(g.g()));
+
+ int count = 0;
+ for (auto node : loco::active_nodes(loco::output_nodes(g.g())))
+ {
+ if (dynamic_cast<luci::CircleQuantize *>(node))
+ {
+ count++;
+ }
+ }
+
+ ASSERT_EQ(0, count);
+}
diff --git a/compiler/luci/pass/src/RemoveRedundantTransposePass.cpp b/compiler/luci/pass/src/RemoveRedundantTransposePass.cpp
index 71c51ecda..75cf72795 100644
--- a/compiler/luci/pass/src/RemoveRedundantTransposePass.cpp
+++ b/compiler/luci/pass/src/RemoveRedundantTransposePass.cpp
@@ -71,7 +71,7 @@ bool remove_consecutive_transpose_function(luci::CircleTranspose *target_node)
for (uint32_t i = 0; i < pred_perm->size<loco::DataType::S32>(); i++)
{
new_const_node->at<loco::DataType::S32>(i) =
- target_perm->at<loco::DataType::S32>(pred_perm->at<loco::DataType::S32>(i));
+ pred_perm->at<loco::DataType::S32>(target_perm->at<loco::DataType::S32>(i));
}
new_const_node->name(name + "/Transpose/perm");
diff --git a/compiler/luci/pass/src/RemoveRedundantTransposePass.test.cpp b/compiler/luci/pass/src/RemoveRedundantTransposePass.test.cpp
index e80623499..bb8e292d4 100644
--- a/compiler/luci/pass/src/RemoveRedundantTransposePass.test.cpp
+++ b/compiler/luci/pass/src/RemoveRedundantTransposePass.test.cpp
@@ -271,6 +271,31 @@ TEST(RemoveRedundantTransposePass, remove_consecutive_transpose_function_type2)
ASSERT_EQ(2, perm->at<loco::DataType::S32>(3));
}
+TEST(RemoveRedundantTransposePass, remove_consecutive_transpose_function_type3)
+{
+ auto graph = loco::make_graph();
+ create_redundunt_transpose(graph.get(), {0, 3, 2, 1}, {0, 2, 3, 1});
+
+ luci::RemoveRedundantTransposePass pass;
+ while (pass.run(graph.get()))
+ ;
+ luci::CircleTranspose *transpose_node = nullptr;
+ for (auto node : loco::active_nodes(loco::output_nodes(graph.get())))
+ {
+ auto trans = dynamic_cast<luci::CircleTranspose *>(node);
+ if (not trans)
+ continue;
+ transpose_node = trans;
+ break;
+ }
+ ASSERT_NE(nullptr, transpose_node);
+ auto perm = loco::must_cast<luci::CircleConst *>(transpose_node->perm());
+ ASSERT_EQ(0, perm->at<loco::DataType::S32>(0));
+ ASSERT_EQ(2, perm->at<loco::DataType::S32>(1));
+ ASSERT_EQ(1, perm->at<loco::DataType::S32>(2));
+ ASSERT_EQ(3, perm->at<loco::DataType::S32>(3));
+}
+
/**
* @brief Test case that first transpose output become input of operations more than one.
*/
diff --git a/compiler/luci/pass/src/RemoveUnnecessaryReshapePass.cpp b/compiler/luci/pass/src/RemoveUnnecessaryReshapePass.cpp
index 3f0c4ee82..fb46f490d 100644
--- a/compiler/luci/pass/src/RemoveUnnecessaryReshapePass.cpp
+++ b/compiler/luci/pass/src/RemoveUnnecessaryReshapePass.cpp
@@ -58,6 +58,25 @@ bool remove_no_effect_reshape(luci::CircleNode *node)
namespace luci
{
+/**
+ * BEFORE
+ * [CircleNode]
+ * |
+ * [CircleReshape]
+ * |
+ * [CircleNode]
+ *
+ * AFTER
+ * [CircleNode]
+ * | \
+ * | [CircleReshape]
+ * |
+ * [CircleNode]
+ *
+ * NOTE
+ * This pass will remove Reshape when input and output has same shape
+ */
+
bool RemoveUnnecessaryReshapePass::run(loco::Graph *g)
{
bool changed = false;
diff --git a/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.cpp b/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.cpp
index a0cc0194f..bca0a9483 100644
--- a/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.cpp
+++ b/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.cpp
@@ -26,8 +26,17 @@ namespace
luci::CircleConst *create_weights_from_gamma(luci::CircleConst *gamma)
{
- assert(gamma->rank() == 1);
- auto channel_size = gamma->dim(0).value();
+ assert(gamma->rank() == 1 or gamma->rank() == 4);
+
+ uint32_t channel_idx = gamma->rank() - 1;
+ uint32_t channel_size = gamma->dim(channel_idx).value();
+
+ // Gamma should be broadcastable in the channel direction
+ for (uint32_t i = 0; i < gamma->rank(); i++)
+ {
+ if (i != channel_idx)
+ assert(gamma->dim(i).value() == 1); // FIX is_batchnorm_mul UNLESS
+ }
auto name = gamma->name();
assert(name.length() > 0);
@@ -53,8 +62,17 @@ luci::CircleConst *create_weights_from_gamma(luci::CircleConst *gamma)
luci::CircleConst *create_bias_from_beta(luci::CircleConst *beta)
{
- assert(beta->rank() == 1);
- auto channel_size = beta->dim(0).value();
+ assert(beta->rank() == 1 or beta->rank() == 4);
+
+ uint32_t channel_idx = beta->rank() - 1;
+ uint32_t channel_size = beta->dim(channel_idx).value();
+
+ // Beta should be broadcastable in the channel direction
+ for (uint32_t i = 0; i < beta->rank(); i++)
+ {
+ if (i != channel_idx)
+ assert(beta->dim(i).value() == 1); // FIX is_batchnorm_add UNLESS
+ }
auto name = beta->name();
assert(name.length() > 0);
diff --git a/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.test.cpp b/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.test.cpp
index 903d4dcc9..bac033112 100644
--- a/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.test.cpp
+++ b/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.test.cpp
@@ -141,6 +141,37 @@ TEST(ReplaceMulAddWithDepthwiseConv, simple)
}
}
+TEST(ReplaceMulAddWithDepthwiseConv, simple_rank4)
+{
+ SimpleGraph g;
+
+ const uint32_t channel_size = 16;
+ g.gamma->shape({1, 1, 1, channel_size});
+ g.beta->shape({1, 1, 1, channel_size});
+
+ luci::ReplaceMulAddWithDepthwiseConvPass pass;
+ while (pass.run(&g.g))
+ ;
+
+ auto dwconv = dynamic_cast<luci::CircleDepthwiseConv2D *>(g.output->from());
+ EXPECT_NE(nullptr, dwconv);
+
+ auto weights = dynamic_cast<luci::CircleConst *>(dwconv->filter());
+ auto bias = dynamic_cast<luci::CircleConst *>(dwconv->bias());
+ EXPECT_NE(nullptr, weights);
+ EXPECT_EQ(4, weights->rank());
+ EXPECT_EQ(channel_size, weights->dim(3).value());
+ EXPECT_NE(nullptr, bias);
+ EXPECT_EQ(1, bias->rank());
+ EXPECT_EQ(channel_size, bias->dim(0).value());
+
+  for (uint32_t i = 0; i < channel_size; i++)
+ {
+ EXPECT_FLOAT_EQ(i, weights->at<loco::DataType::FLOAT32>(i));
+ EXPECT_FLOAT_EQ(i, bias->at<loco::DataType::FLOAT32>(i));
+ }
+}
+
TEST(ReplaceMulAddWithDepthwiseConv, wrong_op_NEG)
{
SimpleGraph g;
@@ -154,3 +185,18 @@ TEST(ReplaceMulAddWithDepthwiseConv, wrong_op_NEG)
EXPECT_EQ(false, changed);
}
+
+TEST(ReplaceMulAddWithDepthwiseConv, rank3_NEG)
+{
+ SimpleGraph g;
+
+ g.input->shape({4, 4, 16});
+ g.mul->shape({4, 4, 16});
+ g.add->shape({4, 4, 16});
+ g.output->shape({4, 4, 16});
+
+ luci::ReplaceMulAddWithDepthwiseConvPass pass;
+ auto changed = pass.run(&g.g);
+
+ EXPECT_EQ(false, changed);
+}
diff --git a/compiler/luci/pass/src/SubstituteSplitVToSplitPass.cpp b/compiler/luci/pass/src/SubstituteSplitVToSplitPass.cpp
index 9cba9a9e7..57c386d99 100644
--- a/compiler/luci/pass/src/SubstituteSplitVToSplitPass.cpp
+++ b/compiler/luci/pass/src/SubstituteSplitVToSplitPass.cpp
@@ -24,15 +24,6 @@
namespace
{
-void copy_quantparam(luci::CircleNode *dst, const luci::CircleNode *src)
-{
- auto q = src->quantparam();
- if (q == nullptr)
- dst->quantparam(nullptr);
- else
- dst->quantparam(std::make_unique<luci::CircleQuantParam>(*q));
-}
-
// SplitV is substituted to Split if the contents of size_splits are all same
// For example,
// size_splits = [32, 32] -> substitute
@@ -67,7 +58,7 @@ bool resolve_splitv(luci::CircleSplitV *sv)
split_node->split_dim(sv->split_dim());
split_node->num_split(sv->num_split());
split_node->name(sv->name());
- copy_quantparam(split_node, sv);
+ copy_quantparam(sv, split_node);
luci::add_origin(split_node, luci::get_origin(sv));
auto succs = loco::succs(sv);
@@ -78,7 +69,7 @@ bool resolve_splitv(luci::CircleSplitV *sv)
so_node->input(split_node);
so_node->index(svo->index());
so_node->name(svo->name());
- copy_quantparam(so_node, svo);
+ copy_quantparam(svo, so_node);
luci::add_origin(so_node, luci::get_origin(svo));
replace(svo).with(so_node);
diff --git a/compiler/luci/pass/src/SubstituteSqueezeToReshapePass.cpp b/compiler/luci/pass/src/SubstituteSqueezeToReshapePass.cpp
index f48763782..df7266df9 100644
--- a/compiler/luci/pass/src/SubstituteSqueezeToReshapePass.cpp
+++ b/compiler/luci/pass/src/SubstituteSqueezeToReshapePass.cpp
@@ -76,18 +76,6 @@ std::vector<uint32_t> node_shape(const luci::CircleNode *input)
}
/**
- * @brief copy quantparam of src to dst
- */
-void copy_quantparam(luci::CircleNode *dst, const luci::CircleNode *src)
-{
- auto q = src->quantparam();
- if (q == nullptr)
- dst->quantparam(nullptr);
- else
- dst->quantparam(std::make_unique<luci::CircleQuantParam>(*q));
-}
-
-/**
* @brief return CircleConst ptr with values of new_shape
*/
luci::CircleConst *create_shape_const(loco::Graph *graph, const std::vector<uint32_t> &new_shape)
@@ -142,7 +130,7 @@ bool substitute_squeeze_to_reshape(luci::CircleSqueeze *squeeze)
auto graph = squeeze->graph();
auto reshape = graph->nodes()->create<luci::CircleReshape>();
auto shape_const = create_shape_const(graph, reshape_shape);
- copy_quantparam(reshape, squeeze);
+ copy_quantparam(squeeze, reshape);
reshape->name(name + "/Reshape");
luci::add_origin(reshape, luci::get_origin(squeeze));
shape_const->name(name + "/Reshape/shape");
diff --git a/compiler/luci/pass/src/SubstituteStridedSliceToReshapePass.cpp b/compiler/luci/pass/src/SubstituteStridedSliceToReshapePass.cpp
index f50f2f54f..9e1c5a4a3 100644
--- a/compiler/luci/pass/src/SubstituteStridedSliceToReshapePass.cpp
+++ b/compiler/luci/pass/src/SubstituteStridedSliceToReshapePass.cpp
@@ -124,7 +124,7 @@ bool substitute_strided_slice_to_reshape(luci::CircleStridedSlice *ss_node)
std::bitset<32> end_mask(ss_node->end_mask());
std::bitset<32> shrink_axis_mask(ss_node->shrink_axis_mask());
- uint input_rank = input_node->rank();
+ uint32_t input_rank = input_node->rank();
for (uint32_t i = 0; i < input_rank; i++)
{
if (!input_node->dim(i).known())
diff --git a/compiler/luci/pass/src/VerifyQuantizedBiasScale.cpp b/compiler/luci/pass/src/VerifyQuantizedBiasScale.cpp
new file mode 100644
index 000000000..e65d576cd
--- /dev/null
+++ b/compiler/luci/pass/src/VerifyQuantizedBiasScale.cpp
@@ -0,0 +1,105 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "VerifyQuantizedBiasScale.h"
+
+#include <cmath>
+
+// This macro is undef at the end of the file
+#define RETURN_FALSE_UNLESS(ARG) \
+ if (not(ARG)) \
+ { \
+ return false; \
+ }
+
+namespace
+{
+
+bool same(float a, float b)
+{
+ constexpr float epsilon = 1e-10;
+  return std::fabs(a - b) < epsilon;
+}
+
+// Check bias scale = input scale * weight scale
+// This function checks both LWQ and CWQ
+bool check_bias_scale(const loco::Node *input, const loco::Node *weights, const loco::Node *bias)
+{
+ auto input_node = loco::must_cast<const luci::CircleNode *>(input);
+ auto input_qparam = input_node->quantparam();
+ RETURN_FALSE_UNLESS(input_qparam != nullptr);
+
+ auto weights_node = loco::must_cast<const luci::CircleNode *>(weights);
+ auto weights_qparam = weights_node->quantparam();
+ RETURN_FALSE_UNLESS(weights_qparam != nullptr);
+
+ auto bias_node = loco::must_cast<const luci::CircleNode *>(bias);
+ auto bias_qparam = bias_node->quantparam();
+ RETURN_FALSE_UNLESS(bias_qparam != nullptr);
+
+ RETURN_FALSE_UNLESS(input_qparam->scale.size() == 1);
+ RETURN_FALSE_UNLESS(weights_qparam->scale.size() == bias_qparam->scale.size());
+
+ auto input_scale = input_qparam->scale[0];
+ for (uint32_t i = 0; i < weights_qparam->scale.size(); i++)
+ {
+ auto weights_scale = weights_qparam->scale[i];
+ auto bias_scale = bias_qparam->scale[i];
+ RETURN_FALSE_UNLESS(same(bias_scale, input_scale * weights_scale));
+ }
+ return true;
+}
+
+} // namespace
+
+namespace luci
+{
+
+bool VerifyQuantizedBiasScale::visit(const luci::CircleConv2D *node)
+{
+ RETURN_FALSE_UNLESS(check_bias_scale(node->input(), node->filter(), node->bias()));
+ return true;
+}
+
+bool VerifyQuantizedBiasScale::visit(const luci::CircleDepthwiseConv2D *node)
+{
+ RETURN_FALSE_UNLESS(check_bias_scale(node->input(), node->filter(), node->bias()));
+ return true;
+}
+
+bool VerifyQuantizedBiasScale::visit(const luci::CircleFullyConnected *node)
+{
+ luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
+ if (bias != nullptr)
+ {
+ RETURN_FALSE_UNLESS(check_bias_scale(node->input(), node->weights(), node->bias()));
+ }
+ return true;
+}
+
+bool VerifyQuantizedBiasScale::visit(const luci::CircleTransposeConv *node)
+{
+ luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
+ if (bias != nullptr)
+ {
+ RETURN_FALSE_UNLESS(check_bias_scale(node->outBackprop(), node->filter(), node->bias()));
+ }
+ return true;
+}
+
+} // namespace luci
+
+#undef RETURN_FALSE_UNLESS
diff --git a/compiler/luci/pass/src/VerifyQuantizedBiasScale.h b/compiler/luci/pass/src/VerifyQuantizedBiasScale.h
new file mode 100644
index 000000000..b41f78eca
--- /dev/null
+++ b/compiler/luci/pass/src/VerifyQuantizedBiasScale.h
@@ -0,0 +1,59 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_VERIFY_QUANTIZED_BIAS_SCALE_H__
+#define __LUCI_VERIFY_QUANTIZED_BIAS_SCALE_H__
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/IR/CircleNodeVisitor.h>
+
+#include <memory>
+
+namespace luci
+{
+
+/**
+ * @brief Verify the scale of quantized bias node
+ * @details
+ *
+ * Bias of CONV, DCONV, TCONV, FC layers should meet the following condition.
+ *
+ * bias scale = input scale * weights scale
+ */
+class VerifyQuantizedBiasScale : public luci::CircleNodeVisitor<bool>
+{
+public:
+ static std::shared_ptr<VerifyQuantizedBiasScale> create()
+ {
+ return std::make_shared<VerifyQuantizedBiasScale>();
+ };
+
+public:
+ bool verify(luci::CircleNode *node) { return node->accept(this); }
+
+private:
+ // Operators with bias
+ bool visit(const luci::CircleConv2D *node);
+ bool visit(const luci::CircleDepthwiseConv2D *node);
+ bool visit(const luci::CircleFullyConnected *node);
+ bool visit(const luci::CircleTransposeConv *node);
+
+ bool visit(const luci::CircleNode *) { return true; }
+};
+
+} // namespace luci
+
+#endif // __LUCI_VERIFY_QUANTIZED_BIAS_SCALE_H__
diff --git a/compiler/luci/pass/src/VerifyQuantizedNodeGranularity.cpp b/compiler/luci/pass/src/VerifyQuantizedNodeGranularity.cpp
new file mode 100644
index 000000000..8697090a7
--- /dev/null
+++ b/compiler/luci/pass/src/VerifyQuantizedNodeGranularity.cpp
@@ -0,0 +1,38 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "VerifyQuantizedNodeGranularity.h"
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/Pass/QuantizationParameters.h>
+
+#include <memory>
+
+namespace luci
+{
+
+std::shared_ptr<VerifyQuantizedNodeGranularity>
+VerifyQuantizedNodeGranularity::create(Granularity granularity)
+{
+ if (granularity == Granularity::ChannelWise)
+ return std::make_shared<VerifyQuantizedNodeChannelWiseGranularity>();
+ else if (granularity == Granularity::LayerWise)
+ return std::make_shared<VerifyQuantizedNodeLayerWiseGranularity>();
+ else
+ throw std::domain_error("Not supported Granularity type");
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/VerifyQuantizedNodeChannelWiseGranularity.h b/compiler/luci/pass/src/VerifyQuantizedNodeGranularity.h
index bf3ff2e8a..442183c18 100644
--- a/compiler/luci/pass/src/VerifyQuantizedNodeChannelWiseGranularity.h
+++ b/compiler/luci/pass/src/VerifyQuantizedNodeGranularity.h
@@ -1,5 +1,6 @@
/*
- * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
@@ -13,13 +14,15 @@
* limitations under the License.
*/
-#ifndef __LUCI_VERIFY_QUANTIZED_NODE_CHANNELWISE_GRANULARITY_H__
-#define __LUCI_VERIFY_QUANTIZED_NODE_CHANNELWISE_GRANULARITY_H__
+#ifndef __LUCI_VERIFY_QUANTIZED_NODE_GRANULARITY_H__
+#define __LUCI_VERIFY_QUANTIZED_NODE_GRANULARITY_H__
#include <luci/IR/CircleNodes.h>
#include <luci/IR/CircleNodeVisitor.h>
#include <luci/Pass/QuantizationParameters.h>
+#include <memory>
+
using Granularity = luci::QuantizationGranularity;
// This macro is undef at the end of the file
@@ -33,16 +36,19 @@ namespace luci
{
/**
- * @brief Verify the granualrity of channel-wise quantized node
+ * @brief Verify the granualrity of quantized node
* @details
*
* Targets to verify
* - node's output (i.e., node itself)
* - node's inputs
*/
-struct VerifyQuantizedNodeChannelWiseGranularity final : public luci::CircleNodeVisitor<bool>
+class VerifyQuantizedNodeGranularity : public luci::CircleNodeVisitor<bool>
{
-private:
+public:
+ static std::shared_ptr<VerifyQuantizedNodeGranularity> create(Granularity granularity);
+
+protected:
bool is_lwq(const loco::Node *node)
{
auto circle_node = loco::must_cast<const luci::CircleNode *>(node);
@@ -59,48 +65,15 @@ private:
return true;
}
- uint32_t rank(const loco::Node *node)
- {
- auto circle_node = loco::must_cast<const luci::CircleNode *>(node);
- return circle_node->rank();
- }
-
- bool is_cwq_const(const loco::Node *node, uint32_t channel_dim)
- {
- auto circle_node = loco::must_cast<const luci::CircleConst *>(node);
-
- assert(channel_dim < circle_node->rank()); // FIX_CALLER_UNLESS
- auto channel_size = circle_node->dim(channel_dim).value();
-
- if (circle_node->quantparam() == nullptr)
- return false;
-
- if (circle_node->quantparam()->quantized_dimension != static_cast<int32_t>(channel_dim))
- return false;
-
- if (circle_node->quantparam()->scale.size() != channel_size)
- return false;
-
- if (circle_node->quantparam()->zerop.size() != channel_size)
- return false;
-
- return true;
- }
-
private:
- bool visit(const luci::CircleConv2D *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->input()))
- RETURN_FALSE_UNLESS(is_cwq_const(node->filter(), 0))
- luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
- if (bias != nullptr)
- RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1))
- return true;
- }
+ virtual bool visit(const luci::CircleConv2D *node) = 0;
bool visit(const luci::CircleConcatenation *node)
{
+ // Skip granularity check for concatenation of indices
+ if (node->dtype() == loco::DataType::S32 or node->dtype() == loco::DataType::S64)
+ return true;
+
RETURN_FALSE_UNLESS(is_lwq(node))
for (uint32_t i = 0; i < node->numValues(); i++)
{
@@ -116,25 +89,9 @@ private:
return true;
}
- bool visit(const luci::CircleDepthwiseConv2D *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->input()))
- RETURN_FALSE_UNLESS(is_cwq_const(node->filter(), 3))
- luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
- if (bias != nullptr)
- RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1))
- return true;
- }
+ virtual bool visit(const luci::CircleDepthwiseConv2D *node) = 0;
- bool visit(const luci::CircleInstanceNorm *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->input()))
- RETURN_FALSE_UNLESS(is_cwq_const(node->gamma(), rank(node->gamma()) - 1))
- RETURN_FALSE_UNLESS(is_cwq_const(node->beta(), rank(node->beta()) - 1))
- return true;
- }
+ virtual bool visit(const luci::CircleInstanceNorm *node) = 0;
bool visit(const luci::CirclePack *node)
{
@@ -168,37 +125,11 @@ private:
return true;
}
- bool visit(const luci::CirclePRelu *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->input()))
- RETURN_FALSE_UNLESS(is_cwq_const(node->alpha(), rank(node->alpha()) - 1))
- return true;
- }
-
- bool visit(const luci::CircleTransposeConv *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->outBackprop()))
- RETURN_FALSE_UNLESS(is_cwq_const(node->filter(), 0))
- luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
- if (bias != nullptr)
- RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1))
+ virtual bool visit(const luci::CirclePRelu *node) = 0;
- return true;
- }
+ virtual bool visit(const luci::CircleTransposeConv *node) = 0;
- bool visit(const luci::CircleFullyConnected *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->input()))
- RETURN_FALSE_UNLESS(is_cwq_const(node->weights(), 0))
- luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
- // Bias is optional (it can be CircleOutputExclude)
- if (bias != nullptr)
- RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1))
- return true;
- }
+ virtual bool visit(const luci::CircleFullyConnected *node) = 0;
bool visit(const luci::CircleAdd *node)
{
@@ -258,6 +189,14 @@ private:
return true;
}
+ bool visit(const luci::CircleOneHot *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node));
+ RETURN_FALSE_UNLESS(is_lwq(node->off_value()));
+ RETURN_FALSE_UNLESS(is_lwq(node->on_value()));
+ return true;
+ }
+
bool visit(const luci::CircleRelu *node)
{
RETURN_FALSE_UNLESS(is_lwq(node));
@@ -480,8 +419,186 @@ private:
bool visit(const luci::CircleNode *) { return true; }
};
+class VerifyQuantizedNodeChannelWiseGranularity final : public VerifyQuantizedNodeGranularity
+{
+private:
+ uint32_t rank(const loco::Node *node)
+ {
+ auto circle_node = loco::must_cast<const luci::CircleNode *>(node);
+ return circle_node->rank();
+ }
+
+ bool is_cwq_const(const loco::Node *node, uint32_t channel_dim)
+ {
+ auto circle_node = loco::must_cast<const luci::CircleConst *>(node);
+
+ assert(channel_dim < circle_node->rank()); // FIX_CALLER_UNLESS
+ auto channel_size = circle_node->dim(channel_dim).value();
+
+ if (circle_node->quantparam() == nullptr)
+ return false;
+
+ if (circle_node->quantparam()->quantized_dimension != static_cast<int32_t>(channel_dim))
+ return false;
+
+ if (circle_node->quantparam()->scale.size() != channel_size)
+ return false;
+
+ if (circle_node->quantparam()->zerop.size() != channel_size)
+ return false;
+
+ return true;
+ }
+
+private:
+ bool visit(const luci::CircleConv2D *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ RETURN_FALSE_UNLESS(is_lwq(node->input()))
+ RETURN_FALSE_UNLESS(is_cwq_const(node->filter(), 0))
+ luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
+ if (bias != nullptr)
+ RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1))
+ return true;
+ }
+
+ bool visit(const luci::CircleDepthwiseConv2D *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ RETURN_FALSE_UNLESS(is_lwq(node->input()))
+ RETURN_FALSE_UNLESS(is_cwq_const(node->filter(), 3))
+ luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
+ if (bias != nullptr)
+ RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1))
+ return true;
+ }
+
+ bool visit(const luci::CircleInstanceNorm *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ RETURN_FALSE_UNLESS(is_lwq(node->input()))
+ RETURN_FALSE_UNLESS(is_cwq_const(node->gamma(), rank(node->gamma()) - 1))
+ RETURN_FALSE_UNLESS(is_cwq_const(node->beta(), rank(node->beta()) - 1))
+ return true;
+ }
+
+ bool visit(const luci::CirclePRelu *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ RETURN_FALSE_UNLESS(is_lwq(node->input()))
+ RETURN_FALSE_UNLESS(is_cwq_const(node->alpha(), rank(node->alpha()) - 1))
+ return true;
+ }
+
+ bool visit(const luci::CircleTransposeConv *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ RETURN_FALSE_UNLESS(is_lwq(node->outBackprop()))
+ RETURN_FALSE_UNLESS(is_cwq_const(node->filter(), 0))
+ luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
+ if (bias != nullptr)
+ RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1))
+
+ return true;
+ }
+
+ bool visit(const luci::CircleFullyConnected *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ RETURN_FALSE_UNLESS(is_lwq(node->input()))
+ RETURN_FALSE_UNLESS(is_cwq_const(node->weights(), 0))
+ luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
+ // Bias is optional (it can be CircleOutputExclude)
+ if (bias != nullptr)
+ RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1))
+ return true;
+ }
+};
+
+class VerifyQuantizedNodeLayerWiseGranularity final : public VerifyQuantizedNodeGranularity
+{
+private:
+ bool is_lwq_const(const loco::Node *node)
+ {
+ auto circle_node = loco::must_cast<const luci::CircleConst *>(node);
+
+ if (circle_node->quantparam() == nullptr)
+ return false;
+
+ if (circle_node->quantparam()->scale.size() != 1)
+ return false;
+
+ if (circle_node->quantparam()->zerop.size() != 1)
+ return false;
+
+ return true;
+ }
+
+private:
+ bool visit(const luci::CircleConv2D *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ RETURN_FALSE_UNLESS(is_lwq(node->input()))
+ RETURN_FALSE_UNLESS(is_lwq_const(node->filter()))
+ luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
+ if (bias != nullptr)
+ RETURN_FALSE_UNLESS(is_lwq_const(node->bias()))
+ return true;
+ }
+
+ bool visit(const luci::CircleDepthwiseConv2D *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ RETURN_FALSE_UNLESS(is_lwq(node->input()))
+ RETURN_FALSE_UNLESS(is_lwq_const(node->filter()))
+ luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
+ if (bias != nullptr)
+ RETURN_FALSE_UNLESS(is_lwq_const(node->bias()))
+ return true;
+ }
+
+ bool visit(const luci::CircleInstanceNorm *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ RETURN_FALSE_UNLESS(is_lwq(node->input()))
+ RETURN_FALSE_UNLESS(is_lwq_const(node->gamma()))
+ RETURN_FALSE_UNLESS(is_lwq_const(node->beta()))
+ return true;
+ }
+
+ bool visit(const luci::CirclePRelu *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ RETURN_FALSE_UNLESS(is_lwq(node->input()))
+ RETURN_FALSE_UNLESS(is_lwq_const(node->alpha()))
+ return true;
+ }
+
+ bool visit(const luci::CircleTransposeConv *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ RETURN_FALSE_UNLESS(is_lwq(node->outBackprop()))
+ RETURN_FALSE_UNLESS(is_lwq_const(node->filter()))
+ luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
+ if (bias != nullptr)
+ RETURN_FALSE_UNLESS(is_lwq_const(node->bias()))
+ return true;
+ }
+
+ bool visit(const luci::CircleFullyConnected *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ RETURN_FALSE_UNLESS(is_lwq(node->input()))
+ RETURN_FALSE_UNLESS(is_lwq_const(node->weights()))
+ luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
+ if (bias != nullptr)
+ RETURN_FALSE_UNLESS(is_lwq_const(node->bias()))
+ return true;
+ }
+};
+
} // namespace luci
#undef RETURN_FALSE_UNLESS
-#endif // __LUCI_VERIFY_QUANTIZED_NODE_CHANNELWISE_GRANULARITY_H__
+#endif // __LUCI_VERIFY_QUANTIZED_NODE_GRANULARITY_H__
diff --git a/compiler/luci/pass/src/VerifyQuantizedNodeLayerWiseGranularity.h b/compiler/luci/pass/src/VerifyQuantizedNodeLayerWiseGranularity.h
deleted file mode 100644
index 9bc8b31df..000000000
--- a/compiler/luci/pass/src/VerifyQuantizedNodeLayerWiseGranularity.h
+++ /dev/null
@@ -1,473 +0,0 @@
-/*
- * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef __LUCI_VERIFY_QUANTIZED_NODE_LAYERWISE_GRANULARITY_H__
-#define __LUCI_VERIFY_QUANTIZED_NODE_LAYERWISE_GRANULARITY_H__
-
-#include <luci/IR/CircleNodes.h>
-#include <luci/IR/CircleNodeVisitor.h>
-#include <luci/Pass/QuantizationParameters.h>
-
-using Granularity = luci::QuantizationGranularity;
-
-// This macro is undef at the end of the file
-#define RETURN_FALSE_UNLESS(ARG) \
- if (not(ARG)) \
- { \
- return false; \
- }
-
-namespace luci
-{
-
-/**
- * @brief Verify the granualrity of layer-wise quantized node
- * @details
- *
- * Targets to verify
- * - node's output (i.e., node itself)
- * - node's inputs
- */
-struct VerifyQuantizedNodeLayerWiseGranularity final : public luci::CircleNodeVisitor<bool>
-{
-private:
- bool is_lwq(const loco::Node *node)
- {
- auto circle_node = loco::must_cast<const luci::CircleNode *>(node);
-
- if (circle_node->quantparam() == nullptr)
- return false;
-
- if (circle_node->quantparam()->scale.size() != 1)
- return false;
-
- if (circle_node->quantparam()->zerop.size() != 1)
- return false;
-
- return true;
- }
-
- bool is_lwq_const(const loco::Node *node)
- {
- auto circle_node = loco::must_cast<const luci::CircleConst *>(node);
-
- if (circle_node->quantparam() == nullptr)
- return false;
-
- if (circle_node->quantparam()->scale.size() != 1)
- return false;
-
- if (circle_node->quantparam()->zerop.size() != 1)
- return false;
-
- return true;
- }
-
-private:
- bool visit(const luci::CircleConv2D *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->input()))
- RETURN_FALSE_UNLESS(is_lwq_const(node->filter()))
- luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
- if (bias != nullptr)
- RETURN_FALSE_UNLESS(is_lwq_const(node->bias()))
- return true;
- }
-
- bool visit(const luci::CircleConcatenation *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- for (uint32_t i = 0; i < node->numValues(); i++)
- {
- RETURN_FALSE_UNLESS(is_lwq(node->values(i)));
- }
- return true;
- }
-
- bool visit(const luci::CircleDepthToSpace *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->input()))
- return true;
- }
-
- bool visit(const luci::CircleDepthwiseConv2D *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->input()))
- RETURN_FALSE_UNLESS(is_lwq_const(node->filter()))
- luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
- if (bias != nullptr)
- RETURN_FALSE_UNLESS(is_lwq_const(node->bias()))
- return true;
- }
-
- bool visit(const luci::CircleInstanceNorm *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->input()))
- RETURN_FALSE_UNLESS(is_lwq_const(node->gamma()))
- RETURN_FALSE_UNLESS(is_lwq_const(node->beta()))
- return true;
- }
-
- bool visit(const luci::CirclePack *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- for (uint32_t i = 0; i < node->values_count(); i++)
- {
- RETURN_FALSE_UNLESS(is_lwq(node->values(i)));
- }
- return true;
- }
-
- bool visit(const luci::CirclePad *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->input()))
- return true;
- }
-
- bool visit(const luci::CirclePadV2 *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->input()))
- RETURN_FALSE_UNLESS(is_lwq(node->constant_values()))
- return true;
- }
-
- bool visit(const luci::CircleMirrorPad *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->input()))
- return true;
- }
-
- bool visit(const luci::CirclePRelu *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->input()))
- RETURN_FALSE_UNLESS(is_lwq_const(node->alpha()))
- return true;
- }
-
- bool visit(const luci::CircleTransposeConv *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->outBackprop()))
- RETURN_FALSE_UNLESS(is_lwq_const(node->filter()))
- luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
- if (bias != nullptr)
- RETURN_FALSE_UNLESS(is_lwq_const(node->bias()))
- return true;
- }
-
- bool visit(const luci::CircleFullyConnected *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->input()))
- RETURN_FALSE_UNLESS(is_lwq_const(node->weights()))
- luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
- if (bias != nullptr)
- RETURN_FALSE_UNLESS(is_lwq_const(node->bias()))
- return true;
- }
-
- bool visit(const luci::CircleAdd *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->x()));
- RETURN_FALSE_UNLESS(is_lwq(node->y()));
- return true;
- }
-
- bool visit(const luci::CircleAveragePool2D *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->value()));
- return true;
- }
-
- bool visit(const luci::CircleLogicalOr *)
- {
- // Logical OR has bool-type inputs and output
- // Nothing to be checked
- return true;
- }
-
- bool visit(const luci::CircleMaxPool2D *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->value()));
- return true;
- }
-
- bool visit(const luci::CircleLocalResponseNormalization *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->input()));
- return true;
- }
-
- bool visit(const luci::CircleMean *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->input()));
- return true;
- }
-
- bool visit(const luci::CircleMul *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->x()));
- RETURN_FALSE_UNLESS(is_lwq(node->y()));
- return true;
- }
-
- bool visit(const luci::CircleNotEqual *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node->x()));
- RETURN_FALSE_UNLESS(is_lwq(node->y()));
- return true;
- }
-
- bool visit(const luci::CircleRelu *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->features()));
- return true;
- }
-
- bool visit(const luci::CircleReshape *node)
- {
- auto input = loco::must_cast<const luci::CircleNode *>(node->tensor());
- bool input_quantized = input->quantparam() != nullptr;
- bool node_quantized = node->quantparam() != nullptr;
- RETURN_FALSE_UNLESS(input_quantized == node_quantized);
- RETURN_FALSE_UNLESS(not node_quantized or is_lwq(node))
- RETURN_FALSE_UNLESS(not input_quantized or is_lwq(input));
- return true;
- }
-
- bool visit(const luci::CircleLogistic *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node));
- RETURN_FALSE_UNLESS(is_lwq(node->x()));
- return true;
- }
-
- bool visit(const luci::CircleSoftmax *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node));
- RETURN_FALSE_UNLESS(is_lwq(node->logits()));
- return true;
- }
-
- bool visit(const luci::CircleSpaceToBatchND *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node));
- RETURN_FALSE_UNLESS(is_lwq(node->input()));
- return true;
- }
-
- bool visit(const luci::CircleSpaceToDepth *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node));
- RETURN_FALSE_UNLESS(is_lwq(node->input()));
- return true;
- }
-
- bool visit(const luci::CircleSlice *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node));
- RETURN_FALSE_UNLESS(is_lwq(node->input()));
- return true;
- }
-
- bool visit(const luci::CircleSplit *node)
- {
- // node's output is the input of CircleSplitOut, thus not quantized
- RETURN_FALSE_UNLESS(is_lwq(node->input()));
- return true;
- }
-
- bool visit(const luci::CircleSplitOut *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node));
- return true;
- }
-
- bool visit(const luci::CircleSplitV *node)
- {
- // node's output is the input of CircleSplitVOut, thus not quantized
- RETURN_FALSE_UNLESS(is_lwq(node->input()));
- return true;
- }
-
- bool visit(const luci::CircleSplitVOut *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node));
- return true;
- }
-
- bool visit(const luci::CircleStridedSlice *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node));
- RETURN_FALSE_UNLESS(is_lwq(node->input()));
- return true;
- }
-
- bool visit(const luci::CircleArgMax *node)
- {
- // node's output is index, thus not quantized
- RETURN_FALSE_UNLESS(is_lwq(node->input()));
- return true;
- }
-
- bool visit(const luci::CircleBatchToSpaceND *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node));
- RETURN_FALSE_UNLESS(is_lwq(node->input()));
- return true;
- }
-
- bool visit(const luci::CircleTanh *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node));
- RETURN_FALSE_UNLESS(is_lwq(node->x()));
- return true;
- }
-
- bool visit(const luci::CircleTranspose *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node));
- RETURN_FALSE_UNLESS(is_lwq(node->a()));
- return true;
- }
-
- bool visit(const luci::CircleFloor *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node));
- RETURN_FALSE_UNLESS(is_lwq(node->x()));
- return true;
- }
-
- bool visit(const luci::CircleGreater *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node->x()));
- RETURN_FALSE_UNLESS(is_lwq(node->y()));
- return true;
- }
-
- bool visit(const luci::CircleGreaterEqual *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node->x()));
- RETURN_FALSE_UNLESS(is_lwq(node->y()));
- return true;
- }
-
- bool visit(const luci::CircleDiv *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node));
- RETURN_FALSE_UNLESS(is_lwq(node->x()));
- RETURN_FALSE_UNLESS(is_lwq(node->y()));
- return true;
- }
-
- bool visit(const luci::CircleFloorDiv *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node));
- RETURN_FALSE_UNLESS(is_lwq(node->x()));
- RETURN_FALSE_UNLESS(is_lwq(node->y()));
- return true;
- }
-
- bool visit(const luci::CircleRsqrt *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node));
- RETURN_FALSE_UNLESS(is_lwq(node->x()));
- return true;
- }
-
- bool visit(const luci::CircleSqrt *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node));
- RETURN_FALSE_UNLESS(is_lwq(node->x()));
- return true;
- }
-
- bool visit(const luci::CircleElu *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node));
- RETURN_FALSE_UNLESS(is_lwq(node->features()));
- return true;
- }
-
- bool visit(const luci::CirclePow *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node));
- RETURN_FALSE_UNLESS(is_lwq(node->x()));
- RETURN_FALSE_UNLESS(is_lwq(node->y()));
- return true;
- }
-
- bool visit(const luci::CircleResizeBilinear *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node));
- RETURN_FALSE_UNLESS(is_lwq(node->input()));
- return true;
- }
-
- bool visit(const luci::CircleResizeNearestNeighbor *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node));
- RETURN_FALSE_UNLESS(is_lwq(node->input()));
- return true;
- }
-
- bool visit(const luci::CircleUnpack *node)
- {
- // node's output is the input of CircleUnpackOut, thus not quantized
- RETURN_FALSE_UNLESS(is_lwq(node->value()));
- return true;
- }
-
- bool visit(const luci::CircleUnpackOut *node)
- {
- RETURN_FALSE_UNLESS(is_lwq(node));
- return true;
- }
-
- bool visit(const luci::CircleCast *node)
- {
- auto input = loco::must_cast<const luci::CircleNode *>(node->x());
- bool input_quantized = input->quantparam() != nullptr;
- bool node_quantized = node->quantparam() != nullptr;
- RETURN_FALSE_UNLESS(not input_quantized or is_lwq(input));
- RETURN_FALSE_UNLESS(not node_quantized or is_lwq(node));
- return true;
- }
-
- // TODO: Implement more Ops
-
- bool visit(const luci::CircleNode *) { return true; }
-};
-
-} // namespace luci
-
-#undef RETURN_FALSE_UNLESS
-
-#endif // __LUCI_VERIFY_QUANTIZED_NODE_LAYERWISE_GRANULARITY_H__
diff --git a/compiler/luci/pass/src/VerifyQuantizedNodeS16Type.h b/compiler/luci/pass/src/VerifyQuantizedNodeS16Type.h
deleted file mode 100644
index eeec7b82b..000000000
--- a/compiler/luci/pass/src/VerifyQuantizedNodeS16Type.h
+++ /dev/null
@@ -1,516 +0,0 @@
-/*
- * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef __LUCI_VERIFY_QUANTIZED_NODE_S16_TYPE_H__
-#define __LUCI_VERIFY_QUANTIZED_NODE_S16_TYPE_H__
-
-#include <luci/IR/CircleNodes.h>
-#include <luci/IR/CircleNodeVisitor.h>
-
-#include <cmath>
-
-using Type = loco::DataType;
-
-// This macro is undef at the end of the file
-#define RETURN_FALSE_UNLESS(ARG) \
- if (not(ARG)) \
- { \
- return false; \
- }
-
-namespace luci
-{
-
-/**
- * @brief Verify the data type of INT16 quantized node
- * @details
- *
- * Targets to verify
- * - node's output (i.e., node itself)
- * - node's inputs
- */
-struct VerifyQuantizedNodeS16Type final : public luci::CircleNodeVisitor<bool>
-{
-private:
- bool has_type(const loco::Node *node, Type dtype)
- {
- auto circle_node = loco::must_cast<const luci::CircleNode *>(node);
- return circle_node->dtype() == dtype;
- }
-
-private:
- bool visit(const luci::CircleConv2D *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->filter(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->bias(), Type::S64))
- return true;
- }
-
- bool visit(const luci::CircleConcatenation *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- for (uint32_t i = 0; i < node->numValues(); i++)
- {
- RETURN_FALSE_UNLESS(has_type(node->values(i), Type::S16))
- }
- return true;
- }
-
- bool visit(const luci::CircleDepthToSpace *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleDepthwiseConv2D *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->filter(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->bias(), Type::S64))
- return true;
- }
-
- bool visit(const luci::CircleInstanceNorm *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->gamma(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->beta(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CirclePack *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- for (uint32_t i = 0; i < node->values_count(); i++)
- {
- RETURN_FALSE_UNLESS(has_type(node->values(i), Type::S16))
- }
- return true;
- }
-
- bool visit(const luci::CirclePad *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->paddings(), Type::S32))
- return true;
- }
-
- bool visit(const luci::CirclePadV2 *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->paddings(), Type::S32))
- RETURN_FALSE_UNLESS(has_type(node->constant_values(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleMirrorPad *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->paddings(), Type::S32))
- return true;
- }
-
- bool visit(const luci::CirclePRelu *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->alpha(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleTransposeConv *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->outBackprop(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->filter(), Type::S16))
- luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
- if (bias != nullptr)
- RETURN_FALSE_UNLESS(has_type(bias, Type::S64))
- return true;
- }
-
- bool visit(const luci::CircleFullyConnected *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->weights(), Type::S16))
- luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
- if (bias != nullptr)
- RETURN_FALSE_UNLESS(has_type(bias, Type::S64))
- return true;
- }
-
- bool visit(const luci::CircleAdd *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->y(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleAveragePool2D *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->value(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleLogicalOr *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::BOOL))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::BOOL))
- RETURN_FALSE_UNLESS(has_type(node->y(), Type::BOOL))
- return true;
- }
-
- bool visit(const luci::CircleMaxPool2D *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->value(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleLocalResponseNormalization *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleMean *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->reduction_indices(), Type::S32))
- return true;
- }
-
- bool visit(const luci::CircleMul *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->y(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleNotEqual *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::BOOL))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->y(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleRelu *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->features(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleReshape *node)
- {
- if (node->quantparam())
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->tensor(), Type::S16))
- }
- else
- {
- RETURN_FALSE_UNLESS(has_type(node->tensor(), node->dtype()))
- }
- luci::CircleConst *shape = dynamic_cast<luci::CircleConst *>(node->shape());
- if (shape != nullptr)
- RETURN_FALSE_UNLESS(has_type(shape, Type::S32))
- return true;
- }
-
- bool visit(const luci::CircleLogistic *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16))
-
- RETURN_FALSE_UNLESS(node->quantparam());
- RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == 1.0f / 32768.0f);
- RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == 0);
- return true;
- }
-
- bool visit(const luci::CircleSoftmax *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->logits(), Type::S16))
-
- RETURN_FALSE_UNLESS(node->quantparam());
- RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == 1.0f / 32767.0f);
- RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == 0);
- return true;
- }
-
- bool visit(const luci::CircleSpaceToBatchND *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleSpaceToDepth *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleSlice *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->begin(), Type::S32) || has_type(node->begin(), Type::S64))
- RETURN_FALSE_UNLESS(has_type(node->size(), Type::S32) || has_type(node->size(), Type::S64))
- return true;
- }
-
- bool visit(const luci::CircleSplit *node)
- {
- // node's output is the input of CircleSplitOut, thus not quantized
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleSplitOut *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
-
- // SplitOut has the same qparam with the input of Split
- auto split = loco::must_cast<luci::CircleSplit *>(node->input());
- auto input = loco::must_cast<luci::CircleNode *>(split->input());
- RETURN_FALSE_UNLESS(node->quantparam());
- RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == input->quantparam()->scale[0]);
- RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == input->quantparam()->zerop[0]);
- return true;
- }
-
- bool visit(const luci::CircleSplitV *node)
- {
- // node's output is the input of CircleSplitVOut, thus not quantized
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleSplitVOut *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
-
- // SplitVOut has the same qparam with the input of SplitV
- auto splitv = loco::must_cast<luci::CircleSplitV *>(node->input());
- auto input = loco::must_cast<luci::CircleNode *>(splitv->input());
- RETURN_FALSE_UNLESS(node->quantparam());
- RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == input->quantparam()->scale[0]);
- RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == input->quantparam()->zerop[0]);
- return true;
- }
-
- bool visit(const luci::CircleStridedSlice *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
-
- auto input = loco::must_cast<luci::CircleNode *>(node->input());
- RETURN_FALSE_UNLESS(node->quantparam());
- RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == input->quantparam()->scale[0]);
- RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == input->quantparam()->zerop[0]);
- return true;
- }
-
- bool visit(const luci::CircleArgMax *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, node->output_type()))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->dimension(), Type::S32) ||
- has_type(node->dimension(), Type::S64))
- return true;
- }
-
- bool visit(const luci::CircleBatchToSpaceND *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleTanh *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16))
-
- RETURN_FALSE_UNLESS(node->quantparam());
- RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == 1.0f / 32768.0f);
- RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == 0);
- return true;
- }
-
- bool visit(const luci::CircleTranspose *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->a(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->perm(), Type::S32))
- return true;
- }
-
- bool visit(const luci::CircleFloor *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16))
-
- // This checks the value of scale is an integer
- RETURN_FALSE_UNLESS(node->quantparam());
- RETURN_FALSE_UNLESS(std::roundf(node->quantparam()->scale[0]) == node->quantparam()->scale[0]);
- return true;
- }
-
- bool visit(const luci::CircleGreater *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::BOOL))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->y(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleGreaterEqual *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::BOOL))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->y(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleDiv *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->y(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleFloorDiv *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->y(), Type::S16))
-
- // This checks the value of scale is an integer
- RETURN_FALSE_UNLESS(node->quantparam());
- RETURN_FALSE_UNLESS(std::roundf(node->quantparam()->scale[0]) == node->quantparam()->scale[0]);
- return true;
- }
-
- bool visit(const luci::CircleRsqrt *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleSqrt *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleElu *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->features(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CirclePow *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->y(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleResizeBilinear *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleResizeNearestNeighbor *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleUnpack *node)
- {
- // node's output is the input of CircleUnpackOut, thus not quantized
- RETURN_FALSE_UNLESS(has_type(node->value(), Type::S16))
- return true;
- }
-
- bool visit(const luci::CircleUnpackOut *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
-
- // UnpackOut has the same qparam with the input of Unpack
- auto Unpack = loco::must_cast<luci::CircleUnpack *>(node->input());
- auto input = loco::must_cast<luci::CircleNode *>(Unpack->value());
- RETURN_FALSE_UNLESS(node->quantparam() && input->quantparam());
- RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == input->quantparam()->scale[0]);
- RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == input->quantparam()->zerop[0]);
- return true;
- }
-
- bool visit(const luci::CircleCast *node)
- {
- auto *input = loco::must_cast<luci::CircleNode *>(node->x());
- RETURN_FALSE_UNLESS(has_type(input, node->in_data_type()))
-
- bool input_quantized = input->quantparam() != nullptr;
- if (input_quantized)
- RETURN_FALSE_UNLESS(has_type(input, Type::S16))
-
- RETURN_FALSE_UNLESS(has_type(node, node->out_data_type()))
-
- bool node_quantized = node->quantparam() != nullptr;
- if (node_quantized)
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- return true;
- }
-
- // TODO: Implement more Ops
-
- bool visit(const luci::CircleNode *) { return true; }
-};
-
-} // namespace luci
-
-#undef RETURN_FALSE_UNLESS
-
-#endif // __LUCI_VERIFY_QUNTIZED_NODE_S16_TYPE_H__
diff --git a/compiler/luci/pass/src/VerifyQuantizedNodeType.cpp b/compiler/luci/pass/src/VerifyQuantizedNodeType.cpp
new file mode 100644
index 000000000..4e1c062c0
--- /dev/null
+++ b/compiler/luci/pass/src/VerifyQuantizedNodeType.cpp
@@ -0,0 +1,554 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "VerifyQuantizedNodeType.h"
+
+#include <cmath>
+#include <memory>
+
+// This macro is undef at the end of the file
+#define RETURN_FALSE_UNLESS(ARG) \
+ if (not(ARG)) \
+ { \
+ return false; \
+ }
+
+namespace luci
+{
+
+std::shared_ptr<VerifyQuantizedNodeType> VerifyQuantizedNodeType::create(loco::DataType dtype)
+{
+ if (dtype == loco::DataType::U8)
+ return std::make_shared<VerifyQuantizedNodeU8Type>();
+ else if (dtype == loco::DataType::S16)
+ return std::make_shared<VerifyQuantizedNodeS16Type>();
+ else
+ throw std::domain_error("Not supported Quantized type");
+}
+
+} // namespace luci
+
+namespace luci
+{
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleAdd *node)
+{
+ return group_has_type(node, Qtype);
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleArgMax *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, node->output_type()))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->dimension(), loco::DataType::S32) ||
+ has_type(node->dimension(), loco::DataType::S64))
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleAveragePool2D *node)
+{
+ return group_has_type(node, Qtype);
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleBatchToSpaceND *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Qtype))
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleCast *node)
+{
+ auto *input = loco::must_cast<luci::CircleNode *>(node->x());
+ bool input_quantized = input->quantparam() != nullptr;
+ if (input_quantized)
+ {
+ RETURN_FALSE_UNLESS(has_type(input, node->in_data_type()))
+ RETURN_FALSE_UNLESS(has_type(input, Qtype))
+ }
+
+ bool node_quantized = node->quantparam() != nullptr;
+ if (node_quantized)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, node->out_data_type()))
+ RETURN_FALSE_UNLESS(has_type(node, Qtype))
+ }
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleConv2D *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->filter(), Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->bias(), Btype))
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleConcatenation *node)
+{
+ // Allow concatenation of indices
+ if (group_has_type(node, loco::DataType::S32) or group_has_type(node, loco::DataType::S64))
+ return true;
+
+ return group_has_type(node, Qtype);
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleDepthToSpace *node)
+{
+ return group_has_type(node, Qtype);
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleDepthwiseConv2D *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->filter(), Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->bias(), Btype))
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleDiv *node)
+{
+ return group_has_type(node, Qtype);
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleElu *node)
+{
+ return group_has_type(node, Qtype);
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleFloor *node)
+{
+ RETURN_FALSE_UNLESS(group_has_type(node, Qtype));
+
+ // This checks the value of scale is an integer
+ RETURN_FALSE_UNLESS(node->quantparam());
+ RETURN_FALSE_UNLESS(std::roundf(node->quantparam()->scale[0]) == node->quantparam()->scale[0]);
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleFloorDiv *node)
+{
+ RETURN_FALSE_UNLESS(group_has_type(node, Qtype));
+
+ // This checks the value of scale is an integer
+ RETURN_FALSE_UNLESS(node->quantparam());
+ RETURN_FALSE_UNLESS(std::roundf(node->quantparam()->scale[0]) == node->quantparam()->scale[0]);
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleFullyConnected *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->weights(), Qtype))
+ luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
+ if (bias != nullptr)
+ RETURN_FALSE_UNLESS(has_type(bias, Btype))
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleGreater *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, loco::DataType::BOOL))
+ RETURN_FALSE_UNLESS(has_type(node->x(), Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->y(), Qtype))
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleGreaterEqual *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, loco::DataType::BOOL))
+ RETURN_FALSE_UNLESS(has_type(node->x(), Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->y(), Qtype))
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleInstanceNorm *node)
+{
+ return group_has_type(node, Qtype);
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(
+ const luci::CircleLocalResponseNormalization *node)
+{
+ return group_has_type(node, Qtype);
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleLogicalOr *node)
+{
+ return group_has_type(node, loco::DataType::BOOL);
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleMaxPool2D *node)
+{
+ return group_has_type(node, Qtype);
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleMean *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->reduction_indices(), loco::DataType::S32))
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleMirrorPad *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->paddings(), loco::DataType::S32))
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleMul *node)
+{
+ return group_has_type(node, Qtype);
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleNotEqual *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, loco::DataType::BOOL))
+ RETURN_FALSE_UNLESS(has_type(node->x(), Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->y(), Qtype))
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleOneHot *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, Qtype));
+ RETURN_FALSE_UNLESS(has_type(node->indices(), loco::DataType::S32) ||
+ has_type(node->indices(), loco::DataType::S64));
+ RETURN_FALSE_UNLESS(has_type(node->depth(), loco::DataType::S32));
+ RETURN_FALSE_UNLESS(has_type(node->on_value(), Qtype));
+ RETURN_FALSE_UNLESS(has_type(node->off_value(), Qtype));
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CirclePack *node)
+{
+ return group_has_type(node, Qtype);
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CirclePad *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->paddings(), loco::DataType::S32))
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CirclePadV2 *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->paddings(), loco::DataType::S32))
+ RETURN_FALSE_UNLESS(has_type(node->constant_values(), Qtype))
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CirclePRelu *node)
+{
+ return group_has_type(node, Qtype);
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CirclePow *node)
+{
+ return group_has_type(node, Qtype);
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleRelu *node)
+{
+ return group_has_type(node, Qtype);
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleReshape *node)
+{
+ if (node->quantparam())
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->tensor(), Qtype))
+ }
+ else
+ {
+ RETURN_FALSE_UNLESS(has_type(node->tensor(), node->dtype()))
+ }
+ luci::CircleConst *shape = dynamic_cast<luci::CircleConst *>(node->shape());
+ if (shape != nullptr)
+ RETURN_FALSE_UNLESS(has_type(shape, loco::DataType::S32))
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleResizeBilinear *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Qtype))
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleResizeNearestNeighbor *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Qtype))
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleRsqrt *node)
+{
+ return group_has_type(node, Qtype);
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleSlice *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->begin(), loco::DataType::S32) ||
+ has_type(node->begin(), loco::DataType::S64))
+ RETURN_FALSE_UNLESS(has_type(node->size(), loco::DataType::S32) ||
+ has_type(node->size(), loco::DataType::S64))
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleSpaceToBatchND *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Qtype))
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleSpaceToDepth *node)
+{
+ return group_has_type(node, Qtype);
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleSplit *node)
+{
+ // node's output is the input of CircleSplitOut, thus not quantized
+ RETURN_FALSE_UNLESS(has_type(node->input(), Qtype))
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleSplitOut *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, Qtype))
+
+ // SplitOut has the same qparam with the input of Split
+ auto split = loco::must_cast<luci::CircleSplit *>(node->input());
+ auto input = loco::must_cast<luci::CircleNode *>(split->input());
+ RETURN_FALSE_UNLESS(node->quantparam());
+ RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == input->quantparam()->scale[0]);
+ RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == input->quantparam()->zerop[0]);
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleSplitV *node)
+{
+ // node's output is the input of CircleSplitVOut, thus not quantized
+ RETURN_FALSE_UNLESS(has_type(node->input(), Qtype))
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleSplitVOut *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, Qtype))
+
+ // SplitVOut has the same qparam with the input of SplitV
+ auto splitv = loco::must_cast<luci::CircleSplitV *>(node->input());
+ auto input = loco::must_cast<luci::CircleNode *>(splitv->input());
+ RETURN_FALSE_UNLESS(node->quantparam());
+ RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == input->quantparam()->scale[0]);
+ RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == input->quantparam()->zerop[0]);
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleSqrt *node)
+{
+ return group_has_type(node, Qtype);
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleStridedSlice *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Qtype))
+
+ auto input = loco::must_cast<luci::CircleNode *>(node->input());
+ RETURN_FALSE_UNLESS(node->quantparam());
+ RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == input->quantparam()->scale[0]);
+ RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == input->quantparam()->zerop[0]);
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleTranspose *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->a(), Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->perm(), loco::DataType::S32))
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleTransposeConv *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->outBackprop(), Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->filter(), Qtype))
+ luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
+ if (bias != nullptr)
+ RETURN_FALSE_UNLESS(has_type(bias, Btype))
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleUnpack *node)
+{
+ // node's output is the input of CircleUnpackOut, thus not quantized
+ RETURN_FALSE_UNLESS(has_type(node->value(), Qtype))
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleUnpackOut *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, Qtype))
+
+ // UnpackOut has the same qparam with the input of Unpack
+ auto Unpack = loco::must_cast<luci::CircleUnpack *>(node->input());
+ auto input = loco::must_cast<luci::CircleNode *>(Unpack->value());
+ RETURN_FALSE_UNLESS(node->quantparam() && input->quantparam());
+ RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == input->quantparam()->scale[0]);
+ RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == input->quantparam()->zerop[0]);
+ return true;
+}
+
+} // namespace luci
+
+namespace luci
+{
+
+bool VerifyQuantizedNodeU8Type::visit(const luci::CircleTanh *node)
+{
+ RETURN_FALSE_UNLESS(group_has_type(node, loco::DataType::U8));
+
+ RETURN_FALSE_UNLESS(node->quantparam());
+ RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == 2.0f / 256.0f);
+ RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == 128);
+ return true;
+}
+
+bool VerifyQuantizedNodeU8Type::visit(const luci::CircleLogistic *node)
+{
+ RETURN_FALSE_UNLESS(group_has_type(node, loco::DataType::U8));
+
+ RETURN_FALSE_UNLESS(node->quantparam());
+ RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == 1.0f / 256.0f);
+ RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == 0);
+ return true;
+}
+
+bool VerifyQuantizedNodeU8Type::visit(const luci::CircleSoftmax *node)
+{
+ RETURN_FALSE_UNLESS(group_has_type(node, loco::DataType::U8));
+
+ RETURN_FALSE_UNLESS(node->quantparam());
+ RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == 1.0f / 255.0f);
+ RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == 0);
+ return true;
+}
+
+} // namespace luci
+
+namespace luci
+{
+
+bool VerifyQuantizedNodeS16Type::visit(const luci::CircleTanh *node)
+{
+ RETURN_FALSE_UNLESS(group_has_type(node, loco::DataType::S16));
+
+ RETURN_FALSE_UNLESS(node->quantparam());
+ RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == 1.0f / 32768.0f);
+ RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == 0);
+ return true;
+}
+
+bool VerifyQuantizedNodeS16Type::visit(const luci::CircleLogistic *node)
+{
+ RETURN_FALSE_UNLESS(group_has_type(node, loco::DataType::S16));
+
+ RETURN_FALSE_UNLESS(node->quantparam());
+ RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == 1.0f / 32768.0f);
+ RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == 0);
+ return true;
+}
+
+bool VerifyQuantizedNodeS16Type::visit(const luci::CircleSoftmax *node)
+{
+ RETURN_FALSE_UNLESS(group_has_type(node, loco::DataType::S16));
+
+ RETURN_FALSE_UNLESS(node->quantparam());
+ RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == 1.0f / 32767.0f);
+ RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == 0);
+ return true;
+}
+
+} // namespace luci
+
+#undef RETURN_FALSE_UNLESS
diff --git a/compiler/luci/pass/src/VerifyQuantizedNodeType.h b/compiler/luci/pass/src/VerifyQuantizedNodeType.h
new file mode 100644
index 000000000..ff1acbd6f
--- /dev/null
+++ b/compiler/luci/pass/src/VerifyQuantizedNodeType.h
@@ -0,0 +1,157 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_VERIFY_QUANTIZED_NODE_TYPE_H__
+#define __LUCI_VERIFY_QUANTIZED_NODE_TYPE_H__
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/IR/CircleNodeVisitor.h>
+
+namespace luci
+{
+
+/**
+ * @brief Verify the data type of quantized node
+ * @details
+ *
+ * Targets to verify
+ * - node's output (i.e., node itself)
+ * - node's inputs
+ */
+class VerifyQuantizedNodeType
+{
+public:
+ static std::shared_ptr<VerifyQuantizedNodeType> create(loco::DataType dtype);
+
+public:
+ virtual bool verify(luci::CircleNode *node) = 0;
+};
+
+/**
+ * @brief Verify using quantization type of a node and bias
+ *
+ * @tparam Qtype Quantization type for a node (e.g. Q8, Q16, ...)
+ * @tparam Btype Bias quantization type (e.g. For Q8, S32 is used)
+ */
+template <loco::DataType Qtype, loco::DataType Btype>
+class VerifyQuantizedNodeTypeBase : public luci::CircleNodeVisitor<bool>,
+ public VerifyQuantizedNodeType
+{
+public:
+ bool verify(luci::CircleNode *node) { return node->accept(this); }
+
+protected:
+ bool has_type(const loco::Node *node, loco::DataType dtype)
+ {
+ auto circle_node = loco::must_cast<const luci::CircleNode *>(node);
+ return circle_node->dtype() == dtype;
+ }
+
+ // Check whether a node and all of its inputs have dtype or not
+ bool group_has_type(const loco::Node *node, loco::DataType dtype)
+ {
+ if (!has_type(node, dtype))
+ return false;
+
+ for (uint32_t i = 0; i < node->arity(); ++i)
+ if (!has_type(node->arg(i), dtype))
+ return false;
+
+ return true;
+ }
+
+private:
+ bool visit(const luci::CircleAdd *node);
+ bool visit(const luci::CircleArgMax *node);
+ bool visit(const luci::CircleAveragePool2D *node);
+ bool visit(const luci::CircleBatchToSpaceND *node);
+ bool visit(const luci::CircleCast *node);
+ bool visit(const luci::CircleConv2D *node);
+ bool visit(const luci::CircleConcatenation *node);
+ bool visit(const luci::CircleDepthToSpace *node);
+ bool visit(const luci::CircleDepthwiseConv2D *node);
+ bool visit(const luci::CircleDiv *node);
+ bool visit(const luci::CircleElu *node);
+ bool visit(const luci::CircleFloor *node);
+ bool visit(const luci::CircleFloorDiv *node);
+ bool visit(const luci::CircleFullyConnected *node);
+ bool visit(const luci::CircleGreater *node);
+ bool visit(const luci::CircleGreaterEqual *node);
+ bool visit(const luci::CircleInstanceNorm *node);
+ bool visit(const luci::CircleLocalResponseNormalization *node);
+ bool visit(const luci::CircleLogicalOr *node);
+ bool visit(const luci::CircleMaxPool2D *node);
+ bool visit(const luci::CircleMean *node);
+ bool visit(const luci::CircleMirrorPad *node);
+ bool visit(const luci::CircleMul *node);
+ bool visit(const luci::CircleNotEqual *node);
+ bool visit(const luci::CircleOneHot *node);
+ bool visit(const luci::CirclePack *node);
+ bool visit(const luci::CirclePad *node);
+ bool visit(const luci::CirclePadV2 *node);
+ bool visit(const luci::CirclePRelu *node);
+ bool visit(const luci::CirclePow *node);
+ bool visit(const luci::CircleRelu *node);
+ bool visit(const luci::CircleReshape *node);
+ bool visit(const luci::CircleResizeBilinear *node);
+ bool visit(const luci::CircleResizeNearestNeighbor *node);
+ bool visit(const luci::CircleRsqrt *node);
+ bool visit(const luci::CircleSlice *node);
+ bool visit(const luci::CircleSpaceToBatchND *node);
+ bool visit(const luci::CircleSpaceToDepth *node);
+ bool visit(const luci::CircleSplit *node);
+ bool visit(const luci::CircleSplitOut *node);
+ bool visit(const luci::CircleSplitV *node);
+ bool visit(const luci::CircleSplitVOut *node);
+ bool visit(const luci::CircleSqrt *node);
+ bool visit(const luci::CircleStridedSlice *node);
+ bool visit(const luci::CircleTranspose *node);
+ bool visit(const luci::CircleTransposeConv *node);
+ bool visit(const luci::CircleUnpack *node);
+ bool visit(const luci::CircleUnpackOut *node);
+
+ // NOTE below nodes has differnent implementation for Qtype/Btype and
+ // implementations exist in VerifyQuantizedNodeU8Type, VerifyQuantizedNodeS16Type
+ // bool visit(const luci::CircleLogistic *node);
+ // bool visit(const luci::CircleSoftmax *node);
+ // bool visit(const luci::CircleTanh *node);
+
+ // TODO: Implement more Ops
+
+ bool visit(const luci::CircleNode *) { return true; }
+};
+
+class VerifyQuantizedNodeU8Type
+ : public VerifyQuantizedNodeTypeBase<loco::DataType::U8, loco::DataType::S32>
+{
+private:
+ bool visit(const luci::CircleLogistic *node);
+ bool visit(const luci::CircleSoftmax *node);
+ bool visit(const luci::CircleTanh *node);
+};
+
+class VerifyQuantizedNodeS16Type
+ : public VerifyQuantizedNodeTypeBase<loco::DataType::S16, loco::DataType::S64>
+{
+private:
+ bool visit(const luci::CircleLogistic *node);
+ bool visit(const luci::CircleSoftmax *node);
+ bool visit(const luci::CircleTanh *node);
+};
+
+} // namespace luci
+
+#endif // __LUCI_VERIFY_QUANTIZED_NODE_TYPE_H__
diff --git a/compiler/luci/pass/src/VerifyQuantizedNodeU8Type.h b/compiler/luci/pass/src/VerifyQuantizedNodeU8Type.h
deleted file mode 100644
index e7dd1b072..000000000
--- a/compiler/luci/pass/src/VerifyQuantizedNodeU8Type.h
+++ /dev/null
@@ -1,518 +0,0 @@
-/*
- * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef __LUCI_VERIFY_QUANTIZED_NODE_U8_TYPE_H__
-#define __LUCI_VERIFY_QUANTIZED_NODE_U8_TYPE_H__
-
-#include <luci/IR/CircleNodes.h>
-#include <luci/IR/CircleNodeVisitor.h>
-
-#include <cmath>
-
-using Type = loco::DataType;
-
-// This macro is undef at the end of the file
-#define RETURN_FALSE_UNLESS(ARG) \
- if (not(ARG)) \
- { \
- return false; \
- }
-
-namespace luci
-{
-
-/**
- * @brief Verify the data type of UINT8 quantized node
- * @details
- *
- * Targets to verify
- * - node's output (i.e., node itself)
- * - node's inputs
- */
-struct VerifyQuantizedNodeU8Type final : public luci::CircleNodeVisitor<bool>
-{
-private:
- bool has_type(const loco::Node *node, Type dtype)
- {
- auto circle_node = loco::must_cast<const luci::CircleNode *>(node);
- return circle_node->dtype() == dtype;
- }
-
-private:
- bool visit(const luci::CircleConv2D *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->filter(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->bias(), Type::S32))
- return true;
- }
-
- bool visit(const luci::CircleConcatenation *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- for (uint32_t i = 0; i < node->numValues(); i++)
- {
- RETURN_FALSE_UNLESS(has_type(node->values(i), Type::U8))
- }
- return true;
- }
-
- bool visit(const luci::CircleDepthToSpace *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleDepthwiseConv2D *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->filter(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->bias(), Type::S32))
- return true;
- }
-
- bool visit(const luci::CircleInstanceNorm *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->gamma(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->beta(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CirclePack *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- for (uint32_t i = 0; i < node->values_count(); i++)
- {
- RETURN_FALSE_UNLESS(has_type(node->values(i), Type::U8))
- }
- return true;
- }
-
- bool visit(const luci::CirclePad *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->paddings(), Type::S32))
- return true;
- }
-
- bool visit(const luci::CirclePadV2 *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->paddings(), Type::S32))
- RETURN_FALSE_UNLESS(has_type(node->constant_values(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleMirrorPad *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->paddings(), Type::S32))
- return true;
- }
-
- bool visit(const luci::CirclePRelu *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->alpha(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleTransposeConv *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->outBackprop(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->filter(), Type::U8))
- luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
- if (bias != nullptr)
- RETURN_FALSE_UNLESS(has_type(bias, Type::S32))
- return true;
- }
-
- bool visit(const luci::CircleFullyConnected *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->weights(), Type::U8))
- luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
- if (bias != nullptr)
- RETURN_FALSE_UNLESS(has_type(bias, Type::S32))
- return true;
- }
-
- bool visit(const luci::CircleAdd *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->y(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleAveragePool2D *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->value(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleBatchToSpaceND *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleLogicalOr *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::BOOL))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::BOOL))
- RETURN_FALSE_UNLESS(has_type(node->y(), Type::BOOL))
- return true;
- }
-
- bool visit(const luci::CircleMaxPool2D *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->value(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleLocalResponseNormalization *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleMean *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->reduction_indices(), Type::S32))
- return true;
- }
-
- bool visit(const luci::CircleMul *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->y(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleNotEqual *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::BOOL))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->y(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleRelu *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->features(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleReshape *node)
- {
- if (node->quantparam())
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->tensor(), Type::U8))
- }
- else
- {
- RETURN_FALSE_UNLESS(has_type(node->tensor(), node->dtype()))
- }
- luci::CircleConst *shape = dynamic_cast<luci::CircleConst *>(node->shape());
- if (shape != nullptr)
- RETURN_FALSE_UNLESS(has_type(shape, Type::S32))
- return true;
- }
-
- bool visit(const luci::CircleLogistic *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
-
- RETURN_FALSE_UNLESS(node->quantparam());
- RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == 1.0f / 256.0f);
- RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == 0);
- return true;
- }
-
- bool visit(const luci::CircleSoftmax *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->logits(), Type::U8))
-
- RETURN_FALSE_UNLESS(node->quantparam());
- RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == 1.0f / 255.0f);
- RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == 0);
- return true;
- }
-
- bool visit(const luci::CircleSpaceToBatchND *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleSpaceToDepth *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleSlice *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->begin(), Type::S32) || has_type(node->begin(), Type::S64))
- RETURN_FALSE_UNLESS(has_type(node->size(), Type::S32) || has_type(node->size(), Type::S64))
- return true;
- }
-
- bool visit(const luci::CircleSplit *node)
- {
- // node's output is the input of CircleSplitOut, thus not quantized
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleSplitOut *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
-
- // SplitOut has the same qparam with the input of Split
- auto split = loco::must_cast<luci::CircleSplit *>(node->input());
- auto input = loco::must_cast<luci::CircleNode *>(split->input());
- RETURN_FALSE_UNLESS(node->quantparam());
- RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == input->quantparam()->scale[0]);
- RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == input->quantparam()->zerop[0]);
- return true;
- }
-
- bool visit(const luci::CircleSplitV *node)
- {
- // node's output is the input of CircleSplitVOut, thus not quantized
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleSplitVOut *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
-
- // SplitVOut has the same qparam with the input of SplitV
- auto splitv = loco::must_cast<luci::CircleSplitV *>(node->input());
- auto input = loco::must_cast<luci::CircleNode *>(splitv->input());
- RETURN_FALSE_UNLESS(node->quantparam());
- RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == input->quantparam()->scale[0]);
- RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == input->quantparam()->zerop[0]);
- return true;
- }
-
- bool visit(const luci::CircleStridedSlice *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
-
- auto input = loco::must_cast<luci::CircleNode *>(node->input());
- RETURN_FALSE_UNLESS(node->quantparam());
- RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == input->quantparam()->scale[0]);
- RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == input->quantparam()->zerop[0]);
- return true;
- }
-
- bool visit(const luci::CircleArgMax *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, node->output_type()))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->dimension(), Type::S32) ||
- has_type(node->dimension(), Type::S64))
- return true;
- }
-
- bool visit(const luci::CircleTanh *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
-
- RETURN_FALSE_UNLESS(node->quantparam());
- RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == 2.0f / 256.0f);
- RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == 128);
- return true;
- }
-
- bool visit(const luci::CircleTranspose *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->a(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->perm(), Type::S32))
- return true;
- }
-
- bool visit(const luci::CircleFloor *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
-
- // This checks the value of scale is an integer
- RETURN_FALSE_UNLESS(node->quantparam());
- RETURN_FALSE_UNLESS(std::roundf(node->quantparam()->scale[0]) == node->quantparam()->scale[0]);
- return true;
- }
-
- bool visit(const luci::CircleGreater *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::BOOL))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->y(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleGreaterEqual *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::BOOL))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->y(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleDiv *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->y(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleFloorDiv *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->y(), Type::U8))
-
- // This checks the value of scale is an integer
- RETURN_FALSE_UNLESS(node->quantparam());
- RETURN_FALSE_UNLESS(std::roundf(node->quantparam()->scale[0]) == node->quantparam()->scale[0]);
- return true;
- }
-
- bool visit(const luci::CircleRsqrt *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleSqrt *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleElu *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->features(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CirclePow *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->y(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleResizeBilinear *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleResizeNearestNeighbor *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleUnpack *node)
- {
- // node's output is the input of CircleUnpackOut, thus not quantized
- RETURN_FALSE_UNLESS(has_type(node->value(), Type::U8))
- return true;
- }
-
- bool visit(const luci::CircleUnpackOut *node)
- {
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
-
- // UnpackOut has the same qparam with the input of Unpack
- auto Unpack = loco::must_cast<luci::CircleUnpack *>(node->input());
- auto input = loco::must_cast<luci::CircleNode *>(Unpack->value());
- RETURN_FALSE_UNLESS(node->quantparam() && input->quantparam());
- RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == input->quantparam()->scale[0]);
- RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == input->quantparam()->zerop[0]);
- return true;
- }
-
- bool visit(const luci::CircleCast *node)
- {
- auto *input = loco::must_cast<luci::CircleNode *>(node->x());
- bool input_quantized = input->quantparam() != nullptr;
- if (input_quantized)
- {
- RETURN_FALSE_UNLESS(has_type(input, node->in_data_type()))
- RETURN_FALSE_UNLESS(has_type(input, Type::U8))
- }
-
- bool node_quantized = node->quantparam() != nullptr;
- if (node_quantized)
- {
- RETURN_FALSE_UNLESS(has_type(node, node->out_data_type()))
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- }
- return true;
- }
-
- // TODO: Implement more Ops
-
- bool visit(const luci::CircleNode *) { return true; }
-};
-
-} // namespace luci
-
-#undef RETURN_FALSE_UNLESS
-
-#endif // __LUCI_VERIFY_QUNTIZED_NODE_U8_TYPE_H__
diff --git a/compiler/luci/pass/src/helpers/LayerInfoMap.cpp b/compiler/luci/pass/src/helpers/LayerInfoMap.cpp
new file mode 100644
index 000000000..ac07f9ec9
--- /dev/null
+++ b/compiler/luci/pass/src/helpers/LayerInfoMap.cpp
@@ -0,0 +1,189 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "LayerInfoMap.h"
+
+#include <luci/IR/CircleNode.h>
+
+#include <cassert>
+
+namespace luci
+{
+namespace
+{
+
+bool is_multiple_output_node(const luci::CircleNode *node)
+{
+ switch (node->opcode())
+ {
+ // The following nodes have multiple outputs. Output tensors are not produced by themselves but
+ // by the corresponding *Out nodes.
+ case luci::CircleOpcode::SPLIT:
+ case luci::CircleOpcode::SPLIT_V:
+ case luci::CircleOpcode::TOPK_V2:
+ case luci::CircleOpcode::UNIQUE:
+ case luci::CircleOpcode::UNPACK:
+ return true;
+ // TODO: Support ops
+ case luci::CircleOpcode::BIDIRECTIONAL_SEQUENCE_LSTM:
+ case luci::CircleOpcode::CUSTOM:
+ case luci::CircleOpcode::IF:
+ case luci::CircleOpcode::NON_MAX_SUPPRESSION_V4:
+ case luci::CircleOpcode::NON_MAX_SUPPRESSION_V5:
+ case luci::CircleOpcode::WHILE:
+ throw std::runtime_error("Unsupported op now");
+ default:
+ return false;
+ }
+}
+
+const luci::CircleNode *get_multi_output_node(const luci::CircleNode *node)
+{
+ if (is_multiple_output_node(node))
+ return node;
+
+ switch (node->opcode())
+ {
+ // The following nodes denote outputs of multiple-output nodes.
+ case luci::CircleOpcode::CIRCLESPLITOUT:
+ {
+ const auto split_out = loco::must_cast<const CircleSplitOut *>(node);
+ return loco::must_cast<luci::CircleNode *>(split_out->input());
+ }
+ case luci::CircleOpcode::CIRCLESPLITVOUT:
+ {
+ const auto splitv_out = loco::must_cast<const CircleSplitVOut *>(node);
+ return loco::must_cast<luci::CircleNode *>(splitv_out->input());
+ }
+ case luci::CircleOpcode::CIRCLETOPKV2OUT:
+ {
+ const auto top_kv2_out = loco::must_cast<const CircleTopKV2Out *>(node);
+ return loco::must_cast<luci::CircleNode *>(top_kv2_out->input());
+ }
+ case luci::CircleOpcode::CIRCLEUNIQUEOUT:
+ {
+ const auto unique_out = loco::must_cast<const CircleUniqueOut *>(node);
+ return loco::must_cast<luci::CircleNode *>(unique_out->input());
+ }
+ case luci::CircleOpcode::CIRCLEUNPACKOUT:
+ {
+ const auto unpack_out = loco::must_cast<const CircleUnpackOut *>(node);
+ return loco::must_cast<luci::CircleNode *>(unpack_out->input());
+ }
+ // TODO: Support these ops
+ case luci::CircleOpcode::CIRCLEBIDIRECTIONAL_SEQUENCE_LSTM_OUT:
+ case luci::CircleOpcode::CIRCLECUSTOMOUT:
+ case luci::CircleOpcode::CIRCLEIFOUT:
+ case luci::CircleOpcode::CIRCLENONMAXSUPPRESSIONV4OUT:
+ case luci::CircleOpcode::CIRCLENONMAXSUPPRESSIONV5OUT:
+ case luci::CircleOpcode::CIRCLEWHILEOUT:
+ throw std::runtime_error("Unsupported op now");
+ default:
+ return nullptr;
+ }
+}
+
+bool same_setting(const LayerInfo &left, const LayerInfo &right)
+{
+ return left.dtype == right.dtype and left.granularity == right.granularity;
+}
+
+void add_multi_output_node(LayerInfoMap &info_by_name, LayerInfo &layer_info,
+ const luci::CircleNode *node)
+{
+ assert(is_multiple_output_node(node)); // FIX_CALLER_UNLESS
+
+ const auto succs_nodes = loco::succs(node);
+ const auto name = node->name();
+
+ if (info_by_name.find(name) != info_by_name.end())
+ {
+ // Check that all outputs have equal dtype and granularity
+ for (const auto succs_node : succs_nodes)
+ {
+ const auto succs_circle_node = loco::must_cast<luci::CircleNode *>(succs_node);
+
+ const auto it = info_by_name.find(succs_circle_node->name());
+ if (it != info_by_name.end() and not same_setting(layer_info, (it->second)))
+ throw std::runtime_error("Outputs of multiple-output nodes should have equal dtype and "
+ "granularity. Check the quantization configuration file");
+ }
+ return;
+ }
+
+ // Add multiple output node to info_by_name
+ info_by_name[name] = {name, layer_info.dtype, layer_info.granularity};
+
+ // Add outputs node to info_by_name
+ for (const auto succs_node : succs_nodes)
+ {
+ const auto succs_circle_node = loco::must_cast<luci::CircleNode *>(succs_node);
+ const auto succs_circle_node_name = succs_circle_node->name();
+ info_by_name[succs_circle_node_name] = {succs_circle_node_name, layer_info.dtype,
+ layer_info.granularity};
+ }
+}
+
+} // namespace
+
+LayerInfoMap layer_info_map(loco::Graph *g, std::vector<LayerInfo> &layers_info)
+{
+ LayerInfoMap info_by_name;
+
+ for (auto &&info : layers_info)
+ {
+ auto name = info.name;
+ bool found = false;
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ auto cnode = loco::must_cast<luci::CircleNode *>(node);
+ if (cnode->opcode() == luci::CircleOpcode::CIRCLEOUTPUT)
+ continue;
+
+ if (cnode->name() == name)
+ {
+ // Check and add multiple-output node and its outputs to info_by_name
+ if (const auto multi_output = get_multi_output_node(cnode))
+ {
+ add_multi_output_node(info_by_name, info, multi_output);
+ found = true;
+ continue;
+ }
+
+ if (info_by_name.find(name) != info_by_name.end())
+ {
+ throw std::runtime_error("Duplicate layer name " + name +
+ ". Check layer names in the quantization configuration file.");
+ }
+
+ info_by_name[name] = info;
+ found = true;
+ continue;
+ }
+ }
+
+ if (not found)
+ throw std::runtime_error("No such layer named " + name +
+ ". Check layer names in the quantization configuration file.");
+ }
+
+ // TODO Check all names in layers_info exist in the info_by_name
+ // TODO Check names in info_by_name but not in layers_info are from virtual outputs
+
+ return info_by_name;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/helpers/LayerInfoMap.h b/compiler/luci/pass/src/helpers/LayerInfoMap.h
new file mode 100644
index 000000000..bb4724a50
--- /dev/null
+++ b/compiler/luci/pass/src/helpers/LayerInfoMap.h
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_PASS_HELPERS_LAYER_INFO_MAP_H__
+#define __LUCI_PASS_HELPERS_LAYER_INFO_MAP_H__
+
+#include <luci/Pass/QuantizationParameters.h>
+
+#include <unordered_map>
+
+namespace luci
+{
+
+using LayerInfoMap = std::unordered_map<std::string, luci::LayerInfo>;
+
+LayerInfoMap layer_info_map(loco::Graph *g, std::vector<LayerInfo> &layers_info);
+
+} // namespace luci
+
+#endif // __LUCI_PASS_HELPERS_LAYER_INFO_MAP_H__
diff --git a/compiler/luci/pass/src/helpers/LayerInfoMap.test.cpp b/compiler/luci/pass/src/helpers/LayerInfoMap.test.cpp
new file mode 100644
index 000000000..2ed28eda4
--- /dev/null
+++ b/compiler/luci/pass/src/helpers/LayerInfoMap.test.cpp
@@ -0,0 +1,201 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "LayerInfoMap.h"
+
+#include <luci/IR/CircleNode.h>
+#include <luci/test/TestIOGraph.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+class SoftmaxTestGraph : public luci::test::TestIOGraph
+{
+public:
+ void init(void)
+ {
+ TestIOGraph::init({32}, {32});
+ _softmax = g()->nodes()->create<luci::CircleSoftmax>();
+ {
+ _softmax->logits(input());
+ _softmax->beta(0.1);
+ _softmax->name("test");
+ }
+ output()->from(_softmax);
+ }
+
+private:
+ luci::CircleSoftmax *_softmax = nullptr;
+};
+
+class SplitAddTestGraph : public luci::test::TestIOGraph
+{
+public:
+ void init(void)
+ {
+ TestIOGraph::init({6, 1, 2}, {3, 1, 2});
+ _split_dim = g()->nodes()->create<luci::CircleConst>();
+ {
+ _split_dim->rank(1);
+ _split_dim->dtype(loco::DataType::S32);
+ _split_dim->size<loco::DataType::S32>(1);
+ _split_dim->at<loco::DataType::S32>(0);
+ _split_dim->shape({1});
+ _split_dim->name("split_dim");
+ }
+
+ _split = g()->nodes()->create<luci::CircleSplit>();
+ {
+ _split->input(input());
+ _split->num_split(2);
+ _split->split_dim(_split_dim);
+ _split->name("split0");
+ }
+
+ _split_out_1 = g()->nodes()->create<luci::CircleSplitOut>();
+ {
+ _split_out_1->input(_split);
+ _split_out_1->index(0);
+ _split_out_1->name("split0");
+ }
+
+ _split_out_2 = g()->nodes()->create<luci::CircleSplitOut>();
+ {
+ _split_out_2->input(_split);
+ _split_out_2->index(1);
+ _split_out_2->name("split1");
+ }
+
+ _add = g()->nodes()->create<luci::CircleAdd>();
+ {
+ _add->x(_split_out_1);
+ _add->y(_split_out_2);
+ _add->name("add");
+ }
+ output()->from(_add);
+ }
+
+private:
+ luci::CircleSplit *_split = nullptr;
+ luci::CircleSplitOut *_split_out_1 = nullptr;
+ luci::CircleSplitOut *_split_out_2 = nullptr;
+ luci::CircleConst *_split_dim = nullptr;
+ luci::CircleAdd *_add = nullptr;
+};
+
+} // namespace
+
+TEST(LayerInfoMapTest, simple_test)
+{
+ SoftmaxTestGraph g;
+ g.init();
+
+ luci::LayerInfo info;
+ {
+ info.name = "test";
+ info.dtype = loco::DataType::U8;
+ info.granularity = luci::QuantizationGranularity::ChannelWise;
+ }
+ std::vector<luci::LayerInfo> v;
+ v.emplace_back(info);
+ auto map = luci::layer_info_map(g.g(), v);
+
+ EXPECT_EQ("test", map["test"].name);
+ EXPECT_EQ(loco::DataType::U8, map["test"].dtype);
+ EXPECT_EQ(luci::QuantizationGranularity::ChannelWise, map["test"].granularity);
+}
+
+TEST(LayerInfoMapTest, multiple_output_node_test)
+{
+ SplitAddTestGraph g;
+ g.init();
+
+ luci::LayerInfo info;
+ {
+ info.name = "split0";
+ info.dtype = loco::DataType::U8;
+ info.granularity = luci::QuantizationGranularity::ChannelWise;
+ }
+ std::vector<luci::LayerInfo> v;
+ v.emplace_back(info);
+ auto map = luci::layer_info_map(g.g(), v);
+
+ EXPECT_EQ(map.size(), 2);
+ EXPECT_EQ("split0", map["split0"].name);
+ EXPECT_EQ("split1", map["split1"].name);
+
+ EXPECT_EQ(loco::DataType::U8, map["split0"].dtype);
+ EXPECT_EQ(luci::QuantizationGranularity::ChannelWise, map["split0"].granularity);
+}
+
+TEST(LayerInfoMapTest, invalid_layer_info_multiple_output_node_NEG)
+{
+ SplitAddTestGraph g;
+ g.init();
+
+ luci::LayerInfo info_0;
+ {
+ info_0.name = "split0";
+ info_0.dtype = loco::DataType::U8;
+ info_0.granularity = luci::QuantizationGranularity::ChannelWise;
+ }
+ luci::LayerInfo info_1;
+ {
+ info_1.name = "split1";
+ info_1.dtype = loco::DataType::S16;
+ info_1.granularity = luci::QuantizationGranularity::ChannelWise;
+ }
+ std::vector<luci::LayerInfo> v;
+ v.emplace_back(info_0);
+ v.emplace_back(info_1);
+
+ EXPECT_ANY_THROW(luci::layer_info_map(g.g(), v));
+}
+
+TEST(LayerInfoMapTest, duplicate_name_NEG)
+{
+ SoftmaxTestGraph g;
+ g.init();
+ g.input()->name("test");
+
+ luci::LayerInfo info;
+ {
+ info.name = "test";
+ info.dtype = loco::DataType::U8;
+ info.granularity = luci::QuantizationGranularity::ChannelWise;
+ }
+ std::vector<luci::LayerInfo> v;
+ v.emplace_back(info);
+ EXPECT_ANY_THROW(luci::layer_info_map(g.g(), v));
+}
+
+TEST(LayerInfoMapTest, no_name_NEG)
+{
+ SoftmaxTestGraph g;
+ g.init();
+
+ luci::LayerInfo info;
+ {
+ info.name = "noname";
+ info.dtype = loco::DataType::U8;
+ info.granularity = luci::QuantizationGranularity::ChannelWise;
+ }
+ std::vector<luci::LayerInfo> v;
+ v.emplace_back(info);
+ EXPECT_ANY_THROW(luci::layer_info_map(g.g(), v));
+}
diff --git a/compiler/luci/requires.cmake b/compiler/luci/requires.cmake
index 3ccc58128..e896188be 100644
--- a/compiler/luci/requires.cmake
+++ b/compiler/luci/requires.cmake
@@ -4,8 +4,8 @@ require("loco")
require("locop")
require("logo")
require("logo-core")
-require("mio-circle")
-require("mio-tflite")
+require("mio-circle04")
+require("mio-tflite280")
require("oops")
require("hermes")
require("hermes-std")
diff --git a/compiler/luci/service/CMakeLists.txt b/compiler/luci/service/CMakeLists.txt
index 0e6097f96..24bdfc152 100644
--- a/compiler/luci/service/CMakeLists.txt
+++ b/compiler/luci/service/CMakeLists.txt
@@ -10,7 +10,6 @@ add_library(luci_service ${LUCI_LIBRARY_TYPE} ${SOURCES})
target_include_directories(luci_service PRIVATE src)
target_include_directories(luci_service PUBLIC include)
target_link_libraries(luci_service PUBLIC luci_lang)
-target_link_libraries(luci_service PUBLIC mio_circle)
target_link_libraries(luci_service PUBLIC logo_core)
target_link_libraries(luci_service PRIVATE luci_log)
target_link_libraries(luci_service PRIVATE luci_logex)
diff --git a/compiler/luci/service/include/luci/Service/CircleShapeInference.h b/compiler/luci/service/include/luci/Service/CircleShapeInference.h
index ead12d074..2c1120941 100644
--- a/compiler/luci/service/include/luci/Service/CircleShapeInference.h
+++ b/compiler/luci/service/include/luci/Service/CircleShapeInference.h
@@ -17,11 +17,12 @@
#ifndef __LUCI_CIRCLE_SHAPE_INFERENCE_H__
#define __LUCI_CIRCLE_SHAPE_INFERENCE_H__
-#include <loco/IR/Nodes.h>
-
+#include <luci/Service/CircleShapeInferenceRule.h>
#include <luci/IR/CircleNodes.h>
#include <luci/IR/CircleNodeVisitor.h>
-#include <luci/Service/CircleShapeInferenceRule.h>
+
+#include <loco/IR/NodeShape.h>
+#include <loco/IR/TensorShape.h>
namespace luci
{
diff --git a/compiler/luci/service/include/luci/Service/CircleTypeInference.h b/compiler/luci/service/include/luci/Service/CircleTypeInference.h
index d62731380..e0ceabeac 100644
--- a/compiler/luci/service/include/luci/Service/CircleTypeInference.h
+++ b/compiler/luci/service/include/luci/Service/CircleTypeInference.h
@@ -17,13 +17,11 @@
#ifndef __LUCI_CIRCLE_TYPE_INFERENCE_H__
#define __LUCI_CIRCLE_TYPE_INFERENCE_H__
-#include <loco/IR/Nodes.h>
-
-#include <mio/circle/schema_generated.h>
-
+#include <luci/Service/CircleTypeInferenceRule.h>
#include <luci/IR/CircleNodes.h>
#include <luci/IR/CircleNodeVisitor.h>
-#include <luci/Service/CircleTypeInferenceRule.h>
+
+#include <loco/IR/DataType.h>
namespace luci
{
diff --git a/compiler/luci/service/src/CircleCloneNode.h b/compiler/luci/service/src/CircleCloneNode.h
index 3926147f5..99e4561b3 100644
--- a/compiler/luci/service/src/CircleCloneNode.h
+++ b/compiler/luci/service/src/CircleCloneNode.h
@@ -208,6 +208,7 @@ public:
luci::CircleNode *visit(const luci::CircleSquaredDifference *) final;
luci::CircleNode *visit(const luci::CircleSqueeze *) final;
luci::CircleNode *visit(const luci::CircleStridedSlice *) final;
+ luci::CircleNode *visit(const luci::CircleSVDF *) final;
luci::CircleNode *visit(const luci::CircleSub *) final;
luci::CircleNode *visit(const luci::CircleSum *) final;
luci::CircleNode *visit(const luci::CircleTanh *) final;
@@ -269,6 +270,7 @@ public:
luci::CircleNode *visit(const luci::CircleTopKV2Out *) final;
luci::CircleNode *visit(const luci::CircleUniqueOut *) final;
luci::CircleNode *visit(const luci::CircleUnpackOut *) final;
+ luci::CircleNode *visit(const luci::CircleVariable *) final;
luci::CircleNode *visit(const luci::CircleWhileOut *) final;
// Handle in CircleNode
diff --git a/compiler/luci/service/src/CircleNodeClone.cpp b/compiler/luci/service/src/CircleNodeClone.cpp
index d2033dd0c..220c6096c 100644
--- a/compiler/luci/service/src/CircleNodeClone.cpp
+++ b/compiler/luci/service/src/CircleNodeClone.cpp
@@ -14,6 +14,7 @@
* limitations under the License.
*/
+#include "luci/IR/CircleQuantParam.h"
#include "luci/Service/CircleNodeClone.h"
#include "CircleCloneNode.h"
@@ -45,18 +46,7 @@ void copy_common_attributes(const luci::CircleNode *src, luci::CircleNode *dst)
dst->shape_status(src->shape_status());
// quantparam
- const auto *quantparam = src->quantparam();
- if (quantparam != nullptr)
- {
- auto qparam = std::make_unique<luci::CircleQuantParam>();
- qparam->scale = quantparam->scale;
- qparam->zerop = quantparam->zerop;
- qparam->min = quantparam->min;
- qparam->max = quantparam->max;
- qparam->quantized_dimension = quantparam->quantized_dimension;
-
- dst->quantparam(std::move(qparam));
- }
+ copy_quantparam(src, dst);
// sparsity
const auto *sparsity = src->sparsityparam();
diff --git a/compiler/luci/service/src/CircleShapeInferenceRule.cpp b/compiler/luci/service/src/CircleShapeInferenceRule.cpp
index 5d6a31050..9d156f3e2 100644
--- a/compiler/luci/service/src/CircleShapeInferenceRule.cpp
+++ b/compiler/luci/service/src/CircleShapeInferenceRule.cpp
@@ -1,5 +1,6 @@
/*
* Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -196,23 +197,18 @@ template <class CIRCLENODE> loco::NodeShape broadcast_xy(const CIRCLENODE *node)
return loco::NodeShape{output_shape};
}
-template <class CIRCLENODE> loco::NodeShape use_inputs(const CIRCLENODE *node)
-{
- auto inputs_shape = luci::shape_get(node->inputs()).template as<loco::TensorShape>();
- return loco::NodeShape{inputs_shape};
-}
+#define DECLARE_USE_SINGLE(NAME) \
+ template <class CIRCLENODE> loco::NodeShape use_##NAME(const CIRCLENODE *node) \
+ { \
+ auto inputs_shape = luci::shape_get(node->NAME()).template as<loco::TensorShape>(); \
+ return loco::NodeShape{inputs_shape}; \
+ }
-template <class CIRCLENODE> loco::NodeShape use_x(const CIRCLENODE *node)
-{
- auto x_shape = luci::shape_get(node->x()).template as<loco::TensorShape>();
- return loco::NodeShape{x_shape};
-}
+DECLARE_USE_SINGLE(inputs);
+DECLARE_USE_SINGLE(x);
+DECLARE_USE_SINGLE(logits);
-template <class CIRCLENODE> loco::NodeShape use_logits(const CIRCLENODE *node)
-{
- auto shape = luci::shape_get(node->logits()).template as<loco::TensorShape>();
- return loco::NodeShape{shape};
-}
+#undef DECLARE_USE_SINGLE
template <class CIRCLENODE>
loco::NodeShape use_paddings(const CIRCLENODE *node, const luci::CircleConst *paddings)
@@ -721,6 +717,8 @@ loco::NodeShape infer_fully_connected(const luci::CircleFullyConnected *node)
auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
auto weights_shape = luci::shape_get(node->weights()).as<loco::TensorShape>();
+// TODO Remove following unused code
+#if 0
// Checking shape capability for fully connected layer
// Input: a tensor of at least rank 2 [D1, D2, ... Dn]
// Weight: [# of units, K]
@@ -741,6 +739,40 @@ loco::NodeShape infer_fully_connected(const luci::CircleFullyConnected *node)
out_shape.rank(2);
out_shape.dim(0) = batch_size;
out_shape.dim(1) = weights_shape.dim(0);
+#endif
+
+ loco::TensorShape out_shape;
+
+ // NOTE Some recipes in some repositories are using rank 4 input for FullyConnected.
+ // Until they are all fixed, disable following assert.
+ // TODO Enable following assert after related fixes are applied
+ // https://github.com/tensorflow/tensorflow/blob/ea33c1e7a25d8025e8ee405ad8ab7be261798d76/tensorflow/lite/kernels/fully_connected.cc#L194
+ // LUCI_ASSERT(input_shape.rank() == 2 || input_shape.rank() == 3,
+ // "Input rank of FullyConnected should be 2 or 3");
+
+ // https://github.com/tensorflow/tensorflow/blob/ea33c1e7a25d8025e8ee405ad8ab7be261798d76/tensorflow/lite/kernels/fully_connected.cc#L225
+ LUCI_ASSERT(weights_shape.rank() == 2, "Weights of FullyConnected should be 2");
+
+ // https://github.com/tensorflow/tensorflow/blob/ea33c1e7a25d8025e8ee405ad8ab7be261798d76/tensorflow/lite/kernels/fully_connected.cc#L353-L367
+ if (node->keep_num_dims())
+ {
+ out_shape.rank(input_shape.rank());
+ for (uint32_t i = 0; i < input_shape.rank(); ++i)
+ out_shape.dim(i) = input_shape.dim(i);
+ out_shape.dim(out_shape.rank() - 1) = weights_shape.dim(0);
+ }
+ else
+ {
+ uint32_t input_size = 1;
+ for (uint32_t i = 0; i < input_shape.rank(); i++)
+ {
+ input_size = input_size * input_shape.dim(i).value();
+ }
+ const uint32_t batch_size = input_size / weights_shape.dim(1).value();
+ out_shape.rank(2);
+ out_shape.dim(0) = batch_size;
+ out_shape.dim(1) = weights_shape.dim(0);
+ }
return loco::NodeShape{out_shape};
}
@@ -1554,6 +1586,30 @@ loco::NodeShape infer_squeeze(const luci::CircleSqueeze *node)
return loco::NodeShape{output_shape};
}
+loco::NodeShape infer_svdf(const luci::CircleSVDF *node)
+{
+ const auto ifm_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
+ const auto weight_feature_shape = luci::shape_get(node->weight_feature()).as<loco::TensorShape>();
+
+ assert(ifm_shape.rank() == 2);
+ assert(weight_feature_shape.rank() == 2);
+
+ assert(ifm_shape.dim(1) == weight_feature_shape.dim(1));
+ assert(weight_feature_shape.dim(0).known());
+
+ const auto rank = node->svdf_rank();
+ const auto num_filters = weight_feature_shape.dim(0).value();
+ assert(num_filters % rank == 0);
+ const auto num_units = num_filters / rank;
+
+ loco::TensorShape ofm_shape;
+ ofm_shape.rank(2);
+ ofm_shape.dim(0) = ifm_shape.dim(0);
+ ofm_shape.dim(1) = num_units;
+
+ return loco::NodeShape{ofm_shape};
+}
+
loco::NodeShape infer_tile(const luci::CircleTile *node)
{
const loco::DataType S32 = loco::DataType::S32;
@@ -2393,6 +2449,8 @@ public:
return loco::NodeShape{output_shape};
}
+ loco::NodeShape visit(const luci::CircleSVDF *node) final { return infer_svdf(node); }
+
loco::NodeShape visit(const luci::CircleTanh *node) final { return use_x(node); }
loco::NodeShape visit(const luci::CircleTile *node) final { return infer_tile(node); }
@@ -2486,6 +2544,8 @@ public:
loco::NodeShape visit(const luci::CircleUnpackOut *node) final { return infer_unpack_out(node); }
+ loco::NodeShape visit(const luci::CircleVariable *node) final { return use_own(node); }
+
loco::NodeShape visit(const luci::CircleWhileOut *node) final { return infer_while_out(node); }
};
diff --git a/compiler/luci/service/src/CircleTypeInferenceRule.cpp b/compiler/luci/service/src/CircleTypeInferenceRule.cpp
index 5f6d46f2b..438c4a364 100644
--- a/compiler/luci/service/src/CircleTypeInferenceRule.cpp
+++ b/compiler/luci/service/src/CircleTypeInferenceRule.cpp
@@ -478,6 +478,11 @@ struct TypeInferenceAlgorithm final : public luci::CircleNodeVisitor<loco::DataT
loco::DataType visit(const luci::CircleSum *node) final { return luci::dtype_get(node->input()); }
+ loco::DataType visit(const luci::CircleSVDF *node) final
+ {
+ return luci::dtype_get(node->input());
+ }
+
loco::DataType visit(const luci::CircleTanh *node) final { return luci::dtype_get(node->x()); }
loco::DataType visit(const luci::CircleTile *node) final
@@ -605,6 +610,8 @@ struct TypeInferenceAlgorithm final : public luci::CircleNodeVisitor<loco::DataT
return loco::DataType::S32;
}
+ loco::DataType visit(const luci::CircleVariable *node) final { return node->dtype(); }
+
loco::DataType visit(const luci::CircleUniqueOut *node) final
{
if (node->index() == 0)
diff --git a/compiler/luci/service/src/Nodes/CircleSVDF.cpp b/compiler/luci/service/src/Nodes/CircleSVDF.cpp
new file mode 100644
index 000000000..d4c3ce88f
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleSVDF.cpp
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNodeLet<CN::STUV>::visit(const luci::CircleSVDF *node)
+{
+ if (node->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED)
+ return nullptr;
+
+ auto *cloned = _graph->nodes()->create<luci::CircleSVDF>();
+ if (cloned != nullptr)
+ {
+ cloned->fusedActivationFunction(node->fusedActivationFunction());
+ cloned->asymmetric_quantize_inputs(node->asymmetric_quantize_inputs());
+ cloned->svdf_rank(node->svdf_rank());
+ }
+ return cloned;
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleSVDF.test.cpp b/compiler/luci/service/src/Nodes/CircleSVDF.test.cpp
new file mode 100644
index 000000000..d6edaf1cc
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleSVDF.test.cpp
@@ -0,0 +1,47 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_SVDF)
+{
+ auto g = loco::make_graph();
+ auto node_svdf = g->nodes()->create<luci::CircleSVDF>();
+ node_svdf->fusedActivationFunction(luci::FusedActFunc::RELU);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_svdf, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_svdf = dynamic_cast<luci::CircleSVDF *>(cloned);
+ ASSERT_NE(nullptr, cloned_svdf);
+ ASSERT_EQ(node_svdf->asymmetric_quantize_inputs(), cloned_svdf->asymmetric_quantize_inputs());
+ ASSERT_EQ(node_svdf->svdf_rank(), cloned_svdf->svdf_rank());
+}
+
+TEST(CloneNodeTest, clone_SVDF_NEG)
+{
+ auto g = loco::make_graph();
+ auto node_svdf = g->nodes()->create<luci::CircleSVDF>();
+ node_svdf->fusedActivationFunction(luci::FusedActFunc::UNDEFINED);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_svdf, gc.get());
+ ASSERT_EQ(nullptr, cloned);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleVariable.cpp b/compiler/luci/service/src/Nodes/CircleVariable.cpp
new file mode 100644
index 000000000..c1430bd3a
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleVariable.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleVariable *)
+{
+ return _graph->nodes()->create<luci::CircleVariable>();
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleVariable.test.cpp b/compiler/luci/service/src/Nodes/CircleVariable.test.cpp
new file mode 100644
index 000000000..7d29438be
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleVariable.test.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_Variable)
+{
+ auto g = loco::make_graph();
+ auto node_dummy = g->nodes()->create<luci::CircleVariable>();
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_dummy, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_variable = dynamic_cast<luci::CircleVariable *>(cloned);
+ ASSERT_NE(nullptr, cloned_variable);
+}
diff --git a/compiler/luci/tests/CMakeLists.txt b/compiler/luci/tests/CMakeLists.txt
index c03835823..1333efb7d 100644
--- a/compiler/luci/tests/CMakeLists.txt
+++ b/compiler/luci/tests/CMakeLists.txt
@@ -1,3 +1,14 @@
+set(CIRCLECHEF_FILE_PATH $<TARGET_FILE:circlechef-file>)
+set(TFLCHEF_FILE_PATH $<TARGET_FILE:tflchef-file>)
+set(TFLITE2CIRCLE_PATH $<TARGET_FILE:tflite2circle>)
+if(DEFINED ENV{BUILD_HOST_EXEC})
+ # TODO use better way to represent path for host executable
+ set(CIRCLECHEF_FILE_PATH $ENV{BUILD_HOST_EXEC}/compiler/circlechef/tools/file/circlechef-file)
+ set(TFLCHEF_FILE_PATH $ENV{BUILD_HOST_EXEC}/compiler/tflchef/tools/file/tflchef-file)
+ set(TFLITE2CIRCLE_PATH $ENV{BUILD_HOST_EXEC}/compiler/tflite2circle/tflite2circle)
+ message(STATUS "TFLITE2CIRCLE_PATH = ${TFLITE2CIRCLE_PATH}")
+endif(DEFINED ENV{BUILD_HOST_EXEC})
+
# TODO use local test.recipe files for small networks
file(GLOB RECIPES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*/test.recipe")
@@ -17,14 +28,14 @@ foreach(RECIPE IN ITEMS ${RECIPES})
# Generate .tflite
add_custom_command(OUTPUT "${RECIPE_OUTPUT_FILE}"
- COMMAND tflchef-file "${RECIPE_SOURCE_FILE}" "${RECIPE_OUTPUT_FILE}"
- DEPENDS tflchef-file "${RECIPE_SOURCE_FILE}"
+ COMMAND ${TFLCHEF_FILE_PATH} "${RECIPE_SOURCE_FILE}" "${RECIPE_OUTPUT_FILE}"
+ DEPENDS ${TFLCHEF_FILE_PATH} "${RECIPE_SOURCE_FILE}"
COMMENT "Generating ${RECIPE_OUTPUT_FILE}")
# Generate .circle
add_custom_command(OUTPUT "${CIRCLE_OUTPUT_FILE}"
- COMMAND tflite2circle "${RECIPE_OUTPUT_FILE}" "${CIRCLE_OUTPUT_FILE}"
- DEPENDS tflite2circle "${RECIPE_OUTPUT_FILE}"
+ COMMAND ${TFLITE2CIRCLE_PATH} "${RECIPE_OUTPUT_FILE}" "${CIRCLE_OUTPUT_FILE}"
+ DEPENDS ${TFLITE2CIRCLE_PATH} "${RECIPE_OUTPUT_FILE}"
COMMENT "Generating ${CIRCLE_OUTPUT_FILE}")
list(APPEND TESTFILES "${CIRCLE_OUTPUT_FILE}")
@@ -52,14 +63,14 @@ foreach(RECIPE IN ITEMS ${RECIPES})
# Generate .tflite
add_custom_command(OUTPUT "${RECIPE_OUTPUT_FILE}"
- COMMAND tflchef-file "${RECIPE_SOURCE_FILE}" "${RECIPE_OUTPUT_FILE}"
- DEPENDS tflchef-file "${RECIPE_SOURCE_FILE}"
+ COMMAND ${TFLCHEF_FILE_PATH} "${RECIPE_SOURCE_FILE}" "${RECIPE_OUTPUT_FILE}"
+ DEPENDS ${TFLCHEF_FILE_PATH} "${RECIPE_SOURCE_FILE}"
COMMENT "Generating ${RECIPE_OUTPUT_FILE}")
# Generate .circle
add_custom_command(OUTPUT "${CIRCLE_OUTPUT_FILE}"
- COMMAND tflite2circle "${RECIPE_OUTPUT_FILE}" "${CIRCLE_OUTPUT_FILE}"
- DEPENDS tflite2circle "${RECIPE_OUTPUT_FILE}"
+ COMMAND ${TFLITE2CIRCLE_PATH} "${RECIPE_OUTPUT_FILE}" "${CIRCLE_OUTPUT_FILE}"
+ DEPENDS ${TFLITE2CIRCLE_PATH} "${RECIPE_OUTPUT_FILE}"
COMMENT "Generating ${CIRCLE_OUTPUT_FILE}")
list(APPEND TESTFILES "${CIRCLE_OUTPUT_FILE}")
@@ -87,8 +98,8 @@ foreach(RECIPE IN ITEMS ${RECIPES2})
# Generate .circle
add_custom_command(OUTPUT "${CIRCLE_OUTPUT_FILE}"
- COMMAND circlechef-file "${RECIPE_SOURCE_FILE}" "${CIRCLE_OUTPUT_FILE}"
- DEPENDS circlechef-file "${RECIPE_SOURCE_FILE}"
+ COMMAND ${CIRCLECHEF_FILE_PATH} "${RECIPE_SOURCE_FILE}" "${CIRCLE_OUTPUT_FILE}"
+ DEPENDS ${CIRCLECHEF_FILE_PATH} "${RECIPE_SOURCE_FILE}"
COMMENT "Generating ${CIRCLE_OUTPUT_FILE}")
list(APPEND TESTFILES "${CIRCLE_OUTPUT_FILE}")
@@ -111,6 +122,8 @@ include("test.lst")
# Read "test.local.lst" if exists
include("test.local.lst" OPTIONAL)
+# NOTE $<TARGET_FILE:luci_readtester> is used as-is as test itself should
+# run in target device for cross build also
add_test(NAME luci_unit_readtest
COMMAND "${CMAKE_CURRENT_SOURCE_DIR}/readverify.sh"
"${CMAKE_CURRENT_BINARY_DIR}"
diff --git a/compiler/luci/tests/test.lst b/compiler/luci/tests/test.lst
index 28ddcf672..94e723f21 100644
--- a/compiler/luci/tests/test.lst
+++ b/compiler/luci/tests/test.lst
@@ -180,6 +180,8 @@ addread(Sub_000)
addread(Sub_U8_000)
addread(Sum_000)
addread(Sum_001)
+addread(SVDF_000)
+addread(SVDF_001)
addread(Tanh_000)
addread(Tanh_U8_000)
addread(Tile_000)
@@ -403,6 +405,8 @@ addwrite(Sub_000)
addwrite(Sub_U8_000)
addwrite(Sum_000)
addwrite(Sum_001)
+addwrite(SVDF_000)
+addwrite(SVDF_001)
addwrite(Tanh_000)
addwrite(Tanh_U8_000)
addwrite(Tile_000)
diff --git a/compiler/mio-circle/CMakeLists.txt b/compiler/mio-circle/CMakeLists.txt
index fa05ef0fa..d24717343 100644
--- a/compiler/mio-circle/CMakeLists.txt
+++ b/compiler/mio-circle/CMakeLists.txt
@@ -1,13 +1,14 @@
-nnas_find_package(FlatBuffers EXACT 1.10 QUIET)
+nnas_find_package(FlatBuffers EXACT 2.0 QUIET)
if(NOT FlatBuffers_FOUND)
+ message(STATUS "mio-circle skip: FlatBuffers 2.0 NOT FOUND")
return()
endif(NOT FlatBuffers_FOUND)
message(STATUS "Build mio-circle: TRUE")
# TODO Find a better way
-set(SCHEMA_FILE "${NNAS_PROJECT_SOURCE_DIR}/nnpackage/schema/circle_schema.fbs")
+set(SCHEMA_FILE "${NNAS_PROJECT_SOURCE_DIR}/res/CircleSchema/0.3/circle_schema.fbs")
# NOTE Copy circle_schema.fbs as schema.fbs to generate "schema_generated.fbs" instead of "circle_schema_generated.fbs"
add_custom_command(OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/schema.fbs"
@@ -26,3 +27,10 @@ FlatBuffers_Target(mio_circle
# This example shows how to use "mio-circle" library
add_executable(mio_circle_example example.cpp)
target_link_libraries(mio_circle_example mio_circle)
+
+file(GLOB_RECURSE SOURCES "src/*.cpp")
+
+add_library(mio_circle_helper STATIC ${SOURCES})
+target_include_directories(mio_circle_helper PRIVATE src)
+target_include_directories(mio_circle_helper PUBLIC include)
+target_link_libraries(mio_circle_helper mio_circle)
diff --git a/compiler/mio-circle/include/mio_circle/Helper.h b/compiler/mio-circle/include/mio_circle/Helper.h
new file mode 100644
index 000000000..c0f8115fe
--- /dev/null
+++ b/compiler/mio-circle/include/mio_circle/Helper.h
@@ -0,0 +1,36 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MIO_CIRCLE_HELPER_H__
+#define __MIO_CIRCLE_HELPER_H__
+
+#include <mio/circle/schema_generated.h>
+
+namespace mio
+{
+namespace circle
+{
+
+bool is_valid(const ::circle::OperatorCode *opcode);
+bool is_custom(const ::circle::OperatorCode *opcode);
+std::string opcode_name(const ::circle::OperatorCode *opcode);
+const char *tensor_type(const ::circle::Tensor *tensor);
+const char *tensor_name(const ::circle::Tensor *tensor);
+
+} // namespace circle
+} // namespace mio
+
+#endif // __MIO_CIRCLE_HELPER_H__
diff --git a/compiler/mio-circle/src/Helper.cpp b/compiler/mio-circle/src/Helper.cpp
new file mode 100644
index 000000000..6f30c8c10
--- /dev/null
+++ b/compiler/mio-circle/src/Helper.cpp
@@ -0,0 +1,81 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "mio_circle/Helper.h"
+
+#include <sstream>
+
+namespace mio
+{
+namespace circle
+{
+
+bool is_valid(const ::circle::OperatorCode *opcode)
+{
+ ::circle::BuiltinOperator code = opcode->builtin_code();
+ return (::circle::BuiltinOperator_MIN <= code && code <= ::circle::BuiltinOperator_MAX);
+}
+
+bool is_custom(const ::circle::OperatorCode *opcode)
+{
+ ::circle::BuiltinOperator code = opcode->builtin_code();
+ return (code == ::circle::BuiltinOperator_CUSTOM);
+}
+
+std::string opcode_name(const ::circle::OperatorCode *opcode)
+{
+ assert(opcode);
+
+ if (!is_valid(opcode))
+ {
+ std::ostringstream oss;
+ oss << "(invalid)";
+ return oss.str();
+ }
+
+ if (is_custom(opcode))
+ {
+ if (!opcode->custom_code())
+ return "(invalid custom)";
+
+ std::string custom_op = "CUSTOM(";
+ custom_op += opcode->custom_code()->c_str();
+ custom_op += ")";
+ return custom_op;
+ }
+
+ ::circle::BuiltinOperator code = opcode->builtin_code();
+ return ::circle::EnumNameBuiltinOperator(code);
+}
+
+const char *tensor_type(const ::circle::Tensor *tensor)
+{
+ return ::circle::EnumNameTensorType(tensor->type());
+}
+
+const char *tensor_name(const ::circle::Tensor *tensor)
+{
+ static const char *kEmptyTensorName = "(noname)";
+
+ auto name = tensor->name();
+ if (name)
+ return name->c_str();
+
+ return kEmptyTensorName;
+}
+
+} // namespace circle
+} // namespace mio
diff --git a/compiler/mio-circle04/CMakeLists.txt b/compiler/mio-circle04/CMakeLists.txt
new file mode 100644
index 000000000..8ee6da44c
--- /dev/null
+++ b/compiler/mio-circle04/CMakeLists.txt
@@ -0,0 +1,52 @@
+nnas_find_package(FlatBuffers EXACT 2.0 QUIET)
+
+if(NOT FlatBuffers_FOUND)
+ message(STATUS "mio-circle04 skip: FlatBuffers 2.0 NOT FOUND")
+ return()
+endif(NOT FlatBuffers_FOUND)
+
+message(STATUS "Build mio-circle04: TRUE")
+
+# TODO Find a better way
+# TODO use nnpackage
+# set(SCHEMA_FILE "${NNAS_PROJECT_SOURCE_DIR}/nnpackage/schema/circle_schema.fbs")
+set(SCHEMA_FILE "${NNAS_PROJECT_SOURCE_DIR}/res/CircleSchema/0.4/circle_schema.fbs")
+
+# NOTE Copy circle_schema.fbs as schema.fbs to generate "schema_generated.fbs" instead of "circle_schema_generated.fbs"
+add_custom_command(OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/schema.fbs"
+ COMMAND ${CMAKE_COMMAND} -E copy "${SCHEMA_FILE}" schema.fbs
+ WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}"
+ DEPENDS "${SCHEMA_FILE}"
+)
+
+FlatBuffers_Target(mio_circle04
+ OUTPUT_DIR "${CMAKE_CURRENT_BINARY_DIR}/gen/mio/circle"
+ INCLUDE_DIR "${CMAKE_CURRENT_BINARY_DIR}/gen"
+ SCHEMA_DIR "${CMAKE_CURRENT_BINARY_DIR}"
+ SCHEMA_FILES "schema.fbs"
+)
+
+# This example shows how to use "mio-circle04" library
+add_executable(mio_circle04_example example.cpp)
+target_link_libraries(mio_circle04_example mio_circle04)
+
+file(GLOB_RECURSE SOURCES "src/*.cpp")
+file(GLOB_RECURSE TESTS "src/*.test.cpp")
+list(REMOVE_ITEM SOURCES ${TESTS})
+
+add_library(mio_circle04_helper STATIC ${SOURCES})
+set_target_properties(mio_circle04_helper PROPERTIES POSITION_INDEPENDENT_CODE ON)
+target_include_directories(mio_circle04_helper PRIVATE src)
+target_include_directories(mio_circle04_helper PUBLIC include)
+target_link_libraries(mio_circle04_helper mio_circle04)
+
+if(NOT ENABLE_TEST)
+ return()
+endif(NOT ENABLE_TEST)
+
+nnas_find_package(GTest REQUIRED)
+
+GTest_AddTest(mio_circle04_helper_test ${TESTS})
+target_include_directories(mio_circle04_helper_test PRIVATE src)
+target_link_libraries(mio_circle04_helper_test mio_circle04)
+target_link_libraries(mio_circle04_helper_test mio_circle04_helper)
diff --git a/compiler/mio-circle04/README.md b/compiler/mio-circle04/README.md
new file mode 100644
index 000000000..d12dd78ff
--- /dev/null
+++ b/compiler/mio-circle04/README.md
@@ -0,0 +1,3 @@
+# mio-circle04
+
+Let's make it easy to read and write Circle models.
diff --git a/compiler/mio-circle04/example.cpp b/compiler/mio-circle04/example.cpp
new file mode 100644
index 000000000..1970f4066
--- /dev/null
+++ b/compiler/mio-circle04/example.cpp
@@ -0,0 +1,41 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+//
+// This example shows how to include and use "mio-circle04"
+//
+#include <mio/circle/schema_generated.h>
+
+#include <fstream>
+#include <iostream>
+#include <vector>
+
+int main(int argc, char **argv)
+{
+ std::ifstream ifs(argv[1], std::ios_base::binary);
+ std::vector<char> buf(std::istreambuf_iterator<char>{ifs}, std::istreambuf_iterator<char>{});
+
+ flatbuffers::Verifier verifier{reinterpret_cast<uint8_t *>(buf.data()), buf.size()};
+
+ if (!circle::VerifyModelBuffer(verifier))
+ {
+ std::cout << "Fail" << std::endl;
+ return 255;
+ }
+
+ std::cout << "Pass" << std::endl;
+ return 0;
+}
diff --git a/compiler/mio-circle04/include/mio_circle/Helper.h b/compiler/mio-circle04/include/mio_circle/Helper.h
new file mode 100644
index 000000000..d3ffc23e5
--- /dev/null
+++ b/compiler/mio-circle04/include/mio_circle/Helper.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MIO_CIRCLE04_HELPER_H__
+#define __MIO_CIRCLE04_HELPER_H__
+
+#include <mio/circle/schema_generated.h>
+
+namespace mio
+{
+namespace circle
+{
+
+::circle::BuiltinOperator builtin_code_neutral(const ::circle::OperatorCode *opcode);
+bool is_valid(const ::circle::OperatorCode *opcode);
+bool is_custom(const ::circle::OperatorCode *opcode);
+std::string opcode_name(const ::circle::OperatorCode *opcode);
+const char *tensor_type(const ::circle::Tensor *tensor);
+const char *tensor_name(const ::circle::Tensor *tensor);
+
+} // namespace circle
+} // namespace mio
+
+#endif // __MIO_CIRCLE04_HELPER_H__
diff --git a/compiler/mio-circle04/src/Helper.cpp b/compiler/mio-circle04/src/Helper.cpp
new file mode 100644
index 000000000..8b8737a2d
--- /dev/null
+++ b/compiler/mio-circle04/src/Helper.cpp
@@ -0,0 +1,110 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "mio_circle/Helper.h"
+
+#include <algorithm>
+#include <sstream>
+
+namespace mio
+{
+namespace circle
+{
+
+/**
+ * This will provide v3/v3a/v3b format neutral BuiltinOperator
+ * NOTE circle has minus value opcode (252~254 as uint8_t)
+ * we cannot use std::max() like tflite because deprecated_builtin_code can be
+ * negative while builtin_code is 0 for v0.3 files.
+ */
+::circle::BuiltinOperator builtin_code_neutral(const ::circle::OperatorCode *opcode)
+{
+ assert(opcode != nullptr);
+ if (opcode->deprecated_builtin_code() == 127)
+ {
+ assert(opcode->builtin_code() >= 127);
+ return opcode->builtin_code();
+ }
+ // There was no 255(-1) value in v0.3
+ assert(opcode->deprecated_builtin_code() != -1);
+ return static_cast<::circle::BuiltinOperator>(opcode->deprecated_builtin_code());
+}
+
+bool is_valid(const ::circle::OperatorCode *opcode)
+{
+ // Valid Range : BuiltinOperator_MIN <= deprecated_builtin_code <= 127
+ const int8_t deprecated_builtin_code = opcode->deprecated_builtin_code();
+ if (deprecated_builtin_code < ::circle::BuiltinOperator_MIN)
+ return false;
+ // There was no 255(-1) value in v0.3
+ if (deprecated_builtin_code == -1)
+ return false;
+
+ const ::circle::BuiltinOperator builtin_code = opcode->builtin_code();
+ if (!(::circle::BuiltinOperator_MIN <= builtin_code &&
+ builtin_code <= ::circle::BuiltinOperator_MAX))
+ return false;
+
+ return true;
+}
+
+bool is_custom(const ::circle::OperatorCode *opcode)
+{
+ ::circle::BuiltinOperator code = builtin_code_neutral(opcode);
+ return (code == ::circle::BuiltinOperator_CUSTOM);
+}
+
+std::string opcode_name(const ::circle::OperatorCode *opcode)
+{
+ assert(opcode);
+
+ if (!is_valid(opcode))
+ {
+ std::ostringstream oss;
+ oss << "(invalid)";
+ return oss.str();
+ }
+
+ if (is_custom(opcode))
+ {
+ if (!opcode->custom_code())
+ return "(invalid custom)";
+
+ std::string custom_op = "CUSTOM(";
+ custom_op += opcode->custom_code()->c_str();
+ custom_op += ")";
+ return custom_op;
+ }
+
+ ::circle::BuiltinOperator code = builtin_code_neutral(opcode);
+ return ::circle::EnumNameBuiltinOperator(code);
+}
+
+const char *tensor_type(const ::circle::Tensor *tensor)
+{
+ return ::circle::EnumNameTensorType(tensor->type());
+}
+
+const char *tensor_name(const ::circle::Tensor *tensor)
+{
+ if (tensor->name() == nullptr || std::string(tensor->name()->c_str()).empty())
+ return "(noname)";
+
+ return tensor->name()->c_str();
+}
+
+} // namespace circle
+} // namespace mio
diff --git a/compiler/mio-circle04/src/Helper.test.cpp b/compiler/mio-circle04/src/Helper.test.cpp
new file mode 100644
index 000000000..20fce0843
--- /dev/null
+++ b/compiler/mio-circle04/src/Helper.test.cpp
@@ -0,0 +1,153 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "mio_circle/Helper.h"
+
+#include <flatbuffers/flatbuffers.h>
+#include <gtest/gtest.h>
+
+#include <vector>
+
+class mio_circle04_helper_test : public ::testing::Test
+{
+protected:
+ void initialization_finish(void)
+ {
+ _fbb.Finish(circle::CreateModelDirect(_fbb, 0, &_opcodes_vec));
+ }
+
+protected:
+ void add_operator_code(int8_t deprecated_builtin_code, const char *custom_code,
+ circle::BuiltinOperator builtin_code)
+ {
+ _opcodes_vec.push_back(circle::CreateOperatorCodeDirect(
+ _fbb, deprecated_builtin_code, custom_code, 1 /* version */, builtin_code));
+ }
+
+ const circle::OperatorCode *get_operator_code(uint8_t idx)
+ {
+ return circle::GetModel(_fbb.GetBufferPointer())->operator_codes()->Get(idx);
+ }
+
+private:
+ flatbuffers::FlatBufferBuilder _fbb;
+ std::vector<flatbuffers::Offset<circle::OperatorCode>> _opcodes_vec;
+};
+
+TEST_F(mio_circle04_helper_test, v04)
+{
+ // BuiltinOperator_ADD = 0
+ // BuiltinOperator_CONV_2D = 3
+ add_operator_code(3, "", circle::BuiltinOperator_ADD);
+ initialization_finish();
+
+ ASSERT_TRUE(mio::circle::is_valid(get_operator_code(0)));
+ ASSERT_EQ(mio::circle::builtin_code_neutral(get_operator_code(0)),
+ circle::BuiltinOperator_CONV_2D);
+ ASSERT_FALSE(mio::circle::is_custom(get_operator_code(0)));
+}
+
+TEST_F(mio_circle04_helper_test, v04_custom_old)
+{
+ // BuiltinOperator_ADD = 0
+ // BuiltinOperator_CUSTOM = 32
+ add_operator_code(32, "custom", circle::BuiltinOperator_ADD);
+ initialization_finish();
+
+ ASSERT_TRUE(mio::circle::is_valid(get_operator_code(0)));
+ ASSERT_EQ(mio::circle::builtin_code_neutral(get_operator_code(0)),
+ circle::BuiltinOperator_CUSTOM);
+ ASSERT_TRUE(mio::circle::is_custom(get_operator_code(0)));
+}
+
+TEST_F(mio_circle04_helper_test, v04_NEG)
+{
+ // BuiltinOperator_ADD = 0
+ // BuiltinOperator_CUMSUM = 128
+ // deprecated_builtin_code cannot be negative value
+ add_operator_code(128, "", circle::BuiltinOperator_ADD);
+ initialization_finish();
+
+ ASSERT_FALSE(mio::circle::is_valid(get_operator_code(0)));
+}
+
+TEST_F(mio_circle04_helper_test, v04_under127)
+{
+ // BuiltinOperator_CONV_2D = 3
+ add_operator_code(3, "", circle::BuiltinOperator_CONV_2D);
+ initialization_finish();
+
+ ASSERT_TRUE(mio::circle::is_valid(get_operator_code(0)));
+ ASSERT_EQ(mio::circle::builtin_code_neutral(get_operator_code(0)),
+ circle::BuiltinOperator_CONV_2D);
+ ASSERT_FALSE(mio::circle::is_custom(get_operator_code(0)));
+}
+
+TEST_F(mio_circle04_helper_test, v04_under127_NEG)
+{
+ // BuiltinOperator_CONV_2D = 3
+ // BuiltinOperator_CUMSUM = 128
+ // deprecated_builtin_code cannot be negative value
+ add_operator_code(128, "", circle::BuiltinOperator_CONV_2D);
+ initialization_finish();
+
+ ASSERT_FALSE(mio::circle::is_valid(get_operator_code(0)));
+}
+
+TEST_F(mio_circle04_helper_test, v04_custom)
+{
+ // BuiltinOperator_CUSTOM = 32
+ add_operator_code(32, "custom", circle::BuiltinOperator_CUSTOM);
+ initialization_finish();
+
+ ASSERT_TRUE(mio::circle::is_valid(get_operator_code(0)));
+ ASSERT_EQ(mio::circle::builtin_code_neutral(get_operator_code(0)),
+ circle::BuiltinOperator_CUSTOM);
+ ASSERT_TRUE(mio::circle::is_custom(get_operator_code(0)));
+}
+
+TEST_F(mio_circle04_helper_test, v04_custom_NEG)
+{
+ // BuiltinOperator_CUMSUM = 128
+ // deprecated_builtin_code cannot be negative value
+ add_operator_code(128, "custom", circle::BuiltinOperator_CUSTOM);
+ initialization_finish();
+
+ ASSERT_FALSE(mio::circle::is_valid(get_operator_code(0)));
+}
+
+TEST_F(mio_circle04_helper_test, v04_over127)
+{
+ // BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES = 127
+ // BuiltinOperator_CUMSUM = 128
+ add_operator_code(127, "", circle::BuiltinOperator_CUMSUM);
+ initialization_finish();
+
+ ASSERT_TRUE(mio::circle::is_valid(get_operator_code(0)));
+ ASSERT_EQ(mio::circle::builtin_code_neutral(get_operator_code(0)),
+ circle::BuiltinOperator_CUMSUM);
+ ASSERT_FALSE(mio::circle::is_custom(get_operator_code(0)));
+}
+
+TEST_F(mio_circle04_helper_test, v04_over127_NEG)
+{
+ // BuiltinOperator_CUMSUM = 128
+ // deprecated_builtin_code cannot be negative value
+ add_operator_code(128, "", circle::BuiltinOperator_CUMSUM);
+ initialization_finish();
+
+ ASSERT_FALSE(mio::circle::is_valid(get_operator_code(0)));
+}
diff --git a/compiler/mio-tflite/CMakeLists.txt b/compiler/mio-tflite/CMakeLists.txt
index 4660e4003..90187b037 100644
--- a/compiler/mio-tflite/CMakeLists.txt
+++ b/compiler/mio-tflite/CMakeLists.txt
@@ -1,4 +1,4 @@
-nnas_find_package(FlatBuffers EXACT 1.10 QUIET)
+nnas_find_package(FlatBuffers EXACT 2.0 QUIET)
if(NOT FlatBuffers_FOUND)
message(STATUS "Build mio-tflite: FAILED (missing Flatbuffers)")
diff --git a/compiler/mio-tflite260/CMakeLists.txt b/compiler/mio-tflite260/CMakeLists.txt
index 39f4d9a31..f2cfeafcc 100644
--- a/compiler/mio-tflite260/CMakeLists.txt
+++ b/compiler/mio-tflite260/CMakeLists.txt
@@ -1,7 +1,7 @@
-nnas_find_package(FlatBuffers EXACT 1.12 QUIET)
+nnas_find_package(FlatBuffers EXACT 2.0 QUIET)
if(NOT FlatBuffers_FOUND)
- message(STATUS "Build mio-tflite260: FAILED (missing Flatbuffers 1.12)")
+ message(STATUS "Build mio-tflite260: FAILED (missing Flatbuffers 2.0)")
return()
endif(NOT FlatBuffers_FOUND)
@@ -47,3 +47,23 @@ endif(NOT TensorFlowGEMMLowpSource_FOUND)
add_library(mio_tflite260_inc INTERFACE)
target_include_directories(mio_tflite260_inc SYSTEM INTERFACE "${TensorFlowSource_DIR}")
target_include_directories(mio_tflite260_inc SYSTEM INTERFACE "${TensorFlowGEMMLowpSource_DIR}")
+
+file(GLOB_RECURSE SOURCES "src/*.cpp")
+file(GLOB_RECURSE TESTS "src/*.test.cpp")
+list(REMOVE_ITEM SOURCES ${TESTS})
+
+add_library(mio_tflite260_helper STATIC ${SOURCES})
+target_include_directories(mio_tflite260_helper PRIVATE src)
+target_include_directories(mio_tflite260_helper PUBLIC include)
+target_link_libraries(mio_tflite260_helper mio_tflite260)
+
+if(NOT ENABLE_TEST)
+ return()
+endif(NOT ENABLE_TEST)
+
+nnas_find_package(GTest REQUIRED)
+
+GTest_AddTest(mio_tflite260_helper_test ${TESTS})
+target_include_directories(mio_tflite260_helper_test PRIVATE src)
+target_link_libraries(mio_tflite260_helper_test mio_tflite260)
+target_link_libraries(mio_tflite260_helper_test mio_tflite260_helper)
diff --git a/compiler/mio-tflite260/include/mio_tflite260/Helper.h b/compiler/mio-tflite260/include/mio_tflite260/Helper.h
new file mode 100644
index 000000000..cb027e604
--- /dev/null
+++ b/compiler/mio-tflite260/include/mio_tflite260/Helper.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MIO_TFLITE260_HELPER_H__
+#define __MIO_TFLITE260_HELPER_H__
+
+#include <mio/tflite/schema_generated.h>
+
+namespace mio
+{
+namespace tflite
+{
+
+::tflite::BuiltinOperator builtin_code_neutral(const ::tflite::OperatorCode *opcode);
+bool is_valid(const ::tflite::OperatorCode *opcode);
+bool is_custom(const ::tflite::OperatorCode *opcode);
+std::string opcode_name(const ::tflite::OperatorCode *opcode);
+const char *tensor_type(const ::tflite::Tensor *tensor);
+const char *tensor_name(const ::tflite::Tensor *tensor);
+
+} // namespace tflite
+} // namespace mio
+
+#endif // __MIO_TFLITE260_HELPER_H__
diff --git a/compiler/mio-tflite260/src/Helper.cpp b/compiler/mio-tflite260/src/Helper.cpp
new file mode 100644
index 000000000..9669058ea
--- /dev/null
+++ b/compiler/mio-tflite260/src/Helper.cpp
@@ -0,0 +1,104 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "mio_tflite260/Helper.h"
+
+#include <sstream>
+
+namespace mio
+{
+namespace tflite
+{
+
+/**
+ * This will provide v3/v3a format neutral BuiltinOperator
+ *
+ * This function referenced
+ * https://github.com/tensorflow/tensorflow/blob/7d12007d7800d3714a02e05059f3ea602d1aec78/tensorflow/lite/schema/schema_utils.cc
+ */
+::tflite::BuiltinOperator builtin_code_neutral(const ::tflite::OperatorCode *opcode)
+{
+ assert(opcode != nullptr);
+ return std::max(opcode->builtin_code(),
+ static_cast<::tflite::BuiltinOperator>(opcode->deprecated_builtin_code()));
+}
+
+bool is_valid(const ::tflite::OperatorCode *opcode)
+{
+ // Valid Range : 0 <= deprecated_builtin_code <= 127
+ const int8_t deprecated_builtin_code = opcode->deprecated_builtin_code();
+ if (deprecated_builtin_code < 0)
+ return false;
+
+ const ::tflite::BuiltinOperator builtin_code = opcode->builtin_code();
+ if (!(::tflite::BuiltinOperator_MIN <= builtin_code &&
+ builtin_code <= ::tflite::BuiltinOperator_MAX))
+ return false;
+
+ return true;
+}
+
+bool is_custom(const ::tflite::OperatorCode *opcode)
+{
+ ::tflite::BuiltinOperator code = builtin_code_neutral(opcode);
+ return (code == ::tflite::BuiltinOperator_CUSTOM);
+}
+
+std::string opcode_name(const ::tflite::OperatorCode *opcode)
+{
+ assert(opcode);
+
+ if (!is_valid(opcode))
+ {
+ std::ostringstream oss;
+ oss << "(invalid)";
+ return oss.str();
+ }
+
+ if (is_custom(opcode))
+ {
+ if (!opcode->custom_code())
+ return "(invalid custom)";
+
+ std::string custom_op = "CUSTOM(";
+ custom_op += opcode->custom_code()->c_str();
+ custom_op += ")";
+ return custom_op;
+ }
+
+ ::tflite::BuiltinOperator code = builtin_code_neutral(opcode);
+ return ::tflite::EnumNameBuiltinOperator(code);
+}
+
+const char *tensor_type(const ::tflite::Tensor *tensor)
+{
+ return ::tflite::EnumNameTensorType(tensor->type());
+}
+
+const char *tensor_name(const ::tflite::Tensor *tensor)
+{
+ static const char *kEmptyTensorName = "(noname)";
+
+ auto name = tensor->name();
+ if (name)
+ return name->c_str();
+
+ return kEmptyTensorName;
+}
+
+} // namespace tflite
+} // namespace mio
diff --git a/compiler/mio-tflite260/src/Helper.test.cpp b/compiler/mio-tflite260/src/Helper.test.cpp
new file mode 100644
index 000000000..e1ef04ca7
--- /dev/null
+++ b/compiler/mio-tflite260/src/Helper.test.cpp
@@ -0,0 +1,159 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "mio_tflite260/Helper.h"
+
+#include <flatbuffers/flatbuffers.h>
+#include <gtest/gtest.h>
+
+#include <vector>
+
+class mio_tflite260_helper_test : public ::testing::Test
+{
+protected:
+ void initialization_finish(void)
+ {
+ _fbb.Finish(tflite::CreateModelDirect(_fbb, 0, &_opcodes_vec));
+ }
+
+protected:
+ void add_operator_code(int8_t deprecated_builtin_code, const char *custom_code,
+ tflite::BuiltinOperator builtin_code)
+ {
+ _opcodes_vec.push_back(tflite::CreateOperatorCodeDirect(
+ _fbb, deprecated_builtin_code, custom_code, 1 /* version */, builtin_code));
+ }
+
+ const tflite::OperatorCode *get_operator_code(uint8_t idx)
+ {
+ return tflite::GetModel(_fbb.GetBufferPointer())->operator_codes()->Get(idx);
+ }
+
+private:
+ flatbuffers::FlatBufferBuilder _fbb;
+ std::vector<flatbuffers::Offset<tflite::OperatorCode>> _opcodes_vec;
+};
+
+/**
+ * Extended 'builtin_code' is not in TFLite schema v3.
+ *
+ * Thus it is filled with 0 (BuiltinOperator_ADD) in schema v3. Please refer to
+ * https://github.com/tensorflow/tensorflow/blob/1ab788fa8d08430be239ab970980b891ad7af494/tensorflow/lite/schema/schema_utils.cc#L28-L31
+ */
+TEST_F(mio_tflite260_helper_test, v3)
+{
+ // BuiltinOperator_ADD = 0
+ // BuiltinOperator_CONV_2D = 3
+ add_operator_code(3, "", tflite::BuiltinOperator_ADD);
+ initialization_finish();
+
+ ASSERT_TRUE(mio::tflite::is_valid(get_operator_code(0)));
+ ASSERT_EQ(mio::tflite::builtin_code_neutral(get_operator_code(0)),
+ tflite::BuiltinOperator_CONV_2D);
+ ASSERT_FALSE(mio::tflite::is_custom(get_operator_code(0)));
+}
+
+TEST_F(mio_tflite260_helper_test, v3_custom)
+{
+ // BuiltinOperator_ADD = 0
+ // BuiltinOperator_CUSTOM = 32
+ add_operator_code(32, "custom", tflite::BuiltinOperator_ADD);
+ initialization_finish();
+
+ ASSERT_TRUE(mio::tflite::is_valid(get_operator_code(0)));
+ ASSERT_EQ(mio::tflite::builtin_code_neutral(get_operator_code(0)),
+ tflite::BuiltinOperator_CUSTOM);
+ ASSERT_TRUE(mio::tflite::is_custom(get_operator_code(0)));
+}
+
+TEST_F(mio_tflite260_helper_test, v3_NEG)
+{
+ // BuiltinOperator_ADD = 0
+ // BuiltinOperator_CUMSUM = 128
+ // deprecated_builtin_code cannot be negative value
+ add_operator_code(128, "", tflite::BuiltinOperator_ADD);
+ initialization_finish();
+
+ ASSERT_FALSE(mio::tflite::is_valid(get_operator_code(0)));
+}
+
+TEST_F(mio_tflite260_helper_test, v3a_under127)
+{
+ // BuiltinOperator_CONV_2D = 3
+ add_operator_code(3, "", tflite::BuiltinOperator_CONV_2D);
+ initialization_finish();
+
+ ASSERT_TRUE(mio::tflite::is_valid(get_operator_code(0)));
+ ASSERT_EQ(mio::tflite::builtin_code_neutral(get_operator_code(0)),
+ tflite::BuiltinOperator_CONV_2D);
+ ASSERT_FALSE(mio::tflite::is_custom(get_operator_code(0)));
+}
+
+TEST_F(mio_tflite260_helper_test, v3a_under127_NEG)
+{
+ // BuiltinOperator_CONV_2D = 3
+ // BuiltinOperator_CUMSUM = 128
+ // deprecated_builtin_code cannot be negative value
+ add_operator_code(128, "", tflite::BuiltinOperator_CONV_2D);
+ initialization_finish();
+
+ ASSERT_FALSE(mio::tflite::is_valid(get_operator_code(0)));
+}
+
+TEST_F(mio_tflite260_helper_test, v3a_custom)
+{
+ // BuiltinOperator_CUSTOM = 32
+ add_operator_code(32, "custom", tflite::BuiltinOperator_CUSTOM);
+ initialization_finish();
+
+ ASSERT_TRUE(mio::tflite::is_valid(get_operator_code(0)));
+ ASSERT_EQ(mio::tflite::builtin_code_neutral(get_operator_code(0)),
+ tflite::BuiltinOperator_CUSTOM);
+ ASSERT_TRUE(mio::tflite::is_custom(get_operator_code(0)));
+}
+
+TEST_F(mio_tflite260_helper_test, v3a_custom_NEG)
+{
+ // BuiltinOperator_CUMSUM = 128
+ // deprecated_builtin_code cannot be negative value
+ add_operator_code(128, "custom", tflite::BuiltinOperator_CUSTOM);
+ initialization_finish();
+
+ ASSERT_FALSE(mio::tflite::is_valid(get_operator_code(0)));
+}
+
+TEST_F(mio_tflite260_helper_test, v3a_over127)
+{
+ // BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES = 127
+ // BuiltinOperator_CUMSUM = 128
+ add_operator_code(127, "", tflite::BuiltinOperator_CUMSUM);
+ initialization_finish();
+
+ ASSERT_TRUE(mio::tflite::is_valid(get_operator_code(0)));
+ ASSERT_EQ(mio::tflite::builtin_code_neutral(get_operator_code(0)),
+ tflite::BuiltinOperator_CUMSUM);
+ ASSERT_FALSE(mio::tflite::is_custom(get_operator_code(0)));
+}
+
+TEST_F(mio_tflite260_helper_test, v3a_over127_NEG)
+{
+ // BuiltinOperator_CUMSUM = 128
+ // deprecated_builtin_code cannot be negative value
+ add_operator_code(128, "", tflite::BuiltinOperator_CUMSUM);
+ initialization_finish();
+
+ ASSERT_FALSE(mio::tflite::is_valid(get_operator_code(0)));
+}
diff --git a/compiler/mio-tflite280/CMakeLists.txt b/compiler/mio-tflite280/CMakeLists.txt
new file mode 100644
index 000000000..f48711eb7
--- /dev/null
+++ b/compiler/mio-tflite280/CMakeLists.txt
@@ -0,0 +1,69 @@
+nnas_find_package(FlatBuffers EXACT 2.0 QUIET)
+
+if(NOT FlatBuffers_FOUND)
+ message(STATUS "Build mio-tflite280: FAILED (missing Flatbuffers 2.0)")
+ return()
+endif(NOT FlatBuffers_FOUND)
+
+nnas_find_package(TensorFlowSource EXACT 2.8.0 QUIET)
+
+if(NOT TensorFlowSource_FOUND)
+ message(STATUS "Build mio-tflite280: FAILED (missing TensorFlowSource 2.8.0)")
+ return()
+endif(NOT TensorFlowSource_FOUND)
+
+message(STATUS "Build mio-tflite280: TRUE")
+
+set(SCHEMA_FILE "${TensorFlowSource_DIR}/tensorflow/lite/schema/schema.fbs")
+
+# NOTE Use a copy of schema.fbs so as to provide a unified way for circle as well
+add_custom_command(OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/schema.fbs"
+ COMMAND ${CMAKE_COMMAND} -E copy "${SCHEMA_FILE}" schema.fbs
+ WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}"
+ DEPENDS "${SCHEMA_FILE}"
+)
+
+FlatBuffers_Target(mio_tflite280
+ OUTPUT_DIR "${CMAKE_CURRENT_BINARY_DIR}/gen/mio/tflite"
+ INCLUDE_DIR "${CMAKE_CURRENT_BINARY_DIR}/gen"
+ SCHEMA_DIR "${CMAKE_CURRENT_BINARY_DIR}"
+ SCHEMA_FILES "schema.fbs"
+)
+
+add_executable(mio_tflite280_example example.cpp)
+target_link_libraries(mio_tflite280_example mio_tflite280)
+
+# Temporary tflite validation tool to replace nnkit-tflite
+# TODO provide full tflite validation with runtime/interpreter
+add_executable(mio_tflite280_validate example.cpp)
+target_link_libraries(mio_tflite280_validate mio_tflite280)
+
+nnas_find_package(TensorFlowGEMMLowpSource EXACT 2.8.0 QUIET)
+
+if(NOT TensorFlowGEMMLowpSource_FOUND)
+ return()
+endif(NOT TensorFlowGEMMLowpSource_FOUND)
+
+add_library(mio_tflite280_inc INTERFACE)
+target_include_directories(mio_tflite280_inc SYSTEM INTERFACE "${TensorFlowSource_DIR}")
+target_include_directories(mio_tflite280_inc SYSTEM INTERFACE "${TensorFlowGEMMLowpSource_DIR}")
+
+file(GLOB_RECURSE SOURCES "src/*.cpp")
+file(GLOB_RECURSE TESTS "src/*.test.cpp")
+list(REMOVE_ITEM SOURCES ${TESTS})
+
+add_library(mio_tflite280_helper STATIC ${SOURCES})
+target_include_directories(mio_tflite280_helper PRIVATE src)
+target_include_directories(mio_tflite280_helper PUBLIC include)
+target_link_libraries(mio_tflite280_helper mio_tflite280)
+
+if(NOT ENABLE_TEST)
+ return()
+endif(NOT ENABLE_TEST)
+
+nnas_find_package(GTest REQUIRED)
+
+GTest_AddTest(mio_tflite280_helper_test ${TESTS})
+target_include_directories(mio_tflite280_helper_test PRIVATE src)
+target_link_libraries(mio_tflite280_helper_test mio_tflite280)
+target_link_libraries(mio_tflite280_helper_test mio_tflite280_helper)
diff --git a/compiler/mio-tflite280/README.md b/compiler/mio-tflite280/README.md
new file mode 100644
index 000000000..73219a7df
--- /dev/null
+++ b/compiler/mio-tflite280/README.md
@@ -0,0 +1,3 @@
+# mio-tflite280
+
+_mio-tflite280_ provides a library to access TensorFlow lite model files with V2.8.0.
diff --git a/compiler/mio-tflite280/example.cpp b/compiler/mio-tflite280/example.cpp
new file mode 100644
index 000000000..83356b943
--- /dev/null
+++ b/compiler/mio-tflite280/example.cpp
@@ -0,0 +1,41 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+//
+// This example shows how to include and use "mio-tflite280"
+//
+#include <mio/tflite/schema_generated.h>
+
+#include <fstream>
+#include <iostream>
+#include <vector>
+
+int main(int argc, char **argv)
+{
+ std::ifstream ifs(argv[1], std::ios_base::binary);
+ std::vector<char> buf(std::istreambuf_iterator<char>{ifs}, std::istreambuf_iterator<char>{});
+
+ flatbuffers::Verifier verifier{reinterpret_cast<uint8_t *>(buf.data()), buf.size()};
+
+ if (!tflite::VerifyModelBuffer(verifier))
+ {
+ std::cout << "Fail" << std::endl;
+ return 255;
+ }
+
+ std::cout << "Pass" << std::endl;
+ return 0;
+}
diff --git a/compiler/mio-tflite280/include/mio_tflite280/Helper.h b/compiler/mio-tflite280/include/mio_tflite280/Helper.h
new file mode 100644
index 000000000..b0fb0ace7
--- /dev/null
+++ b/compiler/mio-tflite280/include/mio_tflite280/Helper.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MIO_TFLITE280_HELPER_H__
+#define __MIO_TFLITE280_HELPER_H__
+
+#include <mio/tflite/schema_generated.h>
+
+namespace mio
+{
+namespace tflite
+{
+
+::tflite::BuiltinOperator builtin_code_neutral(const ::tflite::OperatorCode *opcode);
+bool is_valid(const ::tflite::OperatorCode *opcode);
+bool is_custom(const ::tflite::OperatorCode *opcode);
+std::string opcode_name(const ::tflite::OperatorCode *opcode);
+const char *tensor_type(const ::tflite::Tensor *tensor);
+const char *tensor_name(const ::tflite::Tensor *tensor);
+
+} // namespace tflite
+} // namespace mio
+
+#endif // __MIO_TFLITE280_HELPER_H__
diff --git a/compiler/mio-tflite280/src/Helper.cpp b/compiler/mio-tflite280/src/Helper.cpp
new file mode 100644
index 000000000..ebf0bd140
--- /dev/null
+++ b/compiler/mio-tflite280/src/Helper.cpp
@@ -0,0 +1,104 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "mio_tflite280/Helper.h"
+
+#include <sstream>
+
+namespace mio
+{
+namespace tflite
+{
+
+/**
+ * This will provide v3/v3a format neutral BuiltinOperator
+ *
+ * This function referenced
+ * https://github.com/tensorflow/tensorflow/blob/7d12007d7800d3714a02e05059f3ea602d1aec78/tensorflow/lite/schema/schema_utils.cc
+ */
+::tflite::BuiltinOperator builtin_code_neutral(const ::tflite::OperatorCode *opcode)
+{
+ assert(opcode != nullptr);
+ return std::max(opcode->builtin_code(),
+ static_cast<::tflite::BuiltinOperator>(opcode->deprecated_builtin_code()));
+}
+
+bool is_valid(const ::tflite::OperatorCode *opcode)
+{
+ // Valid Range : 0 <= deprecated_builtin_code <= 127
+ const int8_t deprecated_builtin_code = opcode->deprecated_builtin_code();
+ if (deprecated_builtin_code < 0)
+ return false;
+
+ const ::tflite::BuiltinOperator builtin_code = opcode->builtin_code();
+ if (!(::tflite::BuiltinOperator_MIN <= builtin_code &&
+ builtin_code <= ::tflite::BuiltinOperator_MAX))
+ return false;
+
+ return true;
+}
+
+bool is_custom(const ::tflite::OperatorCode *opcode)
+{
+ ::tflite::BuiltinOperator code = builtin_code_neutral(opcode);
+ return (code == ::tflite::BuiltinOperator_CUSTOM);
+}
+
+std::string opcode_name(const ::tflite::OperatorCode *opcode)
+{
+ assert(opcode);
+
+ if (!is_valid(opcode))
+ {
+ std::ostringstream oss;
+ oss << "(invalid)";
+ return oss.str();
+ }
+
+ if (is_custom(opcode))
+ {
+ if (!opcode->custom_code())
+ return "(invalid custom)";
+
+ std::string custom_op = "CUSTOM(";
+ custom_op += opcode->custom_code()->c_str();
+ custom_op += ")";
+ return custom_op;
+ }
+
+ ::tflite::BuiltinOperator code = builtin_code_neutral(opcode);
+ return ::tflite::EnumNameBuiltinOperator(code);
+}
+
+const char *tensor_type(const ::tflite::Tensor *tensor)
+{
+ return ::tflite::EnumNameTensorType(tensor->type());
+}
+
+const char *tensor_name(const ::tflite::Tensor *tensor)
+{
+ static const char *kEmptyTensorName = "(noname)";
+
+ auto name = tensor->name();
+ if (name)
+ return name->c_str();
+
+ return kEmptyTensorName;
+}
+
+} // namespace tflite
+} // namespace mio
diff --git a/compiler/mio-tflite280/src/Helper.test.cpp b/compiler/mio-tflite280/src/Helper.test.cpp
new file mode 100644
index 000000000..df573bf44
--- /dev/null
+++ b/compiler/mio-tflite280/src/Helper.test.cpp
@@ -0,0 +1,159 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "mio_tflite280/Helper.h"
+
+#include <flatbuffers/flatbuffers.h>
+#include <gtest/gtest.h>
+
+#include <vector>
+
+class mio_tflite280_helper_test : public ::testing::Test
+{
+protected:
+ void initialization_finish(void)
+ {
+ _fbb.Finish(tflite::CreateModelDirect(_fbb, 0, &_opcodes_vec));
+ }
+
+protected:
+ void add_operator_code(int8_t deprecated_builtin_code, const char *custom_code,
+ tflite::BuiltinOperator builtin_code)
+ {
+ _opcodes_vec.push_back(tflite::CreateOperatorCodeDirect(
+ _fbb, deprecated_builtin_code, custom_code, 1 /* version */, builtin_code));
+ }
+
+ const tflite::OperatorCode *get_operator_code(uint8_t idx)
+ {
+ return tflite::GetModel(_fbb.GetBufferPointer())->operator_codes()->Get(idx);
+ }
+
+private:
+ flatbuffers::FlatBufferBuilder _fbb;
+ std::vector<flatbuffers::Offset<tflite::OperatorCode>> _opcodes_vec;
+};
+
+/**
+ * Extended 'builtin_code' is not in TFLite schema v3.
+ *
+ * Thus it is filled with 0 (BuiltinOperator_ADD) in schema v3. Please refer to
+ * https://github.com/tensorflow/tensorflow/blob/1ab788fa8d08430be239ab970980b891ad7af494/tensorflow/lite/schema/schema_utils.cc#L28-L31
+ */
+TEST_F(mio_tflite280_helper_test, v3)
+{
+ // BuiltinOperator_ADD = 0
+ // BuiltinOperator_CONV_2D = 3
+ add_operator_code(3, "", tflite::BuiltinOperator_ADD);
+ initialization_finish();
+
+ ASSERT_TRUE(mio::tflite::is_valid(get_operator_code(0)));
+ ASSERT_EQ(mio::tflite::builtin_code_neutral(get_operator_code(0)),
+ tflite::BuiltinOperator_CONV_2D);
+ ASSERT_FALSE(mio::tflite::is_custom(get_operator_code(0)));
+}
+
+TEST_F(mio_tflite280_helper_test, v3_custom)
+{
+ // BuiltinOperator_ADD = 0
+ // BuiltinOperator_CUSTOM = 32
+ add_operator_code(32, "custom", tflite::BuiltinOperator_ADD);
+ initialization_finish();
+
+ ASSERT_TRUE(mio::tflite::is_valid(get_operator_code(0)));
+ ASSERT_EQ(mio::tflite::builtin_code_neutral(get_operator_code(0)),
+ tflite::BuiltinOperator_CUSTOM);
+ ASSERT_TRUE(mio::tflite::is_custom(get_operator_code(0)));
+}
+
+TEST_F(mio_tflite280_helper_test, v3_NEG)
+{
+ // BuiltinOperator_ADD = 0
+ // BuiltinOperator_CUMSUM = 128
+ // deprecated_builtin_code cannot be a negative value
+ add_operator_code(128, "", tflite::BuiltinOperator_ADD);
+ initialization_finish();
+
+ ASSERT_FALSE(mio::tflite::is_valid(get_operator_code(0)));
+}
+
+TEST_F(mio_tflite280_helper_test, v3a_under127)
+{
+ // BuiltinOperator_CONV_2D = 3
+ add_operator_code(3, "", tflite::BuiltinOperator_CONV_2D);
+ initialization_finish();
+
+ ASSERT_TRUE(mio::tflite::is_valid(get_operator_code(0)));
+ ASSERT_EQ(mio::tflite::builtin_code_neutral(get_operator_code(0)),
+ tflite::BuiltinOperator_CONV_2D);
+ ASSERT_FALSE(mio::tflite::is_custom(get_operator_code(0)));
+}
+
+TEST_F(mio_tflite280_helper_test, v3a_under127_NEG)
+{
+ // BuiltinOperator_CONV_2D = 3
+ // BuiltinOperator_CUMSUM = 128
+ // deprecated_builtin_code cannot be a negative value
+ add_operator_code(128, "", tflite::BuiltinOperator_CONV_2D);
+ initialization_finish();
+
+ ASSERT_FALSE(mio::tflite::is_valid(get_operator_code(0)));
+}
+
+TEST_F(mio_tflite280_helper_test, v3a_custom)
+{
+ // BuiltinOperator_CUSTOM = 32
+ add_operator_code(32, "custom", tflite::BuiltinOperator_CUSTOM);
+ initialization_finish();
+
+ ASSERT_TRUE(mio::tflite::is_valid(get_operator_code(0)));
+ ASSERT_EQ(mio::tflite::builtin_code_neutral(get_operator_code(0)),
+ tflite::BuiltinOperator_CUSTOM);
+ ASSERT_TRUE(mio::tflite::is_custom(get_operator_code(0)));
+}
+
+TEST_F(mio_tflite280_helper_test, v3a_custom_NEG)
+{
+ // BuiltinOperator_CUMSUM = 128
+ // deprecated_builtin_code cannot be a negative value
+ add_operator_code(128, "custom", tflite::BuiltinOperator_CUSTOM);
+ initialization_finish();
+
+ ASSERT_FALSE(mio::tflite::is_valid(get_operator_code(0)));
+}
+
+TEST_F(mio_tflite280_helper_test, v3a_over127)
+{
+ // BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES = 127
+ // BuiltinOperator_CUMSUM = 128
+ add_operator_code(127, "", tflite::BuiltinOperator_CUMSUM);
+ initialization_finish();
+
+ ASSERT_TRUE(mio::tflite::is_valid(get_operator_code(0)));
+ ASSERT_EQ(mio::tflite::builtin_code_neutral(get_operator_code(0)),
+ tflite::BuiltinOperator_CUMSUM);
+ ASSERT_FALSE(mio::tflite::is_custom(get_operator_code(0)));
+}
+
+TEST_F(mio_tflite280_helper_test, v3a_over127_NEG)
+{
+ // BuiltinOperator_CUMSUM = 128
+ // deprecated_builtin_code cannot be negative value
+ add_operator_code(128, "", tflite::BuiltinOperator_CUMSUM);
+ initialization_finish();
+
+ ASSERT_FALSE(mio::tflite::is_valid(get_operator_code(0)));
+}
diff --git a/compiler/mir/src/mir_onnx_importer/CMakeLists.txt b/compiler/mir/src/mir_onnx_importer/CMakeLists.txt
index e6eb13b93..04c22055e 100644
--- a/compiler/mir/src/mir_onnx_importer/CMakeLists.txt
+++ b/compiler/mir/src/mir_onnx_importer/CMakeLists.txt
@@ -112,6 +112,10 @@ target_include_directories(mir_onnx_importer PUBLIC ../../include/mir_onnx_impor
target_include_directories(mir_onnx_importer PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
target_link_libraries(mir_onnx_importer PUBLIC mir mir_onnx_proto PRIVATE mir_interpreter nncc_common)
+if(NOT ENABLE_TEST)
+ return()
+endif(NOT ENABLE_TEST)
+
nnas_find_package(GTest REQUIRED)
file(GLOB_RECURSE TEST_SOURCES "*.test.cpp")
diff --git a/compiler/mir/src/mir_tflite_importer/CMakeLists.txt b/compiler/mir/src/mir_tflite_importer/CMakeLists.txt
index 42eb4f8a5..6c6c28a32 100644
--- a/compiler/mir/src/mir_tflite_importer/CMakeLists.txt
+++ b/compiler/mir/src/mir_tflite_importer/CMakeLists.txt
@@ -1,4 +1,4 @@
-nnas_find_package(FlatBuffers EXACT 1.10 REQUIRED)
+nnas_find_package(FlatBuffers EXACT 2.0 REQUIRED)
if (NOT FlatBuffers_FOUND)
return()
diff --git a/compiler/mir2loco/CMakeLists.txt b/compiler/mir2loco/CMakeLists.txt
index a8a096ef4..217f1bd15 100644
--- a/compiler/mir2loco/CMakeLists.txt
+++ b/compiler/mir2loco/CMakeLists.txt
@@ -8,11 +8,11 @@ target_include_directories(mir2loco PUBLIC include)
target_link_libraries(mir2loco PUBLIC mir)
target_link_libraries(mir2loco PUBLIC loco)
-nnas_find_package(GTest QUIET)
-
-if(NOT GTest_FOUND)
+if(NOT ENABLE_TEST)
return()
-endif(NOT GTest_FOUND)
+endif(NOT ENABLE_TEST)
+
+nnas_find_package(GTest QUIET)
GTest_AddTest(mir2loco_test ${TESTS})
target_link_libraries(mir2loco_test mir2loco)
diff --git a/compiler/moco-tf/CMakeLists.txt b/compiler/moco-tf/CMakeLists.txt
index 7c42761ba..95669264f 100644
--- a/compiler/moco-tf/CMakeLists.txt
+++ b/compiler/moco-tf/CMakeLists.txt
@@ -26,6 +26,7 @@ target_link_libraries(moco_tf_frontend PRIVATE locomotiv)
target_link_libraries(moco_tf_frontend PRIVATE plier_tf)
target_link_libraries(moco_tf_frontend PRIVATE locoex_customop)
target_link_libraries(moco_tf_frontend PRIVATE logo)
+target_link_libraries(moco_tf_frontend PRIVATE logo_ex)
target_link_libraries(moco_tf_frontend PRIVATE oops)
install(TARGETS moco_tf_frontend DESTINATION lib)
@@ -46,4 +47,5 @@ target_link_libraries(moco_tf_frontend_test moco_tf_frontend)
target_link_libraries(moco_tf_frontend_test plier_tf)
target_link_libraries(moco_tf_frontend_test locoex_customop)
target_link_libraries(moco_tf_frontend_test logo)
+target_link_libraries(moco_tf_frontend_test logo_ex)
add_test(moco_tf_frontend_test moco_tf_frontend_test)
diff --git a/compiler/moco-tf/requires.cmake b/compiler/moco-tf/requires.cmake
index 90590e374..71755556c 100644
--- a/compiler/moco-tf/requires.cmake
+++ b/compiler/moco-tf/requires.cmake
@@ -9,5 +9,6 @@ require("mio-tf")
require("plier-tf")
require("locoex-customop")
require("logo")
+require("logo-ex")
require("oops")
require("bino")
diff --git a/compiler/moco-tf/src/Transforms.h b/compiler/moco-tf/src/Transforms.h
index f14b81675..a197a796e 100644
--- a/compiler/moco-tf/src/Transforms.h
+++ b/compiler/moco-tf/src/Transforms.h
@@ -21,6 +21,7 @@
#include "Transforms/TypeInferencePass.h"
#include <logo/Passes.h>
+#include <logo/PassesEx.h>
#include <moco/Pass/Passes.h>
#endif // __MOCO_TF_TRANSFORMS_H__
diff --git a/compiler/morph/CMakeLists.txt b/compiler/morph/CMakeLists.txt
index ec7da8d30..5a5ae2623 100644
--- a/compiler/morph/CMakeLists.txt
+++ b/compiler/morph/CMakeLists.txt
@@ -8,11 +8,11 @@ target_include_directories(morph PUBLIC include)
target_link_libraries(morph PRIVATE nncc_common)
target_link_libraries(morph PUBLIC angkor)
-nnas_find_package(GTest QUIET)
-
-if(NOT GTest_FOUND)
+if(NOT ENABLE_TEST)
return()
-endif(NOT GTest_FOUND)
+endif(NOT ENABLE_TEST)
+
+nnas_find_package(GTest QUIET)
add_executable(morph_test ${TESTS})
target_link_libraries(morph_test morph)
diff --git a/compiler/nest/core/CMakeLists.txt b/compiler/nest/core/CMakeLists.txt
index b603f9ae9..4f17db3b4 100644
--- a/compiler/nest/core/CMakeLists.txt
+++ b/compiler/nest/core/CMakeLists.txt
@@ -15,11 +15,11 @@ foreach(EXAMPLE_FILE IN ITEMS ${EXAMPLE_FILES})
target_link_libraries(${TARGET_NAME} nest_core)
endforeach(EXAMPLE_FILE)
-nnas_find_package(GTest QUIET)
-
-if(NOT GTest_FOUND)
+if(NOT ENABLE_TEST)
return()
-endif(NOT GTest_FOUND)
+endif(NOT ENABLE_TEST)
+
+nnas_find_package(GTest QUIET)
add_executable(nest_core_test ${TESTS})
target_link_libraries(nest_core_test gtest_main)
diff --git a/compiler/nike/CMakeLists.txt b/compiler/nike/CMakeLists.txt
index 737c73b8f..6bd3199e3 100644
--- a/compiler/nike/CMakeLists.txt
+++ b/compiler/nike/CMakeLists.txt
@@ -5,11 +5,11 @@ list(REMOVE_ITEM SOURCES ${TESTS})
add_library(nike STATIC ${SOURCES})
target_include_directories(nike PUBLIC include)
-nnas_find_package(GTest QUIET)
-
-if(NOT GTest_FOUND)
+if(NOT ENABLE_TEST)
return()
-endif(NOT GTest_FOUND)
+endif(NOT ENABLE_TEST)
+
+nnas_find_package(GTest QUIET)
GTest_AddTest(nike_test ${TESTS})
target_link_libraries(nike_test nike)
diff --git a/compiler/nnc/unittests/soft_backend/ModelAnalyzer.cpp b/compiler/nnc/unittests/soft_backend/ModelAnalyzer.cpp
index d38385e91..c2135c4be 100644
--- a/compiler/nnc/unittests/soft_backend/ModelAnalyzer.cpp
+++ b/compiler/nnc/unittests/soft_backend/ModelAnalyzer.cpp
@@ -22,6 +22,8 @@
#include <gtest/gtest.h>
+#include <algorithm>
+
using namespace std;
using namespace nnc;
using namespace mir;
diff --git a/compiler/nnop/CMakeLists.txt b/compiler/nnop/CMakeLists.txt
index 82c0e3a86..d2c8af26d 100644
--- a/compiler/nnop/CMakeLists.txt
+++ b/compiler/nnop/CMakeLists.txt
@@ -2,11 +2,11 @@ add_library(nnop INTERFACE)
target_include_directories(nnop INTERFACE include)
target_link_libraries(nnop INTERFACE angkor)
-nnas_find_package(GTest QUIET)
-
-if(NOT GTest_FOUND)
+if(NOT ENABLE_TEST)
return()
-endif(NOT GTest_FOUND)
+endif(NOT ENABLE_TEST)
+
+nnas_find_package(GTest QUIET)
file(GLOB_RECURSE TESTS "src/*.test.cpp")
diff --git a/compiler/one-cmds/CMakeLists.txt b/compiler/one-cmds/CMakeLists.txt
index 729bfa80a..8732340ae 100644
--- a/compiler/one-cmds/CMakeLists.txt
+++ b/compiler/one-cmds/CMakeLists.txt
@@ -14,6 +14,11 @@ set(ONE_COMMAND_FILES
onecc
)
+# pytorch importer is an experimental feature, it is not used in default configuration
+if(ENABLE_ONE_IMPORT_PYTORCH)
+ list(APPEND ONE_COMMAND_FILES one-import-pytorch)
+endif(ENABLE_ONE_IMPORT_PYTORCH)
+
foreach(ONE_COMMAND IN ITEMS ${ONE_COMMAND_FILES})
set(ONE_COMMAND_FILE ${ONE_COMMAND})
@@ -41,6 +46,7 @@ set(ONE_UTILITY_FILES
one-build.template.cfg
onecc.template.cfg
utils.py
+ onnx_legalizer.py
)
foreach(ONE_UTILITY IN ITEMS ${ONE_UTILITY_FILES})
@@ -66,6 +72,39 @@ foreach(ONE_UTILITY IN ITEMS ${ONE_UTILITY_FILES})
endforeach(ONE_UTILITY)
+# make python directory
+set(ONE_PYTHON_FILES constant.py
+ make_cmd.py)
+
+foreach(ONE_PYTHON_FILE IN ITEMS ${ONE_PYTHON_FILES})
+
+ set(ONE_PYTHON_DIR "onelib")
+ set(ONE_PYTHON_DIR_BIN "${CMAKE_CURRENT_BINARY_DIR}/${ONE_PYTHON_DIR}")
+ set(ONE_PYTHON_FILE_SRC "${CMAKE_CURRENT_SOURCE_DIR}/${ONE_PYTHON_DIR}/${ONE_PYTHON_FILE}")
+ set(ONE_PYTHON_FILE_BIN "${CMAKE_CURRENT_BINARY_DIR}/${ONE_PYTHON_DIR}/${ONE_PYTHON_FILE}")
+ set(ONE_PYTHON_TARGET "${ONE_PYTHON_FILE}_target")
+
+ add_custom_command(OUTPUT ${ONE_PYTHON_DIR_BIN}
+ COMMAND ${CMAKE_COMMAND} -E make_directory "${ONE_PYTHON_DIR_BIN}"
+ COMMENT "Generate ${ONE_PYTHON_DIR_BIN}"
+ )
+
+ add_custom_command(OUTPUT ${ONE_PYTHON_FILE_BIN}
+ COMMAND ${CMAKE_COMMAND} -E copy "${ONE_PYTHON_FILE_SRC}" "${ONE_PYTHON_FILE_BIN}"
+ DEPENDS ${ONE_PYTHON_SRC}
+ COMMENT "Generate ${ONE_PYTHON_FILE_BIN}"
+ )
+
+ add_custom_target(${ONE_PYTHON_TARGET} ALL DEPENDS ${ONE_PYTHON_DIR_BIN} ${ONE_PYTHON_FILE_BIN})
+
+ install(DIRECTORY ${ONE_PYTHON_DIR}
+ FILE_PERMISSIONS OWNER_WRITE OWNER_READ
+ GROUP_READ
+ WORLD_READ
+ DESTINATION bin)
+
+endforeach(ONE_PYTHON_FILE)
+
set(ONE_DOCUMENT_FILES
how-to-use-one-commands.txt
how-to-prepare-virtualenv.txt
diff --git a/compiler/one-cmds/how-to-prepare-virtualenv.txt b/compiler/one-cmds/how-to-prepare-virtualenv.txt
index 6d846c081..8d6007f38 100644
--- a/compiler/one-cmds/how-to-prepare-virtualenv.txt
+++ b/compiler/one-cmds/how-to-prepare-virtualenv.txt
@@ -5,7 +5,7 @@ Last update: 2020-09-15
This document explains about 'one-prepare-venv' command.
-'one-prepare-venv' will prepare python3 virtual environment with tensorflow-cpu
+'one-prepare-venv' will prepare python3.8 virtual environment with tensorflow-cpu
version 2.3.0, recommended 2.x version as of now, so that 'one-import-tf'
command can execute properly.
@@ -20,7 +20,7 @@ Please install these required packages before venv preparation.
$ sudo apt-get update
$ sudo apt-get upgrade
-$ sudo apt-get install python3-pip python3-venv
+$ sudo apt-get install python3.8 python3-pip python3.8-venv
How to run for Ubuntu
@@ -36,18 +36,9 @@ There will be venv folder as of result.
How to run for Windows
----------------------
-1. First, please prepare Python 3.5-3.7
-2. Open the Command Prompt as an administrator
-3. cd(change directory) to the directory where one-compiler is installed
-4. run below command
-```
-$ ONE\install\bin> python -m venv venv
-$ ONE\install\bin> cd venv/Scripts
-$ ONE\install\bin\venv/Scripts> pip.exe install -U pip
-$ ONE\install\bin\venv/Scripts> pip.exe install -U tensorflow-cpu==2.3.0
-```
-
-After running the above command, go back to MinGW and run one-compiler.
+Support for Windows is not maintained for now.
+If you have any needs for running in Windows, please fire an issue.
+Or you can use Docker for Windows.
Trouble shooting
diff --git a/compiler/one-cmds/how-to-use-one-commands.txt b/compiler/one-cmds/how-to-use-one-commands.txt
index 0a0c4b14c..ebc165167 100644
--- a/compiler/one-cmds/how-to-use-one-commands.txt
+++ b/compiler/one-cmds/how-to-use-one-commands.txt
@@ -155,6 +155,7 @@ Current transformation options are
- fold_cast : This removes Cast operation which can be folded
- fold_dequantize : This removes Dequantize operation which can be folded
- fold_dwconv : This folds Depthwise Convolution operation which can be folded
+- fold_gather : This removes Gather operation which can be folded
- fold_sparse_to_dense : This removes SparseToDense operation which can be folded
- forward_reshape_to_unaryop: This will move Reshape after UnaryOp for certain condition
- fuse_add_with_fully_connected: This fuses Add operator with the preceding FullyConnected operator if possible
@@ -178,6 +179,7 @@ Current transformation options are
- generate_profile_data : This will turn on profiling data generation.
- remove_fakequant : This will remove all fakequant operators.
- remove_quantdequant : This will remove all Quantize-Dequantize sequence.
+- remove_redundant_quantize : This removes redundant quantize operators.
- remove_redundant_reshape : This fuses or removes redundant reshape operators.
- remove_redundant_transpose : This fuses or removes redundant transpose operators.
- remove_unnecessary_reshape : This removes unnecessary reshape operators.
diff --git a/compiler/one-cmds/one-build b/compiler/one-cmds/one-build
index 90dfa77b8..5c313b44b 100644
--- a/compiler/one-cmds/one-build
+++ b/compiler/one-cmds/one-build
@@ -154,25 +154,31 @@ def main():
config = _parse_cfg(args)
# verify configuration file
- drivers = [
- 'one-import-tf', 'one-import-tflite', 'one-import-bcq', 'one-import-onnx',
- 'one-optimize', 'one-quantize', 'one-pack', 'one-codegen'
+ bin_dir = os.path.dirname(os.path.realpath(__file__))
+ import_drivers_dict = _utils._detect_one_import_drivers(bin_dir)
+ transform_drivers = [
+ 'one-optimize', 'one-quantize', 'one-pack', 'one-codegen', 'one-profile'
]
- _verify_cfg(drivers, config)
+ _verify_cfg(import_drivers_dict, config)
# verify optimization option file
_verify_opt(args)
# get sections to run
section_to_run = []
- for d in drivers:
+ for d in list(import_drivers_dict) + transform_drivers:
if _is_available_driver(config, d):
section_to_run.append(d)
# run
dir_path = os.path.dirname(os.path.realpath(__file__))
for section in section_to_run:
- driver_path = os.path.join(dir_path, _get_driver_name(section))
+ if section in import_drivers_dict:
+ # we already have the driver name in the dict
+ driver_name = import_drivers_dict[section]
+ else:
+ driver_name = _get_driver_name(section)
+ driver_path = os.path.join(dir_path, driver_name)
cmd = [driver_path, '--config', getattr(args, 'config'), '--section', section]
if section == 'one-optimize' and _utils._is_valid_attr(args, 'O'):
cmd += ['-O', getattr(args, 'O')]
diff --git a/compiler/one-cmds/one-import-bcq b/compiler/one-cmds/one-import-bcq
index 9aef6270e..ef89a9297 100644
--- a/compiler/one-cmds/one-import-bcq
+++ b/compiler/one-cmds/one-import-bcq
@@ -25,6 +25,7 @@ import subprocess
import sys
import tempfile
+import onelib.make_cmd as _make_cmd
import utils as _utils
import generate_bcq_output_arrays as _bcq_info_gen
@@ -32,6 +33,10 @@ import generate_bcq_output_arrays as _bcq_info_gen
sys.tracebacklimit = 0
+def get_driver_cfg_section():
+ return "one-import-bcq"
+
+
def _get_parser():
parser = argparse.ArgumentParser(
description='command line tool to convert TensorFlow with BCQ to circle')
@@ -155,7 +160,7 @@ def _convert(args):
tmpdir,
os.path.splitext(
os.path.basename(generate_bcq_metadata_output_path))[0]) + '.tflite'
- tf2tfliteV2_cmd = _utils._make_tf2tfliteV2_cmd(args, tf2tfliteV2_path,
+ tf2tfliteV2_cmd = _make_cmd.make_tf2tfliteV2_cmd(args, tf2tfliteV2_path,
generate_bcq_metadata_output_path,
tf2tfliteV2_output_path)
try:
@@ -171,7 +176,7 @@ def _convert(args):
# make a command to convert from tflite to circle
tflite2circle_path = os.path.join(dir_path, 'tflite2circle')
- tflite2circle_cmd = _utils._make_tflite2circle_cmd(tflite2circle_path,
+ tflite2circle_cmd = _make_cmd.make_tflite2circle_cmd(tflite2circle_path,
tf2tfliteV2_output_path,
getattr(args, 'output_path'))
diff --git a/compiler/one-cmds/one-import-onnx b/compiler/one-cmds/one-import-onnx
index 1c0c5498e..eaa136197 100644
--- a/compiler/one-cmds/one-import-onnx
+++ b/compiler/one-cmds/one-import-onnx
@@ -27,12 +27,25 @@ import tempfile
import onnx
import onnx_tf
+# ONNX legalizer is an optional feature
+# It enables conversion of some operations, but in experimental phase for now
+try:
+ import onnx_legalizer
+ _onnx_legalizer_enabled = True
+except ImportError:
+ _onnx_legalizer_enabled = False
+
+import onelib.make_cmd as _make_cmd
import utils as _utils
# TODO Find better way to suppress trackback on error
sys.tracebacklimit = 0
+def get_driver_cfg_section():
+ return "one-import-onnx"
+
+
def _get_parser():
parser = argparse.ArgumentParser(
description='command line tool to convert ONNX to circle')
@@ -64,6 +77,10 @@ def _get_parser():
tf2tfliteV2_group.add_argument('--model_format', default='saved_model')
tf2tfliteV2_group.add_argument('--converter_version', default='v2')
+ parser.add_argument('--unroll_rnn', action='store_true', help='Unroll RNN operators')
+ parser.add_argument(
+ '--unroll_lstm', action='store_true', help='Unroll LSTM operators')
+
# save intermediate file(s)
parser.add_argument(
'--save_intermediate',
@@ -120,6 +137,11 @@ def _convert(args):
tmpdir = os.path.dirname(logfile_path)
# convert onnx to tf saved model
onnx_model = onnx.load(getattr(args, 'input_path'))
+ if _onnx_legalizer_enabled:
+ options = onnx_legalizer.LegalizeOptions
+ options.unroll_rnn = _utils._is_valid_attr(args, 'unroll_rnn')
+ options.unroll_lstm = _utils._is_valid_attr(args, 'unroll_lstm')
+ onnx_legalizer.legalize(onnx_model, options)
tf_savedmodel = onnx_tf.backend.prepare(onnx_model)
savedmodel_name = os.path.splitext(os.path.basename(
@@ -133,7 +155,7 @@ def _convert(args):
args.output_path))[0] + '.tflite'
tf2tfliteV2_output_path = os.path.join(tmpdir, tf2tfliteV2_output_name)
- tf2tfliteV2_cmd = _utils._make_tf2tfliteV2_cmd(
+ tf2tfliteV2_cmd = _make_cmd.make_tf2tfliteV2_cmd(
args, tf2tfliteV2_path, savedmodel_output_path, tf2tfliteV2_output_path)
f.write((' '.join(tf2tfliteV2_cmd) + '\n').encode())
@@ -143,7 +165,7 @@ def _convert(args):
# make a command to convert from tflite to circle
tflite2circle_path = os.path.join(dir_path, 'tflite2circle')
- tflite2circle_cmd = _utils._make_tflite2circle_cmd(tflite2circle_path,
+ tflite2circle_cmd = _make_cmd.make_tflite2circle_cmd(tflite2circle_path,
tf2tfliteV2_output_path,
getattr(args, 'output_path'))
diff --git a/compiler/one-cmds/one-import-pytorch b/compiler/one-cmds/one-import-pytorch
new file mode 100644
index 000000000..dbf1ba6d7
--- /dev/null
+++ b/compiler/one-cmds/one-import-pytorch
@@ -0,0 +1,366 @@
+#!/usr/bin/env bash
+''''export SCRIPT_PATH="$(cd "$(dirname "$(readlink -f "${BASH_SOURCE[0]}")")" && pwd)" # '''
+''''export PY_PATH=${SCRIPT_PATH}/venv/bin/python # '''
+''''test -f ${PY_PATH} && exec ${PY_PATH} "$0" "$@" # '''
+''''echo "Error: Virtual environment not found. Please run 'one-prepare-venv' command." # '''
+''''exit 255 # '''
+
+# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import importlib
+import inspect
+import os
+import sys
+import tempfile
+import torch
+import onnx
+import onnx_tf
+import json
+import zipfile
+
+import onnx_legalizer
+import onelib.make_cmd as _make_cmd
+import utils as _utils
+
+# TODO Find better way to suppress traceback on error
+sys.tracebacklimit = 0
+
+
+def get_driver_spec():
+ return ("one-import-pytorch", _utils.DriverType.IMPORTER)
+
+
+def _get_parser():
+ parser = argparse.ArgumentParser(
+ description='command line tool to convert PyTorch to Circle')
+
+ _utils._add_default_arg(parser)
+
+ ## converter arguments
+ converter_group = parser.add_argument_group('converter arguments')
+
+ # input and output path.
+ converter_group.add_argument(
+ '-i', '--input_path', type=str, help='full filepath of the input file')
+ converter_group.add_argument(
+ '-p', '--python_path', type=str, help='full filepath of the python model file')
+ converter_group.add_argument(
+ '-o', '--output_path', type=str, help='full filepath of the output file')
+
+ # input arrays.
+ converter_group.add_argument(
+ '-s',
+ '--input_shapes',
+ type=str,
+ help=
+ 'Shapes corresponding to --input_arrays, colon-separated.(ex:\"1,4,4,3:1,20,20,3\")'
+ )
+ converter_group.add_argument(
+ '-t',
+ '--input_types',
+ type=str,
+ help='data types of input tensors, colon-separated (ex: float32, uint8, int32)')
+
+ # fixed options
+ tf2tflite_group = parser.add_argument_group('tf2tfliteV2 arguments')
+ tf2tflite_group.add_argument('--model_format', default='saved_model')
+ tf2tflite_group.add_argument('--converter_version', default='v2')
+
+ parser.add_argument('--unroll_rnn', action='store_true', help='Unroll RNN operators')
+ parser.add_argument('--unroll_lstm', action='store_true', help='Unroll LSTM operators')
+
+ # save intermediate file(s)
+ parser.add_argument(
+ '--save_intermediate',
+ action='store_true',
+ help='Save intermediate files to output folder')
+
+ return parser
+
+
+def _verify_arg(parser, args):
+ """verify given arguments"""
+ # check if required arguments is given
+ missing = []
+ if not _utils._is_valid_attr(args, 'input_path'):
+ missing.append('-i/--input_path')
+ if not _utils._is_valid_attr(args, 'output_path'):
+ missing.append('-o/--output_path')
+ if not _utils._is_valid_attr(args, 'input_shapes'):
+ missing.append('-s/--input_shapes')
+ if not _utils._is_valid_attr(args, 'input_types'):
+ missing.append('-t/--input_types')
+
+ if len(missing):
+ parser.error('the following arguments are required: ' + ' '.join(missing))
+
+
+def _parse_arg(parser):
+ args = parser.parse_args()
+ # print version
+ if args.version:
+ _utils._print_version_and_exit(__file__)
+
+ return args
+
+
+def _apply_verbosity(verbosity):
+ # NOTE
+ # TF_CPP_MIN_LOG_LEVEL
+ # 0 : INFO + WARNING + ERROR + FATAL
+ # 1 : WARNING + ERROR + FATAL
+ # 2 : ERROR + FATAL
+ # 3 : FATAL
+ if verbosity:
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0'
+ else:
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
+
+
+def _parse_shapes(shapes_str):
+ shapes = []
+ for shape_str in shapes_str.split(":"):
+ if shape_str != "":
+ shapes += [list(map(int, shape_str.split(",")))]
+ else:
+ shapes += [[]]
+ return shapes
+
+
+def _parse_types(types_str):
+ # There is no convenient way to create a torch dtype from a string or numpy dtype, so use this workaround
+ dtype_dict = {
+ "bool": torch.bool,
+ "uint8": torch.uint8,
+ "int8": torch.int8,
+ "int16": torch.int16,
+ "int32": torch.int32,
+ "int64": torch.int64,
+ "float16": torch.float16,
+ "float32": torch.float32,
+ "float64": torch.float64,
+ "complex64": torch.complex64,
+ "complex128": torch.complex128
+ }
+ array = types_str.split(",")
+ types = [dtype_dict[type_str.strip()] for type_str in array]
+ return types
+
+
+# merge contents of module into global namespace
+def _merge_module(module):
+ # is there an __all__? if so respect it
+ if "__all__" in module.__dict__:
+ names = module.__dict__["__all__"]
+ else:
+ # otherwise we import all names that don't begin with _
+ names = [x for x in module.__dict__ if not x.startswith("_")]
+ globals().update({k: getattr(module, k) for k in names})
+
+
+def _list_classes_from_module(module):
+ # Parsing the module to get all defined classes
+ is_member = lambda member: inspect.isclass(member) and member.__module__ == module.__name__
+ classes = [cls[1] for cls in inspect.getmembers(module, is_member)]
+ return classes
+
+
+def _extract_pytorch_model(log_file, parameters_path, python_path):
+ log_file.write(('Trying to load saved model\n').encode())
+ python_model_path = os.path.abspath(python_path)
+ module_name = os.path.basename(python_model_path)
+ module_dir = os.path.dirname(python_model_path)
+ sys.path.append(module_dir)
+ log_file.write(('Trying to load given python module\n').encode())
+ module_loader = importlib.machinery.SourceFileLoader(module_name, python_model_path)
+ module_spec = importlib.util.spec_from_loader(module_name, module_loader)
+ python_model_module = importlib.util.module_from_spec(module_spec)
+
+ try:
+ module_loader.exec_module(python_model_module)
+ except:
+ raise ValueError('Failed to execute given python model file')
+
+ log_file.write(('Model python module is loaded\n').encode())
+ try:
+ # this branch assumes this parameters_path contains state_dict
+ state_dict = torch.load(parameters_path)
+ log_file.write(('Trying to find model class and fill it`s state dict\n').encode())
+ model_class_definitions = _list_classes_from_module(python_model_module)
+ if len(model_class_definitions) != 1:
+ raise ValueError("Expected only one class as model definition. {}".format(
+ model_class_definitions))
+ pytorch_model_class = model_class_definitions[0]
+ model = pytorch_model_class()
+ model.load_state_dict(state_dict)
+ return model
+ except:
+ # this branch assumes this parameters_path contains "entire" model
+ _merge_module(python_model_module)
+ log_file.write(('Model python module is merged into main environment\n').encode())
+ model = torch.load(parameters_path)
+ log_file.write(('Pytorch model loaded\n').encode())
+ return model
+
+
+def _extract_torchscript_model(log_file, input_path):
+ # assuming this is a pytorch script
+ log_file.write(('Trying to load TorchScript model\n').encode())
+ try:
+ pytorch_model = torch.jit.load(input_path)
+ return pytorch_model
+ except RuntimeError as e:
+ log_file.write((str(e) + '\n').encode())
+ log_file.write(
+ 'Failed to import input file. Maybe this it contains only weights? Try pass "python_path" argument\n'.
+ encode())
+ raise
+ log_file.write(('TorchScript model is loaded\n').encode())
+
+
+def _extract_mar_model(log_file, tmpdir, input_path):
+ mar_dir_path = os.path.join(tmpdir, 'mar')
+ with zipfile.ZipFile(input_path) as zip_input:
+ zip_input.extractall(path=mar_dir_path)
+ manifest_path = os.path.join(mar_dir_path, 'MAR-INF/MANIFEST.json')
+ with open(manifest_path) as manifest_file:
+ manifest = json.load(manifest_file)
+ serialized_file = os.path.join(mar_dir_path, manifest['model']['serializedFile'])
+ if 'modelFile' in manifest['model']:
+ model_file = os.path.join(mar_dir_path, manifest['model']['modelFile'])
+ return _extract_pytorch_model(log_file, serialized_file, model_file)
+ else:
+ return _extract_torchscript_model(log_file, serialized_file)
+
+
+def _convert(args):
+ _apply_verbosity(args.verbose)
+
+ # get file path to log
+ dir_path = os.path.dirname(os.path.realpath(__file__))
+ logfile_path = os.path.realpath(args.output_path) + '.log'
+ with open(logfile_path, 'wb') as f, tempfile.TemporaryDirectory() as tmpdir:
+ # save intermediate
+ if _utils._is_valid_attr(args, 'save_intermediate'):
+ tmpdir = os.path.dirname(logfile_path)
+ # convert pytorch to onnx model
+ input_path = getattr(args, 'input_path')
+ model_file = getattr(args, 'python_path')
+
+ if input_path[-4:] == '.mar':
+ pytorch_model = _extract_mar_model(f, tmpdir, input_path)
+ elif model_file is None:
+ pytorch_model = _extract_torchscript_model(f, input_path)
+ else:
+ pytorch_model = _extract_pytorch_model(f, input_path, model_file)
+
+ input_shapes = _parse_shapes(getattr(args, 'input_shapes'))
+ input_types = _parse_types(getattr(args, 'input_types'))
+
+ if len(input_shapes) != len(input_types):
+ raise ValueError('number of input shapes and input types must be equal')
+
+ sample_inputs = []
+ for input_spec in zip(input_shapes, input_types):
+ sample_inputs += [torch.ones(input_spec[0], dtype=input_spec[1])]
+
+ f.write(('Trying to inference loaded model').encode())
+ sample_outputs = pytorch_model(*sample_inputs)
+ f.write(('Acquired sample outputs\n').encode())
+
+ onnx_output_name = os.path.splitext(os.path.basename(
+ args.output_path))[0] + '.onnx'
+ onnx_output_path = os.path.join(tmpdir, onnx_output_name)
+
+ onnx_saved = False
+ # some operations are not supported in early opset versions, try several
+ for onnx_opset_version in range(9, 15):
+ f.write(('Trying to save onnx model using opset version ' +
+ str(onnx_opset_version) + '\n').encode())
+ try:
+ torch.onnx.export(
+ pytorch_model,
+ tuple(sample_inputs),
+ onnx_output_path,
+ example_outputs=sample_outputs,
+ opset_version=onnx_opset_version)
+ onnx_saved = True
+ break
+ except:
+ f.write(('attempt failed\n').encode())
+
+ if not onnx_saved:
+ raise ValueError('Failed to save temporary onnx model')
+
+ # convert onnx to tf saved model
+ onnx_model = onnx.load(onnx_output_path)
+
+ options = onnx_legalizer.LegalizeOptions()
+ options.unroll_rnn = _utils._is_valid_attr(args, 'unroll_rnn')
+ options.unroll_lstm = _utils._is_valid_attr(args, 'unroll_lstm')
+ onnx_legalizer.legalize(onnx_model, options)
+
+ tf_savedmodel = onnx_tf.backend.prepare(onnx_model)
+
+ savedmodel_name = os.path.splitext(os.path.basename(
+ args.output_path))[0] + '.savedmodel'
+ savedmodel_output_path = os.path.join(tmpdir, savedmodel_name)
+ tf_savedmodel.export_graph(savedmodel_output_path)
+
+ # make a command to convert from tf to tflite
+ tf2tfliteV2_path = os.path.join(dir_path, 'tf2tfliteV2.py')
+ tf2tfliteV2_output_name = os.path.splitext(os.path.basename(
+ args.output_path))[0] + '.tflite'
+ tf2tfliteV2_output_path = os.path.join(tmpdir, tf2tfliteV2_output_name)
+
+ del args.input_shapes
+ tf2tfliteV2_cmd = _make_cmd.make_tf2tfliteV2_cmd(
+ args, tf2tfliteV2_path, savedmodel_output_path, tf2tfliteV2_output_path)
+
+ f.write((' '.join(tf2tfliteV2_cmd) + '\n').encode())
+
+ # convert tf to tflite
+ _utils._run(tf2tfliteV2_cmd, logfile=f)
+
+ # make a command to convert from tflite to circle
+ tflite2circle_path = os.path.join(dir_path, 'tflite2circle')
+ tflite2circle_cmd = _make_cmd.make_tflite2circle_cmd(tflite2circle_path,
+ tf2tfliteV2_output_path,
+ getattr(args, 'output_path'))
+
+ f.write((' '.join(tflite2circle_cmd) + '\n').encode())
+
+ # convert tflite to circle
+ _utils._run(tflite2circle_cmd, err_prefix="tflite2circle", logfile=f)
+
+
+def main():
+ # parse arguments
+ parser = _get_parser()
+ args = _parse_arg(parser)
+
+ # parse configuration file
+ _utils._parse_cfg(args, 'one-import-pytorch')
+
+ # verify arguments
+ _verify_arg(parser, args)
+
+ # convert
+ _convert(args)
+
+
+if __name__ == '__main__':
+ _utils._safemain(main, __file__)
diff --git a/compiler/one-cmds/one-import-tf b/compiler/one-cmds/one-import-tf
index e2294caa6..999255a34 100644
--- a/compiler/one-cmds/one-import-tf
+++ b/compiler/one-cmds/one-import-tf
@@ -25,9 +25,14 @@ import subprocess
import sys
import tempfile
+import onelib.make_cmd as _make_cmd
import utils as _utils
+def get_driver_cfg_section():
+ return "one-import-tf"
+
+
def _get_parser():
parser = argparse.ArgumentParser(
description='command line tool to convert TensorFlow to circle')
@@ -146,7 +151,7 @@ def _convert(args):
tf2tfliteV2_output_path = os.path.join(
tmpdir,
os.path.splitext(os.path.basename(args.output_path))[0]) + '.tflite'
- tf2tfliteV2_cmd = _utils._make_tf2tfliteV2_cmd(args, tf2tfliteV2_path,
+ tf2tfliteV2_cmd = _make_cmd.make_tf2tfliteV2_cmd(args, tf2tfliteV2_path,
getattr(args, 'input_path'),
tf2tfliteV2_output_path)
@@ -157,7 +162,7 @@ def _convert(args):
# make a command to convert from tflite to circle
tflite2circle_path = os.path.join(dir_path, 'tflite2circle')
- tflite2circle_cmd = _utils._make_tflite2circle_cmd(tflite2circle_path,
+ tflite2circle_cmd = _make_cmd.make_tflite2circle_cmd(tflite2circle_path,
tf2tfliteV2_output_path,
getattr(args, 'output_path'))
diff --git a/compiler/one-cmds/one-import-tflite b/compiler/one-cmds/one-import-tflite
index 7eee0484a..2d756bff6 100644
--- a/compiler/one-cmds/one-import-tflite
+++ b/compiler/one-cmds/one-import-tflite
@@ -24,12 +24,17 @@ import os
import subprocess
import sys
+import onelib.make_cmd as _make_cmd
import utils as _utils
# TODO Find better way to suppress trackback on error
sys.tracebacklimit = 0
+def get_driver_cfg_section():
+ return "one-import-tflite"
+
+
def _get_parser():
parser = argparse.ArgumentParser(
description='command line tool to convert TensorFlow lite to circle')
@@ -77,7 +82,7 @@ def _convert(args):
with open(logfile_path, 'wb') as f:
# make a command to convert from tflite to circle
tflite2circle_path = os.path.join(dir_path, 'tflite2circle')
- tflite2circle_cmd = _utils._make_tflite2circle_cmd(tflite2circle_path,
+ tflite2circle_cmd = _make_cmd.make_tflite2circle_cmd(tflite2circle_path,
getattr(args, 'input_path'),
getattr(args, 'output_path'))
diff --git a/compiler/one-cmds/one-optimize b/compiler/one-cmds/one-optimize
index a64abff19..8b1f3f7be 100644
--- a/compiler/one-cmds/one-optimize
+++ b/compiler/one-cmds/one-optimize
@@ -24,6 +24,8 @@ import os
import subprocess
import sys
+import onelib.constant as _constant
+import onelib.make_cmd as _make_cmd
import utils as _utils
# TODO Find better way to suppress trackback on error
@@ -60,7 +62,7 @@ def _get_parser():
'-o', '--output_path', type=str, help='full filepath of the output file')
# optimization pass
- for opt in _utils._CONSTANT.OPTIMIZATION_OPTS:
+ for opt in _constant.CONSTANT.OPTIMIZATION_OPTS:
# opt = (option_name, help_message)
circle2circle_group.add_argument('--' + opt[0], action='store_true', help=opt[1])
@@ -99,7 +101,7 @@ def _optimize(args):
with open(logfile_path, 'wb') as f:
# make a command to optimize circle model
circle2circle_path = os.path.join(dir_path, 'circle2circle')
- circle2circle_cmd = _utils._make_circle2circle_cmd(args, circle2circle_path,
+ circle2circle_cmd = _make_cmd.make_circle2circle_cmd(args, circle2circle_path,
getattr(args, 'input_path'),
getattr(args, 'output_path'))
diff --git a/compiler/one-cmds/one-prepare-venv b/compiler/one-cmds/one-prepare-venv
index 285191761..0f75166a7 100644
--- a/compiler/one-cmds/one-prepare-venv
+++ b/compiler/one-cmds/one-prepare-venv
@@ -26,16 +26,17 @@ VENV_PYTHON=${DRIVER_PATH}/venv/bin/python
if [ ! -f ${VENV_ACTIVATE} ]; then
# Create python virtual enviornment
- python3 -m venv "${DRIVER_PATH}/venv"
+ python3.8 -m venv "${DRIVER_PATH}/venv"
fi
# NOTE version
# - https://github.com/onnx/onnx/blob/master/docs/Versioning.md
# - https://github.com/onnx/onnx-tensorflow/blob/master/Versioning.md
-VER_TENSORFLOW=2.3.0
-VER_ONNX=1.10.1
-VER_ONNX_TF=1.9.0
+VER_TENSORFLOW=2.8.0
+VER_ONNX=1.11.0
+VER_ONNXRUNTIME=1.11.0
+VER_ONNX_TF=1.10.0
# Install tensorflow
@@ -54,18 +55,32 @@ if [[ ! -z "$ONE_PREPVENV_PIP_OPTION" ]]; then
PIP_OPTIONS+=" ${ONE_PREPVENV_PIP_OPTION} "
fi
-# TODO remove version number of 'pip==20.2.1 setuptools==49.3.0'
-# NOTE adding version is for temporary hotfix of setuptools 50.x.y version
-${VENV_PYTHON} -m pip ${PIP_OPTIONS} install -U pip==20.2.1 setuptools==49.3.0
-${VENV_PYTHON} -m pip ${PIP_OPTIONS} install tensorflow-cpu==${VER_TENSORFLOW}
-${VENV_PYTHON} -m pip ${PIP_OPTIONS} install Pillow==6.2.2
+${VENV_PYTHON} -m pip ${PIP_OPTIONS} install --upgrade pip setuptools
+if [ -n "${EXT_TENSORFLOW_WHL}" ]; then
+ ${VENV_PYTHON} -m pip ${PIP_OPTIONS} install ${EXT_TENSORFLOW_WHL}
+else
+ ${VENV_PYTHON} -m pip ${PIP_OPTIONS} install tensorflow-cpu==${VER_TENSORFLOW}
+fi
+${VENV_PYTHON} -m pip ${PIP_OPTIONS} install Pillow
+${VENV_PYTHON} -m pip ${PIP_OPTIONS} install tensorflow_probability
# Install PyTorch and ONNX related
-${VENV_PYTHON} -m pip ${PIP_OPTIONS} install torch==1.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
+# NOTE set ONE_PREPVENV_TORCH_STABLE to override 'torch_stable.html' URL.
+# torch_stable.html points to download URL of torch wheel file(s)
+# but sometimes the server gets unstable, especially from in-house CI.
+TORCH_STABLE_URL="https://download.pytorch.org/whl/torch_stable.html"
+if [[ ! -z "$ONE_PREPVENV_TORCH_STABLE" ]]; then
+ TORCH_STABLE_URL="${ONE_PREPVENV_TORCH_STABLE}"
+fi
+${VENV_PYTHON} -m pip ${PIP_OPTIONS} install torch==1.11.0+cpu -f ${TORCH_STABLE_URL}
+
+${VENV_PYTHON} -m pip ${PIP_OPTIONS} install onnx==${VER_ONNX}
+
+${VENV_PYTHON} -m pip ${PIP_OPTIONS} install onnxruntime==${VER_ONNXRUNTIME}
# Provide install of custom onnx-tf
if [ -n "${EXT_ONNX_TF_WHL}" ]; then
- ${VENV_PYTHON} -m pip ${PIP_OPTIONS} install onnx==${VER_ONNX} ${EXT_ONNX_TF_WHL}
+ ${VENV_PYTHON} -m pip ${PIP_OPTIONS} install ${EXT_ONNX_TF_WHL}
else
- ${VENV_PYTHON} -m pip ${PIP_OPTIONS} install onnx==${VER_ONNX} onnx-tf==${VER_ONNX_TF}
+ ${VENV_PYTHON} -m pip ${PIP_OPTIONS} install onnx-tf==${VER_ONNX_TF}
fi
diff --git a/compiler/one-cmds/one-quantize b/compiler/one-cmds/one-quantize
index 22d4ddb0e..f2eff24bd 100644
--- a/compiler/one-cmds/one-quantize
+++ b/compiler/one-cmds/one-quantize
@@ -119,6 +119,18 @@ def _get_parser():
help=
"calibration algorithm for post-training quantization (supported: percentile/moving_average, default=percentile). 'percentile' mode uses the n-th percentiles as min/max values. 'moving_average' mode records the moving average of min/max."
)
+ quantization_group.add_argument(
+ '--TF-style_maxpool',
+ action='store_true',
+ help=
+ "Force MaxPool Op to have the same input/output quantparams. NOTE: This option can degrade accuracy of some models.)"
+ )
+ quantization_group.add_argument(
+ '--quant_config',
+ type=str,
+ help=
+ "Path to the quantization configuration file."
+ )
# arguments for force_quantparam option
force_quantparam_group = parser.add_argument_group(
@@ -137,6 +149,19 @@ def _get_parser():
force_quantparam_group.add_argument(
'--zero_point', type=int, action='append', help='zero point (int)')
+ # arguments for copy_quantparam option
+ copy_quantparam_group = parser.add_argument_group(
+ 'arguments for copy_quantparam option')
+
+ copy_quantparam_group.add_argument(
+ '--copy_quantparam',
+ action='store_true',
+ help='copy quantparam (scale, zero_point) of a tensor to another tensor.')
+ copy_quantparam_group.add_argument(
+ '--src_tensor_name', type=str, action='append', help='tensor name (string)')
+ copy_quantparam_group.add_argument(
+ '--dst_tensor_name', type=str, action='append', help='tensor name (string)')
+
return parser
@@ -171,6 +196,11 @@ def _verify_arg(parser, args):
missing.append('--scale')
if not _utils._is_valid_attr(args, 'zero_point'):
missing.append('--zero_point')
+ if _utils._is_valid_attr(args, 'copy_quantparam'):
+ if not _utils._is_valid_attr(args, 'src_tensor_name'):
+ missing.append('--src_tensor_name')
+ if not _utils._is_valid_attr(args, 'dst_tensor_name'):
+ missing.append('--dst_tensor_name')
if len(missing):
parser.error('the following arguments are required: ' + ' '.join(missing))
if _utils._is_valid_attr(args, 'force_quantparam'):
@@ -180,6 +210,12 @@ def _verify_arg(parser, args):
if len(tensors) != len(scales) or len(tensors) != len(zerops):
parser.error(
'The same number of tensor_name, scale, and zero_point should be given.')
+ if _utils._is_valid_attr(args, 'copy_quantparam'):
+ src_tensors = getattr(args, 'src_tensor_name')
+ dst_tensors = getattr(args, 'dst_tensor_name')
+ if len(src_tensors) != len(dst_tensors):
+ parser.error(
+ 'The same number of src_tensor_name and dst_tensor_name should be given.')
def _parse_arg(parser):
@@ -197,6 +233,11 @@ def _quantize(args):
_write_qparam(args)
return
+ if _utils._is_valid_attr(args, 'copy_quantparam'):
+ # copy quantization parameters
+ _copy_qparam(args)
+ return
+
# get file path to log
dir_path = os.path.dirname(os.path.realpath(__file__))
logfile_path = os.path.realpath(args.output_path) + '.log'
@@ -294,12 +335,19 @@ def _quantize(args):
circle_quantizer_cmd.append(getattr(args, 'quantized_dtype'))
if _utils._is_valid_attr(args, 'granularity'):
circle_quantizer_cmd.append(getattr(args, 'granularity'))
+ if _utils._is_valid_attr(args, 'TF-style_maxpool'):
+ circle_quantizer_cmd.append('--TF-style_maxpool')
if _utils._is_valid_attr(args, 'input_type'):
circle_quantizer_cmd.append('--input_type')
circle_quantizer_cmd.append(getattr(args, 'input_type'))
if _utils._is_valid_attr(args, 'output_type'):
circle_quantizer_cmd.append('--output_type')
circle_quantizer_cmd.append(getattr(args, 'output_type'))
+ if _utils._is_valid_attr(args, 'quant_config'):
+ # NOTE --config conflicts with --config option in onecc, so
+ # we use quant_config for one-quantize
+ circle_quantizer_cmd.append('--config')
+ circle_quantizer_cmd.append(getattr(args, 'quant_config'))
# input and output path
circle_quantizer_cmd.append(tmp_output_path_2)
if _utils._is_valid_attr(args, 'output_path'):
@@ -351,6 +399,40 @@ def _write_qparam(args):
_utils._run(circle_quantizer_cmd, err_prefix="circle_quantizer", logfile=f)
+def _copy_qparam(args):
+ # get file path to log
+ dir_path = os.path.dirname(os.path.realpath(__file__))
+ logfile_path = os.path.realpath(args.output_path) + '.log'
+
+ with open(logfile_path, 'wb') as f:
+ # get driver path
+ circle_quantizer_path = os.path.join(dir_path, 'circle-quantizer')
+
+ # make a command to write qparams to the tensors
+ circle_quantizer_cmd = [circle_quantizer_path]
+ # verbose
+ if _utils._is_valid_attr(args, 'verbose'):
+ circle_quantizer_cmd.append('--verbose')
+ if _utils._is_valid_attr(args, 'src_tensor_name'):
+ src_tensor_name = getattr(args, 'src_tensor_name')
+ if _utils._is_valid_attr(args, 'dst_tensor_name'):
+ dst_tensor_name = getattr(args, 'dst_tensor_name')
+ for (src, dst) in zip(src_tensor_name, dst_tensor_name):
+ circle_quantizer_cmd.append('--copy_quantparam')
+ circle_quantizer_cmd.append(src)
+ circle_quantizer_cmd.append(dst)
+ # input and output path
+ if _utils._is_valid_attr(args, 'input_path'):
+ circle_quantizer_cmd.append(getattr(args, 'input_path'))
+ if _utils._is_valid_attr(args, 'output_path'):
+ circle_quantizer_cmd.append(getattr(args, 'output_path'))
+
+ f.write((' '.join(circle_quantizer_cmd) + '\n').encode())
+
+ # run circle-quantizer
+ _utils._run(circle_quantizer_cmd, err_prefix="circle_quantizer", logfile=f)
+
+
def main():
# parse arguments
parser = _get_parser()
diff --git a/compiler/one-cmds/onecc b/compiler/one-cmds/onecc
index ca440d852..25682ff4b 100644
--- a/compiler/one-cmds/onecc
+++ b/compiler/one-cmds/onecc
@@ -104,10 +104,6 @@ def _verify_arg(parser, args):
def _get_driver_name(driver_name):
return {
- 'one-import-bcq': 'one-import-bcq',
- 'one-import-tf': 'one-import-tf',
- 'one-import-tflite': 'one-import-tflite',
- 'one-import-onnx': 'one-import-onnx',
'one-optimize': 'one-optimize',
'one-quantize': 'one-quantize',
'one-pack': 'one-pack',
@@ -130,19 +126,15 @@ def _is_available_driver(config, driver_name):
'onecc', driver_name)
-def _verify_cfg(driver_list, config):
+def _verify_cfg(import_driver_list, config):
if not config.has_section('onecc'):
raise ImportError('[onecc] section is required in configuration file')
import_driver_cnt = 0
- if _is_available_driver(config, 'one-import-tf'):
- import_driver_cnt += 1
- if _is_available_driver(config, 'one-import-tflite'):
- import_driver_cnt += 1
- if _is_available_driver(config, 'one-import-bcq'):
- import_driver_cnt += 1
- if _is_available_driver(config, 'one-import-onnx'):
- import_driver_cnt += 1
+ for d in import_driver_list:
+ if _is_available_driver(config, d):
+ import_driver_cnt += 1
+
if import_driver_cnt > 1:
raise AssertionError('Only one import-* driver can be executed')
@@ -170,22 +162,27 @@ def main():
config = _parse_cfg(args)
# verify configuration file
- drivers = [
- 'one-import-tf', 'one-import-tflite', 'one-import-bcq', 'one-import-onnx',
+ bin_dir = os.path.dirname(os.path.realpath(__file__))
+ import_drivers_dict = _utils._detect_one_import_drivers(bin_dir)
+ transform_drivers = [
'one-optimize', 'one-quantize', 'one-pack', 'one-codegen', 'one-profile'
]
- _verify_cfg(drivers, config)
+ _verify_cfg(import_drivers_dict, config)
# get sections to run
section_to_run = []
- for d in drivers:
+ for d in list(import_drivers_dict) + transform_drivers:
if _is_available_driver(config, d):
section_to_run.append(d)
# run
dir_path = os.path.dirname(os.path.realpath(__file__))
for section in section_to_run:
- driver_name = _get_driver_name(section)
+ if section in import_drivers_dict:
+ # we already have the driver name in the dict
+ driver_name = import_drivers_dict[section]
+ else:
+ driver_name = _get_driver_name(section)
options = ['--config', getattr(args, 'config'), '--section', section]
if _utils._is_valid_attr(args, 'verbose'):
options.append('--verbose')
diff --git a/compiler/one-cmds/onelib/constant.py b/compiler/one-cmds/onelib/constant.py
new file mode 100644
index 000000000..7ddd7382d
--- /dev/null
+++ b/compiler/one-cmds/onelib/constant.py
@@ -0,0 +1,86 @@
+#!/usr/bin/env python
+
+# Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+class CONSTANT:
+ __slots__ = () # This prevents access via __dict__.
+ OPTIMIZATION_OPTS = (
+ # (OPTION_NAME, HELP_MESSAGE)
+ ('O1', 'enable O1 optimization pass'),
+ ('convert_nchw_to_nhwc',
+ 'Experimental: This will convert NCHW operators to NHWC under the assumption that input model is NCHW.'
+ ),
+ ('expand_broadcast_const', 'expand broadcastable constant node inputs'),
+ ('nchw_to_nhwc_input_shape',
+ 'convert the input shape of the model (argument for convert_nchw_to_nhwc)'),
+ ('nchw_to_nhwc_output_shape',
+ 'convert the output shape of the model (argument for convert_nchw_to_nhwc)'),
+ ('fold_add_v2', 'fold AddV2 op with constant inputs'),
+ ('fold_cast', 'fold Cast op with constant input'),
+ ('fold_dequantize', 'fold Dequantize op'),
+ ('fold_dwconv', 'fold Depthwise Convolution op with constant inputs'),
+ ('fold_gather', 'fold Gather op'),
+ ('fold_sparse_to_dense', 'fold SparseToDense op'),
+ ('forward_reshape_to_unaryop', 'Forward Reshape op'),
+ ('fuse_add_with_tconv', 'fuse Add op to Transposed'),
+ ('fuse_add_with_fully_connected', 'fuse Add op to FullyConnected op'),
+ ('fuse_batchnorm_with_conv', 'fuse BatchNorm op to Convolution op'),
+ ('fuse_batchnorm_with_dwconv', 'fuse BatchNorm op to Depthwise Convolution op'),
+ ('fuse_batchnorm_with_tconv', 'fuse BatchNorm op to Transposed Convolution op'),
+ ('fuse_bcq', 'apply Binary Coded Quantization'),
+ ('fuse_preactivation_batchnorm',
+ 'fuse BatchNorm operators of pre-activations to Convolution op'),
+ ('fuse_mean_with_mean', 'fuse two consecutive Mean ops'),
+ ('fuse_transpose_with_mean',
+ 'fuse Mean with a preceding Transpose under certain conditions'),
+ ('make_batchnorm_gamma_positive',
+ 'make negative gamma of BatchNorm to a small positive value (1e-10).'
+ ' Note that this pass can change the execution result of the model.'
+ ' So, use it only when the impact is known to be acceptable.'),
+ ('fuse_activation_function', 'fuse Activation function to a preceding operator'),
+ ('fuse_instnorm', 'fuse ops to InstanceNorm operator'),
+ ('replace_cw_mul_add_with_depthwise_conv',
+ 'replace channel-wise Mul/Add with DepthwiseConv2D'),
+ ('remove_fakequant', 'remove FakeQuant ops'),
+ ('remove_quantdequant', 'remove Quantize-Dequantize sequence'),
+ ('remove_redundant_quantize', 'remove redundant Quantize ops'),
+ ('remove_redundant_reshape', 'fuse or remove subsequent Reshape ops'),
+ ('remove_redundant_transpose', 'fuse or remove subsequent Transpose ops'),
+ ('remove_unnecessary_reshape', 'remove unnecessary reshape ops'),
+ ('remove_unnecessary_slice', 'remove unnecessary slice ops'),
+ ('remove_unnecessary_strided_slice', 'remove unnecessary strided slice ops'),
+ ('remove_unnecessary_split', 'remove unnecessary split ops'),
+ ('resolve_customop_add', 'convert Custom(Add) op to Add op'),
+ ('resolve_customop_batchmatmul',
+ 'convert Custom(BatchMatmul) op to BatchMatmul op'),
+ ('resolve_customop_matmul', 'convert Custom(Matmul) op to Matmul op'),
+ ('resolve_customop_max_pool_with_argmax',
+ 'convert Custom(MaxPoolWithArgmax) to net of builtin operators'),
+ ('shuffle_weight_to_16x1float32',
+ 'convert weight format of FullyConnected op to SHUFFLED16x1FLOAT32.'
+ ' Note that it only converts weights whose row is a multiple of 16'),
+ ('substitute_pack_to_reshape', 'convert single input Pack op to Reshape op'),
+ ('substitute_padv2_to_pad', 'convert certain condition PadV2 to Pad'),
+ ('substitute_splitv_to_split', 'convert certain condition SplitV to Split'),
+ ('substitute_squeeze_to_reshape', 'convert certain condition Squeeze to Reshape'),
+ ('substitute_strided_slice_to_reshape',
+ 'convert certain condition StridedSlice to Reshape'),
+ ('substitute_transpose_to_reshape',
+ 'convert certain condition Transpose to Reshape'),
+ ('transform_min_max_to_relu6', 'transform Minimum-Maximum pattern to Relu6 op'),
+ ('transform_min_relu_to_relu6', 'transform Minimum(6)-Relu pattern to Relu6 op'))
+
+
+CONSTANT = CONSTANT()
diff --git a/compiler/one-cmds/onelib/make_cmd.py b/compiler/one-cmds/onelib/make_cmd.py
new file mode 100644
index 000000000..d8380f28d
--- /dev/null
+++ b/compiler/one-cmds/onelib/make_cmd.py
@@ -0,0 +1,100 @@
+#!/usr/bin/env python
+
+# Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+
+import onelib.constant as _constant
+
+def _is_valid_attr(args, attr):
+ return hasattr(args, attr) and getattr(args, attr)
+
+
+def make_tf2tfliteV2_cmd(args, driver_path, input_path, output_path):
+ """make a command for running tf2tfliteV2.py"""
+ cmd = [sys.executable, os.path.expanduser(driver_path)]
+ # verbose
+ if _is_valid_attr(args, 'verbose'):
+ cmd.append('--verbose')
+ # model_format
+ if _is_valid_attr(args, 'model_format_cmd'):
+ cmd.append(getattr(args, 'model_format_cmd'))
+ elif _is_valid_attr(args, 'model_format'):
+ cmd.append('--' + getattr(args, 'model_format'))
+ else:
+ cmd.append('--graph_def') # default value
+ # converter version
+ if _is_valid_attr(args, 'converter_version_cmd'):
+ cmd.append(getattr(args, 'converter_version_cmd'))
+ elif _is_valid_attr(args, 'converter_version'):
+ cmd.append('--' + getattr(args, 'converter_version'))
+ else:
+ cmd.append('--v1') # default value
+ # input_path
+ if _is_valid_attr(args, 'input_path'):
+ cmd.append('--input_path')
+ cmd.append(os.path.expanduser(input_path))
+ # output_path
+ if _is_valid_attr(args, 'output_path'):
+ cmd.append('--output_path')
+ cmd.append(os.path.expanduser(output_path))
+ # input_arrays
+ if _is_valid_attr(args, 'input_arrays'):
+ cmd.append('--input_arrays')
+ cmd.append(getattr(args, 'input_arrays'))
+ # input_shapes
+ if _is_valid_attr(args, 'input_shapes'):
+ cmd.append('--input_shapes')
+ cmd.append(getattr(args, 'input_shapes'))
+ # output_arrays
+ if _is_valid_attr(args, 'output_arrays'):
+ cmd.append('--output_arrays')
+ cmd.append(getattr(args, 'output_arrays'))
+
+ return cmd
+
+
+def make_tflite2circle_cmd(driver_path, input_path, output_path):
+ """make a command for running tflite2circle"""
+ cmd = [driver_path, input_path, output_path]
+ return [os.path.expanduser(c) for c in cmd]
+
+
+def make_circle2circle_cmd(args, driver_path, input_path, output_path):
+ """make a command for running circle2circle"""
+ cmd = [os.path.expanduser(c) for c in [driver_path, input_path, output_path]]
+ # profiling
+ if _is_valid_attr(args, 'generate_profile_data'):
+ cmd.append('--generate_profile_data')
+ # optimization pass(only true/false options)
+ # TODO support options whose number of arguments is more than zero
+ for opt in _constant.CONSTANT.OPTIMIZATION_OPTS:
+ if _is_valid_attr(args, opt[0]):
+ # ./driver --opt[0]
+ if type(getattr(args, opt[0])) is bool:
+ cmd.append('--' + opt[0])
+ """
+ This condition check is for config file interface, usually would be
+ SomeOption=True
+ but user can write as follows while development
+ SomeOption=False
+ instead of removing SomeOption option
+ """
+ if type(getattr(args, opt[0])) is str and not getattr(
+ args, opt[0]).lower() in ['false', '0', 'n']:
+ cmd.append('--' + opt[0])
+
+ return cmd
diff --git a/compiler/one-cmds/onnx_legalizer.py b/compiler/one-cmds/onnx_legalizer.py
new file mode 100755
index 000000000..26c2b75b9
--- /dev/null
+++ b/compiler/one-cmds/onnx_legalizer.py
@@ -0,0 +1,1065 @@
+#!/usr/bin/python3
+
+# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import onnx
+import onnx.numpy_helper
+import sys
+import numpy as np
+import re
+
+# Transform onnx model to make it compilable with our toolchain
+#
+# This code works with onnx model in proto format. See proto buffers format in
+# https://github.com/onnx/onnx/blob/96516aecd4c110b0ac57eba08ac236ebf7205728/onnx/onnx.proto3
+#
+# More examples of handling onnx models could be found here:
+# https://github.com/onnx/onnx/tree/96516aecd4c110b0ac57eba08ac236ebf7205728/onnx/examples
+#
+# List of transformations:
+# - Replace RNN operation with unrolled subgraph
+# - Replace LSTM operation with unrolled subgraph
+
+
+class LegalizeOptions:
+ """Controls transformations that the legalizer applies
+
+ Attributes:
+ unroll_rnn (bool): default is False. If True - unrolls RNN operations
+ unroll_lstm (bool): default is False. If True - unrolls LSTM operations
+ """
+
+ unroll_rnn = False
+ unroll_lstm = False
+
+
+def _reverse_str(s):
+ return ''.join(reversed(s))
+
+
+def _parse_tensor_name(name):
+ """Splits tensor name to base part and serial number
+
+ Most tensor names have the following format: "tensor_123".
+ This function breaks name into two values: "tensor_" and 123.
+ Tensor names like this: "321" are broken into "" and 321.
+
+ Serial number is used to create unique tensor names using given base name.
+
+ Args:
+ name (str): tensor name
+
+ Returns:
+ tuple of str, int: base name and serial number of tensor
+ """
+ rev = _reverse_str(name)
+ m = re.match('(\d*)(.*)', rev)
+ if m.groups()[0] != '':
+ return (_reverse_str(m.groups()[1]), int(_reverse_str(m.groups()[0])))
+ else:
+ return (_reverse_str(m.groups()[1]), 0)
+
+
+class _ModelTransformerHelper:
+ """Helper for onnx model transformation
+
+ This helper is used for convenient operation replacement in onnx model
+
+ Attributes:
+ _model (onnx.onnx_ml_pb2.ModelProto): target model that should be changed
+ _nodes_to_delete (list of onnx.onnx_ml_pb2.NodeProto): list of replaced operations
+ _insert_id (int): position to insert created operations (insertions should keep the node list topologically sorted)
+ _base_name_idx (dict from str to int): maps tensor "base" name to
+ largest existing serial num. For example model has tensors "t_1", "t_2", "t_4",
+ in that case _base_name_idx["t_"] == 4.
+ This attribute is used for unique tensor name generation.
+ """
+
+ def __init__(self, model):
+ self._model = model
+ self._nodes_to_delete = []
+ self._insert_id = 0
+ # each tensor has name containing base name and unique number. for example:
+ # "abc_123": "abc_" - base name, "123" - unique number
+ # if no number in name, consider it is equal to "0"
+
+ # mapping from base names to largest given number
+ self._base_name_idx = {}
+ # gather name information for existing tensors
+ for node in model.graph.node:
+ for t in list(node.input) + list(node.output):
+ base_name, number = _parse_tensor_name(t)
+ if base_name in self._base_name_idx:
+ self._base_name_idx[base_name] = max(self._base_name_idx[base_name],
+ number)
+ else:
+ self._base_name_idx[base_name] = number
+
+ def make_tensor_with_base_name(self, base_name):
+ """ Create unique name for given base_name
+
+ Args:
+ base_name (str): base tensor name
+
+ Returns:
+ str : unique tensor name that starts with base_name
+ """
+ if base_name in self._base_name_idx:
+ self._base_name_idx[base_name] += 1
+ return base_name + str(self._base_name_idx[base_name])
+ else:
+ self._base_name_idx[base_name] = 0
+ return base_name + '0'
+
+ def make_node(self, opcode, inputs, outputs, *p_args, **k_args):
+ """Create arbitrary node and insert it in graph.
+
+ Args:
+ opcode (str): opcode name of desired operation
+ inputs (list of str): names of input tensors
+ outputs (list of str or int): names of existing tensors to use as output tensors for operation or
+ number of tensors that should be created
+ p_args: additional arguments for onnx make_node helper
+ k_args: attributes for onnx node
+
+ Returns:
+ list of str: list of output tensor names
+ """
+ if type(outputs) == int:
+ outputs = [self.make_tensor_with_base_name('') for i in range(outputs)]
+ assert (type(outputs) == list)
+ node = onnx.helper.make_node(opcode, inputs, outputs, *p_args, **k_args)
+ self._model.graph.node.insert(self._insert_id, node)
+ self._insert_id += 1
+ return outputs
+
+ def make_split(self, input, split_sizes, axis):
+ """Create Split operation and insert it in graph.
+
+ Args:
+ input (str): name of input tensor
+ split_sizes (list of int): list of split sizes
+ axis (int): number of axis to split
+
+ Returns:
+ list: list of output tensor names
+ """
+ return self.make_node(
+ 'Split', [input], len(split_sizes), axis=axis, split=split_sizes)
+
+ def make_concat(self, inputs, axis):
+ """Create Concat operation and insert it in graph.
+
+ Args:
+ inputs (list of str): list of tensors names to concat
+ axis (int): axis number to concat
+
+ Returns:
+ str: output tensor name
+ """
+ return self.make_node('Concat', inputs, 1, axis=axis)[0]
+
+ def make_squeeze(self, input, axes):
+ """Create Squeeze operation and insert it in graph.
+
+ Args:
+ input (str): name of input tensor
+ axes (list of int): list of dimension containing ones to remove
+
+ Returns:
+ str: output tensor name
+ """
+ return self.make_node('Squeeze', [input], 1, axes=axes)[0]
+
+ def make_unsqueeze(self, input, axes):
+ """Create Unsqueeze operation and insert it in graph.
+
+ Args:
+ input (str): name of input tensor
+ axes (list of int): list of dimension to insert ones
+
+ Returns:
+ str: output tensor name
+ """
+ return self.make_node('Unsqueeze', [input], 1, axes=axes)[0]
+
+ def make_gemm(self, A, B, C, trans_a=False, trans_b=False):
+ """Create Gemm operation and insert it in graph.
+
+ Result tensor contains A*B + C
+
+ Args:
+ A (str): name of tensor A
+ B (str): name of tensor B
+ C (str): name of tensor C
+ trans_a (bool): if True, transpose tensor A before multiplication
+ trans_b (bool): if True, transpose tensor B before multiplication
+
+ Returns:
+ str: output tensor name
+ """
+ return self.make_node(
+ 'Gemm', [A, B, C], 1, transA=bool(trans_a), transB=bool(trans_b))[0]
+
+ def make_add(self, a, b):
+ """Creates Add operation and insert it in graph.
+
+ Args:
+ a (str): name of left operand tensor
+ b (str): name of right operand tensor
+
+ Returns:
+ str: output tensor name
+ """
+ return self.make_node('Add', [a, b], 1)[0]
+
+ def make_mul(self, a, b):
+ """Creates Mul operation and insert it in graph.
+
+ Args:
+ a (str): name of left operand tensor
+ b (str): name of right operand tensor
+
+ Returns:
+ str: output tensor name
+ """
+ return self.make_node('Mul', [a, b], 1)[0]
+
+ def make_clip(self, input, min, max):
+ """Create Clip operation and insert it in graph.
+
+ Args:
+ input (str): input tensor name
+ min (int/float): lower clip bound
+ max (int/float): upper clip bound
+
+ Returns:
+ str: output tensor name
+ """
+ return self.make_node('Clip', [input], 1, min=min, max=max)[0]
+
+ def make_act(self, input, act_name):
+ """Create activation function operation and insert it in graph.
+
+ Args:
+ input (str): input tensor name
+ act_name (str): name of activation function, one of ['Relu', 'Tanh', 'Sigmoid']
+
+ Returns:
+ str: output tensor name
+ """
+ assert (act_name in ['Relu', 'Tanh', 'Sigmoid'])
+ return self.make_node(act_name, [input], 1)[0]
+
+ def make_constant_tensor(self, tensor_data, base_name):
+ """Creates onnx constant tensor
+
+ Args:
+ tensor_data (numpy.ndarray): tensor data
+ base_name (str): prefix of constant tensor name
+
+ Returns:
+ str: name of created constant tensor
+ """
+ tensor = onnx.numpy_helper.from_array(tensor_data)
+ tensor.name = self.make_tensor_with_base_name(base_name)
+ self._model.graph.initializer.append(tensor)
+ return tensor.name
+
+ def mark_for_deletion(self, node):
+ self._nodes_to_delete += [node]
+
+ def get_insert_id(self):
+ return self._insert_id
+
+ def set_insert_id(self, insert_id):
+ self._insert_id = insert_id
+
+ def delete_marked_nodes(self):
+ for node in self._nodes_to_delete:
+ self._model.graph.node.remove(node)
+
+
+class _TensorInfo:
+ def __init__(self, dtype, shape):
+ self.dtype = dtype
+ self.shape = shape
+
+
+def _get_tensor_infos(model):
+ """Infer tensor shapes and dtypes
+ Args:
+ model (onnx.onnx_ml_pb2.ModelProto): model to process
+
+ Returns:
+ dict from str to _TensorInfo: maps tensor name to shape and dtype information
+ """
+
+ inferred_shape_model = onnx.shape_inference.infer_shapes(model)
+
+ infos = {}
+ for tensor in list(inferred_shape_model.graph.value_info) + list(
+ inferred_shape_model.graph.input):
+ info = _TensorInfo(tensor.type.tensor_type.elem_type, [])
+ for dim in tensor.type.tensor_type.shape.dim:
+ info.shape += [dim.dim_value]
+ infos[tensor.name] = info
+
+ for tensor in list(model.graph.initializer):
+ infos[tensor.name] = _TensorInfo(tensor.data_type, tensor.dims)
+ return infos
+
+
+def _dtype_to_np(dtype):
+ """Convert onnx dtype value to numpy dtype class
+
+ For more types see:
+ https://github.com/onnx/onnx/blob/96516aecd4c110b0ac57eba08ac236ebf7205728/onnx/onnx.proto3#L484
+
+ Args:
+ dtype (int): onnx dtype
+
+ Returns:
+ numpy data type: numpy dtype, like np.float32
+ """
+
+ if dtype == 1:
+ return np.float32
+ else:
+ raise NotImplementedError('unsupported data type')
+
+
+def _generate_one_direction_RNN(transformer, X, W, R, B, initial_h, clip, activation_name):
+ """Generate subgraph of one direction of unrolled RNN layer
+
+ Args:
+ transformer (_ModelTransformerHelper): helper for model generation
+ X (list of str): names of input tensors in sequence. Tensor shapes: [batch_size, input_size].
+ W (str): name of weight tensor
+ R (str): name of recurrence weight tensor
+ B (str): name of bias tensor
+ initial_h (str or None): name of tensor containing initial hidden state. Shape [batch_size, hidden_size]
+ clip (float or None): range which clips input of activations
+ activation_name (str): name of activation function
+ """
+ # one direction RNN:
+ #
+ # For details see:
+ # https://github.com/onnx/onnx/blob/5cf5feef5ec3fd5527b2fdb6c29780e3b705059f/docs/Changelog.md#RNN-7
+ #
+ # H = f(X*(W^T) + h*(R^T) + B)
+ #
+ # H - new hidden state
+ # h - previous hidden state
+ # X - current input
+ # W - input weights matrix
+ # R - recurrent weights matrix
+ # Wb - input weights matmul bias
+ # Rb - recurrent weights matmul bias
+ # f - activation function
+
+ seq_length = len(X)
+ first_iter = 0
+ state_tensors = []
+ if initial_h is not None:
+ previous_state_tensor = initial_h
+ else:
+ first_iter = 1
+ state_tensor = transformer.make_gemm(X[0], W, B, trans_b=True)
+ if clip != None:
+ state_tensor = transformer.make_clip(state_tensor, min=-clip, max=clip)
+ previous_state_tensor = transformer.make_act(state_tensor, activation_name)
+ state_tensors += [previous_state_tensor]
+
+ for i in range(first_iter, seq_length):
+ state_tensor = transformer.make_gemm(X[i], W, B, trans_b=True)
+ state_tensor = transformer.make_gemm(
+ previous_state_tensor, R, state_tensor, trans_b=True)
+ if clip != None:
+ state_tensor = transformer.make_clip(state_tensor, min=-clip, max=clip)
+ previous_state_tensor = transformer.make_act(state_tensor, activation_name)
+ state_tensors += [previous_state_tensor]
+ return state_tensors
+
+
+def _transform_unidirectional_RNN(transformer, original_node, x, tensor_infos, activation,
+ clip, direction, hidden_size, layout):
+ """Generate Simple (forward or reverse) unrolled RNN
+
+ Args:
+ transformer (_ModelTransformerHelper): transformation helper
+ original_node (onnx.onnx_ml_pb2.NodeProto): unidirectional RNN operation to unroll
+ x (list of str): list of input tensors (input tensor split along "time" dimension)
+ tensor_infos (dict from str to _TensorInfo): dict maps tensor name to its shape and dtype info
+ activation (str): name of activation function
+ clip (float or None): range which clips input of activations
+ direction (str): "forward" or "reverse"
+ hidden_size (int): size of hidden state
+ layout (int): See attribute description:
+ https://github.com/onnx/onnx/blob/5cf5feef5ec3fd5527b2fdb6c29780e3b705059f/docs/Operators.md#attributes-56
+ """
+
+ inputs = original_node.input
+ outputs = original_node.output
+ if direction == 'reverse':
+ x.reverse()
+ w = transformer.make_squeeze(inputs[1], axes=[0])
+ r = transformer.make_squeeze(inputs[2], axes=[0])
+ if len(inputs) > 3 and inputs[3] != '':
+ raw_bias_tensor = transformer.make_squeeze(inputs[3], axes=[0])
+ splitted_bias_tensors = transformer.make_split(
+ raw_bias_tensor, split_sizes=[hidden_size] * 2, axis=0)
+ b = transformer.make_add(splitted_bias_tensors[0], splitted_bias_tensors[1])
+ else:
+ data_type = _dtype_to_np(tensor_infos[inputs[2]].dtype)
+ b = transformer.make_constant_tensor(
+ np.zeros(hidden_size, dtype=data_type), "zero_bias")
+ if len(inputs) > 5 and inputs[5] != '':
+ direction_dim = layout
+ initial_h = transformer.make_squeeze(inputs[5], axes=[direction_dim])
+ else:
+ initial_h = None
+ state_tensors = _generate_one_direction_RNN(transformer, x, w, r, b, initial_h, clip,
+ activation)
+ y_direction_dim = layout + 1
+ y_h_direction_dim = layout
+ state_layout_tensors = []
+ seq_length_dim = layout
+ for state in state_tensors:
+ state_layout_tensors += [
+ transformer.make_unsqueeze(state, axes=[seq_length_dim, y_direction_dim])
+ ]
+
+ # use low-level interface to attach to existing tensors
+ Y_h = outputs[1]
+ transformer.make_node(
+ 'Unsqueeze', [state_tensors[-1]], [Y_h], axes=[y_h_direction_dim])
+ Y = outputs[0]
+ transformer.make_node(
+ 'Concat', state_layout_tensors, [Y], axis=seq_length_dim)
+
+
+def _transform_bidirectional_RNN(transformer, original_node, x, tensor_infos, activations,
+ clip, hidden_size, layout):
+ """Generate Bidirectional unrolled RNN
+
+ Args:
+ transformer (_ModelTransformerHelper): transformation helper
+ original_node (onnx.onnx_ml_pb2.NodeProto): bidirectional RNN operation to unroll
+ x (list of str): list of input tensors (input tensor split along "time" dimension)
+ tensor_infos (dict from str to _TensorInfo): dict maps tensor name to its shape and dtype info
+ activations (list of str): list of length 2 containing names of forward and reverse activations
+ clip (float or None): range which clips input of activations
+ hidden_size (int): size of hidden state
+ layout (int): See attribute description:
+ https://github.com/onnx/onnx/blob/5cf5feef5ec3fd5527b2fdb6c29780e3b705059f/docs/Operators.md#attributes-56
+ """
+
+ inputs = original_node.input
+ outputs = original_node.output
+ w_bi = transformer.make_split(inputs[1], split_sizes=[1, 1], axis=0)
+ r_bi = transformer.make_split(inputs[2], split_sizes=[1, 1], axis=0)
+ w = []
+ r = []
+ for d in range(2):
+ w += [transformer.make_squeeze(w_bi[d], axes=[0])]
+ r += [transformer.make_squeeze(r_bi[d], axes=[0])]
+
+ b = []
+ if len(inputs) > 3 and inputs[3] != '':
+ raw_bias_tensors = transformer.make_split(inputs[3], split_sizes=[1, 1], axis=0)
+ for d in range(2):
+ raw_bias_tensors_squeezed = transformer.make_squeeze(
+ raw_bias_tensors[d], axes=[0])
+ splitted_bias_tensors = transformer.make_split(
+ raw_bias_tensors_squeezed, split_sizes=[hidden_size] * 2, axis=0)
+ b += [
+ transformer.make_add(splitted_bias_tensors[0], splitted_bias_tensors[1])
+ ]
+ else:
+ data_type = _dtype_to_np(tensor_infos[inputs[2]].dtype)
+ b = [
+ transformer.make_constant_tensor(
+ np.zeros(hidden_size, dtype=data_type), "zero_bias")
+ ] * 2
+ initial_h = [None, None]
+ if len(inputs) > 5 and inputs[5] != '':
+ direction_dim = layout
+ initial_h = transformer.make_split(
+ inputs[5], split_sizes=[1, 1], axis=direction_dim)
+ for d in range(2):
+ initial_h[d] = transformer.make_squeeze(initial_h[d], axes=[direction_dim])
+
+ state_f_tensors = _generate_one_direction_RNN(transformer, x, w[0], r[0], b[0],
+ initial_h[0], clip, activations[0])
+ x.reverse()
+ state_b_tensors = _generate_one_direction_RNN(transformer, x, w[1], r[1], b[1],
+ initial_h[1], clip, activations[1])
+ state_b_tensors.reverse()
+
+ y_direction_dim = layout + 1
+ y_h_direction_dim = layout
+ state_layout_tensors = []
+ seq_length_dim = layout
+ seq_length = len(x)
+ for t in range(seq_length):
+ state_f = state_f_tensors[t]
+ state_b = state_b_tensors[t]
+ state_layout_tensors_f = transformer.make_unsqueeze(
+ state_f, axes=[seq_length_dim, y_direction_dim])
+ state_layout_tensors_b = transformer.make_unsqueeze(
+ state_b, axes=[seq_length_dim, y_direction_dim])
+ state_layout_tensors += [
+ transformer.make_concat(
+ [state_layout_tensors_f, state_layout_tensors_b], axis=y_direction_dim)
+ ]
+
+ last_f_state_layout_tensor = transformer.make_unsqueeze(
+ state_f_tensors[-1], axes=[y_h_direction_dim])
+ last_b_state_layout_tensor = transformer.make_unsqueeze(
+ state_b_tensors[0], axes=[y_h_direction_dim])
+
+ # use low-level interface to attach to existing tensors
+ Y_h = outputs[1]
+ transformer.make_node(
+ 'Concat', [last_f_state_layout_tensor, last_b_state_layout_tensor], [Y_h],
+ axis=y_h_direction_dim)
+
+ Y = outputs[0]
+ transformer.make_node(
+ 'Concat', state_layout_tensors, [Y], axis=seq_length_dim)
+
+
+def _legalize_RNN(transformer, tensor_infos, node):
+ """Unroll RNN operation
+
+ Args:
+ transformer (_ModelTransformerHelper): transformation helper
+ tensor_infos (dict from str to _TensorInfo): dict maps tensor name to its shape and dtype info
+ node (onnx.onnx_ml_pb2.NodeProto): RNN operation to unroll
+ """
+ inputs = node.input
+ if len(inputs) > 4 and inputs[4] != '':
+ raise NotImplementedError('Variadic length of output is not supported')
+ # attributes
+ activation_alpha = []
+ activation_beta = []
+ activations = ['Tanh', 'Tanh']
+ clip = None
+ direction = 'forward'
+ hidden_size = 0
+ layout = 0
+
+ for attr in node.attribute:
+ if attr.name == 'activation_alpha':
+ activation_alpha = attr.floats
+ if attr.name == 'activation_beta':
+ activation_beta = attr.floats
+ if attr.name == 'activations':
+ activations = list(map(lambda item: item.decode('UTF-8'), list(attr.strings)))
+ if attr.name == 'clip':
+ clip = attr.f
+ if attr.name == 'direction':
+ direction = attr.s.decode('UTF-8')
+ if attr.name == 'hidden_size':
+ hidden_size = attr.i
+ if attr.name == 'layout':
+ layout = attr.i
+
+ if len(activation_alpha) > 0 or len(activation_beta) > 0:
+ raise NotImplementedError('Unsupported parameters for RNN activations')
+
+ for act in activations:
+ if act not in ['Relu', 'Tanh', 'Sigmoid']:
+ raise NotImplementedError('Unsupported activation function')
+
+ seq_length_dim = layout
+ seq_length = tensor_infos[inputs[0]].shape[seq_length_dim]
+ if hidden_size == 0:
+ hidden_size = tensor_infos[inputs[2]].shape[2]
+
+ input_split_tensor = transformer.make_split(
+ inputs[0], split_sizes=[1] * seq_length, axis=seq_length_dim)
+ x = []
+ for i in range(len(input_split_tensor)):
+ input_frame_tensor = input_split_tensor[i]
+ squeezed_frame_tensor = transformer.make_squeeze(input_frame_tensor, axes=[0])
+ x += [squeezed_frame_tensor]
+
+ if direction in ['forward', 'reverse']:
+ _transform_unidirectional_RNN(transformer, node, x, tensor_infos, activations[0],
+ clip, direction, hidden_size, layout)
+ elif direction == 'bidirectional':
+ _transform_bidirectional_RNN(transformer, node, x, tensor_infos, activations, clip,
+ hidden_size, layout)
+ else:
+ raise RuntimeError('Unknown RNN type')
+
+ transformer.mark_for_deletion(node)
+
+
+def _generate_one_direction_LSTM(transformer, X, W, R, B, initial_h, initial_c, P, clip,
+ act, dtype, hidden_size, batch_size):
+ """Generate subgraph for one direction of unrolled LSTM layer
+
+ Args:
+ transformer (_ModelTransformerHelper): helper for model generation
+ X (list of str): names of tensors in input sequence. Each tensor shape: [batch_size, input_size]
+ W (str): name of concatenated weight tensor: [input, output, forget, cell]
+ R (str): name of concatenated recurrence weights tensor: [input, output, forget, cell]
+ B (str): name of concatenated bias tensor: [input, output, forget, cell]
+ initial_h (str or None): name of tensor containing initial hidden state. Shape [batch_size, hidden_size]
+ initial_c (str or None): name of tensor containing initial cell state. Shape [batch_size, hidden_size]
+ P (str or None): name of concatenated peephole tensor: [input, output, forget]
+ clip (float or None): range which clips input of activations
+ act (dict of str): activation functions {'f': 'Sigmoid', 'g': 'Tanh', 'h': 'Tanh'}
+ dtype (numpy dtype): data type used in created LSTM operation
+ hidden_size (int): hidden dimension
+ batch_size (int): batch dimension
+ """
+ # one direction LSTM:
+ #
+ # For details see:
+ # https://github.com/onnx/onnx/blob/5cf5feef5ec3fd5527b2fdb6c29780e3b705059f/docs/Changelog.md#LSTM-7
+ #
+ # it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
+ # ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
+ # ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc)
+ # Ct = ft (.) Ct-1 + it (.) ct
+ # ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
+ # Ht = ot (.) h(Ct)
+ #
+ # X - input tensor
+ # i - input gate
+ # o - output gate
+ # f - forget gate
+ # c - cell gate
+ # t - time step (t-1 means previous time step)
+ # W[iofc] - W parameter weight matrix for input, output, forget, and cell gates
+ # R[iofc] - R recurrence weight matrix for input, output, forget, and cell gates
+ # Wb[iofc] - W bias vectors for input, output, forget, and cell gates
+ # Rb[iofc] - R bias vectors for input, output, forget, and cell gates
+ # P[iof] - P peephole weight vector for input, output, and forget gates
+ # WB[iofc] - W parameter weight matrix for backward input, output, forget, and cell gates
+ # RB[iofc] - R recurrence weight matrix for backward input, output, forget, and cell gates
+ # WBb[iofc] - W bias vectors for backward input, output, forget, and cell gates
+ # RBb[iofc] - R bias vectors for backward input, output, forget, and cell gates
+ # PB[iof] - P peephole weight vector for backward input, output, and forget gates
+ # H - Hidden state
+
+ seq_length = len(X)
+ state_h_tensors = []
+
+ w_tensors = transformer.make_split(W, split_sizes=[hidden_size] * 4, axis=0)
+ W = {'i': w_tensors[0], 'o': w_tensors[1], 'f': w_tensors[2], 'c': w_tensors[3]}
+
+ r_tensors = transformer.make_split(R, split_sizes=[hidden_size] * 4, axis=0)
+ R = {'i': r_tensors[0], 'o': r_tensors[1], 'f': r_tensors[2], 'c': r_tensors[3]}
+
+ if B is not None:
+ separate_b_tensors = transformer.make_split(
+ B, split_sizes=[hidden_size] * 8, axis=0)
+ b_tensors = []
+ for i in range(4):
+ b_tensors += [
+ transformer.make_add(separate_b_tensors[i], separate_b_tensors[i + 4])
+ ]
+ else:
+ b_tensors = [
+ transformer.make_constant_tensor(
+ np.zeros((hidden_size), dtype=dtype), 'zero_b')
+ ] * 4
+ B = {'i': b_tensors[0], 'o': b_tensors[1], 'f': b_tensors[2], 'c': b_tensors[3]}
+
+ if initial_h is not None:
+ previous_h_state_tensor = initial_h
+ else:
+ previous_h_state_tensor = transformer.make_constant_tensor(
+ np.zeros((batch_size, hidden_size), dtype=dtype), 'initial_h')
+
+ if initial_c is not None:
+ previous_c_state_tensor = initial_c
+ else:
+ previous_c_state_tensor = transformer.make_constant_tensor(
+ np.zeros((batch_size, hidden_size), dtype=dtype), 'initial_c')
+
+ if P is not None:
+ p_tensors = transformer.make_split(P, split_sizes=[hidden_size] * 3, axis=0)
+ P = {'i': p_tensors[0], 'o': p_tensors[1], 'f': p_tensors[2]}
+ else:
+ zero = transformer.make_constant_tensor(
+ np.zeros((hidden_size), dtype=dtype), 'zero_peephole')
+ P = {'i': zero, 'o': zero, 'f': zero}
+
+ for i in range(seq_length):
+ # it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
+ it = transformer.make_gemm(X[i], W['i'], B['i'], trans_b=True)
+ it = transformer.make_gemm(previous_h_state_tensor, R['i'], it, trans_b=True)
+ peephole_it = transformer.make_mul(P['i'], previous_c_state_tensor)
+ it = transformer.make_add(it, peephole_it)
+ if clip is not None:
+ it = transformer.make_clip(it, min=-clip, max=clip)
+ it = transformer.make_act(it, act['f'])
+
+ # ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
+ ft = transformer.make_gemm(X[i], W['f'], B['f'], trans_b=True)
+ ft = transformer.make_gemm(previous_h_state_tensor, R['f'], ft, trans_b=True)
+ peephole_ft = transformer.make_mul(P['f'], previous_c_state_tensor)
+ ft = transformer.make_add(ft, peephole_ft)
+ if clip is not None:
+ ft = transformer.make_clip(ft, min=-clip, max=clip)
+ ft = transformer.make_act(ft, act['f'])
+
+ # ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc)
+ ct = transformer.make_gemm(X[i], W['c'], B['c'], trans_b=True)
+ ct = transformer.make_gemm(previous_h_state_tensor, R['c'], ct, trans_b=True)
+ if clip is not None:
+ ct = transformer.make_clip(ct, min=-clip, max=clip)
+ ct = transformer.make_act(ct, act['g'])
+
+ # Ct = ft (.) Ct-1 + it (.) ct
+ ft_Ct = transformer.make_mul(ft, previous_c_state_tensor)
+ it_ct = transformer.make_mul(it, ct)
+ Ct = transformer.make_add(ft_Ct, it_ct)
+ previous_c_state_tensor = Ct
+
+ # ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
+ ot = transformer.make_gemm(X[i], W['o'], B['o'], trans_b=True)
+ ot = transformer.make_gemm(previous_h_state_tensor, R['o'], ot, trans_b=True)
+ peephole_ot = transformer.make_mul(P['o'], Ct)
+ ot = transformer.make_add(ot, peephole_ot)
+ if clip is not None:
+ ot = transformer.make_clip(ot, min=-clip, max=clip)
+ ot = transformer.make_act(ot, act['f'])
+
+ # Ht = ot (.) h(Ct)
+ Ht = transformer.make_act(Ct, act['h'])
+ Ht = transformer.make_mul(ot, Ht)
+ previous_h_state_tensor = Ht
+ state_h_tensors += [Ht]
+
+ return (state_h_tensors, previous_c_state_tensor)
+
+
+def _transform_unidirectional_LSTM(transformer, original_node, x, tensor_infos,
+ activations, clip, direction, hidden_size, layout):
+ """Generate Simple (forward or reverse) unrolled LSTM
+
+ Args:
+ transformer (_ModelTransformerHelper): transformation helper
+ original_node (onnx.onnx_ml_pb2.NodeProto): unidirectional LSTM operation to unroll
+ x (list of str): list of input tensors (input tensor split along "time" dimension)
+ tensor_infos (dict from str to _TensorInfo): dict maps tensor name to its shape and dtype info
+ activations (list of str): list of length 3 containing names of activation functions
+ clip (float or None): range which clips input of activations
+ direction (str): "forward" or "reverse"
+ hidden_size (int): size of hidden state
+ layout (int): See attribute description:
+ https://github.com/onnx/onnx/blob/5cf5feef5ec3fd5527b2fdb6c29780e3b705059f/docs/Operators.md#attributes-37
+ """
+
+ inputs = original_node.input
+ outputs = original_node.output
+ if direction == 'reverse':
+ x.reverse()
+ w = transformer.make_squeeze(inputs[1], axes=[0])
+ r = transformer.make_squeeze(inputs[2], axes=[0])
+
+ b = None
+ if len(inputs) > 3 and inputs[3] != '':
+ b = transformer.make_squeeze(inputs[3], axes=[0])
+
+ initial_h = None
+ if len(inputs) > 5 and inputs[5] != '':
+ direction_dim = layout
+ initial_h = transformer.make_squeeze(inputs[5], axes=[direction_dim])
+
+ initial_c = None
+ if len(inputs) > 6 and inputs[6] != '':
+ direction_dim = layout
+ initial_c = transformer.make_squeeze(inputs[6], axes=[direction_dim])
+
+ p = None
+ if len(inputs) > 7 and inputs[7] != '':
+ p = transformer.make_squeeze(inputs[7], axes=[0])
+
+ dtype = _dtype_to_np(tensor_infos[inputs[0]].dtype)
+ batch_size = tensor_infos[inputs[0]].shape[1 - layout]
+
+ act = {'f': activations[0], 'g': activations[1], 'h': activations[2]}
+
+ state_h_tensors, state_c_tensor = _generate_one_direction_LSTM(
+ transformer, x, w, r, b, initial_h, initial_c, p, clip, act, dtype, hidden_size,
+ batch_size)
+
+ y_direction_dim = layout + 1
+ y_h_direction_dim = layout
+ state_layout_tensors = []
+ seq_length_dim = layout
+ for h_state in state_h_tensors:
+ state_layout_tensors += [
+ transformer.make_unsqueeze(h_state, axes=[seq_length_dim, y_direction_dim])
+ ]
+
+ # use low-level interface to attach to existing tensors
+ Y_h = outputs[1]
+ transformer.make_node(
+ 'Unsqueeze', [state_h_tensors[-1]], [Y_h], axes=[y_h_direction_dim])
+ Y_c = outputs[2]
+ transformer.make_node(
+ 'Unsqueeze', [state_c_tensor], [Y_c], axes=[y_h_direction_dim])
+ if direction == 'reverse':
+ state_layout_tensors.reverse()
+ Y = outputs[0]
+ transformer.make_node(
+ 'Concat', state_layout_tensors, [Y], axis=seq_length_dim)
+
+
+def _transform_bidirectional_LSTM(transformer, original_node, x, tensor_infos, activations,
+ clip, hidden_size, layout):
+ """Generate Bidirectional unrolled LSTM
+
+ Args:
+ transformer (_ModelTransformerHelper): transformation helper
+ original_node (onnx.onnx_ml_pb2.NodeProto): bidirectional LSTM operation to unroll
+ x (list of str): list of input tensors (input tensor split along "time" dimension)
+ tensor_infos (dict from str to _TensorInfo): dict maps tensor name to its shape and dtype info
+ activations (list of str): list of length 6, containing names of forward and reverse activations
+ clip (float or None): range which clips input of activations
+ hidden_size (int): size of hidden state
+ layout (int): See attribute description:
+ https://github.com/onnx/onnx/blob/5cf5feef5ec3fd5527b2fdb6c29780e3b705059f/docs/Operators.md#attributes-37
+ """
+
+ inputs = original_node.input
+ outputs = original_node.output
+
+ w = transformer.make_split(inputs[1], split_sizes=[1, 1], axis=0)
+ r = transformer.make_split(inputs[2], split_sizes=[1, 1], axis=0)
+ for d in range(2):
+ w[d] = transformer.make_squeeze(w[d], axes=[0])
+ r[d] = transformer.make_squeeze(r[d], axes=[0])
+
+ b = [None, None]
+ if len(inputs) > 3 and inputs[3] != '':
+ b = transformer.make_split(inputs[3], split_sizes=[1, 1], axis=0)
+ for d in range(2):
+ b[d] = transformer.make_squeeze(b[d], axes=[0])
+
+ initial_h = [None, None]
+ if len(inputs) > 5 and inputs[5] != '':
+ direction_dim = layout
+ initial_h = transformer.make_split(
+ inputs[5], split_sizes=[1, 1], axis=direction_dim)
+ for d in range(2):
+ initial_h[d] = transformer.make_squeeze(initial_h[d], axes=[direction_dim])
+
+ initial_c = [None, None]
+ if len(inputs) > 6 and inputs[6] != '':
+ direction_dim = layout
+ initial_c = transformer.make_split(
+ inputs[6], split_sizes=[1, 1], axis=direction_dim)
+ for d in range(2):
+ initial_c[d] = transformer.make_squeeze(initial_c[d], axes=[direction_dim])
+
+ p = [None, None]
+ if len(inputs) > 7 and inputs[7] != '':
+ p = transformer.make_split(inputs[7], split_sizes=[1, 1], axis=0)
+ for d in range(2):
+ p[d] = transformer.make_squeeze(p[d], axes=[0])
+
+ dtype = _dtype_to_np(tensor_infos[inputs[0]].dtype)
+ batch_size = tensor_infos[inputs[0]].shape[1 - layout]
+
+ act = [{
+ 'f': activations[0],
+ 'g': activations[1],
+ 'h': activations[2]
+ }, {
+ 'f': activations[3],
+ 'g': activations[4],
+ 'h': activations[5]
+ }]
+
+ state_f_h_tensors, state_f_c_tensor = _generate_one_direction_LSTM(
+ transformer, x, w[0], r[0], b[0], initial_h[0], initial_c[0], p[0], clip, act[0],
+ dtype, hidden_size, batch_size)
+ x.reverse()
+ state_b_h_tensors, state_b_c_tensor = _generate_one_direction_LSTM(
+ transformer, x, w[1], r[1], b[1], initial_h[1], initial_c[1], p[1], clip, act[1],
+ dtype, hidden_size, batch_size)
+ state_b_h_tensors.reverse()
+
+ y_direction_dim = layout + 1
+ y_c_direction_dim = layout
+ state_layout_tensors = []
+ seq_length_dim = layout
+ for f_h_state, b_h_state in zip(state_f_h_tensors, state_b_h_tensors):
+ state_f_layout_tensors = transformer.make_unsqueeze(
+ f_h_state, axes=[seq_length_dim, y_direction_dim])
+ state_b_layout_tensors = transformer.make_unsqueeze(
+ b_h_state, axes=[seq_length_dim, y_direction_dim])
+ state_layout_tensors += [
+ transformer.make_concat(
+ [state_f_layout_tensors, state_b_layout_tensors], axis=y_direction_dim)
+ ]
+
+ last_f_state_layout_tensor = transformer.make_unsqueeze(
+ state_f_h_tensors[-1], axes=[y_c_direction_dim])
+ last_b_state_layout_tensor = transformer.make_unsqueeze(
+ state_b_h_tensors[0], axes=[y_c_direction_dim])
+
+ Y_h = outputs[1]
+ transformer.make_node(
+ 'Concat', [last_f_state_layout_tensor, last_b_state_layout_tensor], [Y_h],
+ axis=y_c_direction_dim)
+
+ Y_f_c = transformer.make_unsqueeze(state_f_c_tensor, axes=[y_c_direction_dim])
+ Y_b_c = transformer.make_unsqueeze(state_b_c_tensor, axes=[y_c_direction_dim])
+ Y_c = outputs[2]
+ transformer.make_node(
+ 'Concat', [Y_f_c, Y_b_c], [Y_c], axis=y_c_direction_dim)
+
+ Y = outputs[0]
+ transformer.make_node(
+ 'Concat', state_layout_tensors, [Y], axis=seq_length_dim)
+
+
+def _legalize_LSTM(transformer, tensor_infos, node):
+ """Unroll LSTM operation
+
+ Args:
+ transformer (_ModelTransformerHelper): transformation helper
+ tensor_infos (dict from str to _TensorInfo): dict maps tensor name to its shape and dtype info
+ node (onnx.onnx_ml_pb2.NodeProto): LSTM operation to unroll
+ """
+ inputs = node.input
+ if len(inputs) > 4 and inputs[4] != '':
+ raise NotImplementedError('Variadic length of output is not supported')
+ # attributes
+ activation_alpha = []
+ activation_beta = []
+ activations = ['Sigmoid', 'Tanh', 'Tanh'] * 2
+ clip = None
+ direction = 'forward'
+ hidden_size = 0
+ input_forget = 0
+ layout = 0
+
+ for attr in node.attribute:
+ if attr.name == 'activation_alpha':
+ activation_alpha = attr.floats
+ if attr.name == 'activation_beta':
+ activation_beta = attr.floats
+ if attr.name == 'activations':
+ activations = list(map(lambda item: item.decode('UTF-8'), list(attr.strings)))
+ if attr.name == 'clip':
+ clip = attr.f
+ if attr.name == 'direction':
+ direction = attr.s.decode('UTF-8')
+ if attr.name == 'hidden_size':
+ hidden_size = attr.i
+ if attr.name == 'input_forget':
+ input_forget = attr.i
+ if attr.name == 'layout':
+ layout = attr.i
+
+ if len(activation_alpha) > 0 or len(activation_beta) > 0:
+ raise NotImplementedError('Unsupported parameters for LSTM activations')
+
+ for act in activations:
+ if act not in ['Relu', 'Tanh', 'Sigmoid']:
+ raise NotImplementedError('Unsupported activation function')
+
+ if input_forget != 0:
+ raise NotImplementedError('Unsupported input_forget attribute value')
+
+ seq_length_dim = layout
+ seq_length = tensor_infos[inputs[0]].shape[seq_length_dim]
+ if hidden_size == 0:
+ hidden_size = tensor_infos[inputs[2]].shape[2]
+
+ input_split_tensor = transformer.make_split(
+ inputs[0], split_sizes=[1] * seq_length, axis=seq_length_dim)
+ x = []
+ for i in range(len(input_split_tensor)):
+ input_frame_tensor = input_split_tensor[i]
+ squeezed_frame_tensor = transformer.make_squeeze(input_frame_tensor, axes=[0])
+ x += [squeezed_frame_tensor]
+
+ if direction in ['forward', 'reverse']:
+ _transform_unidirectional_LSTM(transformer, node, x, tensor_infos, activations,
+ clip, direction, hidden_size, layout)
+ elif direction == 'bidirectional':
+ _transform_bidirectional_LSTM(transformer, node, x, tensor_infos, activations,
+ clip, hidden_size, layout)
+ else:
+ raise RuntimeError('Unknown LSTM type')
+
+ transformer.mark_for_deletion(node)
+
+
+def legalize(model, options):
+ """Replace selected operations in onnx model
+
+ Replaces operations, selected by given options with different operation sequences.
+ For example remove unsupported parts of graph with sequences of supported operations.
+
+ Note that the graph is changed in place.
+
+ Args:
+ model (onnx.onnx_ml_pb2.ModelProto): target model
+ options (LegalizeOptions):
+ """
+ tensor_infos = _get_tensor_infos(model)
+
+ transformer = _ModelTransformerHelper(model)
+
+ node_id = 0
+ while node_id < len(model.graph.node):
+ node = model.graph.node[node_id]
+ if node.op_type == 'RNN' and options.unroll_rnn:
+ # opset version is required by split operation
+ if model.opset_import[0].version >= 13:
+ raise NotImplementedError(
+ 'Can not generate code with opcode version 13 and greater')
+ transformer.set_insert_id(node_id)
+ _legalize_RNN(transformer, tensor_infos, node)
+ node_id = transformer.get_insert_id()
+ elif node.op_type == 'LSTM' and options.unroll_lstm:
+ if model.opset_import[0].version >= 13:
+ raise NotImplementedError(
+ 'Can not generate code with opcode version 13 and greater')
+ transformer.set_insert_id(node_id)
+ _legalize_LSTM(transformer, tensor_infos, node)
+ node_id = transformer.get_insert_id()
+ node_id += 1
+
+ transformer.delete_marked_nodes()
+
+
+if __name__ == '__main__':
+ if len(sys.argv) < 3:
+ print('usage: ./legalize_onnx.py <path to input model> <path to output model>\n'
+ '\n'
+ ' In stand-alone utility mode this tool provides basic functionality\n'
+ ' If you want to have more control over applied transformations, use this legalizer as a library')
+ exit(1)
+ options = LegalizeOptions()
+ options.unroll_lstm = True
+ options.unroll_rnn = True
+ model = onnx.load(sys.argv[1])
+ legalize(model, options)
+ onnx.save(model, sys.argv[2])
diff --git a/compiler/one-cmds/tests/CMakeLists.txt b/compiler/one-cmds/tests/CMakeLists.txt
index 6f9f2847e..caea756c2 100644
--- a/compiler/one-cmds/tests/CMakeLists.txt
+++ b/compiler/one-cmds/tests/CMakeLists.txt
@@ -3,6 +3,7 @@
# Gather test scripts
file(GLOB TESTITEMS RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "./*.test")
file(GLOB CONFIGITEMS RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "./*.cfg")
+file(GLOB QCONFIGITEMS RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "./*.qconf.json")
# Create a script to run the tests at installation folder
set(DRIVER_SCRIPT "${CMAKE_CURRENT_BINARY_DIR}/runtestall.sh")
@@ -39,6 +40,11 @@ foreach(CONFIGITEM IN ITEMS ${CONFIGITEMS})
install(FILES ${CONFIGITEM} DESTINATION test)
endforeach(CONFIGITEM)
+foreach(QCONFIGITEM IN ITEMS ${QCONFIGITEMS})
+ get_filename_component(ITEM_PREFIX ${QCONFIGITEM} NAME_WE)
+ install(FILES ${QCONFIGITEM} DESTINATION test)
+endforeach(QCONFIGITEM)
+
file(APPEND "${DRIVER_SCRIPT}" "popd > /dev/null\n\n")
file(APPEND "${DRIVER_SCRIPT}"
@@ -52,6 +58,8 @@ fi\n
set(PREPARE_TEST_MATERIALS_SH "${CMAKE_CURRENT_SOURCE_DIR}/prepare_test_materials.sh")
set(PREPROCESS_IMAGES_PY "${CMAKE_CURRENT_SOURCE_DIR}/preprocess_images.py")
+set(ONNX_LEGALIZE_RUN_COMPARE "${CMAKE_CURRENT_SOURCE_DIR}/onnx_legalize_run_compare.py")
+set(PRINT_ONNX_MODEL "${CMAKE_CURRENT_SOURCE_DIR}/print_onnx_model.py")
install(FILES ${DRIVER_SCRIPT}
PERMISSIONS OWNER_WRITE OWNER_READ OWNER_EXECUTE
@@ -71,5 +79,23 @@ install(FILES ${PREPROCESS_IMAGES_PY}
WORLD_READ
DESTINATION test)
+install(FILES ${ONNX_LEGALIZE_RUN_COMPARE}
+ PERMISSIONS OWNER_WRITE OWNER_READ OWNER_EXECUTE
+ GROUP_READ GROUP_EXECUTE
+ WORLD_READ WORLD_EXECUTE
+ DESTINATION test)
+
+install(FILES ${PRINT_ONNX_MODEL}
+ PERMISSIONS OWNER_WRITE OWNER_READ OWNER_EXECUTE
+ GROUP_READ GROUP_EXECUTE
+ WORLD_READ WORLD_EXECUTE
+ DESTINATION test)
+
install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/README.txt
DESTINATION test)
+
+add_subdirectory(onnx-operations)
+
+if(ENABLE_ONE_IMPORT_PYTORCH)
+ add_subdirectory(pytorch-operations)
+endif(ENABLE_ONE_IMPORT_PYTORCH)
diff --git a/compiler/one-cmds/tests/one-quantize_009.qconf.json b/compiler/one-cmds/tests/one-quantize_009.qconf.json
new file mode 100644
index 000000000..ac274e83a
--- /dev/null
+++ b/compiler/one-cmds/tests/one-quantize_009.qconf.json
@@ -0,0 +1,36 @@
+{
+ "default_quantization_dtype" : "uint8",
+ "default_granularity" : "channel",
+ "layers" : [
+ {
+ "name" : "InceptionV3/InceptionV3/Conv2d_2b_3x3/Relu;InceptionV3/InceptionV3/Conv2d_2b_3x3/BatchNorm/FusedBatchNorm;InceptionV3/InceptionV3/Mixed_6a/Branch_1/Conv2d_0a_1x1/Conv2D;InceptionV3/InceptionV3/Conv2d_2b_3x3/Conv2D",
+ "dtype" : "int16",
+ "granularity" : "channel"
+ },
+ {
+ "name" : "InceptionV3/InceptionV3/MaxPool_5a_3x3/MaxPool",
+ "dtype" : "int16",
+ "granularity" : "channel"
+ },
+ {
+ "name" : "InceptionV3/InceptionV3/Mixed_5b/concat",
+ "dtype" : "int16",
+ "granularity" : "channel"
+ },
+ {
+ "name" : "InceptionV3/InceptionV3/Mixed_5b/Branch_3/AvgPool_0a_3x3/AvgPool",
+ "dtype" : "int16",
+ "granularity" : "channel"
+ },
+ {
+ "name" : "InceptionV3/InceptionV3/Mixed_7c/concat",
+ "dtype" : "int16",
+ "granularity" : "channel"
+ },
+ {
+ "name" : "InceptionV3/Predictions/Reshape_1",
+ "dtype" : "int16",
+ "granularity" : "channel"
+ }
+ ]
+}
diff --git a/compiler/one-cmds/tests/one-quantize_009.test b/compiler/one-cmds/tests/one-quantize_009.test
new file mode 100644
index 000000000..aa0670350
--- /dev/null
+++ b/compiler/one-cmds/tests/one-quantize_009.test
@@ -0,0 +1,55 @@
+#!/bin/bash
+
+# Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+filename_ext="$(basename -- $0)"
+filename="${filename_ext%.*}"
+
+trap_err_onexit()
+{
+ echo "${filename_ext} FAILED"
+ exit 255
+}
+
+trap trap_err_onexit ERR
+
+inputfile="./inception_v3.circle"
+outputfile="./inception_v3.random.quantized.mixed.circle"
+
+rm -rf ${outputfile}
+
+# to create inception_v3.circle
+if [[ ! -s ${inputfile} ]]; then
+ /bin/bash one-import_001.test > /dev/null 2>&1
+ return_code=$?
+ if [[ ${return_code} != 0 ]]; then
+ trap_err_onexit
+ fi
+fi
+
+# run test without input data
+one-quantize \
+--input_dtype float32 \
+--quantized_dtype uint8 \
+--granularity channel \
+--quant_config one-quantize_009.qconf.json \
+--input_path ${inputfile} \
+--output_path ${outputfile} > /dev/null 2>&1
+
+if [[ ! -s "${outputfile}" ]]; then
+ trap_err_onexit
+fi
+
+echo "${filename_ext} SUCCESS"
diff --git a/compiler/one-cmds/tests/onnx-operations/CMakeLists.txt b/compiler/one-cmds/tests/onnx-operations/CMakeLists.txt
new file mode 100644
index 000000000..e6b2b354a
--- /dev/null
+++ b/compiler/one-cmds/tests/onnx-operations/CMakeLists.txt
@@ -0,0 +1,86 @@
+# Install one-cmds test scripts for onnx models
+
+# Gather test scripts
+set(EXAMPLES_DIR "${NNAS_PROJECT_SOURCE_DIR}/res/PyTorchExamples/examples")
+file(GLOB TEST_EXAMPLES RELATIVE "${EXAMPLES_DIR}" "${EXAMPLES_DIR}/*")
+
+set(TEST_DST test/onnx-operations)
+
+install(DIRECTORY "${NNAS_PROJECT_SOURCE_DIR}/res/PyTorchExamples/" DESTINATION "${TEST_DST}")
+
+set(ONNX_IMPORT_OPTIONS "--unroll_rnn --unroll_lstm")
+
+foreach(TEST_ITEM IN ITEMS ${TEST_EXAMPLES})
+ set(TEST_SCRIPT "${CMAKE_CURRENT_BINARY_DIR}/${TEST_ITEM}.test")
+
+ # generate test script
+ file(WRITE "${TEST_SCRIPT}" "#!/bin/bash\n\n")
+ file(APPEND "${TEST_SCRIPT}" "filename_ext=\"\$(basename -- $0)\"\n")
+ file(APPEND "${TEST_SCRIPT}" "filename=\"\${filename_ext%.*}\"\n")
+ file(APPEND "${TEST_SCRIPT}" "trap_err_onexit()\n")
+ file(APPEND "${TEST_SCRIPT}" "{\n")
+ file(APPEND "${TEST_SCRIPT}" "echo \"\${filename_ext} FAILED\"\n")
+ file(APPEND "${TEST_SCRIPT}" "exit 255\n")
+ file(APPEND "${TEST_SCRIPT}" "}\n")
+ file(APPEND "${TEST_SCRIPT}" "trap trap_err_onexit ERR\n")
+ file(APPEND "${TEST_SCRIPT}" "outputfile=\"${TEST_ITEM}.circle\"\n")
+ file(APPEND "${TEST_SCRIPT}" "one-import-onnx --input_path=${TEST_ITEM}.onnx --output_path=${TEST_ITEM}.circle\
+ ${ONNX_IMPORT_OPTIONS} &> /dev/null\n")
+ file(APPEND "${TEST_SCRIPT}" "if [[ ! -s \"\${outputfile}\" ]]; then\n")
+ file(APPEND "${TEST_SCRIPT}" "trap_err_onexit\n")
+ file(APPEND "${TEST_SCRIPT}" "fi\n")
+ file(APPEND "${TEST_SCRIPT}" "echo \"\${filename_ext} SUCCESS\"\n")
+
+ install(FILES "${TEST_SCRIPT}" DESTINATION "${TEST_DST}")
+endforeach(TEST_ITEM)
+
+
+# Create a script to run the tests at installation folder
+set(DRIVER_SCRIPT "${CMAKE_CURRENT_BINARY_DIR}/runtestall.sh")
+
+file(WRITE "${DRIVER_SCRIPT}" "#!/bin/bash\n\n")
+file(APPEND "${DRIVER_SCRIPT}" "SCRIPT_PATH=$(cd $(dirname \${BASH_SOURCE[0]}) && pwd)\n")
+file(APPEND "${DRIVER_SCRIPT}" "pushd $SCRIPT_PATH > /dev/null\n")
+file(APPEND "${DRIVER_SCRIPT}" "rm -rf runtestall.log\n")
+file(APPEND "${DRIVER_SCRIPT}" "export PATH=$SCRIPT_PATH/../bin:$PATH\n")
+file(APPEND "${DRIVER_SCRIPT}" "if [[ $# -ge 1 ]]; then\n")
+file(APPEND "${DRIVER_SCRIPT}" " USER_PATH=$1\n")
+file(APPEND "${DRIVER_SCRIPT}" " export PATH=$USER_PATH:$PATH\n")
+file(APPEND "${DRIVER_SCRIPT}" "fi\n")
+file(APPEND "${DRIVER_SCRIPT}" "\n")
+file(APPEND "${DRIVER_SCRIPT}" "# refer https://github.com/Samsung/ONE/issues/6286\n")
+file(APPEND "${DRIVER_SCRIPT}" "set -o pipefail\n\n")
+file(APPEND "${DRIVER_SCRIPT}" "fail_count=0\n")
+file(APPEND "${DRIVER_SCRIPT}" "trap \"(( fail_count++ ))\" ERR\n\n")
+
+foreach(TEST_ITEM IN ITEMS ${TEST_EXAMPLES})
+ file(APPEND "${DRIVER_SCRIPT}" "/bin/bash \"${TEST_ITEM}.test\" | tee -a runtestall.log\n")
+endforeach(TEST_ITEM)
+
+file(APPEND "${DRIVER_SCRIPT}" "popd > /dev/null\n\n")
+
+file(APPEND "${DRIVER_SCRIPT}"
+"if [[ $fail_count != 0 ]]; then
+ echo \"$fail_count TESTS FAILED\"
+ exit 255
+else
+ echo \"ALL TESTS PASSED!\"
+fi\n
+")
+
+set(PREPARE_TEST_MATERIALS_SH "${CMAKE_CURRENT_SOURCE_DIR}/prepare_test_materials.sh")
+
+install(FILES "${DRIVER_SCRIPT}"
+ PERMISSIONS OWNER_WRITE OWNER_READ OWNER_EXECUTE
+ GROUP_READ GROUP_EXECUTE
+ WORLD_READ WORLD_EXECUTE
+ DESTINATION "${TEST_DST}")
+
+install(FILES "${PREPARE_TEST_MATERIALS_SH}"
+ PERMISSIONS OWNER_WRITE OWNER_READ OWNER_EXECUTE
+ GROUP_READ GROUP_EXECUTE
+ WORLD_READ WORLD_EXECUTE
+ DESTINATION "${TEST_DST}")
+
+install(FILES "${CMAKE_CURRENT_SOURCE_DIR}/README.md"
+ DESTINATION "${TEST_DST}")
diff --git a/compiler/one-cmds/tests/onnx-operations/README.md b/compiler/one-cmds/tests/onnx-operations/README.md
new file mode 100644
index 000000000..928fb84dd
--- /dev/null
+++ b/compiler/one-cmds/tests/onnx-operations/README.md
@@ -0,0 +1,28 @@
+## Overview
+
+This directory contains auxiliary tests for small onnx target models.
+
+Most of the models contain a single operation, but some contain multiple operations that represent one operation with complex semantics.
+
+Models for these tests are taken from res/PyTorchExamples.
+
+## To run all tests
+
+Steps:
+1) run 'one-prepare-venv' in bin folder to prepare python virtual-env with TensorFlow
+ - you need to run this only once
+ - read 'doc/how-to-prepare-virtualenv.txt' for more information
+ ```
+ bin/one-prepare-venv
+ ```
+2) run 'test/onnx-operations/prepare_test_materials.sh' to download test material models
+ - you need to run this only once
+ - you need internet connection to download files
+ - you may need to install 'wget' and 'unzip' packages
+ ```
+ test/onnx-operations/prepare_test_materials.sh
+ ```
+3) run 'test/onnx-operations/runtestall.sh' to run the test
+ ```
+ test/onnx-operations/runtestall.sh
+ ```
diff --git a/compiler/one-cmds/tests/onnx-operations/prepare_test_materials.sh b/compiler/one-cmds/tests/onnx-operations/prepare_test_materials.sh
new file mode 100644
index 000000000..274a60f0a
--- /dev/null
+++ b/compiler/one-cmds/tests/onnx-operations/prepare_test_materials.sh
@@ -0,0 +1,26 @@
+#!/bin/bash
+
+# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+SCRIPT_PATH="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+pushd $SCRIPT_PATH > /dev/null
+
+for test_case in examples/*; do
+ python3 ptem.py $(basename ${test_case})
+done
+
+cp output/*.onnx .
+
+popd > /dev/null
diff --git a/compiler/one-cmds/tests/onnx_legalize_run_compare.py b/compiler/one-cmds/tests/onnx_legalize_run_compare.py
new file mode 100644
index 000000000..9b02b74af
--- /dev/null
+++ b/compiler/one-cmds/tests/onnx_legalize_run_compare.py
@@ -0,0 +1,129 @@
+# Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import onnxruntime as rt
+import onnx
+import sys
+import numpy as np
+import importlib.util
+
+
+def _generate_inputs(model):
+ """Generate random inputs for given model
+
+ Args:
+ model (onnx.onnx_ml_pb2.ModelProto): target model
+
+ Returns:
+ dict from str to numpy.ndarray: generated inputs
+ """
+ inputs = {}
+ for input in model.graph.input:
+ # check if elem type is float32
+ # list of types could be extended, this is a property of current testsuite
+ assert (
+ input.type.tensor_type.elem_type == onnx.TensorProto.DataType.Value("FLOAT"))
+ input_shape = []
+ for dim in input.type.tensor_type.shape.dim:
+ input_shape += [dim.dim_value]
+ inputs[input.name] = np.random.random(input_shape).astype(np.float32)
+ return inputs
+
+
+def _run_model(model, inputs):
+ """Run given model
+
+ Args:
+ model (onnx.onnx_ml_pb2.ModelProto): target model
+ inputs (dict from str to numpy.ndarray): sample inputs
+
+ Returns:
+ list of numpy.ndarray: inference outputs
+ """
+ output_names = list(map(lambda output: output.name, model.graph.output))
+ session = rt.InferenceSession(model.SerializeToString())
+ outputs = session.run(output_names, inputs)
+ return outputs
+
+
+def _compare_results(ref_outputs, test_outputs, tolerance):
+    """Compare reference and test outputs of the original and modified models
+
+ Args:
+ ref_outputs (list of numpy.ndarray): reference values (original model results)
+ test_outputs (list of numpy.ndarray): tested values (modified model results)
+ tolerance (float): maximum acceptable relative difference
+
+ Returns:
+ bool: True if outputs considered equal, False otherwise
+ """
+ num_outputs = len(ref_outputs)
+ assert (len(test_outputs) == num_outputs)
+ for i in range(num_outputs):
+ if ref_outputs[i].shape != test_outputs[i].shape:
+ print("output {} shape mismatch: ref({}) vs test({})".format(
+ i, ref_outputs[i].shape, test_outputs[i].shape))
+ return False
+
+ abs_difference = np.abs(ref_outputs[i] - test_outputs[i])
+ abs_ref_maximum = np.abs(ref_outputs[i]).max()
+ peak_error = abs_difference.max() / abs_ref_maximum
+
+ if peak_error > tolerance:
+ print("output {} peak error to value ratio {} is too big".format(
+ i, peak_error))
+ return False
+ return True
+
+
+if __name__ == '__main__':
+ if len(sys.argv) < 6:
+ exit('expecting 5 arguments:\n'
+ ' - path to input model\n'
+ ' - path to "legalized" model\n'
+ ' - path to onnx_legalizer.py\n'
+ ' - base name for generated test inputs\n'
+ ' - output tolerance')
+ input_model_path = sys.argv[1]
+ output_model_path = sys.argv[2]
+ onnx_legalizer_path = sys.argv[3]
+ input_dump_path = sys.argv[4]
+ tolerance = float(sys.argv[5])
+
+ onnx_legalizer_spec = importlib.util.spec_from_file_location(
+ "onnx_legalizer", onnx_legalizer_path)
+ onnx_legalizer = importlib.util.module_from_spec(onnx_legalizer_spec)
+ onnx_legalizer_spec.loader.exec_module(onnx_legalizer)
+
+ model = onnx.load(input_model_path)
+
+ inputs = _generate_inputs(model)
+
+ for i in inputs:
+ np.save('{}_{}.npy'.format(input_dump_path, i), inputs[i])
+
+ ref_outputs = _run_model(model, inputs)
+
+ options = onnx_legalizer.LegalizeOptions()
+ options.unroll_rnn = True
+ options.unroll_lstm = True
+ onnx_legalizer.legalize(model, options)
+
+ with open(output_model_path, 'wb') as f:
+ f.write(model.SerializeToString())
+
+ test_outputs = _run_model(model, inputs)
+
+ if not _compare_results(ref_outputs, test_outputs, tolerance):
+ exit('comparison failed')
diff --git a/compiler/one-cmds/tests/prepare_test_materials.sh b/compiler/one-cmds/tests/prepare_test_materials.sh
index 7f269530c..c80c59834 100644
--- a/compiler/one-cmds/tests/prepare_test_materials.sh
+++ b/compiler/one-cmds/tests/prepare_test_materials.sh
@@ -91,6 +91,39 @@ if [[ ! -s "onnx_conv2d_conv2d.onnx" ]]; then
# https://github.com/Samsung/ONE/issues/5577#issuecomment-755078444
fi
+function files_missing() {
+ condition="test "
+
+ for f in "${@}"; do
+ condition="${condition} ! -s ${f} -o"
+ done
+
+ # last condition is always false to properly close last "or"
+ condition="${condition} -z non_zero_string "
+ ${condition}
+}
+
+declare -a TEST_RECCURENT_MODELS=(\
+ "RNN.onnx" "RNN-nobias.onnx" "RNN-relu.onnx" "RNN-bi.onnx" "RNN-noinit.onnx"\
+ "LSTM.onnx" "LSTM-bi.onnx" "LSTM-noinit.onnx" "LSTM-nobias.onnx"
+)
+
+if files_missing "${TEST_RECCURENT_MODELS[@]}"; then
+ rm -rf test_onnx_recurrent_models.zip
+ wget https://github.com/Samsung/ONE/files/8067909/test_onnx_recurrent_models.zip
+ unzip test_onnx_recurrent_models.zip
+ # https://github.com/Samsung/ONE/issues/8395#issuecomment-1040072097
+fi
+
+declare -a NEG_TEST_RECCURENT_MODELS=("rnn_variable.onnx" "lstm_variable.onnx")
+
+if files_missing "${NEG_TEST_RECCURENT_MODELS[@]}"; then
+ rm -rf neg_test_onnx_recurrent_models.zip
+ wget https://github.com/Samsung/ONE/files/8137183/neg_test_onnx_recurrent_models.zip
+ unzip neg_test_onnx_recurrent_models.zip
+ # https://github.com/Samsung/ONE/issues/8395#issuecomment-1050364375
+fi
+
# prepare 'inception_v3.circle' file used for quantization test
inputfile="./inception_v3.pb"
outputfile="./inception_v3.circle"
diff --git a/compiler/one-cmds/tests/print_onnx_model.py b/compiler/one-cmds/tests/print_onnx_model.py
new file mode 100644
index 000000000..ecab0f6da
--- /dev/null
+++ b/compiler/one-cmds/tests/print_onnx_model.py
@@ -0,0 +1,20 @@
+# Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import onnx
+import sys
+
+if __name__ == '__main__':
+ model = onnx.load(sys.argv[1])
+ print(model)
diff --git a/compiler/one-cmds/tests/pytorch-operations/CMakeLists.txt b/compiler/one-cmds/tests/pytorch-operations/CMakeLists.txt
new file mode 100644
index 000000000..10f30a5c9
--- /dev/null
+++ b/compiler/one-cmds/tests/pytorch-operations/CMakeLists.txt
@@ -0,0 +1,109 @@
+# Install one-cmds test scripts for pytorch models
+
+# Gather test scripts
+set(EXAMPLES_DIR "${NNAS_PROJECT_SOURCE_DIR}/res/PyTorchExamples/examples")
+file(GLOB TEST_EXAMPLES RELATIVE "${EXAMPLES_DIR}" "${EXAMPLES_DIR}/*")
+file(GLOB SPECIAL_TEST_ITEMS RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "./*.test")
+
+set(TEST_DST test/pytorch-operations)
+
+install(DIRECTORY "${NNAS_PROJECT_SOURCE_DIR}/res/PyTorchExamples/" DESTINATION "${TEST_DST}")
+
+set(PYTORCH_IMPORT_OPTIONS "--unroll_rnn --unroll_lstm")
+
+foreach(TEST_ITEM IN ITEMS ${TEST_EXAMPLES})
+ set(TEST_SCRIPT "${CMAKE_CURRENT_BINARY_DIR}/${TEST_ITEM}.test")
+
+ # generate test script
+ file(WRITE "${TEST_SCRIPT}" "#!/bin/bash\n\n")
+ file(APPEND "${TEST_SCRIPT}" "filename_ext=\"\$(basename -- $0)\"\n")
+ file(APPEND "${TEST_SCRIPT}" "filename=\"\${filename_ext%.*}\"\n")
+ file(APPEND "${TEST_SCRIPT}" "trap_err_onexit()\n")
+ file(APPEND "${TEST_SCRIPT}" "{\n")
+ file(APPEND "${TEST_SCRIPT}" " echo \"\${filename_ext} FAILED\"\n")
+ file(APPEND "${TEST_SCRIPT}" " exit 255\n")
+ file(APPEND "${TEST_SCRIPT}" "}\n")
+ file(APPEND "${TEST_SCRIPT}" "trap trap_err_onexit ERR\n")
+ file(APPEND "${TEST_SCRIPT}" "outputfile=\"${TEST_ITEM}.circle\"\n")
+ file(APPEND "${TEST_SCRIPT}" "input_shapes=\$(head -n 1 ${TEST_ITEM}.spec)\n")
+ file(APPEND "${TEST_SCRIPT}" "input_types=\$(tail -n 1 ${TEST_ITEM}.spec)\n")
+ file(APPEND "${TEST_SCRIPT}" "one-import-pytorch --input_path=${TEST_ITEM}.pth --output_path=${TEST_ITEM}.circle\
+ ${PYTORCH_IMPORT_OPTIONS} --input_shapes=\${input_shapes} --input_types=\${input_types} &> /dev/null\n")
+ file(APPEND "${TEST_SCRIPT}" "if [[ ! -s \"\${outputfile}\" ]]; then\n")
+ file(APPEND "${TEST_SCRIPT}" " trap_err_onexit\n")
+ file(APPEND "${TEST_SCRIPT}" "fi\n")
+ file(APPEND "${TEST_SCRIPT}" "echo \"\${filename_ext} SUCCESS\"\n")
+
+ install(FILES "${TEST_SCRIPT}" DESTINATION "${TEST_DST}")
+endforeach(TEST_ITEM)
+
+
+# Create a script to run the tests at installation folder
+set(DRIVER_SCRIPT "${CMAKE_CURRENT_BINARY_DIR}/runtestall.sh")
+
+file(WRITE "${DRIVER_SCRIPT}" "#!/bin/bash\n\n")
+file(APPEND "${DRIVER_SCRIPT}" "SCRIPT_PATH=$(cd $(dirname \${BASH_SOURCE[0]}) && pwd)\n")
+file(APPEND "${DRIVER_SCRIPT}" "pushd $SCRIPT_PATH > /dev/null\n")
+file(APPEND "${DRIVER_SCRIPT}" "rm -rf runtestall.log\n")
+file(APPEND "${DRIVER_SCRIPT}" "export PATH=$SCRIPT_PATH/../bin:$PATH\n")
+file(APPEND "${DRIVER_SCRIPT}" "if [[ $# -ge 1 ]]; then\n")
+file(APPEND "${DRIVER_SCRIPT}" " USER_PATH=$1\n")
+file(APPEND "${DRIVER_SCRIPT}" " export PATH=$USER_PATH:$PATH\n")
+file(APPEND "${DRIVER_SCRIPT}" "fi\n")
+file(APPEND "${DRIVER_SCRIPT}" "\n")
+file(APPEND "${DRIVER_SCRIPT}" "# refer https://github.com/Samsung/ONE/issues/6286\n")
+file(APPEND "${DRIVER_SCRIPT}" "set -o pipefail\n\n")
+file(APPEND "${DRIVER_SCRIPT}" "fail_count=0\n")
+file(APPEND "${DRIVER_SCRIPT}" "trap \"(( fail_count++ ))\" ERR\n\n")
+
+foreach(TEST_ITEM IN ITEMS ${TEST_EXAMPLES})
+ file(APPEND "${DRIVER_SCRIPT}" "/bin/bash \"${TEST_ITEM}.test\" | tee -a runtestall.log\n")
+endforeach(TEST_ITEM)
+
+file(APPEND "${DRIVER_SCRIPT}" "\necho \"special test items\" | tee -a runtestall.log\n\n")
+
+foreach(TEST_ITEM IN ITEMS ${SPECIAL_TEST_ITEMS})
+ file(APPEND "${DRIVER_SCRIPT}" "/bin/bash \"${TEST_ITEM}\" | tee -a runtestall.log\n")
+endforeach(TEST_ITEM)
+
+file(APPEND "${DRIVER_SCRIPT}" "popd > /dev/null\n\n")
+
+file(APPEND "${DRIVER_SCRIPT}"
+"if [[ $fail_count != 0 ]]; then
+ echo \"$fail_count TESTS FAILED\"
+ exit 255
+else
+ echo \"ALL TESTS PASSED!\"
+fi\n
+")
+
+set(PREPARE_TEST_MATERIALS_SH "${CMAKE_CURRENT_SOURCE_DIR}/prepare_test_materials.sh")
+set(EXAMPLE_GENERATOR "${CMAKE_CURRENT_SOURCE_DIR}/example_generator.py")
+set(AUX_GENERATOR "${CMAKE_CURRENT_SOURCE_DIR}/aux_generator.py")
+
+install(FILES "${DRIVER_SCRIPT}"
+ PERMISSIONS OWNER_WRITE OWNER_READ OWNER_EXECUTE
+ GROUP_READ GROUP_EXECUTE
+ WORLD_READ WORLD_EXECUTE
+ DESTINATION "${TEST_DST}")
+
+install(FILES "${PREPARE_TEST_MATERIALS_SH}"
+ PERMISSIONS OWNER_WRITE OWNER_READ OWNER_EXECUTE
+ GROUP_READ GROUP_EXECUTE
+ WORLD_READ WORLD_EXECUTE
+ DESTINATION "${TEST_DST}")
+
+install(FILES "${EXAMPLE_GENERATOR}" "${AUX_GENERATOR}"
+ PERMISSIONS OWNER_WRITE OWNER_READ OWNER_EXECUTE
+ GROUP_READ GROUP_EXECUTE
+ WORLD_READ WORLD_EXECUTE
+ DESTINATION "${TEST_DST}")
+
+install(FILES ${SPECIAL_TEST_ITEMS}
+ PERMISSIONS OWNER_WRITE OWNER_READ OWNER_EXECUTE
+ GROUP_READ GROUP_EXECUTE
+ WORLD_READ WORLD_EXECUTE
+ DESTINATION "${TEST_DST}")
+
+install(FILES "${CMAKE_CURRENT_SOURCE_DIR}/README.md"
+ DESTINATION "${TEST_DST}")
diff --git a/compiler/one-cmds/tests/pytorch-operations/README.md b/compiler/one-cmds/tests/pytorch-operations/README.md
new file mode 100644
index 000000000..231a10eb4
--- /dev/null
+++ b/compiler/one-cmds/tests/pytorch-operations/README.md
@@ -0,0 +1,28 @@
+## Overview
+
+This directory contains auxiliary tests for small pytorch target models.
+
+Most of the models contain a single operation, but some contain multiple operations that represent one operation with complex semantics.
+
+Models for these tests are taken from res/PyTorchExamples.
+
+## To run all tests
+
+Steps:
+1) run 'one-prepare-venv' in bin folder to prepare python virtual-env with TensorFlow
+ - you need to run this only once
+ - read 'doc/how-to-prepare-virtualenv.txt' for more information
+ ```
+ bin/one-prepare-venv
+ ```
+2) run 'test/pytorch-operations/prepare_test_materials.sh' to download test material models
+ - you need to run this only once
+ - you need internet connection to download files
+ - you may need to install 'wget' and 'unzip' packages
+ ```
+ test/pytorch-operations/prepare_test_materials.sh
+ ```
+3) run 'test/pytorch-operations/runtestall.sh' to run the test
+ ```
+   test/pytorch-operations/runtestall.sh
+ ```
diff --git a/compiler/one-cmds/tests/pytorch-operations/aux_generator.py b/compiler/one-cmds/tests/pytorch-operations/aux_generator.py
new file mode 100644
index 000000000..6c9afcded
--- /dev/null
+++ b/compiler/one-cmds/tests/pytorch-operations/aux_generator.py
@@ -0,0 +1,83 @@
+#!/usr/bin/env python3
+
+# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# PyTorch aux tests generator
+
+import torch
+import torch.nn as nn
+import json
+import zipfile
+import os
+
+
+# model
+class net_abs(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, input):
+ return torch.abs(input)
+
+
+if __name__ == '__main__':
+ model = net_abs()
+ # save "entire" model for entire_model.test
+ torch.save(model, 'entire_model.pth')
+
+ # save state_dict file for state_dict_model.test
+ state_dict_path = 'state_dict_model.pth'
+ torch.save(model.state_dict(), state_dict_path)
+
+ # create files for mar_torchscript_model.test
+ torchscript_path = 'torchscript_model.pth'
+ inp = torch.randn(1, 2, 3, 3)
+ traced_model = torch.jit.trace(model, inp)
+ torch.jit.save(traced_model, torchscript_path)
+ # create manifest
+ manifest = {}
+ manifest['createdOn'] = '11/11/1111 11:11:11'
+ manifest['runtime'] = 'python'
+ manifest['model'] = {}
+ manifest['model']['modelName'] = 'torchscript_model',
+ manifest['model']['serializedFile'] = torchscript_path
+ manifest['model']['handler'] = 'image_classifier'
+ manifest['model']['modelVersion'] = '1.0'
+ manifest['archiverVersion'] = '0.4.2'
+
+ with zipfile.ZipFile('mar_torchscript_model.mar', 'w') as mar_file:
+ with mar_file.open('MAR-INF/MANIFEST.json', 'w') as manifest_file:
+ manifest_file.write(json.dumps(manifest).encode())
+ mar_file.write(torchscript_path)
+
+ # create files for mar_state_dict_model.test
+ model_file_path = os.path.basename(__file__)
+ # create manifest
+ manifest = {}
+ manifest['createdOn'] = '11/11/1111 11:11:11'
+ manifest['runtime'] = 'python'
+ manifest['model'] = {}
+ manifest['model']['modelName'] = 'state_dict_model',
+ manifest['model']['serializedFile'] = state_dict_path
+ manifest['model']['handler'] = 'image_classifier'
+ manifest['model']['modelFile'] = model_file_path
+ manifest['model']['modelVersion'] = '1.0'
+ manifest['archiverVersion'] = '0.4.2'
+
+ with zipfile.ZipFile('mar_state_dict_model.mar', 'w') as mar_file:
+ with mar_file.open('MAR-INF/MANIFEST.json', 'w') as manifest_file:
+ manifest_file.write(json.dumps(manifest).encode())
+ mar_file.write(state_dict_path)
+ mar_file.write(model_file_path)
diff --git a/compiler/one-cmds/tests/pytorch-operations/entire_model.test b/compiler/one-cmds/tests/pytorch-operations/entire_model.test
new file mode 100644
index 000000000..a72a56ffd
--- /dev/null
+++ b/compiler/one-cmds/tests/pytorch-operations/entire_model.test
@@ -0,0 +1,40 @@
+#!/bin/bash
+
+# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Test one-import-pytorch ability to import NN model stored in python file and serialized "entire" model.
+# "Entire" model is serialized with `torch.save(model)` method.
+
+filename_ext="$(basename -- $0)"
+filename="${filename_ext%.*}"
+
+trap_err_onexit()
+{
+ echo "${filename_ext} FAILED"
+ exit 255
+}
+
+trap trap_err_onexit ERR
+
+outputfile="entire_model.circle"
+
+# run test
+one-import-pytorch --input_path=entire_model.pth --python_path=aux_generator.py --output_path=${outputfile} --input_shapes=1,2,3,3 --input_types=float32 &> /dev/null
+
+if [[ ! -s "${outputfile}" ]]; then
+ trap_err_onexit
+fi
+
+echo "${filename_ext} SUCCESS"
diff --git a/compiler/one-cmds/tests/pytorch-operations/example_generator.py b/compiler/one-cmds/tests/pytorch-operations/example_generator.py
new file mode 100644
index 000000000..20a80c895
--- /dev/null
+++ b/compiler/one-cmds/tests/pytorch-operations/example_generator.py
@@ -0,0 +1,116 @@
+#!/usr/bin/env python3
+
+# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# PyTorch Example manager
+
+import torch
+import importlib
+import argparse
+import os
+
+from pathlib import Path
+
+print("PyTorch version=", torch.__version__)
+
+parser = argparse.ArgumentParser(description='Process PyTorch python examples')
+
+parser.add_argument('examples', metavar='EXAMPLES', nargs='+')
+
+args = parser.parse_args()
+
+output_folder = "./"
+
+Path(output_folder).mkdir(parents=True, exist_ok=True)
+
+
+class JitWrapper(torch.nn.Module):
+ def __init__(self, model):
+ super().__init__()
+ self.model = model
+
+ def forward(self, *args):
+ if len(args) == 1:
+ return self.model.forward(args[0])
+ else:
+ return self.model.forward(args)
+
+
+for example in args.examples:
+ print("Generate '" + example + ".pth'", end='')
+ # load example code
+ # replace - with _ in name, otherwise pytorch generates invalid torchscript
+ module_name = "examples." + example.replace('-', '_')
+ module_loader = importlib.machinery.SourceFileLoader(
+ module_name, os.path.join("examples", example, "__init__.py"))
+ module_spec = importlib.util.spec_from_loader(module_name, module_loader)
+ module = importlib.util.module_from_spec(module_spec)
+ module_loader.exec_module(module)
+
+ jittable_model = JitWrapper(module._model_)
+
+ traced_model = torch.jit.trace(jittable_model, module._dummy_)
+ # save .pth
+ torch.jit.save(traced_model, output_folder + example + ".pth")
+
+ input_shapes = ""
+ input_types = ""
+
+ input_samples = module._dummy_
+ if isinstance(input_samples, torch.Tensor):
+ input_samples = [input_samples]
+ for inp_idx in range(len(input_samples)):
+ input_data = input_samples[inp_idx]
+
+ shape = input_data.shape
+ for dim in range(len(shape)):
+ input_shapes += str(shape[dim])
+ if dim != len(shape) - 1:
+ input_shapes += ","
+
+ if input_data.dtype == torch.bool:
+ input_types += "bool"
+ elif input_data.dtype == torch.uint8:
+ input_types += "uint8"
+ elif input_data.dtype == torch.int8:
+ input_types += "int8"
+ elif input_data.dtype == torch.int16:
+ input_types += "int16"
+ elif input_data.dtype == torch.int32:
+ input_types += "int32"
+ elif input_data.dtype == torch.int64:
+ input_types += "int16"
+ elif input_data.dtype == torch.float16:
+ input_types += "float32"
+ elif input_data.dtype == torch.float32:
+ input_types += "float32"
+ elif input_data.dtype == torch.float64:
+ input_types += "float64"
+ elif input_data.dtype == torch.complex64:
+ input_types += "complex64"
+ elif input_data.dtype == torch.complex128:
+ input_types += "complex128"
+ else:
+ raise ValueError('unsupported dtype')
+
+ if inp_idx != len(input_samples) - 1:
+ input_shapes += ":"
+ input_types += ","
+
+ with open(example + ".spec", "w") as spec_file:
+ print(input_shapes, file=spec_file)
+ print(input_types, file=spec_file)
+
+ print(" - Done")
diff --git a/compiler/one-cmds/tests/pytorch-operations/mar_state_dict_model.test b/compiler/one-cmds/tests/pytorch-operations/mar_state_dict_model.test
new file mode 100644
index 000000000..9892dbbed
--- /dev/null
+++ b/compiler/one-cmds/tests/pytorch-operations/mar_state_dict_model.test
@@ -0,0 +1,40 @@
+#!/bin/bash
+
+# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Test one-import-pytorch ability to import .mar file.
+# .mar file contains python source of the model and serialized state_dict.
+
+filename_ext="$(basename -- $0)"
+filename="${filename_ext%.*}"
+
+trap_err_onexit()
+{
+ echo "${filename_ext} FAILED"
+ exit 255
+}
+
+trap trap_err_onexit ERR
+
+outputfile="mar_state_dict_model.circle"
+
+# run test
+one-import-pytorch --input_path=mar_state_dict_model.mar --output_path=${outputfile} --input_shapes=1,2,3,3 --input_types=float32 &> /dev/null
+
+if [[ ! -s "${outputfile}" ]]; then
+ trap_err_onexit
+fi
+
+echo "${filename_ext} SUCCESS"
diff --git a/compiler/one-cmds/tests/pytorch-operations/mar_torchscript_model.test b/compiler/one-cmds/tests/pytorch-operations/mar_torchscript_model.test
new file mode 100644
index 000000000..3ac38a42e
--- /dev/null
+++ b/compiler/one-cmds/tests/pytorch-operations/mar_torchscript_model.test
@@ -0,0 +1,40 @@
+#!/bin/bash
+
+# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Test one-import-pytorch ability to import .mar file.
+# .mar file contains TorchScript.
+
+filename_ext="$(basename -- $0)"
+filename="${filename_ext%.*}"
+
+trap_err_onexit()
+{
+ echo "${filename_ext} FAILED"
+ exit 255
+}
+
+trap trap_err_onexit ERR
+
+outputfile="mar_torchscript_model.circle"
+
+# run test
+one-import-pytorch --input_path=mar_torchscript_model.mar --output_path=${outputfile} --input_shapes=1,2,3,3 --input_types=float32 &> /dev/null
+
+if [[ ! -s "${outputfile}" ]]; then
+ trap_err_onexit
+fi
+
+echo "${filename_ext} SUCCESS"
diff --git a/compiler/one-cmds/tests/pytorch-operations/prepare_test_materials.sh b/compiler/one-cmds/tests/pytorch-operations/prepare_test_materials.sh
new file mode 100644
index 000000000..5f38610d7
--- /dev/null
+++ b/compiler/one-cmds/tests/pytorch-operations/prepare_test_materials.sh
@@ -0,0 +1,26 @@
+#!/bin/bash
+
+# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+SCRIPT_PATH="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+pushd $SCRIPT_PATH > /dev/null
+
+for test_case in examples/*; do
+ python3 example_generator.py $(basename ${test_case})
+done
+
+python3 aux_generator.py
+
+popd > /dev/null
diff --git a/compiler/one-cmds/tests/pytorch-operations/state_dict_model.test b/compiler/one-cmds/tests/pytorch-operations/state_dict_model.test
new file mode 100644
index 000000000..ecd2a8112
--- /dev/null
+++ b/compiler/one-cmds/tests/pytorch-operations/state_dict_model.test
@@ -0,0 +1,39 @@
+#!/bin/bash
+
+# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Test one-import-pytorch ability to import NN model from .py file and serialized state_dict file.
+
+filename_ext="$(basename -- $0)"
+filename="${filename_ext%.*}"
+
+trap_err_onexit()
+{
+ echo "${filename_ext} FAILED"
+ exit 255
+}
+
+trap trap_err_onexit ERR
+
+outputfile="state_dict_model.circle"
+
+# run test
+one-import-pytorch --input_path=state_dict_model.pth --python_path=aux_generator.py --output_path=${outputfile} --input_shapes=1,2,3,3 --input_types=float32 &> /dev/null
+
+if [[ ! -s "${outputfile}" ]]; then
+ trap_err_onexit
+fi
+
+echo "${filename_ext} SUCCESS"
diff --git a/compiler/one-cmds/tests/pytorch-operations/torchscript_model.test b/compiler/one-cmds/tests/pytorch-operations/torchscript_model.test
new file mode 100644
index 000000000..590e5b369
--- /dev/null
+++ b/compiler/one-cmds/tests/pytorch-operations/torchscript_model.test
@@ -0,0 +1,39 @@
+#!/bin/bash
+
+# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Test one-import-pytorch ability to import TorchScript file.
+
+filename_ext="$(basename -- $0)"
+filename="${filename_ext%.*}"
+
+trap_err_onexit()
+{
+ echo "${filename_ext} FAILED"
+ exit 255
+}
+
+trap trap_err_onexit ERR
+
+outputfile="torchscript_model.circle"
+
+# run test
+one-import-pytorch --input_path=torchscript_model.pth --output_path=${outputfile} --input_shapes=1,2,3,3 --input_types=float32 &> /dev/null
+
+if [[ ! -s "${outputfile}" ]]; then
+ trap_err_onexit
+fi
+
+echo "${filename_ext} SUCCESS"
diff --git a/compiler/one-cmds/utils.py b/compiler/one-cmds/utils.py
index 5d84c2bd5..be0322aca 100644
--- a/compiler/one-cmds/utils.py
+++ b/compiler/one-cmds/utils.py
@@ -17,80 +17,13 @@
import argparse
import configparser
import glob
+import importlib
import ntpath
import os
import subprocess
import sys
-
-class _CONSTANT:
- __slots__ = () # This prevents access via __dict__.
- OPTIMIZATION_OPTS = (
- # (OPTION_NAME, HELP_MESSAGE)
- ('O1', 'enable O1 optimization pass'),
- ('convert_nchw_to_nhwc',
- 'Experimental: This will convert NCHW operators to NHWC under the assumption that input model is NCHW.'
- ),
- ('expand_broadcast_const', 'expand broadcastable constant node inputs'),
- ('nchw_to_nhwc_input_shape',
- 'convert the input shape of the model (argument for convert_nchw_to_nhwc)'),
- ('nchw_to_nhwc_output_shape',
- 'convert the output shape of the model (argument for convert_nchw_to_nhwc)'),
- ('fold_add_v2', 'fold AddV2 op with constant inputs'),
- ('fold_cast', 'fold Cast op with constant input'),
- ('fold_dequantize', 'fold Dequantize op'),
- ('fold_dwconv', 'fold Depthwise Convolution op with constant inputs'),
- ('fold_sparse_to_dense', 'fold SparseToDense op'),
- ('forward_reshape_to_unaryop', 'Forward Reshape op'),
- ('fuse_add_with_tconv', 'fuse Add op to Transposed'),
- ('fuse_add_with_fully_connected', 'fuse Add op to FullyConnected op'),
- ('fuse_batchnorm_with_conv', 'fuse BatchNorm op to Convolution op'),
- ('fuse_batchnorm_with_dwconv', 'fuse BatchNorm op to Depthwise Convolution op'),
- ('fuse_batchnorm_with_tconv', 'fuse BatchNorm op to Transposed Convolution op'),
- ('fuse_bcq', 'apply Binary Coded Quantization'),
- ('fuse_preactivation_batchnorm',
- 'fuse BatchNorm operators of pre-activations to Convolution op'),
- ('fuse_mean_with_mean', 'fuse two consecutive Mean ops'),
- ('fuse_transpose_with_mean',
- 'fuse Mean with a preceding Transpose under certain conditions'),
- ('make_batchnorm_gamma_positive',
- 'make negative gamma of BatchNorm to a small positive value (1e-10).'
- ' Note that this pass can change the execution result of the model.'
- ' So, use it only when the impact is known to be acceptable.'),
- ('fuse_activation_function', 'fuse Activation function to a preceding operator'),
- ('fuse_instnorm', 'fuse ops to InstanceNorm operator'),
- ('replace_cw_mul_add_with_depthwise_conv',
- 'replace channel-wise Mul/Add with DepthwiseConv2D'),
- ('remove_fakequant', 'remove FakeQuant ops'),
- ('remove_quantdequant', 'remove Quantize-Dequantize sequence'),
- ('remove_redundant_reshape', 'fuse or remove subsequent Reshape ops'),
- ('remove_redundant_transpose', 'fuse or remove subsequent Transpose ops'),
- ('remove_unnecessary_reshape', 'remove unnecessary reshape ops'),
- ('remove_unnecessary_slice', 'remove unnecessary slice ops'),
- ('remove_unnecessary_strided_slice', 'remove unnecessary strided slice ops'),
- ('remove_unnecessary_split', 'remove unnecessary split ops'),
- ('resolve_customop_add', 'convert Custom(Add) op to Add op'),
- ('resolve_customop_batchmatmul',
- 'convert Custom(BatchMatmul) op to BatchMatmul op'),
- ('resolve_customop_matmul', 'convert Custom(Matmul) op to Matmul op'),
- ('resolve_customop_max_pool_with_argmax',
- 'convert Custom(MaxPoolWithArgmax) to net of builtin operators'),
- ('shuffle_weight_to_16x1float32',
- 'convert weight format of FullyConnected op to SHUFFLED16x1FLOAT32.'
- ' Note that it only converts weights whose row is a multiple of 16'),
- ('substitute_pack_to_reshape', 'convert single input Pack op to Reshape op'),
- ('substitute_padv2_to_pad', 'convert certain condition PadV2 to Pad'),
- ('substitute_splitv_to_split', 'convert certain condition SplitV to Split'),
- ('substitute_squeeze_to_reshape', 'convert certain condition Squeeze to Reshape'),
- ('substitute_strided_slice_to_reshape',
- 'convert certain condition StridedSlice to Reshape'),
- ('substitute_transpose_to_reshape',
- 'convert certain condition Transpose to Reshape'),
- ('transform_min_max_to_relu6', 'transform Minimum-Maximum pattern to Relu6 op'),
- ('transform_min_relu_to_relu6', 'transform Minimum(6)-Relu pattern to Relu6 op'))
-
-
-_CONSTANT = _CONSTANT()
+import onelib.constant as _constant
def _add_default_arg(parser):
@@ -116,7 +49,10 @@ def _add_default_arg(parser):
def is_accumulated_arg(arg, driver):
if driver == "one-quantize":
- if arg == "tensor_name" or arg == "scale" or arg == "zero_point":
+ accumulables = [
+ "tensor_name", "scale", "zero_point", "src_tensor_name", "dst_tensor_name"
+ ]
+ if arg in accumulables:
return True
return False
@@ -189,83 +125,6 @@ def _parse_cfg(args, driver_name):
setattr(args, key, config[secton_to_run][key])
-def _make_tf2tfliteV2_cmd(args, driver_path, input_path, output_path):
- """make a command for running tf2tfliteV2.py"""
- cmd = [sys.executable, os.path.expanduser(driver_path)]
- # verbose
- if _is_valid_attr(args, 'verbose'):
- cmd.append('--verbose')
- # model_format
- if _is_valid_attr(args, 'model_format_cmd'):
- cmd.append(getattr(args, 'model_format_cmd'))
- elif _is_valid_attr(args, 'model_format'):
- cmd.append('--' + getattr(args, 'model_format'))
- else:
- cmd.append('--graph_def') # default value
- # converter version
- if _is_valid_attr(args, 'converter_version_cmd'):
- cmd.append(getattr(args, 'converter_version_cmd'))
- elif _is_valid_attr(args, 'converter_version'):
- cmd.append('--' + getattr(args, 'converter_version'))
- else:
- cmd.append('--v1') # default value
- # input_path
- if _is_valid_attr(args, 'input_path'):
- cmd.append('--input_path')
- cmd.append(os.path.expanduser(input_path))
- # output_path
- if _is_valid_attr(args, 'output_path'):
- cmd.append('--output_path')
- cmd.append(os.path.expanduser(output_path))
- # input_arrays
- if _is_valid_attr(args, 'input_arrays'):
- cmd.append('--input_arrays')
- cmd.append(getattr(args, 'input_arrays'))
- # input_shapes
- if _is_valid_attr(args, 'input_shapes'):
- cmd.append('--input_shapes')
- cmd.append(getattr(args, 'input_shapes'))
- # output_arrays
- if _is_valid_attr(args, 'output_arrays'):
- cmd.append('--output_arrays')
- cmd.append(getattr(args, 'output_arrays'))
-
- return cmd
-
-
-def _make_tflite2circle_cmd(driver_path, input_path, output_path):
- """make a command for running tflite2circle"""
- cmd = [driver_path, input_path, output_path]
- return [os.path.expanduser(c) for c in cmd]
-
-
-def _make_circle2circle_cmd(args, driver_path, input_path, output_path):
- """make a command for running circle2circle"""
- cmd = [os.path.expanduser(c) for c in [driver_path, input_path, output_path]]
- # profiling
- if _is_valid_attr(args, 'generate_profile_data'):
- cmd.append('--generate_profile_data')
- # optimization pass(only true/false options)
- # TODO support options whose number of arguments is more than zero
- for opt in _CONSTANT.OPTIMIZATION_OPTS:
- if _is_valid_attr(args, opt[0]):
- # ./driver --opt[0]
- if type(getattr(args, opt[0])) is bool:
- cmd.append('--' + opt[0])
- """
- This condition check is for config file interface, usually would be
- SomeOption=True
- but user can write as follows while development
- SomeOption=False
- instead of removing SomeOption option
- """
- if type(getattr(args, opt[0])) is str and not getattr(
- args, opt[0]).lower() in ['false', '0', 'n']:
- cmd.append('--' + opt[0])
-
- return cmd
-
-
def _print_version_and_exit(file_path):
"""print version of the file located in the file_path"""
script_path = os.path.realpath(file_path)
@@ -368,3 +227,34 @@ def _get_optimization_list(get_name=False):
opt_list = [_remove_suffix(s, '.cfg') for s in opt_list]
return opt_list
+
+
+def _detect_one_import_drivers(search_path):
+ """Looks for import drivers in given directory
+
+ Args:
+ search_path: path to the directory where to search import drivers
+
+ Returns:
+ dict: each entry is related to single detected driver,
+ key is a config section name, value is a driver name
+
+ """
+ import_drivers_dict = {}
+ for module_name in os.listdir(search_path):
+ full_path = os.path.join(search_path, module_name)
+ if not os.path.isfile(full_path):
+ continue
+ if module_name.find("one-import-") != 0:
+ continue
+ module_loader = importlib.machinery.SourceFileLoader(module_name, full_path)
+ module_spec = importlib.util.spec_from_loader(module_name, module_loader)
+ module = importlib.util.module_from_spec(module_spec)
+ try:
+ module_loader.exec_module(module)
+ if hasattr(module, "get_driver_cfg_section"):
+ section = module.get_driver_cfg_section()
+ import_drivers_dict[section] = module_name
+ except:
+ pass
+ return import_drivers_dict
diff --git a/compiler/oneco/CMakeLists.txt b/compiler/oneco/CMakeLists.txt
index 418bc27ac..951194d9d 100644
--- a/compiler/oneco/CMakeLists.txt
+++ b/compiler/oneco/CMakeLists.txt
@@ -22,11 +22,11 @@ target_link_libraries(moco_onnx_frontend PUBLIC moco_onnx_proto)
target_link_libraries(moco_onnx_frontend PUBLIC loco)
target_link_libraries(moco_onnx_frontend PRIVATE cwrap)
-nnas_find_package(GTest QUIET)
-
-if(NOT GTest_FOUND)
+if(NOT ENABLE_TEST)
return()
-endif(NOT GTest_FOUND)
+endif(NOT ENABLE_TEST)
+
+nnas_find_package(GTest QUIET)
add_executable(moco_onnx_frontend_test ${TESTS})
target_include_directories(moco_onnx_frontend_test PRIVATE src)
diff --git a/compiler/pepper-strcast/CMakeLists.txt b/compiler/pepper-strcast/CMakeLists.txt
index 5f87e9488..bcc07f482 100644
--- a/compiler/pepper-strcast/CMakeLists.txt
+++ b/compiler/pepper-strcast/CMakeLists.txt
@@ -3,7 +3,9 @@ file(GLOB_RECURSE TESTS "src/*.test.cpp")
list(REMOVE_ITEM SOURCES ${TESTS})
add_library(pepper_strcast STATIC ${SOURCES})
-set_target_properties(pepper_strcast PROPERTIES POSITION_INDEPENDENT_CODE ON)
+if (NOT NNCC_LIBRARY_NO_PIC)
+ set_target_properties(pepper_strcast PROPERTIES POSITION_INDEPENDENT_CODE ON)
+endif(NOT NNCC_LIBRARY_NO_PIC)
target_include_directories(pepper_strcast PUBLIC include)
target_link_libraries(pepper_strcast PRIVATE nncc_common)
target_link_libraries(pepper_strcast PUBLIC nncc_coverage)
diff --git a/compiler/pota-quantization-value-test/CMakeLists.txt b/compiler/pota-quantization-value-test/CMakeLists.txt
index 00ffb57de..51fd9a391 100644
--- a/compiler/pota-quantization-value-test/CMakeLists.txt
+++ b/compiler/pota-quantization-value-test/CMakeLists.txt
@@ -1,7 +1,9 @@
unset(QUANTIZATION_VALUE_TEST)
unset(QUANTIZATION_VALUE_TEST_WITH_PARAM)
+unset(QUANTIZATION_CONFIG_VALUE_TEST)
+unset(QUANTIZATION_CONFIG_VALUE_TEST_WITH_PARAM)
-nnas_find_package(FlatBuffers EXACT 1.10 QUIET)
+nnas_find_package(FlatBuffers EXACT 2.0 QUIET)
if(NOT FlatBuffers_FOUND)
message(STATUS "Build pota-quantization-value-test: FAILED (missing FlatBuffers)")
return()
@@ -12,6 +14,11 @@ macro(addTest NAME GRANULARITY DTYPE)
list(APPEND QUANTIZATION_VALUE_TEST_WITH_PARAM ${NAME} ${GRANULARITY} ${DTYPE})
endmacro(addTest)
+macro(addQConfTest NAME GRANULARITY DTYPE)
+ list(APPEND QUANTIZATION_CONFIG_VALUE_TEST ${NAME})
+ list(APPEND QUANTIZATION_CONFIG_VALUE_TEST_WITH_PARAM ${NAME} ${GRANULARITY} ${DTYPE})
+endmacro(addQConfTest)
+
# Read "test.lst"
include("test.lst")
# Read "test.local.lst" if exists
@@ -20,12 +27,12 @@ include("test.local.lst" OPTIONAL)
unset(TEST_DEPS)
get_target_property(ARTIFACTS_BIN_PATH testDataGenerator BINARY_DIR)
-get_target_property(SCHEMA_BIN_PATH mio_circle BINARY_DIR)
+get_target_property(SCHEMA_BIN_PATH mio_circle04 BINARY_DIR)
configure_file("${CMAKE_CURRENT_SOURCE_DIR}/gen_h5_explicit_inputs.py"
"${CMAKE_CURRENT_BINARY_DIR}/gen_h5_explicit_inputs.py" COPYONLY)
-set(VIRTUALENV "${NNCC_OVERLAY_DIR}/venv_2_6_0")
+set(VIRTUALENV "${NNCC_OVERLAY_DIR}/venv_2_8_0")
###
### Generate test.config
@@ -89,5 +96,22 @@ add_test(
${QUANTIZATION_VALUE_TEST_WITH_PARAM}
)
+add_test(
+ NAME pota_fake_wquant_test_with_config
+ COMMAND "${CMAKE_CURRENT_SOURCE_DIR}/test_fake_wquant_with_config.sh"
+ "${TEST_CONFIG}"
+ "${ARTIFACTS_BIN_PATH}"
+ ${QUANTIZATION_CONFIG_VALUE_TEST_WITH_PARAM}
+)
+
+add_test(
+ NAME pota_quantization_test_with_config
+ COMMAND "${CMAKE_CURRENT_SOURCE_DIR}/test_quantization_with_config.sh"
+ "${TEST_CONFIG}"
+ "${ARTIFACTS_BIN_PATH}"
+ ${QUANTIZATION_CONFIG_VALUE_TEST_WITH_PARAM}
+)
+
set_tests_properties(pota_record_minmax_test PROPERTIES DEPENDS pota_fake_wquant_test)
set_tests_properties(pota_quantization_test PROPERTIES DEPENDS pota_record_minmax_test)
+set_tests_properties(pota_quantization_test_with_config PROPERTIES DEPENDS pota_fake_wquant_test_with_config)
diff --git a/compiler/pota-quantization-value-test/config_files/Add_002/channel/int16/qconf.json b/compiler/pota-quantization-value-test/config_files/Add_002/channel/int16/qconf.json
new file mode 100644
index 000000000..838b331fd
--- /dev/null
+++ b/compiler/pota-quantization-value-test/config_files/Add_002/channel/int16/qconf.json
@@ -0,0 +1,9 @@
+{
+ "layers" : [
+ {
+ "name" : "ofm",
+ "dtype" : "uint8",
+ "granularity" : "layer"
+ }
+ ]
+}
diff --git a/compiler/pota-quantization-value-test/config_files/Add_002/layer/uint8/qconf.json b/compiler/pota-quantization-value-test/config_files/Add_002/layer/uint8/qconf.json
new file mode 100644
index 000000000..7cd6ce713
--- /dev/null
+++ b/compiler/pota-quantization-value-test/config_files/Add_002/layer/uint8/qconf.json
@@ -0,0 +1,9 @@
+{
+ "layers" : [
+ {
+ "name" : "ofm",
+ "dtype" : "int16",
+ "granularity" : "channel"
+ }
+ ]
+}
diff --git a/compiler/pota-quantization-value-test/config_files/AveragePool2D_000/channel/int16/qconf.json b/compiler/pota-quantization-value-test/config_files/AveragePool2D_000/channel/int16/qconf.json
new file mode 100644
index 000000000..838b331fd
--- /dev/null
+++ b/compiler/pota-quantization-value-test/config_files/AveragePool2D_000/channel/int16/qconf.json
@@ -0,0 +1,9 @@
+{
+ "layers" : [
+ {
+ "name" : "ofm",
+ "dtype" : "uint8",
+ "granularity" : "layer"
+ }
+ ]
+}
diff --git a/compiler/pota-quantization-value-test/config_files/AveragePool2D_000/layer/uint8/qconf.json b/compiler/pota-quantization-value-test/config_files/AveragePool2D_000/layer/uint8/qconf.json
new file mode 100644
index 000000000..7cd6ce713
--- /dev/null
+++ b/compiler/pota-quantization-value-test/config_files/AveragePool2D_000/layer/uint8/qconf.json
@@ -0,0 +1,9 @@
+{
+ "layers" : [
+ {
+ "name" : "ofm",
+ "dtype" : "int16",
+ "granularity" : "channel"
+ }
+ ]
+}
diff --git a/compiler/pota-quantization-value-test/config_files/Concatenation_001/channel/int16/qconf.json b/compiler/pota-quantization-value-test/config_files/Concatenation_001/channel/int16/qconf.json
new file mode 100644
index 000000000..838b331fd
--- /dev/null
+++ b/compiler/pota-quantization-value-test/config_files/Concatenation_001/channel/int16/qconf.json
@@ -0,0 +1,9 @@
+{
+ "layers" : [
+ {
+ "name" : "ofm",
+ "dtype" : "uint8",
+ "granularity" : "layer"
+ }
+ ]
+}
diff --git a/compiler/pota-quantization-value-test/config_files/Concatenation_001/layer/uint8/qconf.json b/compiler/pota-quantization-value-test/config_files/Concatenation_001/layer/uint8/qconf.json
new file mode 100644
index 000000000..7cd6ce713
--- /dev/null
+++ b/compiler/pota-quantization-value-test/config_files/Concatenation_001/layer/uint8/qconf.json
@@ -0,0 +1,9 @@
+{
+ "layers" : [
+ {
+ "name" : "ofm",
+ "dtype" : "int16",
+ "granularity" : "channel"
+ }
+ ]
+}
diff --git a/compiler/pota-quantization-value-test/config_files/Conv2D_004/channel/int16/qconf.json b/compiler/pota-quantization-value-test/config_files/Conv2D_004/channel/int16/qconf.json
new file mode 100644
index 000000000..838b331fd
--- /dev/null
+++ b/compiler/pota-quantization-value-test/config_files/Conv2D_004/channel/int16/qconf.json
@@ -0,0 +1,9 @@
+{
+ "layers" : [
+ {
+ "name" : "ofm",
+ "dtype" : "uint8",
+ "granularity" : "layer"
+ }
+ ]
+}
diff --git a/compiler/pota-quantization-value-test/config_files/Conv2D_004/layer/uint8/qconf.json b/compiler/pota-quantization-value-test/config_files/Conv2D_004/layer/uint8/qconf.json
new file mode 100644
index 000000000..7cd6ce713
--- /dev/null
+++ b/compiler/pota-quantization-value-test/config_files/Conv2D_004/layer/uint8/qconf.json
@@ -0,0 +1,9 @@
+{
+ "layers" : [
+ {
+ "name" : "ofm",
+ "dtype" : "int16",
+ "granularity" : "channel"
+ }
+ ]
+}
diff --git a/compiler/pota-quantization-value-test/config_files/DepthwiseConv2D_002/channel/int16/qconf.json b/compiler/pota-quantization-value-test/config_files/DepthwiseConv2D_002/channel/int16/qconf.json
new file mode 100644
index 000000000..838b331fd
--- /dev/null
+++ b/compiler/pota-quantization-value-test/config_files/DepthwiseConv2D_002/channel/int16/qconf.json
@@ -0,0 +1,9 @@
+{
+ "layers" : [
+ {
+ "name" : "ofm",
+ "dtype" : "uint8",
+ "granularity" : "layer"
+ }
+ ]
+}
diff --git a/compiler/pota-quantization-value-test/config_files/DepthwiseConv2D_002/layer/uint8/qconf.json b/compiler/pota-quantization-value-test/config_files/DepthwiseConv2D_002/layer/uint8/qconf.json
new file mode 100644
index 000000000..7cd6ce713
--- /dev/null
+++ b/compiler/pota-quantization-value-test/config_files/DepthwiseConv2D_002/layer/uint8/qconf.json
@@ -0,0 +1,9 @@
+{
+ "layers" : [
+ {
+ "name" : "ofm",
+ "dtype" : "int16",
+ "granularity" : "channel"
+ }
+ ]
+}
diff --git a/compiler/pota-quantization-value-test/config_files/FullyConnected_003/channel/int16/qconf.json b/compiler/pota-quantization-value-test/config_files/FullyConnected_003/channel/int16/qconf.json
new file mode 100644
index 000000000..174d6e9b0
--- /dev/null
+++ b/compiler/pota-quantization-value-test/config_files/FullyConnected_003/channel/int16/qconf.json
@@ -0,0 +1,9 @@
+{
+ "layers" : [
+ {
+ "name" : "out",
+ "dtype" : "uint8",
+ "granularity" : "layer"
+ }
+ ]
+}
diff --git a/compiler/pota-quantization-value-test/config_files/FullyConnected_003/layer/uint8/qconf.json b/compiler/pota-quantization-value-test/config_files/FullyConnected_003/layer/uint8/qconf.json
new file mode 100644
index 000000000..733f46e60
--- /dev/null
+++ b/compiler/pota-quantization-value-test/config_files/FullyConnected_003/layer/uint8/qconf.json
@@ -0,0 +1,9 @@
+{
+ "layers" : [
+ {
+ "name" : "out",
+ "dtype" : "int16",
+ "granularity" : "channel"
+ }
+ ]
+}
diff --git a/compiler/pota-quantization-value-test/config_files/InstanceNorm_001/channel/int16/qconf.json b/compiler/pota-quantization-value-test/config_files/InstanceNorm_001/channel/int16/qconf.json
new file mode 100644
index 000000000..838b331fd
--- /dev/null
+++ b/compiler/pota-quantization-value-test/config_files/InstanceNorm_001/channel/int16/qconf.json
@@ -0,0 +1,9 @@
+{
+ "layers" : [
+ {
+ "name" : "ofm",
+ "dtype" : "uint8",
+ "granularity" : "layer"
+ }
+ ]
+}
diff --git a/compiler/pota-quantization-value-test/config_files/InstanceNorm_001/layer/uint8/qconf.json b/compiler/pota-quantization-value-test/config_files/InstanceNorm_001/layer/uint8/qconf.json
new file mode 100644
index 000000000..7cd6ce713
--- /dev/null
+++ b/compiler/pota-quantization-value-test/config_files/InstanceNorm_001/layer/uint8/qconf.json
@@ -0,0 +1,9 @@
+{
+ "layers" : [
+ {
+ "name" : "ofm",
+ "dtype" : "int16",
+ "granularity" : "channel"
+ }
+ ]
+}
diff --git a/compiler/pota-quantization-value-test/config_files/MaxPool2D_000/channel/int16/qconf.json b/compiler/pota-quantization-value-test/config_files/MaxPool2D_000/channel/int16/qconf.json
new file mode 100644
index 000000000..838b331fd
--- /dev/null
+++ b/compiler/pota-quantization-value-test/config_files/MaxPool2D_000/channel/int16/qconf.json
@@ -0,0 +1,9 @@
+{
+ "layers" : [
+ {
+ "name" : "ofm",
+ "dtype" : "uint8",
+ "granularity" : "layer"
+ }
+ ]
+}
diff --git a/compiler/pota-quantization-value-test/config_files/MaxPool2D_000/layer/uint8/qconf.json b/compiler/pota-quantization-value-test/config_files/MaxPool2D_000/layer/uint8/qconf.json
new file mode 100644
index 000000000..7cd6ce713
--- /dev/null
+++ b/compiler/pota-quantization-value-test/config_files/MaxPool2D_000/layer/uint8/qconf.json
@@ -0,0 +1,9 @@
+{
+ "layers" : [
+ {
+ "name" : "ofm",
+ "dtype" : "int16",
+ "granularity" : "channel"
+ }
+ ]
+}
diff --git a/compiler/pota-quantization-value-test/config_files/Mean_000/channel/int16/qconf.json b/compiler/pota-quantization-value-test/config_files/Mean_000/channel/int16/qconf.json
new file mode 100644
index 000000000..838b331fd
--- /dev/null
+++ b/compiler/pota-quantization-value-test/config_files/Mean_000/channel/int16/qconf.json
@@ -0,0 +1,9 @@
+{
+ "layers" : [
+ {
+ "name" : "ofm",
+ "dtype" : "uint8",
+ "granularity" : "layer"
+ }
+ ]
+}
diff --git a/compiler/pota-quantization-value-test/config_files/Mean_000/layer/uint8/qconf.json b/compiler/pota-quantization-value-test/config_files/Mean_000/layer/uint8/qconf.json
new file mode 100644
index 000000000..7cd6ce713
--- /dev/null
+++ b/compiler/pota-quantization-value-test/config_files/Mean_000/layer/uint8/qconf.json
@@ -0,0 +1,9 @@
+{
+ "layers" : [
+ {
+ "name" : "ofm",
+ "dtype" : "int16",
+ "granularity" : "channel"
+ }
+ ]
+}
diff --git a/compiler/pota-quantization-value-test/config_files/Mul_001/channel/int16/qconf.json b/compiler/pota-quantization-value-test/config_files/Mul_001/channel/int16/qconf.json
new file mode 100644
index 000000000..838b331fd
--- /dev/null
+++ b/compiler/pota-quantization-value-test/config_files/Mul_001/channel/int16/qconf.json
@@ -0,0 +1,9 @@
+{
+ "layers" : [
+ {
+ "name" : "ofm",
+ "dtype" : "uint8",
+ "granularity" : "layer"
+ }
+ ]
+}
diff --git a/compiler/pota-quantization-value-test/config_files/Mul_001/layer/uint8/qconf.json b/compiler/pota-quantization-value-test/config_files/Mul_001/layer/uint8/qconf.json
new file mode 100644
index 000000000..7cd6ce713
--- /dev/null
+++ b/compiler/pota-quantization-value-test/config_files/Mul_001/layer/uint8/qconf.json
@@ -0,0 +1,9 @@
+{
+ "layers" : [
+ {
+ "name" : "ofm",
+ "dtype" : "int16",
+ "granularity" : "channel"
+ }
+ ]
+}
diff --git a/compiler/pota-quantization-value-test/config_files/PRelu_001/channel/int16/qconf.json b/compiler/pota-quantization-value-test/config_files/PRelu_001/channel/int16/qconf.json
new file mode 100644
index 000000000..838b331fd
--- /dev/null
+++ b/compiler/pota-quantization-value-test/config_files/PRelu_001/channel/int16/qconf.json
@@ -0,0 +1,9 @@
+{
+ "layers" : [
+ {
+ "name" : "ofm",
+ "dtype" : "uint8",
+ "granularity" : "layer"
+ }
+ ]
+}
diff --git a/compiler/pota-quantization-value-test/config_files/PRelu_001/layer/uint8/qconf.json b/compiler/pota-quantization-value-test/config_files/PRelu_001/layer/uint8/qconf.json
new file mode 100644
index 000000000..7cd6ce713
--- /dev/null
+++ b/compiler/pota-quantization-value-test/config_files/PRelu_001/layer/uint8/qconf.json
@@ -0,0 +1,9 @@
+{
+ "layers" : [
+ {
+ "name" : "ofm",
+ "dtype" : "int16",
+ "granularity" : "channel"
+ }
+ ]
+}
diff --git a/compiler/pota-quantization-value-test/config_files/ReLU_000/channel/int16/qconf.json b/compiler/pota-quantization-value-test/config_files/ReLU_000/channel/int16/qconf.json
new file mode 100644
index 000000000..838b331fd
--- /dev/null
+++ b/compiler/pota-quantization-value-test/config_files/ReLU_000/channel/int16/qconf.json
@@ -0,0 +1,9 @@
+{
+ "layers" : [
+ {
+ "name" : "ofm",
+ "dtype" : "uint8",
+ "granularity" : "layer"
+ }
+ ]
+}
diff --git a/compiler/pota-quantization-value-test/config_files/ReLU_000/layer/uint8/qconf.json b/compiler/pota-quantization-value-test/config_files/ReLU_000/layer/uint8/qconf.json
new file mode 100644
index 000000000..7cd6ce713
--- /dev/null
+++ b/compiler/pota-quantization-value-test/config_files/ReLU_000/layer/uint8/qconf.json
@@ -0,0 +1,9 @@
+{
+ "layers" : [
+ {
+ "name" : "ofm",
+ "dtype" : "int16",
+ "granularity" : "channel"
+ }
+ ]
+}
diff --git a/compiler/pota-quantization-value-test/config_files/Split_000/channel/int16/qconf.json b/compiler/pota-quantization-value-test/config_files/Split_000/channel/int16/qconf.json
new file mode 100644
index 000000000..630c3e420
--- /dev/null
+++ b/compiler/pota-quantization-value-test/config_files/Split_000/channel/int16/qconf.json
@@ -0,0 +1,14 @@
+{
+ "layers" : [
+ {
+ "name" : "ofm1",
+ "dtype" : "uint8",
+ "granularity" : "channel"
+ },
+ {
+ "name" : "ofm2",
+ "dtype" : "uint8",
+ "granularity" : "channel"
+ }
+ ]
+}
diff --git a/compiler/pota-quantization-value-test/config_files/Split_000/channel/uint8/qconf.json b/compiler/pota-quantization-value-test/config_files/Split_000/channel/uint8/qconf.json
new file mode 100644
index 000000000..cc98d7c62
--- /dev/null
+++ b/compiler/pota-quantization-value-test/config_files/Split_000/channel/uint8/qconf.json
@@ -0,0 +1,14 @@
+{
+ "layers" : [
+ {
+ "name" : "ofm1",
+ "dtype" : "int16",
+ "granularity" : "channel"
+ },
+ {
+ "name" : "ofm2",
+ "dtype" : "int16",
+ "granularity" : "channel"
+ }
+ ]
+}
diff --git a/compiler/pota-quantization-value-test/config_files/TransposeConv_001/channel/int16/qconf.json b/compiler/pota-quantization-value-test/config_files/TransposeConv_001/channel/int16/qconf.json
new file mode 100644
index 000000000..838b331fd
--- /dev/null
+++ b/compiler/pota-quantization-value-test/config_files/TransposeConv_001/channel/int16/qconf.json
@@ -0,0 +1,9 @@
+{
+ "layers" : [
+ {
+ "name" : "ofm",
+ "dtype" : "uint8",
+ "granularity" : "layer"
+ }
+ ]
+}
diff --git a/compiler/pota-quantization-value-test/config_files/TransposeConv_001/layer/uint8/qconf.json b/compiler/pota-quantization-value-test/config_files/TransposeConv_001/layer/uint8/qconf.json
new file mode 100644
index 000000000..7cd6ce713
--- /dev/null
+++ b/compiler/pota-quantization-value-test/config_files/TransposeConv_001/layer/uint8/qconf.json
@@ -0,0 +1,9 @@
+{
+ "layers" : [
+ {
+ "name" : "ofm",
+ "dtype" : "int16",
+ "granularity" : "channel"
+ }
+ ]
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/Add_002_config/channel/int16/quantization/ifm1_Quantize.json b/compiler/pota-quantization-value-test/expected_outputs/Add_002_config/channel/int16/quantization/ifm1_Quantize.json
new file mode 100644
index 000000000..a223fa4aa
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/Add_002_config/channel/int16/quantization/ifm1_Quantize.json
@@ -0,0 +1,4 @@
+{
+ "scale": 0.038489170372486115,
+ "zero_point": 129.0
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/Add_002_config/channel/int16/quantization/ifm2.json b/compiler/pota-quantization-value-test/expected_outputs/Add_002_config/channel/int16/quantization/ifm2.json
new file mode 100644
index 000000000..ec6082d55
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/Add_002_config/channel/int16/quantization/ifm2.json
@@ -0,0 +1,32 @@
+{
+ "weights": [
+ [
+ [
+ [
+ 136,
+ 153,
+ 68
+ ],
+ [
+ 51,
+ 34,
+ 221
+ ]
+ ],
+ [
+ [
+ 0,
+ 255,
+ 187
+ ],
+ [
+ 85,
+ 170,
+ 102
+ ]
+ ]
+ ]
+ ],
+ "scale": 0.05882352963089943,
+ "zero_point": 119.0
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/Add_002_config/channel/int16/quantization/ofm.json b/compiler/pota-quantization-value-test/expected_outputs/Add_002_config/channel/int16/quantization/ofm.json
new file mode 100644
index 000000000..afa9b1a8e
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/Add_002_config/channel/int16/quantization/ofm.json
@@ -0,0 +1,4 @@
+{
+ "scale": 0.0892433300614357,
+ "zero_point": 134.0
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/Add_002_config/layer/uint8/quantization/ifm1_Quantize.json b/compiler/pota-quantization-value-test/expected_outputs/Add_002_config/layer/uint8/quantization/ifm1_Quantize.json
new file mode 100644
index 000000000..a7298cb58
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/Add_002_config/layer/uint8/quantization/ifm1_Quantize.json
@@ -0,0 +1,4 @@
+{
+ "scale": 0.00014653272228315473,
+ "zero_point": 0.0
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/Add_002_config/layer/uint8/quantization/ifm2.json b/compiler/pota-quantization-value-test/expected_outputs/Add_002_config/layer/uint8/quantization/ifm2.json
new file mode 100644
index 000000000..ab968c9fc
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/Add_002_config/layer/uint8/quantization/ifm2.json
@@ -0,0 +1,32 @@
+{
+ "weights": [
+ [
+ [
+ [
+ 4096,
+ 8192,
+ -12288
+ ],
+ [
+ -16384,
+ -20479,
+ 24575
+ ]
+ ],
+ [
+ [
+ -28671,
+ 32767,
+ 16384
+ ],
+ [
+ -8192,
+ 12288,
+ -4096
+ ]
+ ]
+ ]
+ ],
+ "scale": 0.0002441480755805969,
+ "zero_point": 0.0
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/Add_002_config/layer/uint8/quantization/ofm.json b/compiler/pota-quantization-value-test/expected_outputs/Add_002_config/layer/uint8/quantization/ofm.json
new file mode 100644
index 000000000..3cb0552e9
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/Add_002_config/layer/uint8/quantization/ofm.json
@@ -0,0 +1,4 @@
+{
+ "scale": 0.00037035736022517085,
+ "zero_point": 0.0
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/AveragePool2D_000_config/channel/int16/quantization/ifm_Quantize.json b/compiler/pota-quantization-value-test/expected_outputs/AveragePool2D_000_config/channel/int16/quantization/ifm_Quantize.json
new file mode 100644
index 000000000..0528cc9cc
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/AveragePool2D_000_config/channel/int16/quantization/ifm_Quantize.json
@@ -0,0 +1,4 @@
+{
+ "scale": 0.03911808878183365,
+ "zero_point": 127.0
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/AveragePool2D_000_config/channel/int16/quantization/ofm.json b/compiler/pota-quantization-value-test/expected_outputs/AveragePool2D_000_config/channel/int16/quantization/ofm.json
new file mode 100644
index 000000000..ac5da0bda
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/AveragePool2D_000_config/channel/int16/quantization/ofm.json
@@ -0,0 +1,4 @@
+{
+ "scale": 0.027372928336262703,
+ "zero_point": 141.0
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/AveragePool2D_000_config/layer/uint8/quantization/ifm_Quantize.json b/compiler/pota-quantization-value-test/expected_outputs/AveragePool2D_000_config/layer/uint8/quantization/ifm_Quantize.json
new file mode 100644
index 000000000..353f15a6b
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/AveragePool2D_000_config/layer/uint8/quantization/ifm_Quantize.json
@@ -0,0 +1,4 @@
+{
+ "scale": 0.0001523942337371409,
+ "zero_point": 0.0
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/AveragePool2D_000_config/layer/uint8/quantization/ofm.json b/compiler/pota-quantization-value-test/expected_outputs/AveragePool2D_000_config/layer/uint8/quantization/ofm.json
new file mode 100644
index 000000000..c4ace78d4
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/AveragePool2D_000_config/layer/uint8/quantization/ofm.json
@@ -0,0 +1,4 @@
+{
+ "scale": 0.00012122748012188822,
+ "zero_point": 0.0
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/Concatenation_001_config/channel/int16/quantization/ifm1_Quantize.json b/compiler/pota-quantization-value-test/expected_outputs/Concatenation_001_config/channel/int16/quantization/ifm1_Quantize.json
new file mode 100644
index 000000000..522880618
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/Concatenation_001_config/channel/int16/quantization/ifm1_Quantize.json
@@ -0,0 +1,4 @@
+{
+ "scale": 0.05882352963089943,
+ "zero_point": 119.0
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/Concatenation_001_config/channel/int16/quantization/ifm2.json b/compiler/pota-quantization-value-test/expected_outputs/Concatenation_001_config/channel/int16/quantization/ifm2.json
new file mode 100644
index 000000000..17ba25363
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/Concatenation_001_config/channel/int16/quantization/ifm2.json
@@ -0,0 +1,28 @@
+{
+ "weights": [
+ [
+ [
+ [
+ 136,
+ 153
+ ],
+ [
+ 68,
+ 51
+ ]
+ ],
+ [
+ [
+ 34,
+ 221
+ ],
+ [
+ 0,
+ 255
+ ]
+ ]
+ ]
+ ],
+ "scale": 0.05882352963089943,
+ "zero_point": 119.0
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/Concatenation_001_config/channel/int16/quantization/ofm.json b/compiler/pota-quantization-value-test/expected_outputs/Concatenation_001_config/channel/int16/quantization/ofm.json
new file mode 100644
index 000000000..522880618
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/Concatenation_001_config/channel/int16/quantization/ofm.json
@@ -0,0 +1,4 @@
+{
+ "scale": 0.05882352963089943,
+ "zero_point": 119.0
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/Concatenation_001_config/layer/uint8/quantization/ifm1_Quantize.json b/compiler/pota-quantization-value-test/expected_outputs/Concatenation_001_config/layer/uint8/quantization/ifm1_Quantize.json
new file mode 100644
index 000000000..71265a270
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/Concatenation_001_config/layer/uint8/quantization/ifm1_Quantize.json
@@ -0,0 +1,4 @@
+{
+ "scale": 0.0002441480755805969,
+ "zero_point": 0.0
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/Concatenation_001_config/layer/uint8/quantization/ifm2.json b/compiler/pota-quantization-value-test/expected_outputs/Concatenation_001_config/layer/uint8/quantization/ifm2.json
new file mode 100644
index 000000000..53d7cdba3
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/Concatenation_001_config/layer/uint8/quantization/ifm2.json
@@ -0,0 +1,28 @@
+{
+ "weights": [
+ [
+ [
+ [
+ 4096,
+ 8192
+ ],
+ [
+ -12288,
+ -16384
+ ]
+ ],
+ [
+ [
+ -20479,
+ 24575
+ ],
+ [
+ -28671,
+ 32767
+ ]
+ ]
+ ]
+ ],
+ "scale": 0.0002441480755805969,
+ "zero_point": 0.0
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/Concatenation_001_config/layer/uint8/quantization/ofm.json b/compiler/pota-quantization-value-test/expected_outputs/Concatenation_001_config/layer/uint8/quantization/ofm.json
new file mode 100644
index 000000000..71265a270
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/Concatenation_001_config/layer/uint8/quantization/ofm.json
@@ -0,0 +1,4 @@
+{
+ "scale": 0.0002441480755805969,
+ "zero_point": 0.0
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/Conv2D_004_config/channel/int16/fake_quantization/ker.json b/compiler/pota-quantization-value-test/expected_outputs/Conv2D_004_config/channel/int16/fake_quantization/ker.json
new file mode 100644
index 000000000..2558bb2be
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/Conv2D_004_config/channel/int16/fake_quantization/ker.json
@@ -0,0 +1,48 @@
+{
+ "weights": [
+ [
+ [
+ [
+ 1.0039215087890625,
+ 2.007843017578125
+ ],
+ [
+ -3.0117650032043457,
+ -4.015686511993408
+ ]
+ ],
+ [
+ [
+ -5.019608497619629,
+ 6.023530006408691
+ ],
+ [
+ -7.027451515197754,
+ 7.9686279296875
+ ]
+ ]
+ ],
+ [
+ [
+ [
+ 4.01568603515625,
+ -2.007843494415283
+ ],
+ [
+ 3.0117645263671875,
+ -1.0039215087890625
+ ]
+ ],
+ [
+ [
+ -7.9686279296875,
+ -6.023530006408691
+ ],
+ [
+ 7.027451515197754,
+ 5.019608497619629
+ ]
+ ]
+ ]
+ ]
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/Conv2D_004_config/channel/int16/quantization/bias.json b/compiler/pota-quantization-value-test/expected_outputs/Conv2D_004_config/channel/int16/quantization/bias.json
new file mode 100644
index 000000000..50d44ece7
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/Conv2D_004_config/channel/int16/quantization/bias.json
@@ -0,0 +1,7 @@
+{
+ "weights": [
+ 4069,
+ 8138
+ ],
+ "scale": 0.0002457468386200985
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/Conv2D_004_config/channel/int16/quantization/ifm_Quantize.json b/compiler/pota-quantization-value-test/expected_outputs/Conv2D_004_config/channel/int16/quantization/ifm_Quantize.json
new file mode 100644
index 000000000..24508860d
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/Conv2D_004_config/channel/int16/quantization/ifm_Quantize.json
@@ -0,0 +1,4 @@
+{
+ "scale": 0.003916590008884668,
+ "zero_point": 0.0
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/Conv2D_004_config/channel/int16/quantization/ker.json b/compiler/pota-quantization-value-test/expected_outputs/Conv2D_004_config/channel/int16/quantization/ker.json
new file mode 100644
index 000000000..b249a0ce5
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/Conv2D_004_config/channel/int16/quantization/ker.json
@@ -0,0 +1,52 @@
+{
+ "weights": [
+ [
+ [
+ [
+ 143,
+ 159
+ ],
+ [
+ 79,
+ 63
+ ]
+ ],
+ [
+ [
+ 47,
+ 223
+ ],
+ [
+ 15,
+ 254
+ ]
+ ]
+ ],
+ [
+ [
+ [
+ 191,
+ 95
+ ],
+ [
+ 175,
+ 111
+ ]
+ ],
+ [
+ [
+ 0,
+ 31
+ ],
+ [
+ 239,
+ 207
+ ]
+ ]
+ ]
+ ],
+ "scale": 0.062745101749897,
+ "zero_point": 127.0,
+ "min": -7.9686279296875,
+ "max": 8.031373023986816
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/Conv2D_004_config/channel/int16/quantization/ofm.json b/compiler/pota-quantization-value-test/expected_outputs/Conv2D_004_config/channel/int16/quantization/ofm.json
new file mode 100644
index 000000000..a2dd6681f
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/Conv2D_004_config/channel/int16/quantization/ofm.json
@@ -0,0 +1,4 @@
+{
+ "scale": 0.037479765713214874,
+ "zero_point": 0.0
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/Conv2D_004_config/layer/uint8/fake_quantization/ker.json b/compiler/pota-quantization-value-test/expected_outputs/Conv2D_004_config/layer/uint8/fake_quantization/ker.json
new file mode 100644
index 000000000..8817cbef7
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/Conv2D_004_config/layer/uint8/fake_quantization/ker.json
@@ -0,0 +1,48 @@
+{
+ "weights": [
+ [
+ [
+ [
+ 1.000030517578125,
+ 2.00006103515625
+ ],
+ [
+ -3.000091552734375,
+ -4.0001220703125
+ ]
+ ],
+ [
+ [
+ -4.999908447265625,
+ 5.99993896484375
+ ],
+ [
+ -6.999969482421875,
+ 8.0
+ ]
+ ]
+ ],
+ [
+ [
+ [
+ 4.0001220703125,
+ -2.00006103515625
+ ],
+ [
+ 3.000091552734375,
+ -1.000030517578125
+ ]
+ ],
+ [
+ [
+ -8.0,
+ -5.99993896484375
+ ],
+ [
+ 6.999969482421875,
+ 4.999908447265625
+ ]
+ ]
+ ]
+ ]
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/Conv2D_004_config/layer/uint8/quantization/bias.json b/compiler/pota-quantization-value-test/expected_outputs/Conv2D_004_config/layer/uint8/quantization/bias.json
new file mode 100644
index 000000000..b00d8d211
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/Conv2D_004_config/layer/uint8/quantization/bias.json
@@ -0,0 +1,10 @@
+{
+ "weights": [
+ 26925029,
+ 53850057
+ ],
+ "scale": [
+ 3.714016479907864e-08,
+ 3.714016479907864e-08
+ ]
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/Conv2D_004_config/layer/uint8/quantization/ifm_Quantize.json b/compiler/pota-quantization-value-test/expected_outputs/Conv2D_004_config/layer/uint8/quantization/ifm_Quantize.json
new file mode 100644
index 000000000..df5d06c09
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/Conv2D_004_config/layer/uint8/quantization/ifm_Quantize.json
@@ -0,0 +1,4 @@
+{
+ "scale": 0.00015212147263810039,
+ "zero_point": 0.0
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/Conv2D_004_config/layer/uint8/quantization/ker.json b/compiler/pota-quantization-value-test/expected_outputs/Conv2D_004_config/layer/uint8/quantization/ker.json
new file mode 100644
index 000000000..94c794fbb
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/Conv2D_004_config/layer/uint8/quantization/ker.json
@@ -0,0 +1,61 @@
+{
+ "weights": [
+ [
+ [
+ [
+ 4096,
+ 8192
+ ],
+ [
+ -12288,
+ -16384
+ ]
+ ],
+ [
+ [
+ -20479,
+ 24575
+ ],
+ [
+ -28671,
+ 32767
+ ]
+ ]
+ ],
+ [
+ [
+ [
+ 16384,
+ -8192
+ ],
+ [
+ 12288,
+ -4096
+ ]
+ ],
+ [
+ [
+ -32767,
+ -24575
+ ],
+ [
+ 28671,
+ 20479
+ ]
+ ]
+ ]
+ ],
+ "scale": [
+ 0.00024414807580797754,
+ 0.00024414807580797754
+ ],
+ "zero_point": 0.0,
+ "min": [
+ -8.0,
+ -8.0
+ ],
+ "max": [
+ 8.0,
+ 8.0
+ ]
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/Conv2D_004_config/layer/uint8/quantization/ofm.json b/compiler/pota-quantization-value-test/expected_outputs/Conv2D_004_config/layer/uint8/quantization/ofm.json
new file mode 100644
index 000000000..e02eeb9dc
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/Conv2D_004_config/layer/uint8/quantization/ofm.json
@@ -0,0 +1,4 @@
+{
+ "scale": 0.002048635622486472,
+ "zero_point": 0.0
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/DepthwiseConv2D_002_config/channel/int16/uint8/fake_quantization/ker.json b/compiler/pota-quantization-value-test/expected_outputs/DepthwiseConv2D_002_config/channel/int16/uint8/fake_quantization/ker.json
new file mode 100644
index 000000000..cd3479781
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/DepthwiseConv2D_002_config/channel/int16/uint8/fake_quantization/ker.json
@@ -0,0 +1,34 @@
+{
+ "weights": [
+ [
+ [
+ [
+ 0.9725494384765625,
+ 1.945098876953125,
+ 3.039216995239258,
+ 4.0117645263671875
+ ],
+ [
+ -8.996077537536621,
+ 9.9686279296875,
+ -10.94117546081543,
+ 12.035295486450195
+ ]
+ ],
+ [
+ [
+ 4.98431396484375,
+ 5.9568634033203125,
+ 7.050981521606445,
+ 8.023530960083008
+ ],
+ [
+ 13.007843017578125,
+ -13.980391502380371,
+ 14.95294189453125,
+ -16.04705810546875
+ ]
+ ]
+ ]
+ ]
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/DepthwiseConv2D_002_config/channel/int16/uint8/quantization/bias.json b/compiler/pota-quantization-value-test/expected_outputs/DepthwiseConv2D_002_config/channel/int16/uint8/quantization/bias.json
new file mode 100644
index 000000000..e60ff312e
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/DepthwiseConv2D_002_config/channel/int16/uint8/quantization/bias.json
@@ -0,0 +1,9 @@
+{
+ "weights": [
+ 2156,
+ 4312,
+ 6468,
+ 8624
+ ],
+ "scale": 0.0004638272181067826
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/DepthwiseConv2D_002_config/channel/int16/uint8/quantization/ifm_Quantize.json b/compiler/pota-quantization-value-test/expected_outputs/DepthwiseConv2D_002_config/channel/int16/uint8/quantization/ifm_Quantize.json
new file mode 100644
index 000000000..4ec4ef2d7
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/DepthwiseConv2D_002_config/channel/int16/uint8/quantization/ifm_Quantize.json
@@ -0,0 +1,4 @@
+{
+ "scale": 0.0038153529167175293,
+ "zero_point": 0.0
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/DepthwiseConv2D_002_config/channel/int16/uint8/quantization/ker.json b/compiler/pota-quantization-value-test/expected_outputs/DepthwiseConv2D_002_config/channel/int16/uint8/quantization/ker.json
new file mode 100644
index 000000000..01835fbde
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/DepthwiseConv2D_002_config/channel/int16/uint8/quantization/ker.json
@@ -0,0 +1,38 @@
+{
+ "weights": [
+ [
+ [
+ [
+ 140,
+ 148,
+ 157,
+ 165
+ ],
+ [
+ 58,
+ 214,
+ 42,
+ 231
+ ]
+ ],
+ [
+ [
+ 173,
+ 181,
+ 190,
+ 198
+ ],
+ [
+ 239,
+ 17,
+ 255,
+ 0
+ ]
+ ]
+ ]
+ ],
+ "scale": 0.12156862765550613,
+ "zero_point": 132.0,
+ "min": -16.04705810546875,
+ "max": 14.952940940856934
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/DepthwiseConv2D_002_config/channel/int16/uint8/quantization/ofm.json b/compiler/pota-quantization-value-test/expected_outputs/DepthwiseConv2D_002_config/channel/int16/uint8/quantization/ofm.json
new file mode 100644
index 000000000..39c64f3ef
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/DepthwiseConv2D_002_config/channel/int16/uint8/quantization/ofm.json
@@ -0,0 +1,4 @@
+{
+ "scale": 0.07362665981054306,
+ "zero_point": 0.0
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/DepthwiseConv2D_002_config/layer/uint8/fake_quantization/ker.json b/compiler/pota-quantization-value-test/expected_outputs/DepthwiseConv2D_002_config/layer/uint8/fake_quantization/ker.json
new file mode 100644
index 000000000..20c1f6759
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/DepthwiseConv2D_002_config/layer/uint8/fake_quantization/ker.json
@@ -0,0 +1,34 @@
+{
+ "weights": [
+ [
+ [
+ [
+ 1.00018310546875,
+ 2.0,
+ 2.99981689453125,
+ 4.0001220703125
+ ],
+ [
+ -9.00006103515625,
+ 10.0,
+ -10.99993896484375,
+ 11.9998779296875
+ ]
+ ],
+ [
+ [
+ 5.0001220703125,
+ 6.0,
+ 6.9998779296875,
+ 8.000244140625
+ ],
+ [
+ 13.0,
+ -14.0,
+ 15.0,
+ -16.0
+ ]
+ ]
+ ]
+ ]
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/DepthwiseConv2D_002_config/layer/uint8/quantization/bias.json b/compiler/pota-quantization-value-test/expected_outputs/DepthwiseConv2D_002_config/layer/uint8/quantization/bias.json
new file mode 100644
index 000000000..632333144
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/DepthwiseConv2D_002_config/layer/uint8/quantization/bias.json
@@ -0,0 +1,14 @@
+{
+ "weights": [
+ 17503969,
+ 32507370,
+ 45510319,
+ 56887898
+ ],
+ "scale": [
+ 5.7129901172951205e-08,
+ 6.152450895548591e-08,
+ 6.591911673802062e-08,
+ 7.031372452055533e-08
+ ]
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/DepthwiseConv2D_002_config/layer/uint8/quantization/ifm_Quantize.json b/compiler/pota-quantization-value-test/expected_outputs/DepthwiseConv2D_002_config/layer/uint8/quantization/ifm_Quantize.json
new file mode 100644
index 000000000..7105a686d
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/DepthwiseConv2D_002_config/layer/uint8/quantization/ifm_Quantize.json
@@ -0,0 +1,4 @@
+{
+ "scale": 0.00014399811334442347,
+ "zero_point": 0.0
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/DepthwiseConv2D_002_config/layer/uint8/quantization/ker.json b/compiler/pota-quantization-value-test/expected_outputs/DepthwiseConv2D_002_config/layer/uint8/quantization/ker.json
new file mode 100644
index 000000000..d465a7c17
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/DepthwiseConv2D_002_config/layer/uint8/quantization/ker.json
@@ -0,0 +1,53 @@
+{
+ "weights": [
+ [
+ [
+ [
+ 2521,
+ 4681,
+ 6553,
+ 8192
+ ],
+ [
+ -22685,
+ 23405,
+ -24029,
+ 24575
+ ]
+ ],
+ [
+ [
+ 12603,
+ 14043,
+ 15291,
+ 16384
+ ],
+ [
+ 32767,
+ -32767,
+ 32767,
+ -32767
+ ]
+ ]
+ ]
+ ],
+ "scale": [
+ 0.0003967406231879635,
+ 0.0004272591326639607,
+ 0.0004577776421399579,
+ 0.0004882961516159551
+ ],
+ "zero_point": 0.0,
+ "min": [
+ -13.0,
+ -14.0,
+ -15.0,
+ -16.0
+ ],
+ "max": [
+ 13.0,
+ 14.0,
+ 15.0,
+ 16.0
+ ]
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/DepthwiseConv2D_002_config/layer/uint8/quantization/ofm.json b/compiler/pota-quantization-value-test/expected_outputs/DepthwiseConv2D_002_config/layer/uint8/quantization/ofm.json
new file mode 100644
index 000000000..2d84cd7d8
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/DepthwiseConv2D_002_config/layer/uint8/quantization/ofm.json
@@ -0,0 +1,4 @@
+{
+ "scale": 0.0031168656423687935,
+ "zero_point": 0.0
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/FullyConnected_003_config/channel/int16/fake_quantization/weight.json b/compiler/pota-quantization-value-test/expected_outputs/FullyConnected_003_config/channel/int16/fake_quantization/weight.json
new file mode 100644
index 000000000..e1da53ab0
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/FullyConnected_003_config/channel/int16/fake_quantization/weight.json
@@ -0,0 +1,76 @@
+{
+ "weights": [
+ [
+ 1.0039215087890625,
+ 2.007843017578125,
+ -3.0117650032043457,
+ -4.015686511993408,
+ -5.019608497619629,
+ 6.023530006408691,
+ -7.027451515197754,
+ 7.9686279296875,
+ 4.01568603515625,
+ -2.007843494415283,
+ 3.0117645263671875,
+ -1.0039215087890625,
+ -7.9686279296875,
+ -6.023530006408691,
+ 7.027451515197754,
+ 5.019608497619629
+ ],
+ [
+ 1.0039215087890625,
+ 2.007843017578125,
+ -3.0117650032043457,
+ -4.015686511993408,
+ -5.019608497619629,
+ 6.023530006408691,
+ -7.027451515197754,
+ 7.9686279296875,
+ 4.01568603515625,
+ -2.007843494415283,
+ 3.0117645263671875,
+ -1.0039215087890625,
+ -7.9686279296875,
+ -6.023530006408691,
+ 7.027451515197754,
+ 5.019608497619629
+ ],
+ [
+ 1.0039215087890625,
+ 2.007843017578125,
+ -3.0117650032043457,
+ -4.015686511993408,
+ -5.019608497619629,
+ 6.023530006408691,
+ -7.027451515197754,
+ 7.9686279296875,
+ 4.01568603515625,
+ -2.007843494415283,
+ 3.0117645263671875,
+ -1.0039215087890625,
+ -7.9686279296875,
+ -6.023530006408691,
+ 7.027451515197754,
+ 5.019608497619629
+ ],
+ [
+ 1.0039215087890625,
+ 2.007843017578125,
+ -3.0117650032043457,
+ -4.015686511993408,
+ -5.019608497619629,
+ 6.023530006408691,
+ -7.027451515197754,
+ 7.9686279296875,
+ 4.01568603515625,
+ -2.007843494415283,
+ 3.0117645263671875,
+ -1.0039215087890625,
+ -7.9686279296875,
+ -6.023530006408691,
+ 7.027451515197754,
+ 5.019608497619629
+ ]
+ ]
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/FullyConnected_003_config/channel/int16/quantization/bias.json b/compiler/pota-quantization-value-test/expected_outputs/FullyConnected_003_config/channel/int16/quantization/bias.json
new file mode 100644
index 000000000..ecb49bb64
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/FullyConnected_003_config/channel/int16/quantization/bias.json
@@ -0,0 +1,9 @@
+{
+ "weights": [
+ 415,
+ -829,
+ -1244,
+ 1658
+ ],
+ "scale": 0.00241205753304663
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/FullyConnected_003_config/channel/int16/quantization/in_Quantize.json b/compiler/pota-quantization-value-test/expected_outputs/FullyConnected_003_config/channel/int16/quantization/in_Quantize.json
new file mode 100644
index 000000000..654824b5d
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/FullyConnected_003_config/channel/int16/quantization/in_Quantize.json
@@ -0,0 +1,4 @@
+{
+ "scale": 0.03844216465950012,
+ "zero_point": 126.0
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/FullyConnected_003_config/channel/int16/quantization/out.json b/compiler/pota-quantization-value-test/expected_outputs/FullyConnected_003_config/channel/int16/quantization/out.json
new file mode 100644
index 000000000..3baa42155
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/FullyConnected_003_config/channel/int16/quantization/out.json
@@ -0,0 +1,4 @@
+{
+ "scale": 0.741962730884552,
+ "zero_point": 156.0
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/FullyConnected_003_config/channel/int16/quantization/weight.json b/compiler/pota-quantization-value-test/expected_outputs/FullyConnected_003_config/channel/int16/quantization/weight.json
new file mode 100644
index 000000000..940224049
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/FullyConnected_003_config/channel/int16/quantization/weight.json
@@ -0,0 +1,80 @@
+{
+ "weights": [
+ [
+ 143,
+ 159,
+ 79,
+ 63,
+ 47,
+ 223,
+ 15,
+ 254,
+ 191,
+ 95,
+ 175,
+ 111,
+ 0,
+ 31,
+ 239,
+ 207
+ ],
+ [
+ 143,
+ 159,
+ 79,
+ 63,
+ 47,
+ 223,
+ 15,
+ 254,
+ 191,
+ 95,
+ 175,
+ 111,
+ 0,
+ 31,
+ 239,
+ 207
+ ],
+ [
+ 143,
+ 159,
+ 79,
+ 63,
+ 47,
+ 223,
+ 15,
+ 254,
+ 191,
+ 95,
+ 175,
+ 111,
+ 0,
+ 31,
+ 239,
+ 207
+ ],
+ [
+ 143,
+ 159,
+ 79,
+ 63,
+ 47,
+ 223,
+ 15,
+ 254,
+ 191,
+ 95,
+ 175,
+ 111,
+ 0,
+ 31,
+ 239,
+ 207
+ ]
+ ],
+ "scale": 0.062745101749897,
+ "zero_point": 127.0,
+ "min": -7.9686279296875,
+ "max": 8.031373023986816
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/FullyConnected_003_config/layer/uint8/fake_quantization/weight.json b/compiler/pota-quantization-value-test/expected_outputs/FullyConnected_003_config/layer/uint8/fake_quantization/weight.json
new file mode 100644
index 000000000..559e537fc
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/FullyConnected_003_config/layer/uint8/fake_quantization/weight.json
@@ -0,0 +1,76 @@
+{
+ "weights": [
+ [
+ 1.000030517578125,
+ 2.00006103515625,
+ -3.000091552734375,
+ -4.0001220703125,
+ -4.999908447265625,
+ 5.99993896484375,
+ -6.999969482421875,
+ 8.0,
+ 4.0001220703125,
+ -2.00006103515625,
+ 3.000091552734375,
+ -1.000030517578125,
+ -8.0,
+ -5.99993896484375,
+ 6.999969482421875,
+ 4.999908447265625
+ ],
+ [
+ 1.000030517578125,
+ 2.00006103515625,
+ -3.000091552734375,
+ -4.0001220703125,
+ -4.999908447265625,
+ 5.99993896484375,
+ -6.999969482421875,
+ 8.0,
+ 4.0001220703125,
+ -2.00006103515625,
+ 3.000091552734375,
+ -1.000030517578125,
+ -8.0,
+ -5.99993896484375,
+ 6.999969482421875,
+ 4.999908447265625
+ ],
+ [
+ 1.000030517578125,
+ 2.00006103515625,
+ -3.000091552734375,
+ -4.0001220703125,
+ -4.999908447265625,
+ 5.99993896484375,
+ -6.999969482421875,
+ 8.0,
+ 4.0001220703125,
+ -2.00006103515625,
+ 3.000091552734375,
+ -1.000030517578125,
+ -8.0,
+ -5.99993896484375,
+ 6.999969482421875,
+ 4.999908447265625
+ ],
+ [
+ 1.000030517578125,
+ 2.00006103515625,
+ -3.000091552734375,
+ -4.0001220703125,
+ -4.999908447265625,
+ 5.99993896484375,
+ -6.999969482421875,
+ 8.0,
+ 4.0001220703125,
+ -2.00006103515625,
+ 3.000091552734375,
+ -1.000030517578125,
+ -8.0,
+ -5.99993896484375,
+ 6.999969482421875,
+ 4.999908447265625
+ ]
+ ]
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/FullyConnected_003_config/layer/uint8/quantization/bias.json b/compiler/pota-quantization-value-test/expected_outputs/FullyConnected_003_config/layer/uint8/quantization/bias.json
new file mode 100644
index 000000000..0186c03f4
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/FullyConnected_003_config/layer/uint8/quantization/bias.json
@@ -0,0 +1,14 @@
+{
+ "weights": [
+ 27619368,
+ -55238737,
+ -82858105,
+ 110477474
+ ],
+ "scale": [
+ 3.620647604581258e-08,
+ 3.620647604581258e-08,
+ 3.620647604581258e-08,
+ 3.620647604581258e-08
+ ]
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/FullyConnected_003_config/layer/uint8/quantization/in_Quantize.json b/compiler/pota-quantization-value-test/expected_outputs/FullyConnected_003_config/layer/uint8/quantization/in_Quantize.json
new file mode 100644
index 000000000..1fd68cabe
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/FullyConnected_003_config/layer/uint8/quantization/in_Quantize.json
@@ -0,0 +1,4 @@
+{
+ "scale": 0.00014829720021225512,
+ "zero_point": 0.0
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/FullyConnected_003_config/layer/uint8/quantization/out.json b/compiler/pota-quantization-value-test/expected_outputs/FullyConnected_003_config/layer/uint8/quantization/out.json
new file mode 100644
index 000000000..b2950218c
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/FullyConnected_003_config/layer/uint8/quantization/out.json
@@ -0,0 +1,4 @@
+{
+ "scale": 0.003870659740641713,
+ "zero_point": 0.0
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/FullyConnected_003_config/layer/uint8/quantization/weight.json b/compiler/pota-quantization-value-test/expected_outputs/FullyConnected_003_config/layer/uint8/quantization/weight.json
new file mode 100644
index 000000000..69254d12b
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/FullyConnected_003_config/layer/uint8/quantization/weight.json
@@ -0,0 +1,95 @@
+{
+ "weights": [
+ [
+ 4096,
+ 8192,
+ -12288,
+ -16384,
+ -20479,
+ 24575,
+ -28671,
+ 32767,
+ 16384,
+ -8192,
+ 12288,
+ -4096,
+ -32767,
+ -24575,
+ 28671,
+ 20479
+ ],
+ [
+ 4096,
+ 8192,
+ -12288,
+ -16384,
+ -20479,
+ 24575,
+ -28671,
+ 32767,
+ 16384,
+ -8192,
+ 12288,
+ -4096,
+ -32767,
+ -24575,
+ 28671,
+ 20479
+ ],
+ [
+ 4096,
+ 8192,
+ -12288,
+ -16384,
+ -20479,
+ 24575,
+ -28671,
+ 32767,
+ 16384,
+ -8192,
+ 12288,
+ -4096,
+ -32767,
+ -24575,
+ 28671,
+ 20479
+ ],
+ [
+ 4096,
+ 8192,
+ -12288,
+ -16384,
+ -20479,
+ 24575,
+ -28671,
+ 32767,
+ 16384,
+ -8192,
+ 12288,
+ -4096,
+ -32767,
+ -24575,
+ 28671,
+ 20479
+ ]
+ ],
+ "scale": [
+ 0.00024414807580797754,
+ 0.00024414807580797754,
+ 0.00024414807580797754,
+ 0.00024414807580797754
+ ],
+ "zero_point": 0.0,
+ "min": [
+ -8.0,
+ -8.0,
+ -8.0,
+ -8.0
+ ],
+ "max": [
+ 8.0,
+ 8.0,
+ 8.0,
+ 8.0
+ ]
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/MaxPool2D_000_config/channel/int16/quantization/ifm_Quantize.json b/compiler/pota-quantization-value-test/expected_outputs/MaxPool2D_000_config/channel/int16/quantization/ifm_Quantize.json
new file mode 100644
index 000000000..9bf6c9bff
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/MaxPool2D_000_config/channel/int16/quantization/ifm_Quantize.json
@@ -0,0 +1,4 @@
+{
+ "scale": 0.03876218944787979,
+ "zero_point": 126.0
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/MaxPool2D_000_config/channel/int16/quantization/ofm.json b/compiler/pota-quantization-value-test/expected_outputs/MaxPool2D_000_config/channel/int16/quantization/ofm.json
new file mode 100644
index 000000000..87de1116e
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/MaxPool2D_000_config/channel/int16/quantization/ofm.json
@@ -0,0 +1,4 @@
+{
+ "scale": 0.029836513102054596,
+ "zero_point": 88.0
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/MaxPool2D_000_config/layer/uint8/quantization/ifm_Quantize.json b/compiler/pota-quantization-value-test/expected_outputs/MaxPool2D_000_config/layer/uint8/quantization/ifm_Quantize.json
new file mode 100644
index 000000000..5d9052815
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/MaxPool2D_000_config/layer/uint8/quantization/ifm_Quantize.json
@@ -0,0 +1,4 @@
+{
+ "scale": 0.00015059474390000105,
+ "zero_point": 0.0
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/MaxPool2D_000_config/layer/uint8/quantization/ofm.json b/compiler/pota-quantization-value-test/expected_outputs/MaxPool2D_000_config/layer/uint8/quantization/ofm.json
new file mode 100644
index 000000000..25491f05d
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/MaxPool2D_000_config/layer/uint8/quantization/ofm.json
@@ -0,0 +1,4 @@
+{
+ "scale": 0.00014986195310484618,
+ "zero_point": 0.0
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/Mean_000_config/channel/int16/quantization/ifm_Quantize.json b/compiler/pota-quantization-value-test/expected_outputs/Mean_000_config/channel/int16/quantization/ifm_Quantize.json
new file mode 100644
index 000000000..ede36c6ad
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/Mean_000_config/channel/int16/quantization/ifm_Quantize.json
@@ -0,0 +1,4 @@
+{
+ "scale": 0.039086975157260895,
+ "zero_point": 128.0
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/Mean_000_config/channel/int16/quantization/ofm.json b/compiler/pota-quantization-value-test/expected_outputs/Mean_000_config/channel/int16/quantization/ofm.json
new file mode 100644
index 000000000..bd2fc7f62
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/Mean_000_config/channel/int16/quantization/ofm.json
@@ -0,0 +1,4 @@
+{
+ "scale": 0.028692100197076797,
+ "zero_point": 131.0
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/Mean_000_config/layer/uint8/quantization/ifm_Quantize.json b/compiler/pota-quantization-value-test/expected_outputs/Mean_000_config/layer/uint8/quantization/ifm_Quantize.json
new file mode 100644
index 000000000..18c3b0421
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/Mean_000_config/layer/uint8/quantization/ifm_Quantize.json
@@ -0,0 +1,4 @@
+{
+ "scale": 0.00015251495642587543,
+ "zero_point": 0.0
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/Mean_000_config/layer/uint8/quantization/ofm.json b/compiler/pota-quantization-value-test/expected_outputs/Mean_000_config/layer/uint8/quantization/ofm.json
new file mode 100644
index 000000000..145ee8fda
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/Mean_000_config/layer/uint8/quantization/ofm.json
@@ -0,0 +1,4 @@
+{
+ "scale": 0.00013844699424225837,
+ "zero_point": 0.0
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/Mean_000_config/layer/uint8/quantization/reduction_indices.json b/compiler/pota-quantization-value-test/expected_outputs/Mean_000_config/layer/uint8/quantization/reduction_indices.json
new file mode 100644
index 000000000..394cfb322
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/Mean_000_config/layer/uint8/quantization/reduction_indices.json
@@ -0,0 +1,5 @@
+{
+ "weights": [
+ -1
+ ]
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/Mul_001_config/channel/int16/quantization/ifm1_Quantize.json b/compiler/pota-quantization-value-test/expected_outputs/Mul_001_config/channel/int16/quantization/ifm1_Quantize.json
new file mode 100644
index 000000000..bbff8952d
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/Mul_001_config/channel/int16/quantization/ifm1_Quantize.json
@@ -0,0 +1,4 @@
+{
+ "scale": 0.03780897706747055,
+ "zero_point": 131.0
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/Mul_001_config/channel/int16/quantization/ifm2.json b/compiler/pota-quantization-value-test/expected_outputs/Mul_001_config/channel/int16/quantization/ifm2.json
new file mode 100644
index 000000000..ec6082d55
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/Mul_001_config/channel/int16/quantization/ifm2.json
@@ -0,0 +1,32 @@
+{
+ "weights": [
+ [
+ [
+ [
+ 136,
+ 153,
+ 68
+ ],
+ [
+ 51,
+ 34,
+ 221
+ ]
+ ],
+ [
+ [
+ 0,
+ 255,
+ 187
+ ],
+ [
+ 85,
+ 170,
+ 102
+ ]
+ ]
+ ]
+ ],
+ "scale": 0.05882352963089943,
+ "zero_point": 119.0
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/Mul_001_config/channel/int16/quantization/ofm.json b/compiler/pota-quantization-value-test/expected_outputs/Mul_001_config/channel/int16/quantization/ofm.json
new file mode 100644
index 000000000..cec0bdf9a
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/Mul_001_config/channel/int16/quantization/ofm.json
@@ -0,0 +1,4 @@
+{
+ "scale": 0.232084259390831,
+ "zero_point": 111.0
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/Mul_001_config/layer/uint8/quantization/ifm1_Quantize.json b/compiler/pota-quantization-value-test/expected_outputs/Mul_001_config/layer/uint8/quantization/ifm1_Quantize.json
new file mode 100644
index 000000000..f329b43be
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/Mul_001_config/layer/uint8/quantization/ifm1_Quantize.json
@@ -0,0 +1,4 @@
+{
+ "scale": 0.0001513722527306527,
+ "zero_point": 0.0
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/Mul_001_config/layer/uint8/quantization/ifm2.json b/compiler/pota-quantization-value-test/expected_outputs/Mul_001_config/layer/uint8/quantization/ifm2.json
new file mode 100644
index 000000000..ab968c9fc
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/Mul_001_config/layer/uint8/quantization/ifm2.json
@@ -0,0 +1,32 @@
+{
+ "weights": [
+ [
+ [
+ [
+ 4096,
+ 8192,
+ -12288
+ ],
+ [
+ -16384,
+ -20479,
+ 24575
+ ]
+ ],
+ [
+ [
+ -28671,
+ 32767,
+ 16384
+ ],
+ [
+ -8192,
+ 12288,
+ -4096
+ ]
+ ]
+ ]
+ ],
+ "scale": 0.0002441480755805969,
+ "zero_point": 0.0
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/Mul_001_config/layer/uint8/quantization/ofm.json b/compiler/pota-quantization-value-test/expected_outputs/Mul_001_config/layer/uint8/quantization/ofm.json
new file mode 100644
index 000000000..4b5118c3e
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/Mul_001_config/layer/uint8/quantization/ofm.json
@@ -0,0 +1,4 @@
+{
+ "scale": 0.000991688808426261,
+ "zero_point": 0.0
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/PRelu_001_config/channel/int16/quantization/alpha.json b/compiler/pota-quantization-value-test/expected_outputs/PRelu_001_config/channel/int16/quantization/alpha.json
new file mode 100644
index 000000000..7c001602f
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/PRelu_001_config/channel/int16/quantization/alpha.json
@@ -0,0 +1,13 @@
+{
+ "weights": [
+ [
+ [
+ 51,
+ 153,
+ 255
+ ]
+ ]
+ ],
+ "scale": 0.0019607844296842813,
+ "zero_point": 0.0
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/PRelu_001_config/channel/int16/quantization/ifm_Quantize.json b/compiler/pota-quantization-value-test/expected_outputs/PRelu_001_config/channel/int16/quantization/ifm_Quantize.json
new file mode 100644
index 000000000..05ce9dd2c
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/PRelu_001_config/channel/int16/quantization/ifm_Quantize.json
@@ -0,0 +1,4 @@
+{
+ "scale": 0.03849203139543533,
+ "zero_point": 127.0
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/PRelu_001_config/channel/int16/quantization/ofm.json b/compiler/pota-quantization-value-test/expected_outputs/PRelu_001_config/channel/int16/quantization/ofm.json
new file mode 100644
index 000000000..8f883094a
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/PRelu_001_config/channel/int16/quantization/ofm.json
@@ -0,0 +1,4 @@
+{
+ "scale": 0.02848827838897705,
+ "zero_point": 82.0
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/PRelu_001_config/layer/uint8/quantization/alpha.json b/compiler/pota-quantization-value-test/expected_outputs/PRelu_001_config/layer/uint8/quantization/alpha.json
new file mode 100644
index 000000000..6f99899d5
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/PRelu_001_config/layer/uint8/quantization/alpha.json
@@ -0,0 +1,21 @@
+{
+ "weights": [
+ [
+ [
+ 1,
+ 1,
+ 1
+ ]
+ ]
+ ],
+ "scale": [
+ 0.10000000149011612,
+ 0.30000001192092896,
+ 0.5
+ ],
+ "zero_point": [
+ 0,
+ 0,
+ 0
+ ]
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/PRelu_001_config/layer/uint8/quantization/ifm_Quantize.json b/compiler/pota-quantization-value-test/expected_outputs/PRelu_001_config/layer/uint8/quantization/ifm_Quantize.json
new file mode 100644
index 000000000..7d1f4c795
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/PRelu_001_config/layer/uint8/quantization/ifm_Quantize.json
@@ -0,0 +1,4 @@
+{
+ "scale": 0.00015214986342471093,
+ "zero_point": 0.0
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/PRelu_001_config/layer/uint8/quantization/ofm.json b/compiler/pota-quantization-value-test/expected_outputs/PRelu_001_config/layer/uint8/quantization/ofm.json
new file mode 100644
index 000000000..533c1e3e0
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/PRelu_001_config/layer/uint8/quantization/ofm.json
@@ -0,0 +1,4 @@
+{
+ "scale": 0.00015159364556893706,
+ "zero_point": 0.0
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/ReLU_000_config/channel/int16/quantization/ifm_Quantize.json b/compiler/pota-quantization-value-test/expected_outputs/ReLU_000_config/channel/int16/quantization/ifm_Quantize.json
new file mode 100644
index 000000000..3b97773ce
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/ReLU_000_config/channel/int16/quantization/ifm_Quantize.json
@@ -0,0 +1,4 @@
+{
+ "scale": 0.03907399624586105,
+ "zero_point": 127.0
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/ReLU_000_config/channel/int16/quantization/ofm.json b/compiler/pota-quantization-value-test/expected_outputs/ReLU_000_config/channel/int16/quantization/ofm.json
new file mode 100644
index 000000000..698a8a7ee
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/ReLU_000_config/channel/int16/quantization/ofm.json
@@ -0,0 +1,4 @@
+{
+ "scale": 0.01955186203122139,
+ "zero_point": 0.0
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/ReLU_000_config/layer/uint8/quantization/ifm_Quantize.json b/compiler/pota-quantization-value-test/expected_outputs/ReLU_000_config/layer/uint8/quantization/ifm_Quantize.json
new file mode 100644
index 000000000..5a52a1b7b
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/ReLU_000_config/layer/uint8/quantization/ifm_Quantize.json
@@ -0,0 +1,4 @@
+{
+ "scale": 0.0001474507007515058,
+ "zero_point": 0.0
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/ReLU_000_config/layer/uint8/quantization/ofm.json b/compiler/pota-quantization-value-test/expected_outputs/ReLU_000_config/layer/uint8/quantization/ofm.json
new file mode 100644
index 000000000..ff9e41ec8
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/ReLU_000_config/layer/uint8/quantization/ofm.json
@@ -0,0 +1,4 @@
+{
+ "scale": 0.0001422425702912733,
+ "zero_point": 0.0
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/Split_000_config/channel/int16/quantization/ifm_Quantize.json b/compiler/pota-quantization-value-test/expected_outputs/Split_000_config/channel/int16/quantization/ifm_Quantize.json
new file mode 100644
index 000000000..aaba6131c
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/Split_000_config/channel/int16/quantization/ifm_Quantize.json
@@ -0,0 +1,4 @@
+{
+ "scale": 0.038689617067575455,
+ "zero_point": 128.0
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/Split_000_config/channel/int16/quantization/ofm1.json b/compiler/pota-quantization-value-test/expected_outputs/Split_000_config/channel/int16/quantization/ofm1.json
new file mode 100644
index 000000000..aaba6131c
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/Split_000_config/channel/int16/quantization/ofm1.json
@@ -0,0 +1,4 @@
+{
+ "scale": 0.038689617067575455,
+ "zero_point": 128.0
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/Split_000_config/channel/int16/quantization/ofm2.json b/compiler/pota-quantization-value-test/expected_outputs/Split_000_config/channel/int16/quantization/ofm2.json
new file mode 100644
index 000000000..aaba6131c
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/Split_000_config/channel/int16/quantization/ofm2.json
@@ -0,0 +1,4 @@
+{
+ "scale": 0.038689617067575455,
+ "zero_point": 128.0
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/Split_000_config/channel/int16/quantization/split_dim.json b/compiler/pota-quantization-value-test/expected_outputs/Split_000_config/channel/int16/quantization/split_dim.json
new file mode 100644
index 000000000..ac7cde187
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/Split_000_config/channel/int16/quantization/split_dim.json
@@ -0,0 +1,5 @@
+{
+ "weights": [
+ 0
+ ]
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/Split_000_config/channel/uint8/quantization/ifm_Quantize.json b/compiler/pota-quantization-value-test/expected_outputs/Split_000_config/channel/uint8/quantization/ifm_Quantize.json
new file mode 100644
index 000000000..2fb0c68d8
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/Split_000_config/channel/uint8/quantization/ifm_Quantize.json
@@ -0,0 +1,4 @@
+{
+ "scale": 0.00014983004075475037,
+ "zero_point": 0.0
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/Split_000_config/channel/uint8/quantization/ofm1.json b/compiler/pota-quantization-value-test/expected_outputs/Split_000_config/channel/uint8/quantization/ofm1.json
new file mode 100644
index 000000000..2fb0c68d8
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/Split_000_config/channel/uint8/quantization/ofm1.json
@@ -0,0 +1,4 @@
+{
+ "scale": 0.00014983004075475037,
+ "zero_point": 0.0
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/Split_000_config/channel/uint8/quantization/ofm2.json b/compiler/pota-quantization-value-test/expected_outputs/Split_000_config/channel/uint8/quantization/ofm2.json
new file mode 100644
index 000000000..2fb0c68d8
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/Split_000_config/channel/uint8/quantization/ofm2.json
@@ -0,0 +1,4 @@
+{
+ "scale": 0.00014983004075475037,
+ "zero_point": 0.0
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/Split_000_config/channel/uint8/quantization/split_dim.json b/compiler/pota-quantization-value-test/expected_outputs/Split_000_config/channel/uint8/quantization/split_dim.json
new file mode 100644
index 000000000..ac7cde187
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/Split_000_config/channel/uint8/quantization/split_dim.json
@@ -0,0 +1,5 @@
+{
+ "weights": [
+ 0
+ ]
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/TransposeConv_001_config/channel/int16/fake_quantization/ker.json b/compiler/pota-quantization-value-test/expected_outputs/TransposeConv_001_config/channel/int16/fake_quantization/ker.json
new file mode 100644
index 000000000..76a0440a0
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/TransposeConv_001_config/channel/int16/fake_quantization/ker.json
@@ -0,0 +1,48 @@
+{
+ "weights": [
+ [
+ [
+ [
+ 0.960784912109375,
+ 2.0588245391845703
+ ],
+ [
+ -3.0196075439453125,
+ -3.980391502380371
+ ],
+ [
+ 4.9411773681640625,
+ -6.039215087890625
+ ]
+ ],
+ [
+ [
+ 7.0,
+ 7.960784912109375
+ ],
+ [
+ -9.058823585510254,
+ -10.019607543945312
+ ],
+ [
+ 10.980392456054688,
+ -11.941176414489746
+ ]
+ ],
+ [
+ [
+ 13.039216995239258,
+ 14.000001907348633
+ ],
+ [
+ -14.960784912109375,
+ -16.05882453918457
+ ],
+ [
+ 17.019607543945312,
+ -17.980392456054688
+ ]
+ ]
+ ]
+ ]
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/TransposeConv_001_config/channel/int16/quantization/ifm_Quantize.json b/compiler/pota-quantization-value-test/expected_outputs/TransposeConv_001_config/channel/int16/quantization/ifm_Quantize.json
new file mode 100644
index 000000000..dc5ca8dd5
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/TransposeConv_001_config/channel/int16/quantization/ifm_Quantize.json
@@ -0,0 +1,4 @@
+{
+ "scale": 0.03869570419192314,
+ "zero_point": 126.0
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/TransposeConv_001_config/channel/int16/quantization/ker.json b/compiler/pota-quantization-value-test/expected_outputs/TransposeConv_001_config/channel/int16/quantization/ker.json
new file mode 100644
index 000000000..bc150bbb0
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/TransposeConv_001_config/channel/int16/quantization/ker.json
@@ -0,0 +1,52 @@
+{
+ "weights": [
+ [
+ [
+ [
+ 138,
+ 146
+ ],
+ [
+ 109,
+ 102
+ ],
+ [
+ 167,
+ 87
+ ]
+ ],
+ [
+ [
+ 182,
+ 189
+ ],
+ [
+ 65,
+ 58
+ ],
+ [
+ 211,
+ 44
+ ]
+ ],
+ [
+ [
+ 226,
+ 233
+ ],
+ [
+ 22,
+ 14
+ ],
+ [
+ 255,
+ 0
+ ]
+ ]
+ ]
+ ],
+ "scale": 0.13725490868091583,
+ "zero_point": 131.0,
+ "min": -17.980392456054688,
+ "max": 17.019609451293945
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/TransposeConv_001_config/channel/int16/quantization/ofm.json b/compiler/pota-quantization-value-test/expected_outputs/TransposeConv_001_config/channel/int16/quantization/ofm.json
new file mode 100644
index 000000000..bfd862189
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/TransposeConv_001_config/channel/int16/quantization/ofm.json
@@ -0,0 +1,4 @@
+{
+ "scale": 1.6333034038543701,
+ "zero_point": 127.0
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/TransposeConv_001_config/layer/uint8/fake_quantization/ker.json b/compiler/pota-quantization-value-test/expected_outputs/TransposeConv_001_config/layer/uint8/fake_quantization/ker.json
new file mode 100644
index 000000000..6df24eb42
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/TransposeConv_001_config/layer/uint8/fake_quantization/ker.json
@@ -0,0 +1,48 @@
+{
+ "weights": [
+ [
+ [
+ [
+ 0.999786376953125,
+ 2.0001220703125
+ ],
+ [
+ -2.999908447265625,
+ -4.000244140625
+ ],
+ [
+ 5.000030517578125,
+ -5.99981689453125
+ ]
+ ],
+ [
+ [
+ 7.000152587890625,
+ 7.99993896484375
+ ],
+ [
+ -9.000274658203125,
+ -10.00006103515625
+ ],
+ [
+ 10.999847412109375,
+ -12.00018310546875
+ ]
+ ],
+ [
+ [
+ 12.999969482421875,
+ 13.999755859375
+ ],
+ [
+ -15.000091552734375,
+ -15.9998779296875
+ ],
+ [
+ 17.000213623046875,
+ -18.0
+ ]
+ ]
+ ]
+ ]
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/TransposeConv_001_config/layer/uint8/quantization/ifm_Quantize.json b/compiler/pota-quantization-value-test/expected_outputs/TransposeConv_001_config/layer/uint8/quantization/ifm_Quantize.json
new file mode 100644
index 000000000..82f7fa2b6
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/TransposeConv_001_config/layer/uint8/quantization/ifm_Quantize.json
@@ -0,0 +1,4 @@
+{
+ "scale": 0.00015178922330960631,
+ "zero_point": 0.0
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/TransposeConv_001_config/layer/uint8/quantization/ker.json b/compiler/pota-quantization-value-test/expected_outputs/TransposeConv_001_config/layer/uint8/quantization/ker.json
new file mode 100644
index 000000000..8d0ceb1c6
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/TransposeConv_001_config/layer/uint8/quantization/ker.json
@@ -0,0 +1,58 @@
+{
+ "weights": [
+ [
+ [
+ [
+ 1820,
+ 3641
+ ],
+ [
+ -5461,
+ -7282
+ ],
+ [
+ 9102,
+ -10922
+ ]
+ ],
+ [
+ [
+ 12743,
+ 14563
+ ],
+ [
+ -16384,
+ -18204
+ ],
+ [
+ 20024,
+ -21845
+ ]
+ ],
+ [
+ [
+ 23665,
+ 25485
+ ],
+ [
+ -27306,
+ -29126
+ ],
+ [
+ 30947,
+ -32767
+ ]
+ ]
+ ]
+ ],
+ "scale": [
+ 0.0005493331705679495
+ ],
+ "zero_point": 0.0,
+ "min": [
+ -18.0
+ ],
+ "max": [
+ 18.0
+ ]
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/TransposeConv_001_config/layer/uint8/quantization/ofm.json b/compiler/pota-quantization-value-test/expected_outputs/TransposeConv_001_config/layer/uint8/quantization/ofm.json
new file mode 100644
index 000000000..f370bf44d
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/TransposeConv_001_config/layer/uint8/quantization/ofm.json
@@ -0,0 +1,4 @@
+{
+ "scale": 0.0122029148042202,
+ "zero_point": 0.0
+}
diff --git a/compiler/pota-quantization-value-test/requires.cmake b/compiler/pota-quantization-value-test/requires.cmake
index 4eb7204e1..5ce8dfb5d 100644
--- a/compiler/pota-quantization-value-test/requires.cmake
+++ b/compiler/pota-quantization-value-test/requires.cmake
@@ -2,4 +2,4 @@ require("record-minmax")
require("circle-quantizer")
require("circle-tensordump")
require("common-artifacts")
-require("mio-circle")
+require("mio-circle04")
diff --git a/compiler/pota-quantization-value-test/test.lst b/compiler/pota-quantization-value-test/test.lst
index 4beec8c0e..e169de57c 100644
--- a/compiler/pota-quantization-value-test/test.lst
+++ b/compiler/pota-quantization-value-test/test.lst
@@ -31,3 +31,32 @@ addTest(Split_000 channel int16)
addTest(TransposeConv_001 channel uint8)
addTest(TransposeConv_001 channel int16)
addTest(TransposeConv_001 layer uint8)
+
+addQConfTest(Add_002 layer uint8)
+addQConfTest(Add_002 channel int16)
+addQConfTest(AveragePool2D_000 layer uint8)
+addQConfTest(AveragePool2D_000 channel int16)
+addQConfTest(Concatenation_001 layer uint8)
+addQConfTest(Concatenation_001 channel int16)
+addQConfTest(Conv2D_004 channel int16)
+addQConfTest(Conv2D_004 layer uint8)
+addQConfTest(DepthwiseConv2D_002 channel int16)
+addQConfTest(DepthwiseConv2D_002 layer uint8)
+addQConfTest(FullyConnected_003 channel int16)
+addQConfTest(FullyConnected_003 layer uint8)
+#addQConfTest(InstanceNorm_001 layer uint8) Enable this when int16 CWQ data is ready.
+#addQConfTest(InstanceNorm_001 channel int16) Enable this when int16 CWQ data is ready.
+addQConfTest(Mean_000 layer uint8)
+addQConfTest(Mean_000 channel int16)
+addQConfTest(MaxPool2D_000 layer uint8)
+addQConfTest(MaxPool2D_000 channel int16)
+addQConfTest(Mul_001 layer uint8)
+addQConfTest(Mul_001 channel int16)
+addQConfTest(PRelu_001 layer uint8)
+addQConfTest(PRelu_001 channel int16)
+addQConfTest(ReLU_000 layer uint8)
+addQConfTest(ReLU_000 channel int16)
+addQConfTest(Split_000 channel uint8)
+addQConfTest(Split_000 channel int16)
+addQConfTest(TransposeConv_001 channel int16)
+addQConfTest(TransposeConv_001 layer uint8)
diff --git a/compiler/pota-quantization-value-test/test_fake_wquant_with_config.sh b/compiler/pota-quantization-value-test/test_fake_wquant_with_config.sh
new file mode 100755
index 000000000..070b2738e
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_fake_wquant_with_config.sh
@@ -0,0 +1,87 @@
+#!/bin/bash
+
+# This script tests fake quantization with config file
+#
+# HOW TO USE
+#
+# ./test_fake_wquant_with_config.sh <path/to/test.config> <path/to/work_dir> <TEST 1> <TEST 2> ...
+# test.config : set ${RECORD_MINMAX_PATH} and ${CIRCLE_QUANTIZER_PATH}
+# work_dir : build directory of quantization-value-test (ex: build/compiler/quantization-value-test)
+
+SOURCE_PATH="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+COMPARE_SCRIPT_PATH="${SOURCE_PATH}/compare_tensors.py"
+CONFIG_PATH="$1"; shift
+BIN_PATH=$(dirname "${CONFIG_PATH}")
+WORKDIR="$1"; shift
+
+source "${CONFIG_PATH}"
+
+echo "-- Found CIRCLE_QUANTIZER: ${CIRCLE_QUANTIZER_PATH}"
+echo "-- Found CIRCLE_TENSORDUMP: ${CIRCLE_TENSORDUMP_PATH}"
+echo "-- Found workdir: ${WORKDIR}"
+
+TESTED=()
+PASSED=()
+FAILED=()
+
+pushd "${WORKDIR}"
+while [ "$1" != "" ]; do
+ MODELNAME=$1; shift
+ GRANULARITY=$1; shift
+ DTYPE=$1; shift
+ TESTCASE="${MODELNAME}.${GRANULARITY}.${DTYPE}"
+
+ TESTED+=("${TESTCASE}")
+
+ TESTCASE_FILE="${WORKDIR}/${TESTCASE}"
+ TEST_RESULT_FILE="${BIN_PATH}/${TESTCASE}"
+
+ PASSED_TAG="${TEST_RESULT_FILE}.fake_quantized.mixed.passed"
+ rm -f "${PASSED_TAG}"
+
+ cat > "${TEST_RESULT_FILE}_fake_quantization_with_config.log" <(
+ exec 2>&1
+ set -ex
+
+ # Run circle-quantizer with --quantize_dequantize_weights
+ "${CIRCLE_QUANTIZER_PATH}" \
+ --quantize_dequantize_weights float32 "${DTYPE}" "${GRANULARITY}" \
+ --config "${SOURCE_PATH}/config_files/${MODELNAME}/${GRANULARITY}/${DTYPE}/qconf.json" \
+ "${WORKDIR}/${MODELNAME}.circle" \
+ "${TEST_RESULT_FILE}.fake_quantized.mixed.circle"
+
+ # Dump weights values (circle-tensordump)
+ "${CIRCLE_TENSORDUMP_PATH}" \
+ "${TEST_RESULT_FILE}.fake_quantized.mixed.circle" \
+ --tensors_to_hdf5 "${TEST_RESULT_FILE}.fake_quantized.mixed.circle.h5"
+
+ # Compare result
+ "${VIRTUALENV}/bin/python" "${COMPARE_SCRIPT_PATH}" \
+ --input_h5 "${TEST_RESULT_FILE}.fake_quantized.mixed.circle.h5" \
+ --expect_dir "${SOURCE_PATH}/expected_outputs/${MODELNAME}_config/${GRANULARITY}/${DTYPE}/fake_quantization" \
+ --mode fake_quantization
+
+ if [[ $? -eq 0 ]]; then
+ touch "${PASSED_TAG}"
+ fi
+ )
+
+ if [[ -f "${PASSED_TAG}" ]]; then
+ PASSED+=("$TESTCASE")
+ else
+ FAILED+=("$TESTCASE")
+ fi
+done
+popd
+
+if [[ ${#TESTED[@]} -ne ${#PASSED[@]} ]]; then
+ echo "FAILED"
+ for TEST in "${FAILED[@]}"
+ do
+ echo "- ${TEST}"
+ done
+ exit 255
+fi
+
+echo "PASSED"
+exit 0
diff --git a/compiler/pota-quantization-value-test/test_inputs/Add_002_config/channel/int16/0.txt b/compiler/pota-quantization-value-test/test_inputs/Add_002_config/channel/int16/0.txt
new file mode 100644
index 000000000..b6e2efa3d
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/Add_002_config/channel/int16/0.txt
@@ -0,0 +1 @@
+-0.8596993, 4.8127713,-3.4127183, 4.2323627,-2.2201376,-1.5362649,-4.9921966, 0.9565166, 3.2879171,-1.3590081,-3.771852 ,-4.1042285
diff --git a/compiler/pota-quantization-value-test/test_inputs/Add_002_config/channel/int16/1.txt b/compiler/pota-quantization-value-test/test_inputs/Add_002_config/channel/int16/1.txt
new file mode 100644
index 000000000..bcf2807ba
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/Add_002_config/channel/int16/1.txt
@@ -0,0 +1 @@
+ 0.14624089, 4.7304125 , 4.833998 , 4.2321773 ,-2.0582533 ,-2.3694758 , 1.4213978 , 2.2444596 , 3.3630798 ,-0.70257574, 3.586656 ,-2.513805
diff --git a/compiler/pota-quantization-value-test/test_inputs/Add_002_config/channel/int16/2.txt b/compiler/pota-quantization-value-test/test_inputs/Add_002_config/channel/int16/2.txt
new file mode 100644
index 000000000..c3e32d2c5
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/Add_002_config/channel/int16/2.txt
@@ -0,0 +1 @@
+ 2.175218 , 0.02776978,-2.6291077 , 3.5350094 ,-1.2364857 ,-3.3151364 ,-0.92507887, 2.8038094 ,-1.8781518 , 3.6221995 , 2.4015775 ,-2.9217577
diff --git a/compiler/pota-quantization-value-test/test_inputs/Add_002_config/channel/int16/3.txt b/compiler/pota-quantization-value-test/test_inputs/Add_002_config/channel/int16/3.txt
new file mode 100644
index 000000000..a92abd4f6
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/Add_002_config/channel/int16/3.txt
@@ -0,0 +1 @@
+-1.0345451,-1.5055941,-4.144375 ,-4.727011 , 1.5841546, 4.5780725,-4.24402 ,-2.3966947,-3.0370803,-1.0234503,-0.2750057, 3.2965126
diff --git a/compiler/pota-quantization-value-test/test_inputs/Add_002_config/channel/int16/4.txt b/compiler/pota-quantization-value-test/test_inputs/Add_002_config/channel/int16/4.txt
new file mode 100644
index 000000000..2f2937fcb
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/Add_002_config/channel/int16/4.txt
@@ -0,0 +1 @@
+-2.4460397 , 2.6090143 , 4.1773095 , 0.11204174,-3.3053472 , 2.5160108 ,-3.0612547 , 1.0667087 , 2.8952355 , 3.842513 , 0.6790793 ,-0.33375
diff --git a/compiler/pota-quantization-value-test/test_inputs/Add_002_config/layer/uint8/0.txt b/compiler/pota-quantization-value-test/test_inputs/Add_002_config/layer/uint8/0.txt
new file mode 100644
index 000000000..a219546a1
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/Add_002_config/layer/uint8/0.txt
@@ -0,0 +1 @@
+-0.48516417,-4.5555663 ,-2.9907737 , 2.422857 , 1.010034 , 3.6436582 , 0.29334423,-4.0628953 , 1.0116768 , 3.0871766 , 3.3341465 , 4.3921704
diff --git a/compiler/pota-quantization-value-test/test_inputs/Add_002_config/layer/uint8/1.txt b/compiler/pota-quantization-value-test/test_inputs/Add_002_config/layer/uint8/1.txt
new file mode 100644
index 000000000..70d3139a0
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/Add_002_config/layer/uint8/1.txt
@@ -0,0 +1 @@
+-0.7787985 , 4.101575 ,-0.4839729 , 0.35971674,-4.3452406 ,-4.811665 ,-3.8693128 , 4.239326 , 0.44103175, 3.5549765 , 2.5334291 , 1.4546562
diff --git a/compiler/pota-quantization-value-test/test_inputs/Add_002_config/layer/uint8/2.txt b/compiler/pota-quantization-value-test/test_inputs/Add_002_config/layer/uint8/2.txt
new file mode 100644
index 000000000..3c38f8d5d
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/Add_002_config/layer/uint8/2.txt
@@ -0,0 +1 @@
+ 3.5943313,-1.4843192, 1.956341 ,-1.3242344, 1.5901331,-3.641623 , 4.6022506,-0.307265 ,-0.6359913,-4.0109854,-1.2064985, 1.1137954
diff --git a/compiler/pota-quantization-value-test/test_inputs/Add_002_config/layer/uint8/3.txt b/compiler/pota-quantization-value-test/test_inputs/Add_002_config/layer/uint8/3.txt
new file mode 100644
index 000000000..e89a022f5
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/Add_002_config/layer/uint8/3.txt
@@ -0,0 +1 @@
+ 3.1036437 ,-0.39538398,-0.07278133, 4.547673 , 3.9132211 , 2.6468625 ,-4.2830634 ,-2.0573084 , 2.1074655 ,-4.0634165 ,-4.55598 ,-0.7942089
diff --git a/compiler/pota-quantization-value-test/test_inputs/Add_002_config/layer/uint8/4.txt b/compiler/pota-quantization-value-test/test_inputs/Add_002_config/layer/uint8/4.txt
new file mode 100644
index 000000000..2b00832cd
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/Add_002_config/layer/uint8/4.txt
@@ -0,0 +1 @@
+-2.7745228, 1.4813256, 4.4699864, 3.7466738,-2.9847758,-4.453416 , 3.2515864,-1.2459193,-4.44965 ,-1.8452735, 4.423347 , 4.2998137
diff --git a/compiler/pota-quantization-value-test/test_inputs/AveragePool2D_000_config/channel/int16/0.txt b/compiler/pota-quantization-value-test/test_inputs/AveragePool2D_000_config/channel/int16/0.txt
new file mode 100644
index 000000000..e42cbf88b
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/AveragePool2D_000_config/channel/int16/0.txt
@@ -0,0 +1 @@
+-4.1358833e+00, 1.7854472e+00, 4.1751757e+00, 5.5915713e-01,-2.6459083e-01,-1.7176826e+00,-1.8155930e+00, 2.8710868e+00,-2.7043006e+00, 1.0959731e+00,-2.0176995e+00,-6.5950048e-01,-3.6413522e+00,-4.1966043e+00,-2.6820884e+00,-3.6055098e+00, 3.6852844e+00, 8.9128174e-02, 1.3107824e+00,-3.6425626e+00,-3.2318896e-01, 3.6238370e+00,-4.9837337e+00,-4.0550299e+00,-1.4882606e+00, 1.5547658e+00,-1.1696080e+00, 2.1651111e+00, 4.9318314e+00,-3.5928023e+00,-1.2348548e+00,-1.7002642e+00, 1.7365140e+00,-8.8151926e-01,-4.1655774e+00,-1.0166957e+00,-3.7440193e+00, 2.8588972e+00, 4.1286149e+00,-4.9504828e+00, 4.8477168e+00,-2.2587967e+00, 2.8542519e+00,-7.9565448e-01, 6.8252671e-01, 2.5875571e-01,-6.3935977e-01,-4.8547015e+00, 4.1373856e-03,-1.3893708e+00, 8.8775367e-01, 2.1222150e-01, 3.1871333e+00, 1.3869151e+00,-3.8274391e+00, 3.2623324e+00, 7.2669631e-01, 1.0303619e+00, 8.1438148e-01, 8.1272924e-01,-2.7527118e+00, 1.8215455e+00,-1.6416427e-01, 4.9103169e+00
diff --git a/compiler/pota-quantization-value-test/test_inputs/AveragePool2D_000_config/channel/int16/1.txt b/compiler/pota-quantization-value-test/test_inputs/AveragePool2D_000_config/channel/int16/1.txt
new file mode 100644
index 000000000..7caf8ce9e
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/AveragePool2D_000_config/channel/int16/1.txt
@@ -0,0 +1 @@
+-4.250757 , 1.4186406 , 0.63726735,-0.35924944, 1.9436699 , 3.2695885 , 3.6638293 , 4.5166173 , 1.3807241 ,-1.9112543 ,-1.9026492 ,-0.4800549 , 2.818216 ,-4.6390033 ,-3.8570547 , 3.6634028 ,-1.2112037 ,-1.3335027 , 1.3524677 , 2.7240725 ,-3.8335826 , 1.1397903 ,-3.1570992 ,-4.802078 , 3.8334577 , 0.23457901, 0.7132307 , 2.9887354 , 2.9702394 ,-1.4113717 ,-0.66712093, 0.77366674, 1.9308351 ,-0.45465755, 4.925366 , 2.4214447 , 2.8401468 , 0.49789894, 0.53141665,-2.7466767 , 0.2059374 ,-4.9661317 ,-4.1334467 , 1.6928389 ,-0.42529574, 1.1033608 , 4.275776 , 1.5063075 , 2.3528252 , 0.79505247, 3.9829993 ,-4.8472476 ,-1.2752185 , 3.7365675 , 1.976164 ,-4.742636 ,-2.7199092 ,-2.9191706 ,-3.181069 ,-4.489485 , 4.0847454 , 2.2164 , 0.9725334 ,-0.72566307
diff --git a/compiler/pota-quantization-value-test/test_inputs/AveragePool2D_000_config/channel/int16/2.txt b/compiler/pota-quantization-value-test/test_inputs/AveragePool2D_000_config/channel/int16/2.txt
new file mode 100644
index 000000000..7facffa57
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/AveragePool2D_000_config/channel/int16/2.txt
@@ -0,0 +1 @@
+-3.8293874 ,-0.13678598,-2.5444264 , 1.654611 ,-4.3037786 ,-3.4240584 ,-4.5642533 , 4.1250315 , 1.0469195 , 4.2802887 , 3.1617825 ,-3.1706758 ,-0.99622065, 2.7707603 , 3.7494645 ,-1.4548893 , 2.328633 , 1.7976477 ,-1.2107176 ,-2.0178459 ,-0.6488357 ,-2.9393644 , 2.8918762 , 3.6192262 ,-4.1777225 , 1.3264071 , 0.32620123, 0.7890992 ,-3.304334 , 3.4893208 , 2.5354576 ,-4.7718143 , 3.8602633 , 0.4927564 , 2.2971296 ,-0.3296792 , 2.8115997 ,-0.75152504, 0.558675 ,-2.343631 , 4.650826 ,-3.0893488 , 0.8726873 , 0.24922371, 2.7634025 , 1.0358421 ,-3.862506 ,-3.169402 ,-2.5373347 , 0.9484093 , 4.1409917 ,-4.0408096 ,-2.7231216 ,-2.548547 ,-2.6315095 , 0.8164778 ,-3.017436 , 1.1860138 ,-1.8634807 , 1.8684052 , 1.8657844 , 1.7747321 ,-3.1472425 ,-1.3989028
diff --git a/compiler/pota-quantization-value-test/test_inputs/AveragePool2D_000_config/channel/int16/3.txt b/compiler/pota-quantization-value-test/test_inputs/AveragePool2D_000_config/channel/int16/3.txt
new file mode 100644
index 000000000..0be8fdd19
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/AveragePool2D_000_config/channel/int16/3.txt
@@ -0,0 +1 @@
+-2.0492268 ,-2.2555764 ,-1.3543441 ,-3.7278662 ,-4.8601675 , 3.1095552 , 4.6319957 , 3.0211062 , 1.7870535 , 4.8839574 ,-1.3494394 , 2.635408 ,-0.24201432, 1.312397 , 0.16790341, 2.42507 ,-3.101355 , 3.1760497 ,-4.500736 ,-2.53691 , 1.064206 , 0.62096214, 2.803344 ,-4.6166744 ,-4.624786 , 3.667064 ,-1.484021 , 4.9401817 ,-3.763283 , 3.4351027 ,-2.906393 , 4.9945946 ,-3.2997096 , 3.6325612 ,-0.47211674, 0.28783202, 1.8703817 ,-4.042374 ,-3.3353784 , 4.9085765 ,-1.6753131 ,-3.4926984 ,-4.8663344 ,-4.495712 , 2.3402312 ,-1.0722051 , 0.28559962, 2.1208072 , 1.3024254 , 3.4810693 , 0.09860361, 1.695624 , 1.3901931 , 1.6858819 , 3.8231227 , 4.5972557 ,-4.6835494 , 0.5753765 ,-2.2377403 , 0.13013013,-2.1165738 ,-0.26044115,-0.653468 , 1.1010929
diff --git a/compiler/pota-quantization-value-test/test_inputs/AveragePool2D_000_config/channel/int16/4.txt b/compiler/pota-quantization-value-test/test_inputs/AveragePool2D_000_config/channel/int16/4.txt
new file mode 100644
index 000000000..7e2d618f9
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/AveragePool2D_000_config/channel/int16/4.txt
@@ -0,0 +1 @@
+ 4.397323 ,-0.51448834, 2.5729322 ,-4.3229046 , 1.149113 ,-3.8652143 ,-1.7352968 ,-0.7575065 ,-0.41720778, 4.327346 ,-4.2363043 , 0.8653738 ,-1.7511971 ,-0.7874244 ,-4.0734816 , 2.5622475 ,-3.1229742 ,-1.1783633 , 0.4017013 ,-0.76175183,-1.058416 , 1.128772 ,-3.0143378 ,-2.6688366 ,-2.575279 ,-4.326955 , 4.175434 , 4.791393 ,-1.10654 ,-4.4417224 , 3.5057635 , 1.5339037 ,-4.0297494 ,-3.7187057 ,-0.6645762 , 4.215642 , 1.6742749 , 2.5468905 , 1.73195 ,-3.3100636 ,-4.4818826 ,-2.5627983 ,-1.4624406 , 1.2433167 ,-4.005364 ,-4.3450556 ,-1.0652863 ,-1.0240986 , 3.989825 ,-4.1690702 ,-4.595108 ,-1.1154945 , 0.65749156, 2.5127344 , 2.509761 ,-4.3936505 , 3.6513395 ,-2.3340352 ,-4.3615093 , 3.5973237 , 0.9316653 , 1.9391845 , 3.6356397 , 0.8133118
diff --git a/compiler/pota-quantization-value-test/test_inputs/AveragePool2D_000_config/layer/uint8/0.txt b/compiler/pota-quantization-value-test/test_inputs/AveragePool2D_000_config/layer/uint8/0.txt
new file mode 100644
index 000000000..2a6b09b27
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/AveragePool2D_000_config/layer/uint8/0.txt
@@ -0,0 +1 @@
+-4.629505 , 1.0121975 ,-0.13417433,-2.329806 ,-3.4927373 ,-0.7574039 ,-2.2674313 , 3.1983519 , 2.4298382 ,-0.23268977, 2.0218065 ,-1.5087285 ,-1.3953347 ,-3.8100643 ,-1.7438283 , 3.9852605 , 2.9817178 ,-4.0460877 , 0.09402129, 4.3802586 ,-1.0991771 , 0.4134776 , 2.8136911 ,-3.6254618 ,-3.925183 , 4.691824 , 4.381538 ,-3.235543 ,-2.6764185 , 2.659456 ,-3.2127233 , 0.0206281 , 3.4056723 ,-1.693684 , 1.1005328 ,-3.1486542 , 0.77198106, 1.4526777 ,-2.3614178 , 4.8214664 ,-3.1486242 , 0.58941853,-4.1100698 , 4.1982718 , 1.7219902 ,-2.4375956 ,-1.7505955 , 1.7465224 ,-2.7494361 , 4.0679016 , 1.8936038 ,-4.523818 ,-3.4124248 ,-4.809946 ,-1.939553 , 4.9411273 , 1.6261404 ,-2.6846552 , 2.1339247 , 0.61396503,-1.6662381 , 2.4282491 , 2.662007 ,-0.40868336
diff --git a/compiler/pota-quantization-value-test/test_inputs/AveragePool2D_000_config/layer/uint8/1.txt b/compiler/pota-quantization-value-test/test_inputs/AveragePool2D_000_config/layer/uint8/1.txt
new file mode 100644
index 000000000..470da6c74
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/AveragePool2D_000_config/layer/uint8/1.txt
@@ -0,0 +1 @@
+ 0.70593804, 3.253847 , 1.1094694 , 0.5295975 , 0.5944647 ,-2.4391694 , 4.7912955 , 4.4374456 ,-2.942428 ,-3.5038033 ,-3.180417 , 2.1914082 ,-4.5295396 ,-3.0037553 ,-2.265191 , 0.20113531, 2.3805366 ,-0.9111223 ,-4.3170924 , 4.08436 , 1.1006241 ,-1.286977 , 4.811279 , 0.9131829 , 3.2051497 ,-2.8660698 ,-3.188871 , 1.4163305 , 4.061829 , 2.7783196 ,-3.4975152 , 3.4888391 , 2.5789826 ,-1.5264264 ,-0.13952135,-1.280177 , 2.4716458 , 2.6200528 ,-2.515086 , 3.441416 , 2.4515297 ,-0.9845471 , 0.9481396 , 1.1518412 , 1.6088997 , 1.445077 , 2.2620194 ,-2.0843177 ,-0.7263964 , 1.8159748 ,-3.3673623 , 0.2554476 ,-4.3550563 ,-1.4280493 ,-2.2702312 ,-4.7424164 ,-0.57241255,-2.813357 , 2.9161859 ,-0.9036504 , 0.00511268, 0.60724795, 4.8010454 , 1.6000834
diff --git a/compiler/pota-quantization-value-test/test_inputs/AveragePool2D_000_config/layer/uint8/2.txt b/compiler/pota-quantization-value-test/test_inputs/AveragePool2D_000_config/layer/uint8/2.txt
new file mode 100644
index 000000000..d9e048b61
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/AveragePool2D_000_config/layer/uint8/2.txt
@@ -0,0 +1 @@
+ 7.07888961e-01, 4.75798702e+00,-1.47843570e-01,-1.95845592e+00, 4.26537895e+00,-3.03711486e+00,-1.35137546e+00,-1.10638596e-01,-1.02415502e+00,-2.65345359e+00, 5.48920631e-01,-4.38003826e+00, 3.61377740e+00,-2.91408587e+00,-3.22874010e-01,-4.74363208e-01, 3.45294738e+00, 1.02204478e+00,-1.44102740e+00, 6.80687547e-01,-2.44050741e+00, 3.71395111e+00,-2.14443612e+00, 3.70928717e+00, 1.35871637e+00, 9.73374963e-01, 1.57826161e+00,-2.91381836e-01, 1.46376801e+00, 2.96391749e+00, 1.08418810e+00,-3.50718546e+00, 4.68637037e+00, 1.04839933e+00, 2.24482760e-01, 2.38816309e+00, 3.18772525e-01,-3.90284014e+00,-3.32757282e+00,-1.61143410e+00,-1.26013708e+00, 2.24948835e+00, 7.63151050e-01, 4.18296242e+00,-8.69123042e-01, 3.19850564e-01, 3.52391124e-01, 3.30018830e+00,-4.64861393e+00,-4.64479780e+00,-2.68103647e+00,-1.13277221e+00, 2.02201343e+00,-4.05572534e-01, 3.06759548e+00,-3.55881310e+00,-1.14900565e+00,-3.00835490e+00, 1.31509733e+00, 2.50206441e-01, 2.47731134e-01, 4.98673916e+00,-1.74064383e-01,-4.43180744e-03
diff --git a/compiler/pota-quantization-value-test/test_inputs/AveragePool2D_000_config/layer/uint8/3.txt b/compiler/pota-quantization-value-test/test_inputs/AveragePool2D_000_config/layer/uint8/3.txt
new file mode 100644
index 000000000..cdbf98e8a
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/AveragePool2D_000_config/layer/uint8/3.txt
@@ -0,0 +1 @@
+ 3.5591762 , 4.8821726 , 0.44271094, 4.786732 ,-2.4497197 , 2.4973536 , 2.034311 , 4.8329844 ,-3.9451184 , 4.9937835 , 2.0246332 ,-2.8319602 , 3.9617133 , 4.10946 ,-4.3191586 ,-2.8492777 ,-2.648121 ,-4.199404 ,-0.05163948,-4.7944984 , 2.8989205 , 1.4747709 ,-3.1194637 ,-2.877846 ,-0.39301065, 2.616311 , 2.6305614 , 1.7303206 , 3.6059175 ,-2.745988 , 2.5924454 , 3.0149276 , 4.0359216 ,-0.6135884 ,-2.5023808 ,-2.3395267 ,-3.0633461 ,-2.3836162 ,-4.4779797 ,-1.30866 , 1.9110863 , 0.654628 ,-4.559368 , 0.34231895,-0.8196542 , 4.7275734 , 3.2823656 ,-4.9644713 , 2.9191613 ,-3.4621727 ,-4.276584 ,-1.7153062 , 1.8820064 , 1.2659297 , 3.4141889 ,-4.905296 , 4.619848 ,-3.9501083 ,-1.5550466 , 3.6841137 , 1.7121594 , 1.9466268 , 1.5684807 , 4.5554323
diff --git a/compiler/pota-quantization-value-test/test_inputs/AveragePool2D_000_config/layer/uint8/4.txt b/compiler/pota-quantization-value-test/test_inputs/AveragePool2D_000_config/layer/uint8/4.txt
new file mode 100644
index 000000000..065d77df6
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/AveragePool2D_000_config/layer/uint8/4.txt
@@ -0,0 +1 @@
+-2.2269225 ,-1.2782103 ,-3.381931 ,-1.5229299 , 2.0681949 , 1.7630705 ,-0.81455594,-2.6558595 ,-3.4870632 ,-4.647749 , 2.4453654 ,-2.242679 ,-1.0272806 , 0.5656208 , 0.69442594,-4.4343104 ,-3.9649677 ,-3.8908577 ,-1.642287 , 3.0714357 , 1.0880747 ,-2.1665683 ,-4.0994506 , 2.004911 , 3.5922902 , 3.775 , 1.1580672 ,-1.4154137 ,-4.4964633 ,-1.696588 , 4.0220857 ,-1.2785947 ,-4.2075186 ,-4.515838 , 0.99715126, 3.0928102 ,-2.295537 ,-4.772882 ,-1.2936146 ,-2.6903791 , 0.10453273,-1.8041211 , 3.787591 , 0.9493053 ,-4.41586 , 3.4252715 ,-0.25001565, 4.655357 ,-1.8767506 , 0.00600041, 4.660605 , 2.550518 ,-3.830558 , 1.7777463 ,-0.7170577 ,-0.26554853,-3.5770113 ,-1.1354474 , 4.663121 , 3.100427 , 0.03313563,-1.7419808 ,-1.4426676 ,-3.912533
diff --git a/compiler/pota-quantization-value-test/test_inputs/Concatenation_001_config/channel/int16/0.txt b/compiler/pota-quantization-value-test/test_inputs/Concatenation_001_config/channel/int16/0.txt
new file mode 100644
index 000000000..9def1c2eb
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/Concatenation_001_config/channel/int16/0.txt
@@ -0,0 +1 @@
+0.24671102,3.271825 ,3.979895 ,1.3334678
diff --git a/compiler/pota-quantization-value-test/test_inputs/Concatenation_001_config/channel/int16/1.txt b/compiler/pota-quantization-value-test/test_inputs/Concatenation_001_config/channel/int16/1.txt
new file mode 100644
index 000000000..eaec2409f
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/Concatenation_001_config/channel/int16/1.txt
@@ -0,0 +1 @@
+ 1.9181111, 2.2396102,-2.8641696,-1.9045062
diff --git a/compiler/pota-quantization-value-test/test_inputs/Concatenation_001_config/channel/int16/2.txt b/compiler/pota-quantization-value-test/test_inputs/Concatenation_001_config/channel/int16/2.txt
new file mode 100644
index 000000000..3e05181cc
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/Concatenation_001_config/channel/int16/2.txt
@@ -0,0 +1 @@
+4.751434 ,2.8798263 ,0.15149078,2.9485583
diff --git a/compiler/pota-quantization-value-test/test_inputs/Concatenation_001_config/channel/int16/3.txt b/compiler/pota-quantization-value-test/test_inputs/Concatenation_001_config/channel/int16/3.txt
new file mode 100644
index 000000000..19d95b267
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/Concatenation_001_config/channel/int16/3.txt
@@ -0,0 +1 @@
+-1.5743442 , 0.6716824 , 0.75737774,-0.27396253
diff --git a/compiler/pota-quantization-value-test/test_inputs/Concatenation_001_config/channel/int16/4.txt b/compiler/pota-quantization-value-test/test_inputs/Concatenation_001_config/channel/int16/4.txt
new file mode 100644
index 000000000..d302e07a9
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/Concatenation_001_config/channel/int16/4.txt
@@ -0,0 +1 @@
+-1.0539489 , 1.9595883 , 0.19975437, 2.526178
diff --git a/compiler/pota-quantization-value-test/test_inputs/Concatenation_001_config/layer/uint8/0.txt b/compiler/pota-quantization-value-test/test_inputs/Concatenation_001_config/layer/uint8/0.txt
new file mode 100644
index 000000000..af1c2dff8
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/Concatenation_001_config/layer/uint8/0.txt
@@ -0,0 +1 @@
+-4.0575085 , 2.5941508 ,-2.550309 ,-0.03760919
diff --git a/compiler/pota-quantization-value-test/test_inputs/Concatenation_001_config/layer/uint8/1.txt b/compiler/pota-quantization-value-test/test_inputs/Concatenation_001_config/layer/uint8/1.txt
new file mode 100644
index 000000000..0ede613ac
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/Concatenation_001_config/layer/uint8/1.txt
@@ -0,0 +1 @@
+ 0.4857123,-4.032874 ,-3.687589 ,-1.235227
diff --git a/compiler/pota-quantization-value-test/test_inputs/Concatenation_001_config/layer/uint8/2.txt b/compiler/pota-quantization-value-test/test_inputs/Concatenation_001_config/layer/uint8/2.txt
new file mode 100644
index 000000000..b0b0392ba
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/Concatenation_001_config/layer/uint8/2.txt
@@ -0,0 +1 @@
+ 0.21878362, 3.9175916 ,-4.6141233 , 3.709655
diff --git a/compiler/pota-quantization-value-test/test_inputs/Concatenation_001_config/layer/uint8/3.txt b/compiler/pota-quantization-value-test/test_inputs/Concatenation_001_config/layer/uint8/3.txt
new file mode 100644
index 000000000..d8a8cad12
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/Concatenation_001_config/layer/uint8/3.txt
@@ -0,0 +1 @@
+-1.9645791,-1.4466153, 1.2543651,-1.0288917
diff --git a/compiler/pota-quantization-value-test/test_inputs/Concatenation_001_config/layer/uint8/4.txt b/compiler/pota-quantization-value-test/test_inputs/Concatenation_001_config/layer/uint8/4.txt
new file mode 100644
index 000000000..ca2a1c3b4
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/Concatenation_001_config/layer/uint8/4.txt
@@ -0,0 +1 @@
+-2.1611342, 2.4875243, 3.096089 ,-1.1327268
diff --git a/compiler/pota-quantization-value-test/test_inputs/Conv2D_004_config/channel/int16/0.txt b/compiler/pota-quantization-value-test/test_inputs/Conv2D_004_config/channel/int16/0.txt
new file mode 100644
index 000000000..0614b5e83
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/Conv2D_004_config/channel/int16/0.txt
@@ -0,0 +1 @@
+0.01090685,0.0581577 ,0.637094 ,0.64067715,0.26264507,0.13692169,0.9649414 ,0.5117181 ,0.18012471,0.07855253,0.6358017 ,0.62257963,0.41469443,0.93169045,0.20763828,0.7634293 ,0.75929826,0.72708374,0.23463063,0.58222896,0.6351517 ,0.68781173,0.5558012 ,0.7652179
diff --git a/compiler/pota-quantization-value-test/test_inputs/Conv2D_004_config/channel/int16/1.txt b/compiler/pota-quantization-value-test/test_inputs/Conv2D_004_config/channel/int16/1.txt
new file mode 100644
index 000000000..b1c39382f
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/Conv2D_004_config/channel/int16/1.txt
@@ -0,0 +1 @@
+0.57017624,0.08235867,0.03672464,0.40372616,0.7353964 ,0.59611887,0.7675548 ,0.21004233,0.09803218,0.20009473,0.8821493 ,0.17015271,0.14840214,0.99910176,0.37003204,0.22893582,0.43173164,0.3105084 ,0.41997132,0.43714985,0.08115962,0.71896386,0.7810953 ,0.00524598
diff --git a/compiler/pota-quantization-value-test/test_inputs/Conv2D_004_config/channel/int16/2.txt b/compiler/pota-quantization-value-test/test_inputs/Conv2D_004_config/channel/int16/2.txt
new file mode 100644
index 000000000..7e562de75
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/Conv2D_004_config/channel/int16/2.txt
@@ -0,0 +1 @@
+0.65292275,0.79842275,0.97853714,0.6711518 ,0.607567 ,0.40971732,0.74838483,0.95853555,0.32158023,0.911524 ,0.66938365,0.8573132 ,0.3047727 ,0.5561248 ,0.914098 ,0.07650814,0.37868017,0.29269257,0.19652605,0.63025194,0.61496884,0.32011527,0.8204132 ,0.21866946
diff --git a/compiler/pota-quantization-value-test/test_inputs/Conv2D_004_config/channel/int16/3.txt b/compiler/pota-quantization-value-test/test_inputs/Conv2D_004_config/channel/int16/3.txt
new file mode 100644
index 000000000..2958a7f54
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/Conv2D_004_config/channel/int16/3.txt
@@ -0,0 +1 @@
+0.4548901 ,0.56957537,0.0252368 ,0.4884317 ,0.7516498 ,0.02631272,0.22107519,0.95249426,0.34902394,0.11520014,0.808911 ,0.4148615 ,0.63615656,0.84020686,0.3633697 ,0.23993976,0.54176176,0.86938345,0.81628686,0.6380988 ,0.91891205,0.0406627 ,0.90289026,0.9429013
diff --git a/compiler/pota-quantization-value-test/test_inputs/Conv2D_004_config/channel/int16/4.txt b/compiler/pota-quantization-value-test/test_inputs/Conv2D_004_config/channel/int16/4.txt
new file mode 100644
index 000000000..fc969308e
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/Conv2D_004_config/channel/int16/4.txt
@@ -0,0 +1 @@
+0.9309136 ,0.02123719,0.64467335,0.6910113 ,0.47402772,0.54622203,0.31527275,0.81530565,0.98981965,0.36102158,0.03114039,0.1902339 ,0.45183742,0.60178596,0.4683102 ,0.59810966,0.40558222,0.5420302 ,0.72699505,0.9575108 ,0.46746576,0.08518691,0.40302262,0.69213694
diff --git a/compiler/pota-quantization-value-test/test_inputs/Conv2D_004_config/layer/uint8/0.txt b/compiler/pota-quantization-value-test/test_inputs/Conv2D_004_config/layer/uint8/0.txt
new file mode 100644
index 000000000..f82ad6704
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/Conv2D_004_config/layer/uint8/0.txt
@@ -0,0 +1 @@
+ 1.4040831 , 4.8621206 , 0.22880335,-0.3116556 , 0.260938 ,-0.61554366, 3.779648 ,-4.650609 , 3.886638 ,-0.25574106,-0.45002133, 4.9870906 ,-2.3277295 ,-4.9648423 ,-3.7695415 , 3.2857463 ,-4.5514555 ,-3.7705963 , 3.8458307 ,-4.797776 ,-3.4295716 ,-4.6026535 ,-1.4011091 , 2.8851774
diff --git a/compiler/pota-quantization-value-test/test_inputs/Conv2D_004_config/layer/uint8/1.txt b/compiler/pota-quantization-value-test/test_inputs/Conv2D_004_config/layer/uint8/1.txt
new file mode 100644
index 000000000..722337286
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/Conv2D_004_config/layer/uint8/1.txt
@@ -0,0 +1 @@
+-4.171929 ,-2.2911541 , 2.8965824 , 0.27504483,-1.6088463 ,-0.6509234 ,-3.262618 , 0.9633116 , 2.4504175 , 0.97706884, 0.4212074 , 1.4083375 ,-2.9757218 ,-3.1010823 ,-1.7146534 , 4.105306 , 0.07195274, 3.0232217 ,-2.7568955 ,-4.8887763 ,-3.4171093 ,-0.91494775, 2.5260248 , 4.74184
diff --git a/compiler/pota-quantization-value-test/test_inputs/Conv2D_004_config/layer/uint8/2.txt b/compiler/pota-quantization-value-test/test_inputs/Conv2D_004_config/layer/uint8/2.txt
new file mode 100644
index 000000000..1283a8ad1
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/Conv2D_004_config/layer/uint8/2.txt
@@ -0,0 +1 @@
+ 0.14139967, 1.9541235 ,-4.945228 ,-0.48999134, 3.7479703 , 0.29318067, 0.21036309, 4.357736 ,-4.3354783 ,-1.9236348 , 0.49615476,-1.8418436 ,-2.425741 , 4.817022 , 1.5093465 , 2.417444 ,-4.69463 , 0.3433745 ,-4.5979595 ,-3.9027495 ,-0.29977685, 4.9239326 ,-0.39175773, 1.277211
diff --git a/compiler/pota-quantization-value-test/test_inputs/Conv2D_004_config/layer/uint8/3.txt b/compiler/pota-quantization-value-test/test_inputs/Conv2D_004_config/layer/uint8/3.txt
new file mode 100644
index 000000000..c931e1752
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/Conv2D_004_config/layer/uint8/3.txt
@@ -0,0 +1 @@
+-3.692852 ,-1.0075341 ,-2.4409268 , 0.92995465,-3.1325107 , 4.028981 , 0.8446181 ,-2.2990613 , 4.0820794 , 3.1633005 , 4.1527267 ,-3.9514909 , 2.6104712 , 4.660645 ,-1.7398617 , 0.15663597,-3.6861904 ,-2.9019265 , 3.8828175 ,-2.712909 , 4.3699546 ,-3.5953352 ,-3.0655813 , 0.59767616
diff --git a/compiler/pota-quantization-value-test/test_inputs/Conv2D_004_config/layer/uint8/4.txt b/compiler/pota-quantization-value-test/test_inputs/Conv2D_004_config/layer/uint8/4.txt
new file mode 100644
index 000000000..d33c2dbec
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/Conv2D_004_config/layer/uint8/4.txt
@@ -0,0 +1 @@
+-2.8695228 , 2.865197 , 0.6635586 , 0.22709726, 2.85572 ,-4.2051144 , 1.5833759 ,-4.4277377 , 4.0004573 , 2.4766827 , 3.0412688 ,-4.8891425 ,-4.489896 , 3.0812325 , 2.1947708 , 1.6387184 , 0.31932488,-0.41092923,-0.0730476 , 0.7265327 , 4.1333 , 3.157228 , 4.7395325 , 3.4576747
diff --git a/compiler/pota-quantization-value-test/test_inputs/DepthwiseConv2D_002_config/channel/int16/0.txt b/compiler/pota-quantization-value-test/test_inputs/DepthwiseConv2D_002_config/channel/int16/0.txt
new file mode 100644
index 000000000..f4fb503ea
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/DepthwiseConv2D_002_config/channel/int16/0.txt
@@ -0,0 +1 @@
+0.4383064 ,0.8700848 ,0.86010957,0.08396256,0.7963264 ,0.4156023 ,0.28146362,0.82196397,0.9921972 ,0.09969576,0.23987265,0.6734369 ,0.5469574 ,0.20805728,0.32639247,0.76773816
diff --git a/compiler/pota-quantization-value-test/test_inputs/DepthwiseConv2D_002_config/channel/int16/1.txt b/compiler/pota-quantization-value-test/test_inputs/DepthwiseConv2D_002_config/channel/int16/1.txt
new file mode 100644
index 000000000..af4b01576
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/DepthwiseConv2D_002_config/channel/int16/1.txt
@@ -0,0 +1 @@
+0.4565062 ,0.92036587,0.47286046,0.18118097,0.5347498 ,0.91550153,0.300375 ,0.00581101,0.38686675,0.91085213,0.07278002,0.35556316,0.13014294,0.7274307 ,0.13867259,0.27517235
diff --git a/compiler/pota-quantization-value-test/test_inputs/DepthwiseConv2D_002_config/channel/int16/2.txt b/compiler/pota-quantization-value-test/test_inputs/DepthwiseConv2D_002_config/channel/int16/2.txt
new file mode 100644
index 000000000..57716034e
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/DepthwiseConv2D_002_config/channel/int16/2.txt
@@ -0,0 +1 @@
+0.6900174 ,0.28745306,0.30255774,0.5095008 ,0.6689176 ,0.4914624 ,0.92629427,0.504829 ,0.33514255,0.49005315,0.08569656,0.60965323,0.82193315,0.12380831,0.06971261,0.8822662
diff --git a/compiler/pota-quantization-value-test/test_inputs/DepthwiseConv2D_002_config/channel/int16/3.txt b/compiler/pota-quantization-value-test/test_inputs/DepthwiseConv2D_002_config/channel/int16/3.txt
new file mode 100644
index 000000000..1e03d83b0
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/DepthwiseConv2D_002_config/channel/int16/3.txt
@@ -0,0 +1 @@
+0.4240734 ,0.5430392 ,0.7536325 ,0.46065134,0.00315792,0.02719985,0.7080977 ,0.24389206,0.8114604 ,0.13292362,0.346597 ,0.70247084,0.55753845,0.01969242,0.82950485,0.66249627
diff --git a/compiler/pota-quantization-value-test/test_inputs/DepthwiseConv2D_002_config/channel/int16/4.txt b/compiler/pota-quantization-value-test/test_inputs/DepthwiseConv2D_002_config/channel/int16/4.txt
new file mode 100644
index 000000000..89ee30a6b
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/DepthwiseConv2D_002_config/channel/int16/4.txt
@@ -0,0 +1 @@
+0.31586212,0.19079527,0.9161567 ,0.8614566 ,0.9018915 ,0.34651542,0.62554437,0.05542602,0.8268219 ,0.38112178,0.9396123 ,0.49426383,0.8034765 ,0.72456217,0.5404088 ,0.8512237
diff --git a/compiler/pota-quantization-value-test/test_inputs/DepthwiseConv2D_002_config/layer/uint8/0.txt b/compiler/pota-quantization-value-test/test_inputs/DepthwiseConv2D_002_config/layer/uint8/0.txt
new file mode 100644
index 000000000..cc434b0a8
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/DepthwiseConv2D_002_config/layer/uint8/0.txt
@@ -0,0 +1 @@
+-4.0618963 ,-0.56899416,-2.6450877 , 2.4534085 , 1.98115 , 1.906561 ,-3.9617727 ,-0.6071247 , 3.1096997 , 4.4270124 ,-2.8755112 ,-1.8822336 ,-2.3567479 , 1.9797888 ,-3.5018713 , 3.429169
diff --git a/compiler/pota-quantization-value-test/test_inputs/DepthwiseConv2D_002_config/layer/uint8/1.txt b/compiler/pota-quantization-value-test/test_inputs/DepthwiseConv2D_002_config/layer/uint8/1.txt
new file mode 100644
index 000000000..2c637a1d2
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/DepthwiseConv2D_002_config/layer/uint8/1.txt
@@ -0,0 +1 @@
+-1.6089132 , 1.4328785 ,-3.2579598 ,-2.1328773 ,-2.6566415 , 2.541386 ,-4.3314023 , 0.48684084, 3.3134763 ,-2.69083 ,-0.45710313,-3.6763198 , 0.22075526,-3.159208 ,-2.1573126 , 4.1621423
diff --git a/compiler/pota-quantization-value-test/test_inputs/DepthwiseConv2D_002_config/layer/uint8/2.txt b/compiler/pota-quantization-value-test/test_inputs/DepthwiseConv2D_002_config/layer/uint8/2.txt
new file mode 100644
index 000000000..4b57fe8e0
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/DepthwiseConv2D_002_config/layer/uint8/2.txt
@@ -0,0 +1 @@
+-4.061572 , 3.0518744 , 2.694435 ,-4.720131 , 1.3782452 , 4.083631 , 4.1221976 ,-1.2299284 , 3.096133 , 3.8382158 ,-1.9518853 , 4.350529 , 0.09219506, 2.6483617 , 0.74373996, 2.7447948
diff --git a/compiler/pota-quantization-value-test/test_inputs/DepthwiseConv2D_002_config/layer/uint8/3.txt b/compiler/pota-quantization-value-test/test_inputs/DepthwiseConv2D_002_config/layer/uint8/3.txt
new file mode 100644
index 000000000..49c3022c2
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/DepthwiseConv2D_002_config/layer/uint8/3.txt
@@ -0,0 +1 @@
+ 4.68769 ,-3.2768764 , 3.1849844 , 4.497627 ,-1.2611016 ,-3.1152303 ,-0.8408633 , 0.4938034 , 4.0921655 ,-2.3150117 , 0.10100875,-3.8374226 , 4.08059 ,-0.74594986,-3.1000822 , 4.3654246
diff --git a/compiler/pota-quantization-value-test/test_inputs/DepthwiseConv2D_002_config/layer/uint8/4.txt b/compiler/pota-quantization-value-test/test_inputs/DepthwiseConv2D_002_config/layer/uint8/4.txt
new file mode 100644
index 000000000..e02c8ca16
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/DepthwiseConv2D_002_config/layer/uint8/4.txt
@@ -0,0 +1 @@
+-3.6168842 , 4.1935644 , 0.73750836, 4.6044145 , 2.8967912 ,-1.8085694 , 4.539956 ,-0.37032878, 1.9738418 , 1.5388782 ,-2.945171 ,-3.3875864 ,-4.516983 ,-3.4998245 ,-4.676514 ,-2.2738194
diff --git a/compiler/pota-quantization-value-test/test_inputs/FullyConnected_003_config/channel/int16/0.txt b/compiler/pota-quantization-value-test/test_inputs/FullyConnected_003_config/channel/int16/0.txt
new file mode 100644
index 000000000..233e5eae3
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/FullyConnected_003_config/channel/int16/0.txt
@@ -0,0 +1 @@
+ 2.7731526 , 2.451602 , 3.7535272 ,-1.2774152 , 1.5482912 , 1.3402948 , 4.4792123 ,-4.4954367 , 3.354679 ,-3.3615496 ,-4.619757 ,-3.3659618 , 4.7626247 ,-1.3596478 ,-4.835548 , 0.78964525
diff --git a/compiler/pota-quantization-value-test/test_inputs/FullyConnected_003_config/channel/int16/1.txt b/compiler/pota-quantization-value-test/test_inputs/FullyConnected_003_config/channel/int16/1.txt
new file mode 100644
index 000000000..6a126081d
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/FullyConnected_003_config/channel/int16/1.txt
@@ -0,0 +1 @@
+ 0.5400839 ,-3.2621996 ,-3.4817135 , 3.8183312 , 0.48498327, 2.9812584 , 4.111276 , 0.11223658, 4.7201405 , 2.4256718 , 1.4895477 , 4.7596602 ,-0.32709372, 1.3507305 ,-0.30043927,-1.8077502
diff --git a/compiler/pota-quantization-value-test/test_inputs/FullyConnected_003_config/channel/int16/2.txt b/compiler/pota-quantization-value-test/test_inputs/FullyConnected_003_config/channel/int16/2.txt
new file mode 100644
index 000000000..eccd2c625
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/FullyConnected_003_config/channel/int16/2.txt
@@ -0,0 +1 @@
+ 3.8758078 , 4.978636 ,-0.22925885,-2.6760504 ,-1.9160627 ,-4.609644 ,-0.9515802 , 3.558274 , 2.9096057 , 0.3340422 , 0.38608226,-0.32168412, 4.688853 ,-4.583811 ,-2.5113506 ,-4.6688786
diff --git a/compiler/pota-quantization-value-test/test_inputs/FullyConnected_003_config/channel/int16/3.txt b/compiler/pota-quantization-value-test/test_inputs/FullyConnected_003_config/channel/int16/3.txt
new file mode 100644
index 000000000..0da05277c
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/FullyConnected_003_config/channel/int16/3.txt
@@ -0,0 +1 @@
+-2.9868221 , 2.4237797 , 1.0833962 ,-0.9231426 ,-2.1091506 ,-2.6163697 ,-0.23101932,-1.9252896 , 4.7034135 , 3.1088963 ,-2.345823 ,-2.7866168 ,-3.186763 ,-4.431844 , 3.3113294 , 0.9501982
diff --git a/compiler/pota-quantization-value-test/test_inputs/FullyConnected_003_config/channel/int16/4.txt b/compiler/pota-quantization-value-test/test_inputs/FullyConnected_003_config/channel/int16/4.txt
new file mode 100644
index 000000000..ace24f7c1
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/FullyConnected_003_config/channel/int16/4.txt
@@ -0,0 +1 @@
+ 3.9716747 ,-2.254871 , 1.1943274 ,-2.212602 , 3.4311683 , 1.114989 , 4.0739036 , 0.47244295,-3.5793104 ,-3.359908 ,-4.7657595 , 2.0369127 ,-2.5619278 ,-3.4452975 ,-4.5852203 ,-1.137643
diff --git a/compiler/pota-quantization-value-test/test_inputs/FullyConnected_003_config/layer/uint8/0.txt b/compiler/pota-quantization-value-test/test_inputs/FullyConnected_003_config/layer/uint8/0.txt
new file mode 100644
index 000000000..18b34c8b1
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/FullyConnected_003_config/layer/uint8/0.txt
@@ -0,0 +1 @@
+ 1.5887886e+00,-4.7446389e+00,-8.6568648e-01,-2.9789083e+00, 4.4470620e+00,-4.6563668e+00,-3.8466794e+00, 1.8815753e-03,-2.7699089e+00, 5.2776605e-01, 3.6518128e+00,-3.0939088e+00,-3.6008542e+00, 7.2454107e-01, 2.2568390e+00,-4.4835806e+00
diff --git a/compiler/pota-quantization-value-test/test_inputs/FullyConnected_003_config/layer/uint8/1.txt b/compiler/pota-quantization-value-test/test_inputs/FullyConnected_003_config/layer/uint8/1.txt
new file mode 100644
index 000000000..d652da699
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/FullyConnected_003_config/layer/uint8/1.txt
@@ -0,0 +1 @@
+ 4.770412 ,-1.7520845 , 2.4057522 ,-0.74166125,-0.10780027, 4.5796657 ,-3.513094 ,-3.0285823 , 1.2001143 , 2.806742 ,-2.0503895 , 2.8160958 ,-1.5392824 ,-3.7772799 , 2.9158401 ,-1.0586692
diff --git a/compiler/pota-quantization-value-test/test_inputs/FullyConnected_003_config/layer/uint8/2.txt b/compiler/pota-quantization-value-test/test_inputs/FullyConnected_003_config/layer/uint8/2.txt
new file mode 100644
index 000000000..e6d6e004f
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/FullyConnected_003_config/layer/uint8/2.txt
@@ -0,0 +1 @@
+ 3.937408 ,-0.11191579, 2.2054992 , 2.847275 , 3.4895647 , 4.2361116 ,-3.2401278 ,-1.5813186 ,-4.558396 ,-0.89455926, 4.204445 , 3.5968838 , 2.773891 ,-2.9562843 ,-0.62606305,-0.03814701
diff --git a/compiler/pota-quantization-value-test/test_inputs/FullyConnected_003_config/layer/uint8/3.txt b/compiler/pota-quantization-value-test/test_inputs/FullyConnected_003_config/layer/uint8/3.txt
new file mode 100644
index 000000000..8b472058e
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/FullyConnected_003_config/layer/uint8/3.txt
@@ -0,0 +1 @@
+ 3.5032003 , 4.6036057 , 0.28915945, 4.671659 ,-1.978598 , 2.1773603 ,-0.54175234,-3.0131943 ,-2.7422159 ,-3.4361897 , 0.2850049 , 4.1412387 ,-4.86403 ,-0.67577606,-1.4206086 ,-2.357092
diff --git a/compiler/pota-quantization-value-test/test_inputs/FullyConnected_003_config/layer/uint8/4.txt b/compiler/pota-quantization-value-test/test_inputs/FullyConnected_003_config/layer/uint8/4.txt
new file mode 100644
index 000000000..bba80be5f
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/FullyConnected_003_config/layer/uint8/4.txt
@@ -0,0 +1 @@
+ 2.5063417 , 0.22874236, 2.2677753 ,-4.4159026 , 1.7464 , 4.6051064 ,-4.2867146 , 2.730521 , 1.6372519 , 0.70292765, 3.459053 ,-4.162376 , 0.36788836, 2.213299 , 4.110952 , 1.6797827
diff --git a/compiler/pota-quantization-value-test/test_inputs/MaxPool2D_000_config/channel/int16/0.txt b/compiler/pota-quantization-value-test/test_inputs/MaxPool2D_000_config/channel/int16/0.txt
new file mode 100644
index 000000000..31a2db03e
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/MaxPool2D_000_config/channel/int16/0.txt
@@ -0,0 +1 @@
+-4.1984134 , 3.7565446 , 1.3521377 ,-4.0263743 ,-1.929471 ,-3.7523155 , 1.3858393 , 4.1565247 ,-2.4681342 , 0.3598748 ,-2.0044599 , 3.7168603 , 3.6330557 , 3.0176272 ,-4.4643235 ,-0.1893698 , 3.8839848 ,-4.5703125 , 3.365731 , 4.5556674 , 4.954971 , 1.7591819 ,-0.9497736 ,-0.8527185 ,-1.1863561 ,-4.522639 ,-4.3187394 ,-3.702939 , 0.15341021, 0.8564923 , 1.9076811 , 4.2765 ,-3.7695112 ,-1.6033245 , 2.3159432 ,-1.6656336 , 1.4186145 , 4.334284 , 4.0654674 ,-4.518256 , 0.72815216, 2.5133176 ,-4.238172 , 1.0198449 ,-0.9638457 , 2.5847483 , 4.0381308 , 4.472872 , 0.11794223, 1.3358012 , 1.7975981 , 2.168553 ,-3.5131238 , 3.8412008 , 3.851232 ,-2.130775 , 3.556102 , 0.69062364,-4.668594 ,-4.619906 ,-2.87768 ,-1.0679495 ,-4.523185 , 4.184176
diff --git a/compiler/pota-quantization-value-test/test_inputs/MaxPool2D_000_config/channel/int16/1.txt b/compiler/pota-quantization-value-test/test_inputs/MaxPool2D_000_config/channel/int16/1.txt
new file mode 100644
index 000000000..2bdd62b24
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/MaxPool2D_000_config/channel/int16/1.txt
@@ -0,0 +1 @@
+ 2.9193265 , 4.315574 ,-3.7834768 , 3.4352486 , 4.1452866 ,-4.0322523 , 1.8039155 ,-4.080042 ,-1.1999705 , 4.9018297 ,-0.27180746, 1.709373 , 4.3322196 , 4.9179945 ,-3.977508 , 2.3486571 ,-0.11026379,-0.24730131, 2.3269305 , 2.1862001 , 0.92486495, 3.5822759 , 2.8370361 , 3.915398 ,-0.6385275 ,-0.02720119,-1.408676 ,-4.4472733 , 1.2901759 ,-4.60209 ,-2.9502335 ,-2.650517 ,-1.4038593 ,-2.967456 ,-2.0060933 ,-1.9603083 ,-0.4727794 ,-1.7877682 ,-3.9565926 , 1.4452418 , 2.5925353 ,-4.5134907 ,-4.195412 , 2.4681656 , 0.7140492 , 3.0753498 , 0.269442 ,-4.768041 ,-3.5370746 , 1.0272335 ,-0.7654047 ,-1.977087 , 3.1920779 , 0.37378865, 4.016262 ,-3.3201067 ,-4.7767315 ,-3.5074112 ,-4.094166 , 1.6035818 , 1.6506963 ,-3.2142932 , 4.7714067 ,-1.7164946
diff --git a/compiler/pota-quantization-value-test/test_inputs/MaxPool2D_000_config/channel/int16/2.txt b/compiler/pota-quantization-value-test/test_inputs/MaxPool2D_000_config/channel/int16/2.txt
new file mode 100644
index 000000000..8c770f61d
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/MaxPool2D_000_config/channel/int16/2.txt
@@ -0,0 +1 @@
+-1.8028042 , 1.7280815 ,-3.0464594 ,-2.810487 , 0.582805 ,-1.786865 ,-1.7263526 ,-0.36871073, 3.3955328 ,-3.9523299 ,-1.880003 , 4.9068613 , 4.6292953 , 3.9778202 ,-1.859954 , 2.8149757 , 4.5020967 ,-4.160163 , 1.9295161 ,-1.2508658 , 0.5669804 , 0.99246883,-2.4829247 , 0.88920474,-3.7942843 , 2.4626305 , 4.3087935 , 3.0680852 , 3.0893688 , 3.1640174 ,-0.41890725, 0.5377459 ,-4.0344224 ,-4.5812287 , 0.5720303 , 1.802316 ,-0.31413126, 2.9586952 , 1.1723012 ,-4.696369 ,-3.7047153 ,-1.8109767 ,-3.6122723 , 1.2727392 , 4.4057164 , 3.8347735 ,-4.739083 , 2.4655118 , 0.45258832, 4.0693913 ,-3.3486447 ,-0.64714307, 1.4990507 , 2.771129 ,-0.6109979 ,-1.0617865 , 2.0837703 ,-1.633663 , 1.8431798 ,-4.3942385 , 4.8523426 , 1.1941985 , 3.0366988 , 4.7991366
diff --git a/compiler/pota-quantization-value-test/test_inputs/MaxPool2D_000_config/channel/int16/3.txt b/compiler/pota-quantization-value-test/test_inputs/MaxPool2D_000_config/channel/int16/3.txt
new file mode 100644
index 000000000..8a4c9ebb5
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/MaxPool2D_000_config/channel/int16/3.txt
@@ -0,0 +1 @@
+-2.2375767 ,-1.1274278 , 0.18025301,-4.598087 , 1.1042122 , 3.1241179 , 1.9084688 ,-1.214722 , 4.596646 , 4.1969523 , 4.658112 , 3.143779 ,-2.6940444 ,-1.5482163 , 1.542811 ,-1.1338089 , 3.721594 , 0.24673286, 4.71102 , 2.7811737 , 1.171089 , 4.145586 ,-2.6335135 , 1.1190183 ,-3.7932637 ,-4.6548123 ,-3.10302 ,-3.392706 ,-3.856141 , 0.6618614 , 0.9668614 , 4.4293485 , 1.3193 , 4.983464 , 1.659716 ,-3.185926 , 4.8983006 , 1.6323217 , 0.18800464,-1.9328839 , 4.6031475 , 3.459718 , 4.128766 ,-3.4701612 ,-2.3796144 , 1.6752707 ,-3.6569223 , 2.922704 , 3.642789 ,-1.6817225 , 3.151759 ,-1.5401909 ,-3.8259532 , 2.4556105 ,-4.4989905 , 1.2779988 ,-0.62634754, 3.5827441 ,-0.82541114, 2.1539748 , 4.583461 , 1.2231985 ,-1.4457659 ,-2.9194565
diff --git a/compiler/pota-quantization-value-test/test_inputs/MaxPool2D_000_config/channel/int16/4.txt b/compiler/pota-quantization-value-test/test_inputs/MaxPool2D_000_config/channel/int16/4.txt
new file mode 100644
index 000000000..5110f86aa
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/MaxPool2D_000_config/channel/int16/4.txt
@@ -0,0 +1 @@
+-4.011289 , 0.9077414 ,-2.8109396 ,-4.33598 ,-2.6516347 ,-3.917852 , 3.2461808 , 1.7588768 ,-1.9439132 , 2.190185 , 1.5180751 , 0.3587409 ,-4.3434815 ,-4.1376143 , 3.750847 , 1.5820616 , 0.03843357, 4.71235 , 1.0592757 ,-1.7640393 , 0.44547582, 2.8698466 , 4.5816092 , 4.6638517 , 1.4207541 , 1.863644 , 3.6007912 , 0.6800818 ,-2.4884489 , 3.0707197 , 3.3961668 ,-4.331953 , 2.7828538 ,-0.16146964,-4.9070745 ,-2.9787786 , 0.3337284 ,-3.935533 ,-3.303555 , 2.376896 ,-4.7058997 ,-2.2409894 , 0.07352693,-2.6024988 , 4.9593167 ,-4.7717366 , 1.6590588 , 4.063875 ,-3.8855767 , 2.6274624 , 4.901856 , 4.157007 ,-3.292969 , 3.579326 , 3.9860668 ,-3.0936542 ,-4.7793274 , 0.71697485,-2.0354068 ,-2.1414943 , 3.6339438 , 0.10732502,-0.86129206, 4.4152017
diff --git a/compiler/pota-quantization-value-test/test_inputs/MaxPool2D_000_config/layer/uint8/0.txt b/compiler/pota-quantization-value-test/test_inputs/MaxPool2D_000_config/layer/uint8/0.txt
new file mode 100644
index 000000000..1a4fc3ed0
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/MaxPool2D_000_config/layer/uint8/0.txt
@@ -0,0 +1 @@
+ 2.2145607 , 0.88045335, 0.45151594, 2.852104 , 3.191637 ,-0.4578638 , 1.4858874 ,-2.1207588 ,-0.77495986,-4.1637363 , 0.83028954,-3.9974387 ,-3.3348315 , 3.7137656 ,-2.9883633 , 3.4332464 , 3.7178712 , 3.5850213 , 0.9240786 ,-0.07091421,-4.516931 , 3.965739 ,-4.828566 , 3.860382 , 0.3243482 , 1.6835089 ,-1.4710085 ,-2.6625636 , 1.942659 , 0.12808529, 1.3640044 ,-3.0124736 ,-3.646485 , 1.6046281 , 1.1087954 ,-2.4648561 ,-2.3274968 , 1.2196178 , 3.0752547 , 1.8316921 ,-2.926682 ,-2.247648 , 4.1264873 , 4.700915 ,-0.6861696 , 3.5246365 ,-2.5577545 , 1.832533 ,-4.3125343 ,-2.8579648 , 3.5299218 ,-0.67911506, 0.86782926,-2.918562 ,-3.3644724 ,-2.0097935 , 0.3721956 ,-1.3528451 , 3.8267515 , 4.916677 , 3.2055025 ,-0.64435905, 3.877367 ,-1.830818
diff --git a/compiler/pota-quantization-value-test/test_inputs/MaxPool2D_000_config/layer/uint8/1.txt b/compiler/pota-quantization-value-test/test_inputs/MaxPool2D_000_config/layer/uint8/1.txt
new file mode 100644
index 000000000..09c06c74c
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/MaxPool2D_000_config/layer/uint8/1.txt
@@ -0,0 +1 @@
+ 4.5410523 , 4.4007382 , 3.3252192 , 0.40420002,-4.7642856 , 2.0282986 , 2.32176 , 3.160375 ,-4.3348713 ,-2.324847 , 4.327631 , 3.253995 , 0.53624976,-4.4896946 , 4.0600896 , 2.697662 ,-3.0693228 ,-4.7954664 , 2.010163 , 4.5790668 , 0.00921074,-4.638007 ,-2.612561 , 4.338762 ,-1.3632652 ,-0.55081725, 4.273717 , 3.1074166 , 3.1386747 ,-4.033469 ,-0.7298752 ,-3.4973295 , 4.454913 ,-0.5148646 ,-2.4100194 , 2.7154703 , 4.1507893 , 2.3424785 ,-1.7028755 ,-2.6013496 ,-1.831555 ,-4.07971 ,-1.039077 ,-1.8733021 ,-3.885844 , 3.5691998 ,-3.8779395 ,-4.7566814 ,-3.570575 ,-3.0510366 ,-4.6841617 ,-4.751285 ,-2.9700782 , 3.4774506 ,-1.3150035 ,-3.6287053 , 2.2280993 , 4.502896 , 3.9448938 , 3.3926914 , 1.560589 , 3.3307595 , 2.6545596 , 2.0503757
diff --git a/compiler/pota-quantization-value-test/test_inputs/MaxPool2D_000_config/layer/uint8/2.txt b/compiler/pota-quantization-value-test/test_inputs/MaxPool2D_000_config/layer/uint8/2.txt
new file mode 100644
index 000000000..24b7a248f
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/MaxPool2D_000_config/layer/uint8/2.txt
@@ -0,0 +1 @@
+ 4.5630627e+00,-4.5077333e+00, 6.8117022e-03,-1.1568142e-02, 2.3568916e+00,-2.9918964e+00,-4.8542055e-01, 4.7381549e+00, 3.1183126e+00,-2.6462586e+00, 3.0083582e+00, 1.4518642e-01,-2.4764729e+00,-4.8520207e+00,-4.8022575e+00,-1.8167463e-01,-3.1106927e+00,-2.4183941e+00,-4.1466684e+00,-3.6997426e+00,-3.9788694e+00,-3.0889416e+00,-2.2332447e+00, 1.8608164e+00, 2.8619974e+00,-3.6986623e+00,-1.3749057e+00,-9.2409855e-01, 2.7646086e+00,-3.3385031e+00, 7.6255083e-01, 1.0236104e+00,-1.7077237e+00,-4.4339476e+00,-1.1930060e+00,-1.7226344e+00,-3.1680160e+00,-1.8338548e+00,-2.6412952e+00,-8.2973856e-01, 4.2303777e+00, 3.4531716e-03,-3.3162324e+00, 8.4682000e-01, 2.5807633e+00, 2.7543969e+00, 6.8153429e-01, 4.7182851e+00, 4.2617507e+00,-1.4446728e+00,-4.3752551e+00, 3.5699592e+00, 9.6946698e-01,-2.0700858e+00, 2.0899124e+00, 1.6371955e+00,-9.5873147e-01, 3.1151581e+00, 2.9369416e+00, 4.4568644e+00,-9.4711387e-01,-4.1349549e+00, 3.3031983e+00, 4.1091359e-01
diff --git a/compiler/pota-quantization-value-test/test_inputs/MaxPool2D_000_config/layer/uint8/3.txt b/compiler/pota-quantization-value-test/test_inputs/MaxPool2D_000_config/layer/uint8/3.txt
new file mode 100644
index 000000000..088eb62cd
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/MaxPool2D_000_config/layer/uint8/3.txt
@@ -0,0 +1 @@
+ 2.5168443 , 3.7492614 ,-3.7076504 , 0.49709523,-4.642194 , 1.8201847 ,-1.396746 ,-1.0660223 , 3.3333528 ,-1.7719259 ,-2.3515563 ,-2.0570705 ,-4.7125244 ,-1.593302 ,-2.1072757 ,-4.4396334 , 4.3185077 ,-2.7568438 ,-0.59535027,-3.9871383 ,-2.6216223 , 0.39957425,-1.3687986 ,-3.1157744 , 1.2557942 , 2.3428473 ,-4.906711 , 3.5663006 ,-0.46128616,-4.7818427 ,-0.8876555 , 2.5066485 ,-1.3254607 ,-3.6097736 , 1.2686944 ,-1.37061 , 4.762917 ,-3.489012 ,-2.7905307 ,-0.2612837 ,-3.3236315 , 0.8347171 , 2.5582032 , 0.42744452, 1.7428764 , 2.4122005 ,-3.6781132 , 2.8811646 ,-2.7060914 ,-0.4752588 , 0.44432116, 0.5011615 , 3.2550313 , 0.02670379, 2.6197197 ,-4.319786 ,-1.4056181 ,-3.3794782 , 0.66822946,-1.4262298 ,-0.2465175 ,-4.6432767 ,-3.580772 , 2.960096
diff --git a/compiler/pota-quantization-value-test/test_inputs/MaxPool2D_000_config/layer/uint8/4.txt b/compiler/pota-quantization-value-test/test_inputs/MaxPool2D_000_config/layer/uint8/4.txt
new file mode 100644
index 000000000..bb8129473
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/MaxPool2D_000_config/layer/uint8/4.txt
@@ -0,0 +1 @@
+-4.9356976 , 3.9426446 ,-4.746647 , 2.3674695 , 0.54803735, 3.1911538 , 0.28858757, 0.4800329 , 2.0652595 ,-4.5046906 , 0.21695825,-0.17217463, 2.4329293 ,-1.2274694 ,-0.11534467,-2.096684 , 2.6882868 ,-2.5291932 , 0.56199783,-2.0743406 , 0.95846254, 4.004705 , 0.89853394, 2.9610496 , 2.9799032 , 1.5339601 ,-1.7136513 , 2.1797504 ,-4.2055335 , 1.5059681 , 3.0828342 ,-1.7946475 ,-2.7096524 , 3.1037905 , 0.75922704,-1.1446673 ,-2.084073 ,-1.2888353 ,-1.6958839 ,-0.8388285 ,-1.0279479 , 1.1291095 , 4.080411 , 3.6791847 , 0.9237894 ,-4.70821 , 0.5730598 ,-1.3565379 ,-2.7533107 ,-0.4583869 ,-1.4416862 ,-3.6039822 ,-1.1611387 ,-2.6919081 ,-0.6557734 ,-2.9248757 , 1.4998456 , 3.2239568 , 0.23668556,-3.4410136 ,-2.3170567 , 3.66808 , 1.9004405 , 4.3537745
diff --git a/compiler/pota-quantization-value-test/test_inputs/Mean_000_config/channel/int16/0.txt b/compiler/pota-quantization-value-test/test_inputs/Mean_000_config/channel/int16/0.txt
new file mode 100644
index 000000000..182eb5290
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/Mean_000_config/channel/int16/0.txt
@@ -0,0 +1 @@
+ 3.4251418 , 1.8884782 ,-4.061519 ,-2.1329548 , 3.851976 , 3.668601 ,-0.7418167 , 2.379966 , 0.87259316,-3.96981 ,-4.627804 ,-3.3958297 , 3.025158 ,-1.299777 ,-4.322816 , 3.9173064 ,-0.55214256, 1.9224825 ,-4.8571157 ,-4.778045 , 3.3015614 , 0.56785774, 4.7985554 ,-0.4355816 , 4.9478025 , 1.7909397 ,-0.7620663 ,-0.09947702,-3.0230513 , 1.3817457 ,-4.5706887 ,-3.4097836 ,-4.7086477 ,-3.4651487 , 1.4401027 , 4.7513933 ,-1.0788624 ,-3.4946275 , 4.607974 ,-3.1215246 ,-1.4637078 ,-3.5266285 , 2.1268125 , 0.19458893, 4.058288 , 2.2452407 , 0.7575343 , 0.12213306, 4.885321 ,-1.2482406 ,-1.1034219 ,-4.054173 ,-3.6471267 , 4.774012 , 0.9450243 ,-2.5827825 ,-2.3991685 ,-2.8482654 , 0.9294943 ,-3.1165063 ,-1.6113516 , 0.04260086, 2.0987031 , 2.1601508 , 4.9740996 , 3.7719023 , 2.6817482 , 0.42131838,-1.4525859 ,-0.5124655 , 2.6313434 , 4.5606523 ,-4.6180778 , 4.788594 ,-0.8446551 ,-1.5460813 , 1.4288356 ,-1.9648911 ,-4.9766145 ,-2.405665 ,-0.30327383, 3.5204673 ,-3.848158 ,-2.6913974 ,-2.76141 , 4.336643 , 1.4205143 , 4.5898 ,-0.93183124, 4.2199287 ,-4.216924 ,-1.0979122 ,-2.3032405 ,-3.4457245 , 2.944412 , 2.137278 , 1.0326933 , 2.3116126 , 4.2138443 , 1.8283377 , 0.28901085,-1.8877143 , 0.50673705, 1.4360197 ,-2.924691 , 0.9819095 , 3.4656513 ,-2.541582 ,-1.9102442 , 3.3629627 ,-0.9675056 , 0.5937253 ,-2.4236617 ,-1.4193813 ,-0.7552614 ,-1.7121441 , 4.39647 ,-2.2712908 ,-4.3387337 , 1.5912663 , 0.8397044 , 0.17277755, 1.5272428 , 3.571715 ,-1.4471695 , 1.8623346 ,-4.3603377 , 1.2116091 , 4.960487 , 2.3681397 , 1.2925869 ,-4.3249073 , 2.4402251 ,-1.4506928 , 3.023616 ,-3.232099 ,-4.0106025 , 3.5774167 ,-0.6024932 , 1.0183483 ,-2.8215308 , 3.7395437 , 1.9100485 , 3.892712 , 4.6569633 ,-3.251774 ,-3.6923678 ,-4.8891983 ,-3.8605282 ,-4.0293036 ,-2.8199108 , 4.1668954 , 2.1569817 ,-2.9700332 ,-0.7035824 ,-0.5176811 ,-3.1826456 ,-3.334556 , 4.9103675 , 3.8513231 , 2.8609774 , 1.1845547 ,-1.4094447 ,-2.0445833 , 0.9833705 , 4.481276 , 3.83006 , 4.6240997 ,-4.268881 
,-0.85518706,-2.2650888 , 4.032545 , 0.9495817 , 1.1353155 ,-4.6551876 ,-2.2839146 , 2.6291692 ,-3.0398533 , 0.52652216,-1.8323399 ,-0.12300313, 0.46178594, 1.120684 , 1.4657134 ,-1.9794375 , 0.08941289,-4.4573083 , 2.7112565 , 4.9227715 , 2.4938288 ,-0.37153494,-4.1604757 , 4.7694197 ,-1.3021677 , 2.454714 ,-2.4902875 ,-2.760436 , 0.05183195,-2.6723208 ,-1.1471758 ,-2.2565122 , 0.20876396,-0.7288584 , 0.4386669 , 0.7846054 , 2.7294593 ,-3.836883 , 2.7501638 ,-4.775067 ,-3.2403855 ,-2.0307286 ,-1.6403166 , 4.9471517 , 1.0428456 , 2.5126355 , 3.0090203 ,-2.3476288 ,-2.9215205 , 3.8079188 , 0.83959275, 4.2670302 , 1.2338712 , 2.7329903 , 2.2549257 , 4.882931 , 0.12783106,-2.4392028 ,-2.4590807 , 4.2874207 ,-0.08333418,-3.4244132 ,-0.2235516 ,-4.23632 ,-1.3970895 , 2.1245553 ,-2.513883 ,-2.8092728 ,-1.9194845 ,-4.1932216 ,-3.7431748 ,-1.1063433 ,-3.714845 , 1.7230242 ,-0.19162221, 1.1123114 , 3.937181 , 2.6165597 ,-0.61531806, 0.44309503,-2.9260228 ,-3.1617007 , 0.0663496 , 2.4541974 ,-2.714474 , 4.2564497 , 1.2300675
diff --git a/compiler/pota-quantization-value-test/test_inputs/Mean_000_config/channel/int16/1.txt b/compiler/pota-quantization-value-test/test_inputs/Mean_000_config/channel/int16/1.txt
new file mode 100644
index 000000000..dd8037244
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/Mean_000_config/channel/int16/1.txt
@@ -0,0 +1 @@
+-4.8834 ,-4.6238756 , 2.020674 ,-2.3068821 , 3.7487323 ,-0.36079448, 0.08661745, 3.423143 , 3.3073757 ,-2.709357 , 4.4810205 , 3.4159606 , 4.1597505 ,-4.249789 , 2.3782206 ,-2.02848 , 0.90137833,-0.6249625 ,-3.5300052 ,-4.1113796 ,-3.768913 ,-3.59854 , 2.0896666 , 1.7677166 ,-2.3101497 ,-1.0116942 ,-3.7846713 , 2.4777756 , 3.413987 ,-2.1964507 , 0.08637846, 0.02552292,-1.9918599 , 0.7785565 ,-4.065995 , 0.8808776 ,-2.0446506 ,-1.8421272 , 0.42566776, 3.8834689 , 4.900111 ,-3.0617309 , 4.0613194 ,-3.3601153 , 3.678536 ,-4.1136184 ,-4.2903633 ,-2.6918027 , 3.4335177 ,-3.9272869 ,-1.6882807 ,-1.9629028 , 4.2125826 , 1.6536059 ,-1.1801353 , 4.8443203 , 2.9393198 , 0.4306524 , 4.390743 ,-4.6322317 , 2.932263 , 4.140538 , 2.7385068 , 2.620753 , 2.0725663 ,-1.3642436 ,-0.48539641,-4.2409816 ,-1.5950899 ,-1.688442 , 4.4769464 ,-1.25038 , 3.462903 , 0.5011836 , 0.981037 , 0.63532305,-3.4727957 , 4.6721544 ,-3.481392 , 2.8904114 ,-1.7057139 , 1.0501702 , 3.0799537 , 1.6698593 ,-1.3895478 , 4.487443 , 2.5352533 ,-0.19357985, 0.78166926, 3.5892236 ,-4.3259463 , 2.8381345 , 1.3652785 ,-0.40142608,-0.62102544,-3.088937 ,-4.0266094 , 4.7095647 , 2.0513067 ,-1.8115149 , 0.11062156,-4.5980725 , 2.809295 , 4.2042894 ,-3.4689455 ,-1.3418434 , 2.9026117 ,-1.6125411 , 2.153075 ,-3.4445221 , 3.4869678 , 1.8746428 , 0.8482056 , 3.0525062 , 1.715966 , 1.7684505 ,-2.0022326 ,-4.3427444 ,-3.1659825 , 1.6855526 , 3.1612136 , 2.0646648 ,-3.972224 ,-2.91726 ,-3.5450957 ,-2.7226381 ,-0.3273488 ,-2.5905557 , 3.6621993 ,-4.3285728 ,-0.6200474 , 0.08522832,-2.1981175 ,-3.4179437 , 2.5989106 ,-0.8503352 ,-3.3723786 , 3.9595454 ,-0.5431398 ,-2.6962373 , 1.9689399 ,-2.8925 ,-1.2064192 , 1.606632 , 2.2728612 ,-0.1403075 ,-4.8031726 , 0.1549256 ,-1.3698703 , 0.78889227,-2.286554 , 0.96417916,-0.10438658,-3.8131578 , 2.9322996 , 2.4103441 , 4.4864798 , 0.02176606,-1.1966147 ,-3.6921146 , 4.943659 ,-1.0050472 ,-1.2238564 ,-4.5758605 ,-2.6865735 , 1.7294792 , 4.180183 , 3.157911 ,-3.581904 ,-2.9112866 , 
4.1674094 , 3.2326035 ,-2.7883985 ,-0.09154221, 0.8667318 ,-4.532571 , 0.816668 , 3.1307516 ,-4.1993947 ,-1.0503744 , 0.123965 , 0.17691068,-3.1465137 ,-1.4964765 , 3.4077635 ,-0.35415363, 1.9092371 ,-4.709203 , 1.148622 , 4.4766874 ,-2.193539 ,-3.7959206 , 1.4420112 ,-2.5300896 , 4.107192 , 3.4666913 ,-2.1158516 ,-3.182484 ,-2.8406513 ,-1.9396024 ,-2.3695247 , 3.8301885 ,-1.5032169 ,-0.48879272, 0.41695955,-1.1829228 , 4.822825 ,-2.9244933 ,-3.8178608 , 2.7742817 , 2.6998327 ,-3.1187122 , 2.508593 , 1.2989064 , 2.3436947 ,-0.39074868,-3.034766 ,-1.8690065 , 4.850296 ,-2.4549792 , 4.839528 , 2.2758777 , 2.6689568 , 3.2014422 , 3.6975234 ,-3.2566156 , 3.546554 , 1.9570364 ,-2.753807 , 2.3366053 ,-4.357898 , 4.9184504 ,-1.0057111 ,-3.8582199 , 1.2416974 , 4.355522 ,-2.7863925 , 0.4679685 , 2.6850772 , 2.9984746 , 2.434312 , 2.9931593 , 2.2637212 ,-0.18371914,-4.07688 ,-2.0402577 , 0.5173147 , 0.19596666, 4.71653 , 4.291663 ,-3.3575501 ,-1.0857964 ,-0.16504912, 3.6683955 , 2.9581416 ,-1.354989
diff --git a/compiler/pota-quantization-value-test/test_inputs/Mean_000_config/channel/int16/2.txt b/compiler/pota-quantization-value-test/test_inputs/Mean_000_config/channel/int16/2.txt
new file mode 100644
index 000000000..1295bfdba
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/Mean_000_config/channel/int16/2.txt
@@ -0,0 +1 @@
+ 1.2340723 ,-1.7371651 , 4.271641 ,-2.3332376 , 0.82301813,-3.4199295 ,-0.75806665,-2.2647665 , 2.613749 , 2.2658496 ,-2.1277714 ,-0.465433 ,-0.1323059 ,-1.9658507 ,-4.7780223 ,-4.392719 ,-0.81063855,-3.639001 ,-3.6398284 , 4.6309023 ,-0.17483327, 1.7921627 ,-1.1493484 ,-3.8145075 , 2.2367268 ,-0.40209827,-1.4159911 , 2.3032134 ,-4.154446 , 1.6760192 , 2.3430173 ,-1.386683 , 3.3363335 ,-2.976934 , 3.3983 ,-0.0069695 , 3.7025425 ,-1.8683758 , 0.72029626, 2.7558882 ,-4.4060984 , 2.553126 ,-3.5888321 , 1.8549582 ,-0.52258795, 4.6549897 , 0.8886988 ,-3.0400214 ,-3.6890693 , 3.6663766 ,-4.8026586 , 1.0636287 ,-2.9774907 , 0.39021772,-4.2414255 , 2.914968 ,-0.24334456,-4.0344954 ,-1.1011956 ,-3.8205252 , 0.05693521,-4.1379023 , 1.0584197 ,-4.0404034 , 4.841462 ,-1.2727845 , 2.6974225 ,-4.2507453 ,-2.7101111 ,-2.9800036 , 0.3082796 , 3.6763537 , 2.3277721 ,-4.9667864 ,-2.4498677 , 0.2704629 , 3.006634 ,-1.1129389 , 4.373073 ,-1.2066779 ,-3.1575904 ,-2.721046 ,-0.861226 , 1.7315729 , 2.255666 , 2.5448847 , 3.1268334 , 1.5189171 ,-3.1992466 , 0.607633 , 4.0749955 , 1.2546133 ,-1.5335796 ,-1.6200712 ,-3.9392874 , 1.053699 ,-0.87970537,-3.9218261 ,-2.2724128 , 0.82235074,-2.3400521 , 3.6467028 , 1.6891364 ,-1.6333519 , 2.2639709 ,-0.08272895,-3.076964 , 3.731091 , 3.7932968 , 2.496441 ,-4.12142 ,-2.0908666 ,-4.994248 ,-0.0429902 ,-4.6083336 ,-4.522535 , 4.717733 , 1.6715643 ,-4.779822 , 1.2919815 ,-4.6121325 ,-0.6206874 ,-2.6633883 ,-1.9632595 ,-3.2203329 ,-0.6556523 , 1.3083993 , 0.13287744, 4.599294 ,-1.1777852 ,-2.9159715 ,-0.25669238, 0.48217958,-3.9736347 ,-0.774503 ,-0.7264863 ,-3.0058725 ,-2.1682055 , 2.6579158 ,-4.4020653 , 3.0450368 , 1.3798735 ,-4.9858127 ,-4.5812607 ,-3.7349749 ,-4.4158583 , 1.631093 ,-3.0769646 ,-3.8406906 , 1.6544044 , 0.36895755,-1.8196682 ,-2.0880237 ,-3.708266 ,-2.0277069 , 1.0536597 ,-3.6726243 , 1.1704421 , 2.3201573 , 1.4994124 , 4.0197086 , 2.1001272 ,-0.39845964, 4.879206 ,-4.6042013 , 4.367211 , 2.2712052 , 2.7754369 ,-3.156667 , 
4.349216 ,-4.111492 , 1.0267047 ,-2.3381946 , 4.8876834 , 4.876814 ,-0.28538027, 4.8861 ,-0.95963717, 0.46279734,-4.5789995 , 0.26168647,-0.8879058 , 2.4468584 , 1.3030591 , 3.7261188 , 3.9933589 , 2.4964094 ,-1.3851117 , 0.7147012 ,-3.8367457 , 0.79737735,-0.5907085 , 4.317288 , 0.7659837 ,-4.821792 ,-1.466433 ,-1.147227 ,-1.8638811 , 2.5115767 , 1.9449657 ,-2.4122007 ,-2.4968379 , 0.7738737 ,-1.4761454 , 4.131583 , 0.4211128 ,-2.4312468 ,-1.9722428 , 2.2810268 , 4.950381 ,-0.0406047 , 4.67312 , 0.66613483,-0.28880936, 3.2917845 , 1.6225572 , 4.809879 , 0.48241946,-3.654634 , 0.68542016, 1.3973923 , 3.479005 ,-1.4296091 , 0.64391786,-4.0887494 ,-2.186845 ,-4.5834355 ,-0.67726034, 2.4158256 ,-2.4787726 , 0.4353257 , 2.9205139 , 0.10488439, 2.0790074 ,-4.5518365 ,-3.3856661 , 3.940736 ,-1.7141095 ,-4.8946457 , 1.1085542 , 3.785141 ,-2.4175835 , 3.7720537 , 4.623048 , 2.2239215 , 0.11616404, 0.09229392,-3.637964 ,-2.334849 ,-0.95000714,-2.1338253 , 3.2281857 ,-4.0220475 , 4.7304025 ,-1.8075961 , 0.2428817
diff --git a/compiler/pota-quantization-value-test/test_inputs/Mean_000_config/channel/int16/3.txt b/compiler/pota-quantization-value-test/test_inputs/Mean_000_config/channel/int16/3.txt
new file mode 100644
index 000000000..378b5fea5
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/Mean_000_config/channel/int16/3.txt
@@ -0,0 +1 @@
+ 2.4605505 ,-2.7001262 ,-4.3874917 ,-2.9867616 ,-3.4332 , 0.76675916, 3.4377892 ,-0.6712793 , 1.8018581 , 1.8148962 , 2.0353577 ,-4.766427 , 3.2487285 , 3.886249 ,-2.8867183 ,-0.7906634 ,-4.376028 ,-4.2085958 ,-0.36025277, 0.6360799 ,-4.687723 , 4.8313313 , 3.3582768 , 2.1117954 , 0.9821817 , 3.3697798 ,-1.1784939 ,-3.1590316 ,-0.24019621, 0.20640443, 1.2808957 , 2.3346424 , 2.13951 , 0.61864626, 2.4020443 ,-1.9671458 ,-1.6852348 , 0.32225233,-2.3928862 ,-4.173372 ,-2.282281 ,-1.271318 , 3.0839682 ,-4.4726086 ,-0.635177 , 3.2710915 , 3.08071 ,-0.7311931 , 2.1444874 , 0.4102332 ,-3.332888 ,-4.8965516 , 3.903695 , 1.4920163 ,-4.041926 ,-0.3941788 , 3.6352818 ,-2.098405 ,-0.9248165 , 2.6277795 , 3.225142 ,-1.4461963 ,-4.2050753 ,-0.2213572 , 1.9704323 , 3.298732 ,-4.710403 , 3.6876736 , 2.0771818 , 1.3559113 , 1.328373 ,-4.4079022 ,-3.28067 , 3.8852313 , 2.322237 , 2.3243637 ,-1.9126451 , 4.6277676 , 1.7031307 , 0.74861574,-4.688967 , 3.9351206 ,-1.8054084 , 1.5824287 , 3.5381088 , 2.4798677 ,-3.3099444 ,-3.8518245 , 1.5562242 ,-1.9466928 , 0.08375791,-0.16754703, 2.9265418 ,-1.6599798 , 2.766202 ,-2.8269696 ,-0.19389874, 2.0869334 ,-1.5073173 ,-3.2024453 ,-3.6522708 ,-4.588111 ,-2.3425827 , 4.8709297 ,-1.4231887 , 1.0590451 ,-1.6406479 , 0.37192422, 0.7313186 , 0.3865313 ,-4.2832613 , 3.9712496 , 0.07653506, 0.2593589 ,-2.6036396 ,-0.45185068, 3.6537335 ,-0.6341783 ,-0.6381408 ,-1.0992868 , 2.766365 , 4.666631 , 4.416099 ,-3.6654727 ,-4.0626607 ,-3.4928396 ,-0.6944366 , 4.869798 , 4.2240977 , 0.9655519 ,-2.5654511 , 1.3396966 ,-3.7639391 ,-1.2369057 ,-3.7242758 ,-0.5189227 , 1.6548159 ,-2.6197302 , 4.2732763 , 2.239486 ,-4.316255 , 3.2419755 ,-1.9283817 , 0.22489135, 2.6034477 , 0.15818155, 2.0811818 , 0.836994 , 2.7832468 ,-0.68581384, 0.89475006,-3.1455147 ,-4.818614 ,-4.1738377 , 0.4281551 ,-2.935886 ,-3.7582467 , 0.58168256, 0.2854076 , 1.0492616 , 2.2415884 ,-4.4923434 ,-3.2479804 , 3.8439462 , 3.9802108 ,-0.9027783 , 1.7783072 ,-2.2782066 , 4.4638705 , 4.28735 
, 4.291463 , 1.1685107 , 1.2765578 ,-3.7954235 ,-3.494621 , 4.4340134 ,-3.5995178 ,-4.3025713 , 3.3037348 ,-3.6675146 ,-1.7871013 ,-1.2922373 , 0.72924066,-4.7065907 , 2.1388702 , 2.3570008 , 3.9203117 , 0.07483537,-2.8389792 ,-1.795164 ,-4.380931 , 1.3189598 , 2.4404252 , 4.4774084 ,-1.2798066 ,-4.95842 , 1.8095461 , 4.2692375 ,-2.0918155 , 0.33083543,-3.794544 , 1.4940621 ,-3.9446015 ,-0.38208306, 0.30863285,-0.6832849 ,-2.5675633 ,-4.948772 , 1.5904989 , 3.0415509 ,-4.899339 , 0.9415345 ,-0.91124976, 4.4849253 ,-3.4605968 , 1.6737833 , 1.9091597 , 1.3111106 , 2.0829957 ,-2.1308084 ,-2.912219 , 1.1306196 , 2.231948 , 4.7522073 ,-2.1438766 ,-2.1000512 ,-0.2984778 ,-1.2093959 , 2.6259391 , 1.8113437 ,-4.137133 , 2.716111 , 3.4318748 ,-0.89123845,-3.70718 , 2.453927 ,-0.22418758,-3.098459 ,-4.4986243 , 0.85048616, 2.8023102 , 3.743153 , 0.9931644 , 3.8588202 , 1.7585737 ,-4.2855363 ,-2.5475764 ,-0.83141845,-1.9358089 , 3.1711586 , 2.4221613 ,-1.881327 ,-3.7230873 ,-4.55259 ,-0.42294836, 4.64625
diff --git a/compiler/pota-quantization-value-test/test_inputs/Mean_000_config/channel/int16/4.txt b/compiler/pota-quantization-value-test/test_inputs/Mean_000_config/channel/int16/4.txt
new file mode 100644
index 000000000..339435425
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/Mean_000_config/channel/int16/4.txt
@@ -0,0 +1 @@
+-3.37344313e+00, 2.78325319e+00,-7.30300546e-01, 1.33456266e+00, 3.96648932e+00, 4.33421373e+00,-3.11558557e+00,-3.64659280e-02,-1.73589993e+00, 4.81018400e+00,-8.32905114e-01, 2.33330703e+00, 1.85830116e+00,-4.60395622e+00, 5.26070774e-01,-4.71355534e+00,-2.97202754e+00, 3.57638383e+00, 4.50985909e+00, 2.08423686e+00,-1.85349309e+00,-2.18306184e+00,-4.65403509e+00, 4.31280661e+00, 1.16069472e+00,-4.85344124e+00, 8.40563923e-02,-1.98723459e+00,-4.29561710e+00,-2.57372570e+00,-4.22641230e+00,-4.00811911e+00,-9.61861551e-01,-2.14665198e+00, 4.18120289e+00,-3.87826174e-01,-2.86187083e-01,-4.84979200e+00,-1.34733701e+00, 1.27489030e+00, 1.98844969e+00,-4.11230135e+00,-1.61191213e+00, 2.63515592e+00, 4.35539484e+00,-1.56582773e+00,-2.45283508e+00, 1.44556177e+00,-8.56053472e-01, 3.25111747e+00, 3.58699083e+00,-2.47732449e+00, 3.64130282e+00,-4.91288567e+00, 8.97059917e-01,-2.26010180e+00, 4.91831064e+00, 4.45047706e-01, 1.88655663e+00, 3.20642543e+00, 1.38243341e+00, 9.06112790e-01, 1.15262544e+00,-2.39862514e+00,-2.87477684e+00, 7.36831248e-01, 3.18799114e+00, 1.22698748e+00, 5.63625395e-01, 1.29130912e+00,-4.89572334e+00, 2.11258578e+00,-4.55420208e+00, 4.94569272e-01,-7.08617330e-01,-1.84863120e-01,-4.81965256e+00,-1.06512284e+00, 4.79633398e-02, 2.70429182e+00, 4.78289175e+00,-2.11806059e+00, 4.23046875e+00, 3.18022132e+00,-8.39496255e-01, 3.13150501e+00,-3.24103773e-01,-7.48505890e-01,-2.45754886e+00, 4.16639376e+00, 3.25864077e+00, 3.40006447e+00,-3.77217412e+00, 2.93266010e+00, 3.33685803e+00, 1.02347994e+00,-2.22839618e+00,-1.90375733e+00, 3.24283957e+00,-4.01684284e-01,-4.45417643e+00, 3.74440104e-01, 3.33520865e+00, 6.64106190e-01, 3.84395885e+00, 2.38586918e-01,-1.51634857e-01,-2.64977455e+00,-3.45786500e+00, 4.89002228e+00,-1.07323432e+00,-2.92749858e+00,-1.76510501e+00,-3.44604325e+00,-1.89681911e+00, 4.20239258e+00,-1.75864971e+00, 2.13181686e+00, 3.90355319e-01,-4.11911535e+00, 6.61891177e-02,-4.32988214e+00,-1.42876351e+00, 
3.12163901e+00,-4.56227779e+00, 4.17938662e+00, 9.63881195e-01, 4.35952139e+00, 1.61931109e+00, 4.11196423e+00, 2.25612569e+00,-4.77538586e+00,-1.72600198e+00,-4.39411783e+00,-8.98730099e-01,-1.04562032e+00,-2.81517529e+00, 3.57167959e+00, 1.90318239e+00, 2.17302442e+00,-3.79942179e+00, 2.19838643e+00,-4.16209459e+00, 4.45025682e+00, 1.68786839e-01,-2.56879544e+00, 3.60925221e+00, 1.06542781e-01,-3.48755455e+00,-6.77028894e-01,-3.51582170e+00, 3.90697241e+00, 4.49116230e+00,-1.56180394e+00, 4.96249914e+00, 9.63374436e-01, 2.72304177e+00, 8.38046610e-01,-2.91993833e+00,-9.41783428e-01, 8.00800502e-01, 3.89176035e+00, 6.70560122e-01, 2.76782703e+00,-1.37075472e+00,-3.25303817e+00,-4.41226482e+00,-8.38777184e-01, 1.73568249e+00,-1.09438455e+00,-1.08815920e+00, 1.06787062e+00, 2.04415274e+00,-2.93027782e+00,-6.86941504e-01, 3.83109421e-01,-3.49270535e+00,-2.13225913e+00,-3.61786675e+00, 1.32213378e+00,-2.89654016e+00, 4.23944092e+00, 4.53665400e+00, 4.26081800e+00,-1.95718706e+00, 4.72295076e-01,-3.08592963e+00, 2.53354859e+00, 3.80069661e+00,-1.14408419e-01, 2.39438844e+00,-4.73618507e+00, 2.35079074e+00,-1.43686843e+00, 1.32946157e+00, 1.10381134e-01,-3.49878430e+00, 2.83181930e+00, 4.57872486e+00, 2.29953095e-01, 7.19881415e-01,-2.97208834e+00, 4.11286211e+00,-3.89149117e+00, 3.83631349e+00, 4.14627981e+00,-1.14082299e-01,-6.89825296e-01,-2.55468488e+00,-4.04466152e+00, 9.95541453e-01,-2.59181118e+00,-4.60567427e+00,-4.77339029e+00,-7.36041367e-02, 1.85957468e+00,-3.42530179e+00, 4.55782986e+00,-3.29603004e+00, 3.55632234e+00, 2.40858841e+00,-2.07399082e+00,-3.96705031e+00, 4.41718817e+00, 3.19581985e+00,-3.72379017e+00,-3.76826024e+00, 6.79764748e-01,-4.43838930e+00, 2.29627752e+00, 2.34923697e+00,-4.23308420e+00, 3.80186272e+00, 8.65862250e-01, 8.44927967e-01,-1.05974531e+00, 4.70531940e+00, 1.25060010e+00, 4.82314730e+00,-4.53093815e+00, 4.51410580e+00, 4.95166332e-01,-3.45584202e+00, 1.82002666e-03,-3.27616286e+00,-2.68104935e+00, 2.39554620e+00, 
2.99364328e+00,-2.57998848e+00,-4.35891914e+00, 4.64737415e+00,-5.74958742e-01, 6.47293210e-01, 1.85961032e+00, 4.49567413e+00,-4.36166048e+00
diff --git a/compiler/pota-quantization-value-test/test_inputs/Mean_000_config/layer/uint8/0.txt b/compiler/pota-quantization-value-test/test_inputs/Mean_000_config/layer/uint8/0.txt
new file mode 100644
index 000000000..e0e52c398
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/Mean_000_config/layer/uint8/0.txt
@@ -0,0 +1 @@
+ 4.5734663 , 3.96675 ,-2.7826853 , 4.377681 , 1.8424977 ,-2.8312624 , 0.65628445,-3.7023883 ,-1.8941027 , 0.53154576,-3.9718776 ,-3.3961854 ,-2.7500536 , 2.6793208 , 3.3515985 , 2.0939343 ,-4.3965416 ,-1.7462187 , 0.5660886 , 4.497879 ,-2.2529721 ,-4.8996797 ,-0.00740948,-2.941367 , 1.9482567 ,-2.462802 ,-0.7897884 , 3.1501546 , 3.1216884 ,-3.506249 , 2.871302 ,-3.964653 ,-0.40679944, 2.8930066 ,-4.783338 ,-1.8733944 , 2.2654383 ,-0.41361305,-3.7790897 ,-1.9458629 ,-2.274427 ,-2.9192872 ,-0.73215395, 2.8135974 , 2.1402152 , 4.516366 , 1.58816 ,-4.607831 ,-3.5409598 , 1.9784997 , 3.11111 , 1.0872442 ,-3.6907403 ,-4.774325 ,-4.9267297 , 1.2962086 , 2.4646177 , 2.2726526 , 4.8766675 ,-2.9272413 ,-0.06221364,-0.80498594,-2.319938 ,-3.8261194 ,-2.3452706 , 2.5408983 ,-0.80628425,-1.4547366 ,-4.4171157 , 3.1584027 , 4.2213454 , 3.0342784 , 2.0285478 , 3.4517126 , 1.870827 , 2.812075 , 1.0776864 ,-4.524331 , 3.1467574 ,-2.366355 ,-4.7368546 , 1.940347 , 4.282059 , 1.2666475 ,-4.9559174 , 2.8177614 , 1.1941892 ,-0.25412267,-2.833778 , 1.1770393 , 4.9503546 , 4.582686 ,-1.0778978 ,-2.9030416 , 3.2517505 , 1.556093 ,-3.7605543 , 0.5915735 ,-2.6323159 , 4.596147 ,-0.90292877, 2.8230112 , 4.9295835 , 3.523853 , 1.7742149 ,-2.6014073 , 2.162894 , 1.9364033 , 4.0920115 , 0.81613404, 2.4198878 ,-0.907447 ,-4.79113 ,-3.4193892 ,-0.3334577 ,-1.0439668 , 4.2233415 , 1.4482704 , 1.3646252 ,-0.9206041 , 4.4994802 ,-4.2411633 , 0.6763335 ,-1.3827848 , 1.8579848 , 1.6426222 , 0.904467 , 3.876264 ,-4.6476808 , 4.576801 ,-1.4680524 , 2.441134 , 3.2343059 , 0.23119794, 2.5640545 ,-0.7293438 , 3.7184558 ,-1.6056752 , 3.1490617 , 4.6837263 , 4.7100887 ,-2.785927 ,-0.1520597 ,-1.9914767 ,-4.00598 ,-2.7502792 , 3.7857378 , 2.8444788 , 4.9911737 , 0.29277426,-4.779576 , 3.223367 , 1.3517398 , 4.8757277 , 3.8083189 , 1.7660266 ,-2.1543872 , 4.822371 , 2.089687 ,-4.7373757 ,-2.4061642 , 2.0387447 ,-4.067881 ,-3.1757388 , 0.24974413,-0.24441184,-0.1168329 ,-0.35149318, 2.0035832 ,-4.248678 
,-1.4723817 , 3.8218668 ,-2.8085105 , 4.6995482 ,-3.0093114 ,-3.648268 ,-1.0374364 , 0.04459473, 2.3945484 ,-0.63439727, 3.3920286 , 2.403765 , 1.303556 , 3.232244 ,-0.44932058, 0.9601637 ,-3.3821623 ,-4.257736 ,-4.095783 , 0.42818338,-4.925627 ,-1.8419602 , 4.9393196 , 0.8049334 , 4.431875 , 2.8487725 , 2.1205912 , 1.7367444 ,-4.337498 ,-3.574642 ,-3.8927085 ,-0.35219863, 2.8415039 ,-0.2887568 ,-0.89806557, 2.669602 , 4.8017626 , 4.278042 ,-1.2604581 , 3.152027 , 2.1625066 , 1.5039738 ,-3.7209976 ,-0.72354925, 4.006067 ,-3.7651584 , 0.7198826 , 3.9594896 , 0.6228397 , 2.8464649 ,-0.18740664,-2.0530953 , 3.5185826 , 2.5037062 , 0.3990585 ,-4.423475 , 4.6931167 ,-1.0078553 , 0.74727917,-4.289701 , 1.697721 , 3.4963684 , 1.5796075 , 2.296678 ,-2.9379995 , 4.4748416 , 0.25155628, 4.1183267 , 0.9506131 , 1.2903908 ,-4.6828184 ,-2.309908 ,-4.2793307 ,-2.2069294 ,-4.038367 , 4.641971 ,-2.3178709 ,-2.2683682 ,-0.96986157, 2.6649144 , 2.3106637 ,-1.8052462 ,-4.9433284 , 1.7941002 , 4.80127 ,-0.06690114
diff --git a/compiler/pota-quantization-value-test/test_inputs/Mean_000_config/layer/uint8/1.txt b/compiler/pota-quantization-value-test/test_inputs/Mean_000_config/layer/uint8/1.txt
new file mode 100644
index 000000000..9a8f222e7
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/Mean_000_config/layer/uint8/1.txt
@@ -0,0 +1 @@
+ 2.2282960e+00, 1.0135865e+00,-4.1930809e+00, 5.3674412e-01,-3.2516165e+00, 1.2745492e+00, 4.2867136e+00, 1.9524460e+00,-3.6757104e+00,-3.6086998e+00,-9.4525421e-01,-3.4005399e+00, 3.3607626e+00, 4.2363039e-01,-2.5177178e+00,-3.0130227e+00,-4.1442380e+00, 4.4951862e-01,-6.4387190e-01, 4.3701029e+00,-3.6790867e+00, 3.2749624e+00,-2.2554400e+00, 1.8269253e+00, 1.8358005e+00,-6.0994375e-01, 3.5964453e+00, 4.8953295e+00,-2.6134133e+00,-3.9301482e-01, 4.0286818e+00,-8.9392501e-01, 2.6430035e+00,-1.0339550e+00,-4.2311502e+00, 5.1657695e-01,-3.0095081e+00,-3.2156844e+00, 3.0075660e+00,-2.4905038e+00, 2.2380588e+00, 4.6933036e+00,-2.7880669e+00,-3.3672907e+00, 2.5187421e+00, 2.1843061e+00,-3.9957666e+00,-4.5409918e+00,-1.7282218e+00,-4.6849327e+00, 3.1863580e+00, 2.4342964e+00,-4.5180349e+00,-2.4310455e+00,-2.6789901e+00,-1.6438740e+00, 4.9613748e+00,-3.7800386e+00,-4.4277740e+00, 1.0571244e+00,-3.3765689e-02,-6.2219787e-01, 2.1075857e+00,-2.0555353e+00, 2.6996508e+00,-3.0303302e+00,-3.8262250e+00,-4.5048919e-01, 2.6760142e+00, 3.2696848e+00, 2.8136756e+00,-2.7064829e+00, 8.5861349e-01,-1.8871003e+00,-9.5355767e-01, 2.3704410e+00, 4.8897211e-02,-4.6371531e+00, 1.5693765e+00, 3.7866819e+00,-2.9738419e+00, 1.2106347e+00,-5.8760280e-03,-6.4124316e-01, 4.2396611e-01, 4.8550687e+00,-3.0650468e+00,-1.2087260e+00,-2.4833875e+00, 2.1272743e+00,-1.8991195e-01,-3.5372739e+00,-2.3402226e+00,-1.0234243e+00, 2.8981063e+00, 8.7964945e-02, 3.2136328e+00,-3.4051507e+00,-4.5538807e+00,-4.0228786e+00,-1.8993270e-01,-4.5704255e+00, 1.8850164e+00, 9.9910229e-01,-4.8424377e+00,-3.1492932e+00, 2.3922281e+00, 4.8503261e+00,-2.1037047e+00, 3.3602579e+00, 1.3546667e+00, 1.3481154e+00,-2.3604252e+00,-1.3253393e+00,-3.5330158e-01,-2.1313765e+00, 3.1442962e+00,-1.1570807e+00,-4.5890884e+00,-4.1608801e+00, 1.8554245e+00, 2.4646142e+00,-1.8453486e+00, 3.3489871e+00,-1.1248070e+00, 3.1451607e+00,-1.4458319e+00,-2.2727523e+00,-2.0378258e+00, 2.4566815e+00, 3.8839689e-01, 4.2570353e+00, 2.3613093e+00, 
1.2956337e+00,-7.5734973e-01,-1.4549307e+00, 9.3240172e-01, 4.3444591e+00,-6.4935732e-01, 2.5328317e+00,-2.3545196e+00,-4.7553263e+00, 2.6134777e+00,-2.5526178e+00,-1.7996631e+00,-2.0215256e+00,-4.6141486e+00,-1.7283168e+00, 2.5297335e-01, 3.7009020e+00,-1.9858284e+00,-3.4631619e+00,-1.5858738e+00,-2.5620985e+00, 3.2822473e+00,-3.2632313e+00,-9.0714562e-01,-2.3562717e+00, 4.4088845e+00,-3.6630182e+00, 5.5761892e-01, 1.6045070e+00,-3.6806375e-01, 4.3184443e+00,-1.3219705e+00, 1.5496376e+00,-1.5801797e+00, 2.1545045e+00,-4.0106788e+00, 3.4172714e+00,-4.2495294e+00,-6.1115064e-03,-7.2607052e-01,-7.3130745e-01,-4.4462271e+00, 4.8119636e+00,-4.7460346e+00,-3.0464313e+00,-2.8801811e+00,-1.4347218e-03, 4.4133449e+00,-3.3173063e-01, 4.3802023e+00, 2.6040417e-01,-2.5531218e+00, 3.7436140e+00,-4.1636271e+00,-3.3907690e+00,-1.4418361e+00,-3.6933661e+00,-2.6342602e+00,-3.1492887e+00,-5.5590755e-01,-1.6814464e-01,-1.0868104e+00, 4.9451909e+00, 3.4104226e+00, 1.0342516e+00, 4.7993002e+00, 1.2480364e-01, 1.6109833e-01, 2.6366503e+00, 1.6535910e+00, 4.3810592e+00, 4.4755011e+00, 4.3265424e+00,-3.1934264e-01, 9.8549920e-01, 1.9962710e-01, 2.8525822e+00,-3.7352023e+00,-1.3402178e+00, 2.5931063e+00,-2.6708813e+00,-7.6831090e-01, 3.0769660e+00, 1.4107993e+00,-1.8936746e+00,-4.7568636e+00,-1.9222193e+00, 4.7693071e+00, 2.8644614e+00, 4.1877995e+00,-3.6974251e+00, 4.5314616e-01,-7.1986055e-01, 4.8653622e+00, 1.4722897e+00,-8.6220115e-01,-4.1846976e+00, 3.7767217e+00, 3.7630556e+00,-4.5851058e-01,-4.9183292e+00,-1.8750135e+00, 1.0773923e+00,-5.2709883e-01,-9.2767686e-01,-1.3984675e+00,-2.0892789e+00,-4.3801632e+00, 4.0080590e+00, 4.2269025e+00,-1.2195336e+00,-2.2649438e+00, 4.6874623e+00,-3.8354571e+00, 5.9588730e-01,-2.8315885e+00, 3.0605823e-01, 2.1416895e+00, 1.6045133e+00,-3.3075256e+00, 4.9898911e+00, 1.7708080e-02, 3.5305614e+00
diff --git a/compiler/pota-quantization-value-test/test_inputs/Mean_000_config/layer/uint8/2.txt b/compiler/pota-quantization-value-test/test_inputs/Mean_000_config/layer/uint8/2.txt
new file mode 100644
index 000000000..1b2e33401
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/Mean_000_config/layer/uint8/2.txt
@@ -0,0 +1 @@
+ 1.9229428 , 2.1045275 , 2.0514195 , 1.7149676 ,-4.1647053 , 4.3958654 , 2.1192055 ,-2.4357705 , 2.249189 , 4.7986865 ,-1.0146881 , 2.5108647 , 0.7262246 ,-2.3110187 ,-0.434008 , 2.6220334 , 1.3261455 ,-2.0402927 , 0.6362597 , 0.12827367, 0.94167644, 1.6396433 , 2.802215 , 0.92637545,-2.8669958 , 2.1684341 , 4.7197456 ,-3.0393784 ,-1.5588902 ,-1.5589788 ,-1.2792847 ,-4.301159 , 3.6853306 , 3.5522077 ,-3.5120559 , 3.6523628 , 0.52381915,-4.3210206 , 3.1021209 ,-4.4059095 , 4.574733 ,-3.708168 ,-3.4609973 , 0.04494883, 4.6041393 , 4.6209555 ,-2.184693 , 3.3114836 , 4.0440845 ,-4.362543 ,-3.0185041 ,-3.4911432 ,-1.0443465 ,-3.1546419 ,-3.0831194 ,-1.8959469 ,-3.7653599 ,-1.8753844 , 3.969308 , 4.0960746 , 0.256032 ,-0.11065102, 4.753394 , 4.8433857 , 0.17249103, 0.44612473, 3.5996687 ,-3.7071083 , 4.15448 , 2.7609568 , 0.7979912 , 2.6985793 , 0.24981445,-0.7343978 ,-3.8946455 ,-3.4738345 ,-2.0124238 , 4.6603985 , 0.9002829 ,-2.2128618 ,-0.8752893 ,-3.0990481 , 2.770291 ,-1.4642559 , 0.4561498 , 0.5808671 , 2.4227936 ,-2.400878 , 0.6494001 , 1.0195295 ,-3.2693145 , 1.9889433 , 3.5208216 , 3.6280289 , 4.322899 ,-2.805155 , 3.7704606 , 0.6797415 , 4.442675 ,-0.5069875 , 1.3373847 , 4.6953626 ,-0.7946793 ,-2.7352958 ,-1.9969261 , 0.43059692, 2.50853 , 1.9314603 , 1.3780333 , 2.0536468 ,-1.572231 ,-4.5323825 ,-1.3175989 ,-1.5515776 ,-0.05870355, 0.32408538,-4.2935586 ,-1.561555 ,-1.7551405 ,-0.93950266, 3.2540953 ,-4.623753 ,-3.4944966 ,-0.7603045 , 0.76591074,-4.9114766 ,-2.679303 , 0.12950227, 4.094419 , 4.781908 ,-3.6946337 , 2.766349 ,-0.45678583,-2.275264 , 2.0858452 , 3.1182098 ,-1.2942638 , 4.4418044 , 2.2264028 ,-3.3838644 , 1.4427853 , 3.7365992 ,-1.1815038 , 1.4555137 , 0.22728541,-0.18817298, 3.454521 , 3.1835914 , 4.0786743 ,-1.5111316 , 1.1560454 ,-0.04693017, 0.44183066,-0.7420173 ,-1.2243766 , 3.4453049 ,-2.969513 ,-0.82397145, 4.870895 , 3.0178127 , 1.7217305 , 4.482936 , 1.9468685 , 3.9970267 , 4.7294793 , 2.9921744 , 4.470473 , 4.7626653 , 
0.13104612,-4.651569 , 2.7991815 ,-4.734433 ,-2.4499187 , 1.0739365 ,-1.5583646 , 3.6531756 , 2.7731194 ,-4.72427 ,-4.5801177 ,-4.035709 , 2.5767221 ,-2.8133557 ,-1.8342617 , 3.5808434 ,-2.1022995 ,-3.5421894 ,-3.0776916 , 3.168665 ,-0.07246887,-1.2413273 , 4.7964606 ,-1.0624843 , 0.75939703, 2.5336463 ,-4.8622346 ,-4.9744167 , 2.1007512 , 1.5271608 , 0.37077245, 1.7765028 , 2.2724373 , 2.1864665 ,-0.37378153, 1.3559381 ,-1.4220421 ,-1.4756224 , 3.6143627 , 2.7846546 ,-2.5194893 , 3.005039 ,-3.6451447 ,-1.9118739 , 0.04718782,-3.0775185 ,-1.4801219 ,-2.35909 ,-0.4728799 , 4.610093 ,-4.472677 ,-4.530808 , 0.12514372, 0.05973044, 4.457302 , 3.1129916 , 3.6036162 , 4.5086145 ,-3.548999 , 0.4976606 ,-3.6525648 ,-2.1937015 ,-1.3205789 ,-2.6594079 , 4.415343 , 3.219482 ,-3.7286756 , 3.4116418 , 0.82889384,-3.0168123 , 4.382766 , 2.7633846 , 3.6949344 , 3.9806223 ,-0.6415279 ,-0.3193684 ,-1.3176754 ,-1.4990829 , 4.694691 ,-1.0581211 , 1.2103747 ,-0.26690048,-1.157015 ,-1.8951306 ,-0.8580171 ,-4.3080263 , 4.0737123 ,-1.2607352
diff --git a/compiler/pota-quantization-value-test/test_inputs/Mean_000_config/layer/uint8/3.txt b/compiler/pota-quantization-value-test/test_inputs/Mean_000_config/layer/uint8/3.txt
new file mode 100644
index 000000000..50ed09011
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/Mean_000_config/layer/uint8/3.txt
@@ -0,0 +1 @@
+ 4.9386005 , 3.7248888 , 3.3261378 , 4.8302746 ,-3.9337704 ,-4.2943096 , 0.16059242, 0.17785172,-2.4971933 ,-2.933359 ,-4.598231 , 4.7816315 ,-0.6563864 , 4.452592 , 1.8066075 , 3.1572745 , 4.500678 ,-1.1609873 ,-1.6962403 , 1.567031 ,-3.3120036 , 1.8150452 ,-2.7486987 ,-1.6800771 , 1.4895486 , 1.120401 , 1.4983965 , 4.7132416 , 0.39645562,-3.12486 ,-0.5966056 , 4.618641 , 1.225812 , 0.99017185, 3.9918585 , 1.299415 ,-1.2995726 , 4.202907 , 3.8657827 ,-4.0268126 ,-0.90370494, 0.5030568 ,-2.9651542 ,-4.1249614 ,-2.8990393 ,-4.1228724 ,-1.2640246 ,-0.72640723,-1.7128279 , 2.7710931 , 2.8189523 ,-0.8384207 , 0.71266395, 3.8393862 ,-1.7801509 ,-3.1485069 , 3.2076547 , 2.267659 ,-3.745656 ,-4.373508 , 0.86005193,-4.9145784 , 0.9253047 , 1.1243923 , 0.46507052, 1.9978004 ,-4.642887 ,-2.1898057 , 0.88199854,-2.1837327 , 1.1112527 ,-1.4548608 ,-3.5766103 ,-1.5607064 ,-3.630397 ,-1.9193211 ,-0.8931484 ,-0.2812017 ,-1.2881653 ,-2.5051243 ,-3.5648384 ,-0.5431733 ,-0.47036746,-2.8132265 ,-0.4302025 ,-4.003176 , 0.31743896,-3.074693 ,-3.3994603 , 0.62276137, 0.12920536,-2.5154057 ,-0.22098878,-2.711012 ,-0.303956 , 4.6025276 , 3.1887815 ,-0.50345755,-2.6543994 ,-0.8452558 ,-1.4075644 , 3.6716504 , 2.7388885 ,-4.9426928 , 3.5494354 , 4.777085 ,-3.3904083 ,-2.4746811 ,-2.943489 , 1.3607427 , 1.313449 ,-2.7959676 , 4.5932074 , 0.2460288 ,-1.1802251 , 0.6807028 ,-3.7335384 ,-0.30950046, 0.0558207 ,-4.7604976 ,-4.5745177 ,-3.3872643 ,-1.102581 ,-1.5612804 ,-1.2933319 , 4.5290637 ,-2.5096595 , 0.8673844 , 0.6069363 , 0.8294639 ,-0.05487671,-2.5923786 , 3.2974155 , 2.252853 ,-2.4157743 , 1.6614583 , 1.975577 ,-2.7390766 ,-0.26459846, 0.8946814 ,-3.257953 , 4.0526175 ,-1.5219783 , 4.6063023 ,-0.09599628, 3.2825923 , 2.0063279 ,-3.597641 ,-0.41604096,-2.5593333 , 1.8169669 ,-3.6998532 ,-2.3723404 , 0.4008657 , 2.1002467 , 4.9284163 , 4.6011457 ,-4.8977246 , 4.7852945 , 1.2170111 ,-1.055987 , 2.27575 , 1.0601226 ,-4.176826 , 0.08197393, 4.0421042 , 3.6263971 , 2.6941037 ,-2.644993 , 
0.10439859,-4.512112 , 3.7939842 ,-4.8532767 , 0.391317 , 3.6432517 ,-3.9992728 , 0.29700363, 1.2722415 ,-2.3793647 ,-3.377246 , 2.0930648 , 2.574604 ,-1.2509564 , 0.4457573 ,-0.46469867, 2.6793416 , 0.02566718,-0.11948132,-3.1046712 ,-0.6204446 ,-4.615342 , 4.057695 , 1.1312845 ,-3.0446556 ,-1.9381613 ,-0.92255247,-3.5459394 ,-1.1972907 , 0.5879403 ,-1.2265042 ,-2.6279037 , 3.7533212 ,-0.2950134 ,-1.6104454 , 4.7811155 , 3.9216835 ,-2.2905827 ,-3.9489107 ,-4.078132 , 4.878544 ,-2.1483154 ,-3.1480436 ,-1.8742744 , 0.38310575,-4.0457416 ,-1.5423136 , 4.9426446 , 2.80434 ,-2.758338 , 1.6596367 ,-4.559686 ,-1.2686385 ,-1.2173673 , 0.49475643,-2.4956207 ,-1.5008336 ,-1.7967415 ,-1.1574938 , 2.2852411 , 1.7171949 ,-3.328038 ,-3.1454384 ,-0.41883984, 3.822312 , 1.1161699 ,-1.5137968 , 3.1651397 , 3.2411747 , 1.2685378 , 2.7408757 ,-3.078621 , 3.3460293 ,-0.34918678,-1.0433053 , 0.9397743 ,-3.9071774 , 0.68924445, 4.896269 ,-4.234603 ,-4.8659916 , 1.472339 , 4.5464644 , 0.35857418, 3.4065645 ,-1.514736 , 4.2301235
diff --git a/compiler/pota-quantization-value-test/test_inputs/Mean_000_config/layer/uint8/4.txt b/compiler/pota-quantization-value-test/test_inputs/Mean_000_config/layer/uint8/4.txt
new file mode 100644
index 000000000..163c037cf
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/Mean_000_config/layer/uint8/4.txt
@@ -0,0 +1 @@
+-0.91463715,-2.9258113 , 4.4465976 ,-0.84762925,-3.3510911 ,-0.15094744, 2.2284694 , 3.9705405 ,-1.6315348 , 4.698665 , 2.8595035 ,-2.4719086 , 4.2091336 ,-3.7003224 , 0.06198901, 4.24617 ,-3.7041452 , 1.4280707 , 0.61925036, 3.873551 , 0.3554166 , 3.0535998 ,-1.403015 , 2.5769274 , 4.0060935 ,-2.134697 , 0.61366636,-2.2069314 , 3.5629356 ,-4.94381 , 3.3054771 ,-0.42945656, 4.4868546 , 4.124087 ,-4.039486 , 0.75716823,-4.530404 ,-0.8464823 , 2.7817092 ,-4.954212 , 4.790015 , 2.5307322 , 0.635834 ,-3.393037 ,-3.7000508 ,-1.1439751 ,-2.4422479 , 3.9414582 ,-4.0586324 ,-3.5872777 , 2.2529798 , 0.50453144,-2.9947112 ,-0.76174486, 0.8427806 ,-0.90798455,-0.5518859 ,-1.1810572 , 1.2787138 ,-1.7791113 ,-4.661412 ,-3.7413049 , 0.03910514, 3.970302 ,-3.0697417 ,-4.107844 ,-1.985001 ,-2.434408 ,-3.0120797 , 0.34467867, 0.09826441, 3.1933572 , 0.09855966, 1.7976784 ,-3.3814316 ,-2.8423817 ,-4.787137 , 0.21746217,-1.8560363 ,-0.7145455 , 3.911294 , 4.6970305 ,-4.0105987 , 3.3843613 , 2.3087065 , 1.8619018 , 1.6607213 ,-4.1276345 ,-0.15251912, 3.1198032 , 1.8143575 , 2.178214 ,-4.6250186 , 4.4006424 ,-3.378407 , 3.6481302 , 4.4439235 , 4.5322957 , 2.7754776 , 1.9026359 ,-2.9371052 , 0.32501587, 4.980984 ,-3.2300677 , 4.190388 , 4.441369 , 0.8116277 ,-4.7056756 , 1.1501676 ,-0.9759702 ,-0.1920487 ,-3.2009268 , 4.654679 , 4.043145 , 4.579935 , 4.917842 ,-3.2166183 , 2.381046 , 2.3470554 , 0.04456256,-2.6785278 ,-2.1683002 ,-0.2686819 , 0.6097173 , 1.5071467 , 3.9692068 ,-3.4313831 ,-0.87708473, 3.9917011 , 0.7843428 ,-4.6622047 , 0.774621 ,-4.6538844 , 3.6392822 , 4.962988 , 1.4132729 ,-0.40482154,-1.8656421 ,-1.6113061 ,-1.3454957 , 0.40846685,-4.5410986 , 2.7158992 ,-1.8403106 ,-3.803351 , 4.406537 ,-1.5868717 , 2.7034876 ,-3.3383765 , 4.6084027 ,-1.691095 ,-0.52188784, 2.9010768 , 0.08786624, 2.7466853 ,-1.7457972 , 0.59371734,-0.1716976 ,-2.6220891 , 4.9432936 , 2.3500183 , 1.6905144 ,-2.7329378 , 4.003541 ,-1.1137847 , 3.9017355 , 0.9116626 , 4.233729 ,-2.6706429 , 
3.4342804 ,-0.42729262, 1.174779 ,-4.944099 , 1.2316282 , 4.9237943 ,-2.2999635 ,-4.9210916 ,-1.9033331 , 0.43241265, 3.2149148 , 4.1269703 , 0.8590868 , 2.734273 , 1.658618 ,-2.1702065 ,-2.0058317 , 4.0706363 , 4.003833 ,-0.35835287, 2.5514262 , 1.2571276 ,-4.655018 , 3.6468434 , 0.06320113,-4.662375 , 1.0745742 ,-1.117399 , 4.167245 , 4.59434 ,-1.686359 ,-0.17328739, 0.3083307 , 3.3926466 , 2.2254786 ,-0.45468137, 2.4956248 ,-3.492782 ,-2.9805465 ,-1.0610795 ,-0.2784433 , 0.7163735 ,-3.0048254 ,-1.8024784 ,-3.3139167 ,-1.8410577 , 4.5702477 ,-3.4454951 ,-1.4504164 ,-1.7432297 ,-4.998418 ,-2.5524495 , 3.028534 , 4.075326 ,-2.2187853 ,-0.6484594 , 3.00815 ,-2.8010397 ,-4.5529976 , 1.7830837 , 0.3373458 , 0.19151935,-1.0437245 ,-3.6349878 , 1.1947471 ,-1.9664146 , 0.27316815,-0.20781417, 2.419226 , 0.02246885, 4.5222287 , 3.1069999 , 3.940458 , 4.2710595 , 3.4216619 , 2.8447206 , 2.7136886 ,-0.60954016, 2.9277234 , 3.995615 ,-0.30593097, 1.7800944 , 1.0608315 , 3.8786283 ,-2.7564247 , 1.8526665 ,-3.8638606
diff --git a/compiler/pota-quantization-value-test/test_inputs/Mul_001_config/channel/int16/0.txt b/compiler/pota-quantization-value-test/test_inputs/Mul_001_config/channel/int16/0.txt
new file mode 100644
index 000000000..e580d6f85
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/Mul_001_config/channel/int16/0.txt
@@ -0,0 +1 @@
+-4.024665 , 3.0544488,-4.5645285,-3.2134292,-2.1543078, 4.039755 ,-4.613908 , 4.2014904, 3.8222141,-4.4992657,-4.02681 ,-3.2933445
diff --git a/compiler/pota-quantization-value-test/test_inputs/Mul_001_config/channel/int16/1.txt b/compiler/pota-quantization-value-test/test_inputs/Mul_001_config/channel/int16/1.txt
new file mode 100644
index 000000000..c593dfbb6
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/Mul_001_config/channel/int16/1.txt
@@ -0,0 +1 @@
+-2.669042 , 2.479217 , 4.691815 , 1.8187722 ,-3.7656548 ,-2.0555806 ,-2.4494352 ,-3.2394514 ,-0.38215363,-1.543695 ,-0.6927158 , 2.3534324
diff --git a/compiler/pota-quantization-value-test/test_inputs/Mul_001_config/channel/int16/2.txt b/compiler/pota-quantization-value-test/test_inputs/Mul_001_config/channel/int16/2.txt
new file mode 100644
index 000000000..14520a177
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/Mul_001_config/channel/int16/2.txt
@@ -0,0 +1 @@
+ 4.036224 ,-1.2903051 , 1.2116423 , 3.92255 ,-0.48049024,-1.0290806 ,-0.9644837 , 1.3379688 ,-1.0027533 ,-1.9611529 , 3.7190473 , 0.45794436
diff --git a/compiler/pota-quantization-value-test/test_inputs/Mul_001_config/channel/int16/3.txt b/compiler/pota-quantization-value-test/test_inputs/Mul_001_config/channel/int16/3.txt
new file mode 100644
index 000000000..2238d5e9e
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/Mul_001_config/channel/int16/3.txt
@@ -0,0 +1 @@
+ 4.560488 ,-1.2475324, 1.8892838,-2.0155866,-4.968927 , 0.3717404,-0.6095849, 3.2483344,-1.2499679, 1.4237018,-3.1225715, 3.0611598
diff --git a/compiler/pota-quantization-value-test/test_inputs/Mul_001_config/channel/int16/4.txt b/compiler/pota-quantization-value-test/test_inputs/Mul_001_config/channel/int16/4.txt
new file mode 100644
index 000000000..14a91ccc9
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/Mul_001_config/channel/int16/4.txt
@@ -0,0 +1 @@
+-1.7167594, 2.116633 ,-1.3816848,-1.7106141,-3.273076 ,-4.148302 ,-2.1654181, 0.4368236, 3.4279666, 1.2954224, 1.3004405,-4.3022
diff --git a/compiler/pota-quantization-value-test/test_inputs/Mul_001_config/layer/uint8/0.txt b/compiler/pota-quantization-value-test/test_inputs/Mul_001_config/layer/uint8/0.txt
new file mode 100644
index 000000000..3b2a3c258
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/Mul_001_config/layer/uint8/0.txt
@@ -0,0 +1 @@
+ 4.9167333 , 0.9170983 ,-2.4031715 , 0.4819133 , 0.21536288,-2.0262568 , 4.364642 , 1.7851653 , 2.0982797 , 0.5736603 , 2.5769486 , 3.68285
diff --git a/compiler/pota-quantization-value-test/test_inputs/Mul_001_config/layer/uint8/1.txt b/compiler/pota-quantization-value-test/test_inputs/Mul_001_config/layer/uint8/1.txt
new file mode 100644
index 000000000..dff8a3b09
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/Mul_001_config/layer/uint8/1.txt
@@ -0,0 +1 @@
+ 3.8708763 , 3.263454 ,-4.796817 , 0.6411522 ,-3.0385532 , 0.49334133,-0.20283684,-0.88814104, 4.826072 ,-4.8037696 , 4.757636 ,-3.036691
diff --git a/compiler/pota-quantization-value-test/test_inputs/Mul_001_config/layer/uint8/2.txt b/compiler/pota-quantization-value-test/test_inputs/Mul_001_config/layer/uint8/2.txt
new file mode 100644
index 000000000..93e747284
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/Mul_001_config/layer/uint8/2.txt
@@ -0,0 +1 @@
+-3.8694625 ,-3.5254061 ,-0.23680535, 4.1042504 , 3.2534697 ,-1.8511593 ,-1.9182487 , 2.6457057 , 0.12923336, 2.618141 , 1.2465005 ,-4.4625525
diff --git a/compiler/pota-quantization-value-test/test_inputs/Mul_001_config/layer/uint8/3.txt b/compiler/pota-quantization-value-test/test_inputs/Mul_001_config/layer/uint8/3.txt
new file mode 100644
index 000000000..c924e03d9
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/Mul_001_config/layer/uint8/3.txt
@@ -0,0 +1 @@
+-2.5559328 , 1.768443 ,-1.4850446 ,-1.2771453 ,-2.7216687 , 2.80077 , 0.21637216,-0.6145739 ,-0.37175298, 3.8750615 ,-1.9910356 ,-1.657059
diff --git a/compiler/pota-quantization-value-test/test_inputs/Mul_001_config/layer/uint8/4.txt b/compiler/pota-quantization-value-test/test_inputs/Mul_001_config/layer/uint8/4.txt
new file mode 100644
index 000000000..1153c85ed
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/Mul_001_config/layer/uint8/4.txt
@@ -0,0 +1 @@
+-1.6168976 ,-3.816399 ,-0.55625045, 4.961818 , 0.19316113,-2.6601286 ,-1.6928803 , 4.1208386 ,-1.4012221 , 2.7742999 , 0.75798005,-2.5877
diff --git a/compiler/pota-quantization-value-test/test_inputs/PRelu_001_config/channel/int16/0.txt b/compiler/pota-quantization-value-test/test_inputs/PRelu_001_config/channel/int16/0.txt
new file mode 100644
index 000000000..1f2993269
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/PRelu_001_config/channel/int16/0.txt
@@ -0,0 +1 @@
+-3.3436873 ,-0.79453826, 2.2211137 , 2.6420908 ,-1.3191302 , 1.2973647 ,-4.506594 , 4.867371 ,-4.318404 , 1.6957753 ,-4.3091793 ,-3.2230556 , 4.9175825 ,-3.1527104 ,-2.6669753 ,-2.1135337 ,-3.7701926 ,-3.358504 ,-4.419803 , 3.2045574 ,-0.5828494 ,-3.5796826 ,-4.0088696 ,-4.7178082 , 2.2726505 , 2.1860175 , 3.7198956 ,-0.5788681 ,-3.7766652 ,-0.65016747, 3.707159 ,-2.240267 , 4.5772953 ,-0.54754776, 4.7143884 ,-3.196982 ,-3.6356654 , 3.7157805 , 3.1312432 , 0.58816016, 2.1710336 ,-1.600533 ,-3.689763 , 4.322089 , 0.4816874 , 2.2769346 ,-3.9072733 ,-0.58615017
diff --git a/compiler/pota-quantization-value-test/test_inputs/PRelu_001_config/channel/int16/1.txt b/compiler/pota-quantization-value-test/test_inputs/PRelu_001_config/channel/int16/1.txt
new file mode 100644
index 000000000..a19ea6696
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/PRelu_001_config/channel/int16/1.txt
@@ -0,0 +1 @@
+-1.275483 ,-3.6622071 ,-0.87433696, 0.60946655, 1.4415421 , 3.3705983 , 2.2635043 , 3.3926573 ,-0.2936643 ,-0.5169573 , 3.2535644 , 2.1269164 ,-3.4180303 , 1.0427854 ,-1.3514856 , 3.6084783 , 4.569944 ,-0.79272085, 2.9771423 ,-1.6668562 , 4.8700657 , 0.3355385 , 0.76509756, 3.5142152 ,-1.6743544 , 4.794434 ,-2.958765 ,-0.23857778, 2.4555902 , 2.459867 , 3.3922994 ,-4.350212 , 0.6286153 , 0.8139546 , 4.1676807 ,-3.3461437 , 0.69633776,-4.6548877 , 0.98267466,-4.508397 ,-1.4581255 ,-1.2289628 , 3.8701873 , 3.334336 ,-3.5611253 , 2.6133575 ,-1.0554558 ,-3.3291767
diff --git a/compiler/pota-quantization-value-test/test_inputs/PRelu_001_config/channel/int16/2.txt b/compiler/pota-quantization-value-test/test_inputs/PRelu_001_config/channel/int16/2.txt
new file mode 100644
index 000000000..7113eb52e
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/PRelu_001_config/channel/int16/2.txt
@@ -0,0 +1 @@
+-0.6250365 ,-4.798417 ,-4.214081 ,-3.625409 , 2.4391694 , 4.1856265 , 3.2472587 ,-3.20996 ,-2.3537548 , 1.3749354 , 2.5947835 ,-1.8891864 ,-3.612735 , 2.246563 , 1.2701501 ,-2.8927476 ,-0.71078295,-3.6037376 ,-4.5916877 , 2.0044398 , 3.4437728 ,-1.0695096 , 4.3483944 ,-3.3387017 ,-0.9384242 , 1.4229002 ,-0.6568144 , 1.1164346 , 1.7145283 ,-2.596518 , 4.6728883 , 3.4737296 , 1.7935314 , 3.1263895 , 1.3614839 ,-3.824968 ,-3.0405738 , 3.1729462 ,-4.1985774 ,-2.9489865 ,-4.2080064 , 2.0368521 ,-2.858539 ,-0.03206728,-1.1123812 , 0.2994737 , 1.6906137 ,-0.8665008
diff --git a/compiler/pota-quantization-value-test/test_inputs/PRelu_001_config/channel/int16/3.txt b/compiler/pota-quantization-value-test/test_inputs/PRelu_001_config/channel/int16/3.txt
new file mode 100644
index 000000000..afeb2c0e6
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/PRelu_001_config/channel/int16/3.txt
@@ -0,0 +1 @@
+-4.5279946 ,-3.4497826 ,-2.058617 ,-0.39549035,-0.26672208, 3.0173857 , 3.2430282 , 1.9996022 , 1.3895315 , 1.7620904 ,-4.9040093 ,-3.2858686 ,-2.2823575 ,-1.4176623 ,-0.537347 , 0.68219584,-3.193989 ,-3.1675165 , 0.47214374,-4.390378 ,-1.8730192 , 1.4416525 ,-3.0460286 ,-0.73547626, 1.8686327 ,-0.8146671 ,-2.0906649 , 0.01226121,-0.06992937, 0.9302521 ,-2.1858516 , 4.8370657 ,-4.1847024 , 4.4963436 ,-1.3834711 ,-1.1244944 , 0.4290957 ,-4.2681174 , 1.2978764 , 3.4149706 ,-2.7011304 ,-3.1285405 ,-3.8857136 ,-0.18625297,-0.13618916, 2.427405 ,-1.7979074 ,-1.4174187
diff --git a/compiler/pota-quantization-value-test/test_inputs/PRelu_001_config/channel/int16/4.txt b/compiler/pota-quantization-value-test/test_inputs/PRelu_001_config/channel/int16/4.txt
new file mode 100644
index 000000000..99c6284d6
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/PRelu_001_config/channel/int16/4.txt
@@ -0,0 +1 @@
+-0.40635094,-2.485209 ,-2.9641154 , 4.09174 ,-1.9137962 ,-2.0860991 , 1.6594787 , 0.53744185, 1.7737653 ,-1.7054961 , 2.5611186 ,-1.1456238 , 2.741241 ,-2.283051 ,-4.2111306 ,-0.8722772 , 1.6465468 ,-0.61518955, 0.08495517, 3.6847656 , 3.7826371 , 2.0023444 ,-3.5326133 , 2.3723035 , 3.7383325 ,-3.3514297 , 2.031452 ,-0.7364658 ,-4.3347225 ,-2.8146286 ,-1.37377 ,-3.518721 ,-0.19657679,-1.6831368 , 1.2457223 , 0.25099897,-4.4722757 ,-4.135197 ,-0.6378818 , 3.8833187 , 1.9291897 , 2.5969315 , 2.146067 ,-2.846719 ,-2.2562532 ,-2.6856182 , 2.824374 , 2.3662992
diff --git a/compiler/pota-quantization-value-test/test_inputs/PRelu_001_config/layer/uint8/0.txt b/compiler/pota-quantization-value-test/test_inputs/PRelu_001_config/layer/uint8/0.txt
new file mode 100644
index 000000000..081a1e6ee
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/PRelu_001_config/layer/uint8/0.txt
@@ -0,0 +1 @@
+-1.9927613e+00,-1.7386111e+00, 4.0895696e+00, 3.7818990e+00, 1.9420158e+00, 2.8482721e+00, 1.9165717e+00, 3.0059583e+00, 1.8346788e+00,-1.9055414e-03, 4.9277787e+00,-2.2794118e+00, 4.4005270e+00, 4.9703922e+00,-4.5275192e+00,-4.0446317e-01,-4.9363256e+00, 4.9506269e+00, 5.5874938e-01, 3.9949589e+00,-3.8152415e-01,-4.1024357e-01,-3.8472393e+00, 4.2956004e+00, 4.8097472e+00, 1.7960385e+00, 1.6767026e+00,-2.2773645e+00, 2.6808765e+00,-3.7214172e+00, 4.0978761e+00, 3.6202488e+00,-3.3211513e+00, 3.6200387e+00,-3.6106458e+00,-3.9778764e+00, 3.8779631e+00,-4.8502750e+00,-2.1901150e+00, 3.1800017e+00, 4.6261444e+00, 3.5151103e+00, 2.8659137e-02, 4.5340648e+00, 1.9836371e+00,-2.1751235e+00,-4.6762753e+00,-3.6951694e+00
diff --git a/compiler/pota-quantization-value-test/test_inputs/PRelu_001_config/layer/uint8/1.txt b/compiler/pota-quantization-value-test/test_inputs/PRelu_001_config/layer/uint8/1.txt
new file mode 100644
index 000000000..f6b31db38
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/PRelu_001_config/layer/uint8/1.txt
@@ -0,0 +1 @@
+-4.7488093 , 4.805902 ,-0.29828382, 0.57486725,-4.864297 , 1.1832287 ,-1.7611881 ,-2.7058024 , 2.707353 ,-3.9832466 , 3.1243927 ,-4.795229 , 1.9835415 , 3.2291937 , 2.4303932 ,-3.556881 , 4.316894 ,-0.6444627 ,-3.8289468 , 4.012964 , 0.7878584 ,-1.8921386 , 2.779619 ,-3.762597 , 3.4239094 ,-0.9103423 ,-3.9791772 ,-2.5613685 ,-4.4910364 , 0.19411987, 4.6296096 ,-0.6827259 , 3.7645729 , 1.5309091 , 3.5163064 , 3.4726381 , 3.5372822 , 1.7671971 , 1.4374614 , 3.5783768 ,-2.4927518 , 3.9427729 , 2.431568 , 2.6959393 , 3.8100271 ,-2.099064 , 3.3663592 ,-2.0818436
diff --git a/compiler/pota-quantization-value-test/test_inputs/PRelu_001_config/layer/uint8/2.txt b/compiler/pota-quantization-value-test/test_inputs/PRelu_001_config/layer/uint8/2.txt
new file mode 100644
index 000000000..acc01cb55
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/PRelu_001_config/layer/uint8/2.txt
@@ -0,0 +1 @@
+ 4.279912 ,-2.2746763 , 4.0609813 , 4.5353827 , 3.624241 ,-3.9593613 , 4.189409 ,-3.9370356 ,-2.7063863 ,-1.9987059 , 4.172294 ,-4.5454354 , 4.362368 , 2.2204642 ,-4.9866576 , 3.31571 , 0.12623785, 4.7834573 ,-1.3521448 ,-1.5408021 ,-4.6578984 ,-2.93307 ,-1.5684534 ,-1.6875995 ,-0.4278419 , 1.1314197 ,-2.9655704 ,-0.48032767,-1.9200082 , 1.3321692 , 0.87586147,-0.1761448 , 3.939337 ,-1.0270193 ,-4.807054 , 2.8373904 ,-1.1184337 ,-0.8979197 , 2.1442132 ,-2.8509672 ,-3.3741531 , 3.6592414 , 0.7632272 ,-4.11465 , 4.892313 , 4.715815 ,-4.6481915 , 0.24676175
diff --git a/compiler/pota-quantization-value-test/test_inputs/PRelu_001_config/layer/uint8/3.txt b/compiler/pota-quantization-value-test/test_inputs/PRelu_001_config/layer/uint8/3.txt
new file mode 100644
index 000000000..0f0b7a939
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/PRelu_001_config/layer/uint8/3.txt
@@ -0,0 +1 @@
+-2.0949495 ,-1.1370499 , 4.6457314 ,-2.243915 ,-1.7996464 , 1.2268789 ,-4.938172 ,-3.2802615 , 1.8788282 , 4.4162655 ,-4.8805113 , 3.1269526 , 3.2644348 , 0.89842725,-1.4484432 ,-0.28381723, 3.046261 ,-1.0718596 ,-3.996107 ,-4.9575796 ,-2.2279077 , 1.5326967 , 4.4588428 ,-2.042381 , 4.6604958 , 4.6422915 ,-1.097833 , 3.666126 , 0.4735639 ,-4.480704 ,-4.831033 ,-0.27288163, 4.588138 , 4.5297036 , 4.3675694 ,-1.6098841 ,-3.4147859 , 2.1168516 ,-1.9529305 ,-0.12548867, 3.4388335 ,-1.4071734 , 0.9507897 , 4.8206787 , 1.676873 ,-1.7102181 , 1.7746873 , 0.02711739
diff --git a/compiler/pota-quantization-value-test/test_inputs/PRelu_001_config/layer/uint8/4.txt b/compiler/pota-quantization-value-test/test_inputs/PRelu_001_config/layer/uint8/4.txt
new file mode 100644
index 000000000..d23450db6
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/PRelu_001_config/layer/uint8/4.txt
@@ -0,0 +1 @@
+-4.707647 ,-4.0921726 , 3.5813692 ,-4.71081 , 3.157816 ,-3.0034213 ,-0.21858999,-1.1736552 ,-1.6042249 ,-3.93102 ,-4.0407577 , 3.7350774 ,-4.9545655 ,-1.5413756 , 0.34996858, 2.0339615 , 0.99290746,-3.9916334 ,-4.149016 ,-3.2332835 , 3.6728513 , 2.4537466 ,-3.103485 ,-0.4829316 , 4.8046784 ,-1.753812 , 4.878712 ,-1.4039769 , 1.6640003 ,-1.2041731 , 0.8046477 , 0.9196048 ,-0.6475092 , 1.1409346 , 2.0324717 ,-0.04227797,-0.5379897 , 3.205104 , 3.3556423 , 4.8447986 ,-1.9695646 ,-2.6304977 ,-3.7261262 ,-4.725599 , 2.1162436 ,-0.5631174 ,-0.5820323 , 0.8398242
diff --git a/compiler/pota-quantization-value-test/test_inputs/ReLU_000_config/channel/int16/0.txt b/compiler/pota-quantization-value-test/test_inputs/ReLU_000_config/channel/int16/0.txt
new file mode 100644
index 000000000..eb058a1c3
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/ReLU_000_config/channel/int16/0.txt
@@ -0,0 +1 @@
+-0.55411166,-4.1992335 , 1.4317423 ,-3.7261302 , 1.151971 ,-2.117022 ,-0.7386241 , 4.654951 , 1.4869142 ,-4.6252975 ,-3.305923 , 3.632628 ,-2.6403873 ,-4.862389 , 3.477561 ,-4.9842925 ,-3.6267536 , 4.9950438
diff --git a/compiler/pota-quantization-value-test/test_inputs/ReLU_000_config/channel/int16/1.txt b/compiler/pota-quantization-value-test/test_inputs/ReLU_000_config/channel/int16/1.txt
new file mode 100644
index 000000000..ff15f032d
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/ReLU_000_config/channel/int16/1.txt
@@ -0,0 +1 @@
+ 0.18094282,-0.58095986, 1.2765085 ,-0.534363 , 4.5564513 ,-0.28305855, 0.80606604,-3.3217795 ,-0.08041744,-3.7558215 ,-0.5370528 , 1.8984528 ,-0.09462419,-0.28595117, 4.6817894 ,-4.6653147 ,-4.127137 ,-2.3407753
diff --git a/compiler/pota-quantization-value-test/test_inputs/ReLU_000_config/channel/int16/2.txt b/compiler/pota-quantization-value-test/test_inputs/ReLU_000_config/channel/int16/2.txt
new file mode 100644
index 000000000..e564168bf
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/ReLU_000_config/channel/int16/2.txt
@@ -0,0 +1 @@
+-0.62747055, 1.4133646 ,-0.9954612 ,-4.687624 ,-2.5390003 ,-4.534569 ,-1.1943612 ,-4.830596 , 4.3214984 ,-2.4795794 , 4.166298 ,-1.4772589 ,-4.074577 , 3.2332711 ,-1.5221404 ,-1.7308865 , 0.06814837, 2.944668
diff --git a/compiler/pota-quantization-value-test/test_inputs/ReLU_000_config/channel/int16/3.txt b/compiler/pota-quantization-value-test/test_inputs/ReLU_000_config/channel/int16/3.txt
new file mode 100644
index 000000000..c763b6311
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/ReLU_000_config/channel/int16/3.txt
@@ -0,0 +1 @@
+-3.2136867 , 0.6229863 , 0.02772082,-0.00820862,-2.4893622 ,-0.6757174 ,-2.2024722 ,-2.0893583 , 0.33953062,-3.5438979 , 0.7000838 , 1.3219849 ,-0.02302017, 2.3125873 ,-1.5376673 ,-4.0330076 , 4.755884 , 2.729685
diff --git a/compiler/pota-quantization-value-test/test_inputs/ReLU_000_config/channel/int16/4.txt b/compiler/pota-quantization-value-test/test_inputs/ReLU_000_config/channel/int16/4.txt
new file mode 100644
index 000000000..12e13272d
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/ReLU_000_config/channel/int16/4.txt
@@ -0,0 +1 @@
+ 0.82922786, 4.762074 ,-3.5043278 , 2.4521468 , 2.6450796 ,-2.8606322 , 0.8321993 ,-1.4020495 ,-0.25749585, 1.0287803 ,-3.911455 ,-1.8311876 , 2.763438 , 3.8604703 ,-3.5478592 ,-4.2335987 ,-3.6402035 ,-1.8485361
diff --git a/compiler/pota-quantization-value-test/test_inputs/ReLU_000_config/layer/uint8/0.txt b/compiler/pota-quantization-value-test/test_inputs/ReLU_000_config/layer/uint8/0.txt
new file mode 100644
index 000000000..42ce6be36
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/ReLU_000_config/layer/uint8/0.txt
@@ -0,0 +1 @@
+ 1.1826919 , 0.07451724, 3.48515 , 3.4905832 , 1.8009655 , 4.155749 , 3.3155255 , 2.6834202 ,-1.7111781 ,-2.2254407 ,-4.578932 ,-2.1239302 ,-0.1269101 ,-2.6022012 ,-4.8320093 , 0.2983099 ,-0.43314072,-0.66332716
diff --git a/compiler/pota-quantization-value-test/test_inputs/ReLU_000_config/layer/uint8/1.txt b/compiler/pota-quantization-value-test/test_inputs/ReLU_000_config/layer/uint8/1.txt
new file mode 100644
index 000000000..f677cc836
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/ReLU_000_config/layer/uint8/1.txt
@@ -0,0 +1 @@
+-1.2971772 ,-3.6082 ,-2.2253058 ,-4.4367466 ,-1.7221912 , 0.02547262,-3.641017 , 0.2953748 , 0.7217547 , 4.663728 , 4.262444 ,-3.196005 ,-1.6792587 ,-1.7463406 , 2.030074 , 0.67998594,-0.92862725,-1.7960806
diff --git a/compiler/pota-quantization-value-test/test_inputs/ReLU_000_config/layer/uint8/2.txt b/compiler/pota-quantization-value-test/test_inputs/ReLU_000_config/layer/uint8/2.txt
new file mode 100644
index 000000000..841ea9f8f
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/ReLU_000_config/layer/uint8/2.txt
@@ -0,0 +1 @@
+ 2.2390285 ,-1.9557759 ,-1.2331479 ,-2.4810686 ,-0.5112022 , 1.741153 , 0.13645513,-2.3543327 ,-3.2610211 , 2.5739572 ,-0.50510126, 2.3544457 , 1.884411 ,-3.7153857 ,-1.7037194 ,-0.36849263,-4.819704 , 3.047652
diff --git a/compiler/pota-quantization-value-test/test_inputs/ReLU_000_config/layer/uint8/3.txt b/compiler/pota-quantization-value-test/test_inputs/ReLU_000_config/layer/uint8/3.txt
new file mode 100644
index 000000000..08ec9fe8f
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/ReLU_000_config/layer/uint8/3.txt
@@ -0,0 +1 @@
+-0.9080747 ,-1.5609599 ,-0.40923035,-2.0569193 , 4.5904484 ,-0.02348744, 0.35939455, 2.2017193 , 2.2766497 ,-2.2080436 ,-2.6453862 ,-3.6456985 , 4.160244 , 1.7283534 , 4.5547447 ,-1.8674839 , 3.019465 , 1.1584582
diff --git a/compiler/pota-quantization-value-test/test_inputs/ReLU_000_config/layer/uint8/4.txt b/compiler/pota-quantization-value-test/test_inputs/ReLU_000_config/layer/uint8/4.txt
new file mode 100644
index 000000000..a4f2d97d1
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/ReLU_000_config/layer/uint8/4.txt
@@ -0,0 +1 @@
+ 4.5920744 , 3.827386 ,-2.1228654 , 3.7227573 ,-3.4464717 , 0.31313375, 0.5531476 ,-0.30391756,-0.21601346, 3.8968146 , 0.23224053,-0.6208954 ,-0.76323295,-1.1700501 ,-1.6203161 , 2.1780837 , 2.3581395 , 2.6519518
diff --git a/compiler/pota-quantization-value-test/test_inputs/Split_000_config/channel/int16/0.txt b/compiler/pota-quantization-value-test/test_inputs/Split_000_config/channel/int16/0.txt
new file mode 100644
index 000000000..0e8d687b1
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/Split_000_config/channel/int16/0.txt
@@ -0,0 +1 @@
+-2.327701 , 1.9312059 ,-2.0069487 ,-1.2584914 ,-0.08435626, 0.47685367,-2.7456024 , 2.1275337 ,-4.9685698 , 1.8143541 , 0.52829266,-2.770121
diff --git a/compiler/pota-quantization-value-test/test_inputs/Split_000_config/channel/int16/1.txt b/compiler/pota-quantization-value-test/test_inputs/Split_000_config/channel/int16/1.txt
new file mode 100644
index 000000000..67732e8f5
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/Split_000_config/channel/int16/1.txt
@@ -0,0 +1 @@
+ 0.01133719,-3.3741624 , 3.556686 ,-4.21059 , 0.49977505, 1.768375 , 3.867543 , 2.270572 ,-3.9507272 ,-4.595618 ,-4.7460327 , 0.5856542
diff --git a/compiler/pota-quantization-value-test/test_inputs/Split_000_config/channel/int16/2.txt b/compiler/pota-quantization-value-test/test_inputs/Split_000_config/channel/int16/2.txt
new file mode 100644
index 000000000..7bc7124d6
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/Split_000_config/channel/int16/2.txt
@@ -0,0 +1 @@
+-2.7181 , 4.6819983 , 2.9022477 ,-0.10716935, 3.6687856 ,-2.5403244 ,-4.477037 , 2.5499978 ,-3.9294813 , 0.08725335,-2.243345 ,-1.4018577
diff --git a/compiler/pota-quantization-value-test/test_inputs/Split_000_config/channel/int16/3.txt b/compiler/pota-quantization-value-test/test_inputs/Split_000_config/channel/int16/3.txt
new file mode 100644
index 000000000..0fac9fb70
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/Split_000_config/channel/int16/3.txt
@@ -0,0 +1 @@
+-3.920553 , 0.87464577,-1.0319884 , 2.1885726 , 2.755115 ,-1.6436632 ,-4.4507327 , 4.915525 , 2.9331517 , 4.7712016 , 4.676084 ,-1.7715888
diff --git a/compiler/pota-quantization-value-test/test_inputs/Split_000_config/channel/int16/4.txt b/compiler/pota-quantization-value-test/test_inputs/Split_000_config/channel/int16/4.txt
new file mode 100644
index 000000000..df79104c2
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/Split_000_config/channel/int16/4.txt
@@ -0,0 +1 @@
+-2.181168 ,-1.6011912 ,-4.359466 ,-1.3662407 ,-0.06876431,-2.9213328 ,-0.5463467 ,-3.7916536 ,-3.751455 ,-2.822578 , 0.8914152 ,-3.0267959
diff --git a/compiler/pota-quantization-value-test/test_inputs/Split_000_config/channel/uint8/0.txt b/compiler/pota-quantization-value-test/test_inputs/Split_000_config/channel/uint8/0.txt
new file mode 100644
index 000000000..4b999a028
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/Split_000_config/channel/uint8/0.txt
@@ -0,0 +1 @@
+ 3.241328 , 2.7033713 ,-2.5329788 ,-4.078369 ,-3.6711028 , 2.8912613 , 0.6188993 , 3.3729403 , 2.9906578 , 0.69040877, 0.6443222 , 1.1676162
diff --git a/compiler/pota-quantization-value-test/test_inputs/Split_000_config/channel/uint8/1.txt b/compiler/pota-quantization-value-test/test_inputs/Split_000_config/channel/uint8/1.txt
new file mode 100644
index 000000000..7061063b9
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/Split_000_config/channel/uint8/1.txt
@@ -0,0 +1 @@
+ 1.572614 , 3.6147017 , 1.4378501 ,-0.81497866, 1.5987366 , 3.7698908 ,-3.8637109 , 4.5728784 ,-0.8706349 , 0.7389268 , 4.64117 ,-0.96047217
diff --git a/compiler/pota-quantization-value-test/test_inputs/Split_000_config/channel/uint8/2.txt b/compiler/pota-quantization-value-test/test_inputs/Split_000_config/channel/uint8/2.txt
new file mode 100644
index 000000000..c048a8a9f
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/Split_000_config/channel/uint8/2.txt
@@ -0,0 +1 @@
+ 0.00864919,-3.1653113 ,-2.125551 , 2.9225516 ,-1.1439148 , 4.6509814 ,-2.097259 , 2.5843353 ,-2.067207 ,-2.5034845 ,-4.9441104 ,-3.9062042
diff --git a/compiler/pota-quantization-value-test/test_inputs/Split_000_config/channel/uint8/3.txt b/compiler/pota-quantization-value-test/test_inputs/Split_000_config/channel/uint8/3.txt
new file mode 100644
index 000000000..55be3b464
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/Split_000_config/channel/uint8/3.txt
@@ -0,0 +1 @@
+ 1.0920542 , 0.5510192 , 1.3465579 ,-2.3510268 , 4.016736 , 4.7848744 ,-0.42403316, 0.00571597, 1.6412207 , 1.7787368 , 2.4728034 ,-3.5900247
diff --git a/compiler/pota-quantization-value-test/test_inputs/Split_000_config/channel/uint8/4.txt b/compiler/pota-quantization-value-test/test_inputs/Split_000_config/channel/uint8/4.txt
new file mode 100644
index 000000000..04c7a1a8a
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/Split_000_config/channel/uint8/4.txt
@@ -0,0 +1 @@
+-2.9799085,-3.9477375, 0.6402844, 3.304766 , 3.8880465,-3.5069442,-2.3702915, 4.126247 ,-3.1614416, 2.9909244,-2.8755414, 0.2627986
diff --git a/compiler/pota-quantization-value-test/test_inputs/TransposeConv_001_config/channel/int16/0.txt b/compiler/pota-quantization-value-test/test_inputs/TransposeConv_001_config/channel/int16/0.txt
new file mode 100644
index 000000000..e9db48f9e
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/TransposeConv_001_config/channel/int16/0.txt
@@ -0,0 +1 @@
+-1.4124781 , 0.42694193, 1.1734594 ,-3.5111153 ,-2.9756174 , 1.3682148 ,-2.318465 , 2.198896 ,-4.5043235 , 3.1775594 ,-0.42802384,-1.4872279 , 1.3821319 ,-4.771963 ,-0.12837897, 4.132799 , 3.697655 , 2.0807178 ,-3.621293 , 2.121878 ,-0.25654107, 0.42100102,-1.4009671 ,-2.9733627 ,-0.7058871 ,-2.831215 , 3.5669627 , 2.1420689 ,-1.8789555 , 0.8104939 ,-2.0503597 , 1.7788508
diff --git a/compiler/pota-quantization-value-test/test_inputs/TransposeConv_001_config/channel/int16/1.txt b/compiler/pota-quantization-value-test/test_inputs/TransposeConv_001_config/channel/int16/1.txt
new file mode 100644
index 000000000..479d062f1
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/TransposeConv_001_config/channel/int16/1.txt
@@ -0,0 +1 @@
+ 3.4726453 , 3.0497985 ,-4.234619 ,-1.0526706 , 1.7278554 ,-3.341614 , 4.54768 , 3.0954597 ,-3.735109 , 2.8810751 ,-2.5381427 ,-3.2360535 ,-1.5378917 , 2.3052745 ,-3.170938 ,-3.327242 , 2.0654576 ,-2.2294598 ,-1.881382 , 0.13216451,-4.2825613 , 0.26616526, 4.6196365 ,-0.88623226, 1.7103885 ,-1.5865034 ,-3.9114466 ,-3.2227128 , 4.909618 , 2.3318915 , 0.84300846, 0.760918
diff --git a/compiler/pota-quantization-value-test/test_inputs/TransposeConv_001_config/channel/int16/2.txt b/compiler/pota-quantization-value-test/test_inputs/TransposeConv_001_config/channel/int16/2.txt
new file mode 100644
index 000000000..ae28234bd
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/TransposeConv_001_config/channel/int16/2.txt
@@ -0,0 +1 @@
+-4.6097918,-4.21991 ,-3.9955974, 3.6492047, 2.9191775, 2.8082933, 1.6189331, 0.2730309,-1.5029653,-1.9471445, 4.8758197, 3.3177438, 3.1338058,-2.1281245,-1.7526287,-2.5518703,-1.7746793, 4.0455256,-0.5839861,-4.408046 ,-4.0034447, 1.5858272,-4.5896654, 4.7211285,-4.677515 ,-2.6027086,-4.7896166,-3.5512326,-1.9068764,-2.9705904,-4.854087 ,-4.892111
diff --git a/compiler/pota-quantization-value-test/test_inputs/TransposeConv_001_config/channel/int16/3.txt b/compiler/pota-quantization-value-test/test_inputs/TransposeConv_001_config/channel/int16/3.txt
new file mode 100644
index 000000000..fd40f84f4
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/TransposeConv_001_config/channel/int16/3.txt
@@ -0,0 +1 @@
+ 2.1514777e-02, 2.6526773e+00,-3.0477784e+00, 1.3287724e+00,-4.1414630e-01,-1.7295350e-01, 7.6649576e-01,-1.8028022e+00,-7.0781744e-01,-2.5262204e-01,-3.0970418e+00,-1.3165286e+00,-4.6649928e+00, 2.0809033e+00,-1.5739973e+00,-4.0531826e-01,-2.1718202e+00, 2.0146034e+00, 2.5044403e+00,-1.1256610e+00, 1.3536702e+00, 1.0283234e-03,-1.8823910e+00, 4.7122188e+00, 9.4781297e-01, 3.2012525e+00,-5.5164534e-01,-2.6158772e+00,-1.8771547e+00,-3.1689723e+00, 4.9054880e+00,-3.4560370e+00
diff --git a/compiler/pota-quantization-value-test/test_inputs/TransposeConv_001_config/channel/int16/4.txt b/compiler/pota-quantization-value-test/test_inputs/TransposeConv_001_config/channel/int16/4.txt
new file mode 100644
index 000000000..e81c3b8e5
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/TransposeConv_001_config/channel/int16/4.txt
@@ -0,0 +1 @@
+-2.0927553 ,-2.107511 ,-1.6963564 , 1.7006218 , 1.4575784 , 0.06095728, 1.2659966 , 4.1905265 , 1.3035946 , 4.9793477 ,-4.3388166 ,-0.23496658, 1.9831208 , 2.6154642 ,-0.2790228 ,-3.1774354 ,-3.178935 ,-1.1564373 ,-0.8199472 ,-2.245698 ,-4.8605046 ,-3.569018 ,-1.4226891 ,-4.1067843 , 2.6078918 ,-3.5830674 , 1.9065963 , 2.435578 ,-3.3216476 , 4.5930347 , 2.9191844 , 1.7885648
diff --git a/compiler/pota-quantization-value-test/test_inputs/TransposeConv_001_config/layer/uint8/0.txt b/compiler/pota-quantization-value-test/test_inputs/TransposeConv_001_config/layer/uint8/0.txt
new file mode 100644
index 000000000..a8874bc5f
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/TransposeConv_001_config/layer/uint8/0.txt
@@ -0,0 +1 @@
+ 3.9384239 ,-3.7377489 , 0.97284186, 3.8309984 , 2.4125865 , 1.7141674 , 3.9459977 ,-0.304659 ,-3.4623327 , 4.4569106 , 4.209985 ,-0.6677348 , 3.4578135 , 1.6779743 , 2.502791 ,-1.324285 , 1.3139176 , 3.4334664 ,-2.2695086 ,-4.001059 ,-0.91164917, 4.4447775 ,-3.0275404 ,-2.0852396 , 3.6677403 ,-2.9595146 , 2.0921555 , 1.7570637 , 3.717391 ,-0.3216191 ,-0.8410847 , 2.662336
diff --git a/compiler/pota-quantization-value-test/test_inputs/TransposeConv_001_config/layer/uint8/1.txt b/compiler/pota-quantization-value-test/test_inputs/TransposeConv_001_config/layer/uint8/1.txt
new file mode 100644
index 000000000..715e680be
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/TransposeConv_001_config/layer/uint8/1.txt
@@ -0,0 +1 @@
+ 0.6663157 ,-0.04146723,-0.8193995 , 4.804576 ,-2.1357434 , 4.0829 ,-1.6380692 , 1.8043218 , 2.3431025 , 0.30111 , 1.2928191 ,-1.8559257 ,-0.68305963,-1.1502715 , 1.9492546 ,-2.7240746 , 2.9279857 ,-3.3329778 ,-4.8343406 ,-0.02708206, 1.1840513 , 3.6476028 , 4.75276 ,-4.9085226 ,-1.1922491 , 0.54225117, 3.17247 ,-2.7856457 ,-3.0866194 ,-2.2077718 , 1.6263398 , 3.7066603
diff --git a/compiler/pota-quantization-value-test/test_inputs/TransposeConv_001_config/layer/uint8/2.txt b/compiler/pota-quantization-value-test/test_inputs/TransposeConv_001_config/layer/uint8/2.txt
new file mode 100644
index 000000000..3ca893e61
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/TransposeConv_001_config/layer/uint8/2.txt
@@ -0,0 +1 @@
+-4.8507566 ,-1.267258 , 0.5099198 , 1.650726 , 3.4329638 ,-2.2652836 , 1.2157568 , 0.18305123, 3.6754217 ,-4.6185255 ,-1.0646905 ,-0.46092424, 2.046326 ,-2.8830478 , 4.156068 ,-2.0503244 , 0.0755459 ,-4.6472006 ,-0.50128895, 3.1129324 ,-4.4048553 , 0.47983927, 1.4510479 , 3.9226127 ,-4.767221 ,-2.795826 ,-4.816457 ,-3.6127663 ,-2.2712553 , 4.586938 , 1.1028811 , 1.5028698
diff --git a/compiler/pota-quantization-value-test/test_inputs/TransposeConv_001_config/layer/uint8/3.txt b/compiler/pota-quantization-value-test/test_inputs/TransposeConv_001_config/layer/uint8/3.txt
new file mode 100644
index 000000000..3fba8ecec
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/TransposeConv_001_config/layer/uint8/3.txt
@@ -0,0 +1 @@
+ 4.9431224 ,-3.4878132 ,-2.4831018 , 2.2395666 ,-2.3317611 ,-1.6786547 ,-2.4702384 , 3.2167027 , 1.7300137 , 2.8848834 ,-4.6395254 , 0.5527259 ,-2.915835 ,-1.0066313 ,-0.278253 , 4.6136203 ,-3.4183645 ,-1.5189631 ,-4.599058 , 3.3198457 ,-3.9464161 ,-0.6357558 , 0.32550323, 3.2147424 , 4.921844 ,-0.30067012, 3.9456701 , 0.5943688 ,-4.7229166 ,-3.6803844 ,-3.3813965 , 3.283583
diff --git a/compiler/pota-quantization-value-test/test_inputs/TransposeConv_001_config/layer/uint8/4.txt b/compiler/pota-quantization-value-test/test_inputs/TransposeConv_001_config/layer/uint8/4.txt
new file mode 100644
index 000000000..16cc23b79
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/TransposeConv_001_config/layer/uint8/4.txt
@@ -0,0 +1 @@
+ 2.232644 , 4.465217 , 1.926956 ,-4.007337 ,-2.7392106 ,-2.4579394 , 2.913538 ,-1.7261469 , 3.8706868 , 0.06259949,-2.018361 , 1.2728635 ,-3.133289 ,-4.943454 ,-1.5415367 ,-4.8183494 , 4.348317 ,-2.4929109 ,-0.9018388 ,-4.776565 , 4.634248 , 3.0753953 , 2.3412373 ,-2.7086196 , 3.4485948 , 0.3561932 , 0.03650501,-2.8704169 , 1.0514414 , 3.3964615 , 1.2783849 , 4.974951
diff --git a/compiler/pota-quantization-value-test/test_quantization_with_config.sh b/compiler/pota-quantization-value-test/test_quantization_with_config.sh
new file mode 100755
index 000000000..1364dfb90
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_quantization_with_config.sh
@@ -0,0 +1,109 @@
+#!/bin/bash
+
+# This script tests quantize_with_minmax option of circle-quantizer with config file
+#
+# HOW TO USE
+#
+# ./test_quantization_with_config.sh <path/to/test.config> <path/to/work_dir> <TEST 1> <TEST 2> ...
+# test.config : set ${RECORD_MINMAX_PATH} and ${CIRCLE_QUANTIZER_PATH}
+# work_dir : build directory of quantization-value-test (ex: build/compiler/quantization-value-test)
+
+SOURCE_PATH="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+COMPARE_SCRIPT_PATH="${SOURCE_PATH}/compare_tensors.py"
+CONFIG_PATH="$1"; shift
+BIN_PATH=$(dirname "${CONFIG_PATH}")
+TEST_INPUT_PATH="${SOURCE_PATH}/test_inputs"
+GEN_SCRIPT_PATH="${BIN_PATH}/gen_h5_explicit_inputs.py"
+WORKDIR="$1"; shift
+
+source "${CONFIG_PATH}"
+
+echo "-- Found CIRCLE_QUANTIZER: ${CIRCLE_QUANTIZER_PATH}"
+echo "-- Found CIRCLE_TENSORDUMP: ${CIRCLE_TENSORDUMP_PATH}"
+echo "-- Found workdir: ${WORKDIR}"
+
+TESTED=()
+PASSED=()
+FAILED=()
+
+pushd "${WORKDIR}"
+while [ "$1" != "" ]; do
+ MODELNAME=$1; shift
+ GRANULARITY=$1; shift
+ DTYPE=$1; shift
+ TESTCASE="${MODELNAME}.${GRANULARITY}.${DTYPE}"
+
+ TESTED+=("${TESTCASE}")
+
+ TESTCASE_FILE="${WORKDIR}/${TESTCASE}"
+ TEST_RESULT_FILE="${BIN_PATH}/${TESTCASE}"
+
+ PASSED_TAG="${TEST_RESULT_FILE}.quantization.mixed.passed"
+ rm -f "${PASSED_TAG}"
+
+ cat > "${TEST_RESULT_FILE}_quantization_with_config.log" <(
+ exec 2>&1
+ set -ex
+
+ # Generate h5 input data
+ source "${VIRTUALENV}/bin/activate"
+ "${VIRTUALENV}/bin/python" "${GEN_SCRIPT_PATH}" \
+ --model "${WORKDIR}/${MODELNAME}.circle" \
+ --input "${TEST_INPUT_PATH}/${MODELNAME}_config/${GRANULARITY}/${DTYPE}" \
+ --output "${TESTCASE_FILE}.mixed.input.h5"
+
+ if [[ $? -ne 0 ]]; then
+ echo "FAILED TO GENERATE INPUT"
+ continue
+ fi
+
+ # Run record-minmax
+ # NOTE There is no '_with_config' test for record-minmax, because it does not
+ # use quantization config file.
+ "${RECORD_MINMAX_PATH}" \
+ --input_model "${TEST_RESULT_FILE}.fake_quantized.mixed.circle" \
+ --input_data "${TESTCASE_FILE}.mixed.input.h5" \
+ --output_model "${TEST_RESULT_FILE}.minmax_recorded.mixed.circle"
+
+ # Run circle-quantizer with --quantize_with_minmax
+ "${CIRCLE_QUANTIZER_PATH}" \
+ --quantize_with_minmax float32 "${DTYPE}" "${GRANULARITY}" \
+ --config "${SOURCE_PATH}/config_files/${MODELNAME}/${GRANULARITY}/${DTYPE}/qconf.json" \
+ "${TEST_RESULT_FILE}.minmax_recorded.mixed.circle" \
+ "${TEST_RESULT_FILE}.quantized.mixed.circle"
+
+ # Dump scale, zp, weights values (circle-tensordump)
+ "${CIRCLE_TENSORDUMP_PATH}" \
+ "${TEST_RESULT_FILE}.quantized.mixed.circle" \
+ --tensors_to_hdf5 "${TEST_RESULT_FILE}.quantized.mixed.circle.h5"
+
+ # Compare result
+ "${VIRTUALENV}/bin/python" "${COMPARE_SCRIPT_PATH}" \
+ --input_h5 "${TEST_RESULT_FILE}.quantized.mixed.circle.h5" \
+ --expect_dir "${SOURCE_PATH}/expected_outputs/${MODELNAME}_config/${GRANULARITY}/${DTYPE}/quantization" \
+ --mode quantization
+
+ if [[ $? -eq 0 ]]; then
+ touch "${PASSED_TAG}"
+ fi
+ )
+
+ if [[ -f "${PASSED_TAG}" ]]; then
+ PASSED+=("$TESTCASE")
+ else
+ FAILED+=("$TESTCASE")
+ fi
+done
+popd
+
+if [[ ${#TESTED[@]} -ne ${#PASSED[@]} ]]; then
+ echo "FAILED"
+ for TEST in "${FAILED[@]}"
+ do
+ echo "- ${TEST}"
+ done
+ exit 255
+fi
+
+echo "PASSED"
+exit 0
diff --git a/compiler/pp/CMakeLists.txt b/compiler/pp/CMakeLists.txt
index 6d58458ca..1db09cb88 100644
--- a/compiler/pp/CMakeLists.txt
+++ b/compiler/pp/CMakeLists.txt
@@ -3,7 +3,9 @@ file(GLOB_RECURSE TESTS "src/*.test.cpp")
list(REMOVE_ITEM SOURCES ${TESTS})
add_library(pp STATIC ${SOURCES})
-set_target_properties(pp PROPERTIES POSITION_INDEPENDENT_CODE ON)
+if (NOT NNCC_LIBRARY_NO_PIC)
+ set_target_properties(pp PROPERTIES POSITION_INDEPENDENT_CODE ON)
+endif(NOT NNCC_LIBRARY_NO_PIC)
target_include_directories(pp PUBLIC include)
target_link_libraries(pp PRIVATE nncc_common)
target_link_libraries(pp PUBLIC nncc_coverage)
diff --git a/compiler/record-minmax-conversion-test/CMakeLists.txt b/compiler/record-minmax-conversion-test/CMakeLists.txt
index 2221e1702..31b906142 100644
--- a/compiler/record-minmax-conversion-test/CMakeLists.txt
+++ b/compiler/record-minmax-conversion-test/CMakeLists.txt
@@ -37,6 +37,6 @@ add_test(
COMMAND "${CMAKE_CURRENT_SOURCE_DIR}/testall.sh"
"${TEST_CONFIG}"
"${ARTIFACTS_BIN_PATH}"
- "${NNCC_OVERLAY_DIR}/venv_1_13_2"
+ "${NNCC_OVERLAY_DIR}/venv_2_8_0"
${RECORD_MINMAX_CONVERSION_TEST}
)
diff --git a/compiler/record-minmax/CMakeLists.txt b/compiler/record-minmax/CMakeLists.txt
index da63bbf5f..b9c08f472 100644
--- a/compiler/record-minmax/CMakeLists.txt
+++ b/compiler/record-minmax/CMakeLists.txt
@@ -1,25 +1,17 @@
-nnas_find_package(HDF5 COMPONENTS STATIC QUIET)
-
-if(NOT HDF5_FOUND)
- message(STATUS "Build record-minmax: FAILED (missing HDF5)")
- return()
-endif(NOT HDF5_FOUND)
-
set(DRIVER "driver/Driver.cpp")
file(GLOB_RECURSE SOURCES "src/*.cpp")
add_executable(record-minmax ${DRIVER} ${SOURCES})
target_include_directories(record-minmax PRIVATE include)
-target_include_directories(record-minmax PRIVATE ${HDF5_INCLUDE_DIRS})
-target_link_libraries(record-minmax ${HDF5_CXX_LIBRARIES})
target_link_libraries(record-minmax arser)
target_link_libraries(record-minmax safemain)
target_link_libraries(record-minmax luci_import)
target_link_libraries(record-minmax luci_env)
target_link_libraries(record-minmax luci_export)
target_link_libraries(record-minmax luci_interpreter)
+target_link_libraries(record-minmax dio_hdf5)
target_link_libraries(record-minmax vconone)
target_link_libraries(record-minmax nncc_coverage)
diff --git a/compiler/record-minmax/requires.cmake b/compiler/record-minmax/requires.cmake
index 9cf12591e..69373e76f 100644
--- a/compiler/record-minmax/requires.cmake
+++ b/compiler/record-minmax/requires.cmake
@@ -2,4 +2,5 @@ require("luci")
require("luci-interpreter")
require("safemain")
require("arser")
+require("dio-hdf5")
require("vconone")
diff --git a/compiler/record-minmax/src/HDF5Importer.h b/compiler/record-minmax/src/HDF5Importer.h
deleted file mode 100644
index 9e98c7752..000000000
--- a/compiler/record-minmax/src/HDF5Importer.h
+++ /dev/null
@@ -1,87 +0,0 @@
-/*
- * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef __RECORD_MINMAX_HDF5IMPORTER_H__
-#define __RECORD_MINMAX_HDF5IMPORTER_H__
-
-#include <luci_interpreter/core/Tensor.h>
-
-#include <H5Cpp.h>
-
-#include <stdexcept>
-
-using Shape = luci_interpreter::Shape;
-using DataType = luci_interpreter::DataType;
-
-namespace record_minmax
-{
-
-// HDF5Importer reads an input data saved in the hdf5 file in the given path
-// The hierarchy of the hdf5 file is as follows.
-// Group "/"
-// > Group "value"
-// > Group <record_idx>
-// > Dataset <input_idx>
-// record_idx : index of the record (dataset file can contain multiple records)
-// input_idx : index of the input (DNN model can have multiple inputs)
-// Ex: the j'th input of the i'th record can be accessed by "/value/i/j"
-class HDF5Importer
-{
-public:
- explicit HDF5Importer(const std::string &path)
- {
- if (_file.isHdf5(path) == false)
- throw std::runtime_error("Given data file is not HDF5");
-
- _file = H5::H5File(path, H5F_ACC_RDONLY);
- }
-
-public:
- /**
- * @brief importGroup has to be called before readTensor is called
- * Otherwise, readTensor will throw an exception
- */
- void importGroup() { _value_grp = _file.openGroup("value"); }
-
- /**
- * @brief Read tensor data from file and store it into buffer
- * @details A tensor in the file can be retrieved with (record_idx, input_idx)
- * @param record_idx : index of the record
- * @param input_idx : index of the input
- * @param dtype : pointer to write the tensor's data type
- * @param shape : pointer to write the tensor's shape
- * @param buffer : pointer to write the tensor's data
- */
- void readTensor(int32_t record_idx, int32_t input_idx, DataType *dtype, Shape *shape,
- void *buffer);
-
- // Read a raw tensor (no type/shape is specified)
- void readTensor(int32_t record_idx, int32_t input_idx, void *buffer);
-
- bool isRawData() { return _value_grp.attrExists("rawData"); }
-
- int32_t numRecords() { return _value_grp.getNumObjs(); }
-
- int32_t numInputs(int32_t record_idx);
-
-private:
- H5::H5File _file;
- H5::Group _value_grp;
-};
-
-} // namespace record_minmax
-
-#endif // __RECORD_MINMAX_HDF5IMPORTER_H__
diff --git a/compiler/record-minmax/src/MinMaxObserver.cpp b/compiler/record-minmax/src/MinMaxObserver.cpp
index 28ae2b33b..8288d3e5e 100644
--- a/compiler/record-minmax/src/MinMaxObserver.cpp
+++ b/compiler/record-minmax/src/MinMaxObserver.cpp
@@ -51,6 +51,16 @@ void MinMaxObserver::postTensorWrite(const luci::CircleNode *node,
// Bool type tensor is not quantized
return;
}
+ if (node->dtype() == DataType::S32)
+ {
+ // Integer type tensor is not quantized
+ return;
+ }
+ if (node->dtype() == DataType::S64)
+ {
+ // Integer type tensor is not quantized
+ return;
+ }
// Only support recording of float32 values
if (tensor->element_type() != DataType::FLOAT32)
@@ -58,9 +68,6 @@ void MinMaxObserver::postTensorWrite(const luci::CircleNode *node,
// Exceptions that should be processed in backends
switch (node->opcode())
{
- case luci::CircleOpcode::ARG_MAX:
- // Output of arg_max is the index of the largest value across axes of a tensor.
- // It always has integer type.
case luci::CircleOpcode::CAST:
// Cast is quantized only if it converts <type> -> float.
// Other cases should be processed in backends.
diff --git a/compiler/record-minmax/src/RecordMinMax.cpp b/compiler/record-minmax/src/RecordMinMax.cpp
index c249960f8..10a14516f 100644
--- a/compiler/record-minmax/src/RecordMinMax.cpp
+++ b/compiler/record-minmax/src/RecordMinMax.cpp
@@ -17,12 +17,12 @@
#include "RecordMinMax.h"
#include "RecordFunction.h"
#include "MinMaxObserver.h"
-#include "HDF5Importer.h"
#include <luci/Importer.h>
#include <luci/CircleExporter.h>
#include <luci/CircleFileExpContract.h>
#include <luci/IR/CircleQuantParam.h>
+#include <dio_hdf5/HDF5Importer.h>
#include <dirent.h>
#include <algorithm>
@@ -33,12 +33,34 @@
#include <iostream>
#include <random>
-using Shape = luci_interpreter::Shape;
-using DataType = luci_interpreter::DataType;
+using Shape = std::vector<loco::Dimension>;
+using DataType = loco::DataType;
namespace
{
+uint32_t numElements(const luci::CircleNode *node)
+{
+ uint32_t num_elements = 1;
+ for (uint32_t i = 0; i < node->rank(); i++)
+ num_elements *= node->dim(i).value();
+
+ return num_elements;
+}
+
+// Throw exception if input has one of the following conditions.
+// 1. Have unknown dimension
+// 2. Number of elements is 0
+void checkInputDimension(const luci::CircleInput *input)
+{
+ for (uint32_t i = 0; i < input->rank(); i++)
+ if (!input->dim(i).known())
+ throw std::runtime_error(input->name() + " has unknown dimension");
+
+ if (numElements(input) == 0)
+ throw std::runtime_error(input->name() + " is a zero-sized input");
+}
+
void readDataFromFile(const std::string &filename, std::vector<char> &data, size_t data_size)
{
assert(data.size() == data_size); // FIX_CALLER_UNLESS
@@ -62,6 +84,21 @@ std::vector<uint8_t> genRandomBoolData(std::mt19937 &gen, uint32_t num_elements)
return input_data;
}
+template <typename T>
+std::vector<T> genRandomIntData(std::mt19937 &gen, uint32_t num_elements, T min, T max)
+{
+ std::uniform_int_distribution<T> dist(min, max);
+ std::vector<T> input_data(num_elements);
+
+ // Write random data
+ {
+ auto const generator = [&gen, &dist]() { return dist(gen); };
+ std::generate(begin(input_data), end(input_data), generator);
+ }
+
+ return input_data;
+}
+
/**
* @brief getTensorSize will return size in bytes
*/
@@ -83,12 +120,12 @@ void verifyTypeShape(const luci::CircleInput *input_node, const DataType &dtype,
if (dtype != input_node->dtype())
throw std::runtime_error("Wrong input type.");
- if (shape.num_dims() != input_node->rank())
+ if (shape.size() != input_node->rank())
throw std::runtime_error("Input rank mismatch.");
- for (uint32_t i = 0; i < shape.num_dims(); i++)
+ for (uint32_t i = 0; i < shape.size(); i++)
{
- if (shape.dim(i) != input_node->dim(i).value())
+ if (not(shape.at(i) == input_node->dim(i)))
throw std::runtime_error("Input shape mismatch.");
}
}
@@ -188,6 +225,7 @@ void RecordMinMax::profileRawDataDirectory(const std::string &mode,
for (auto input : input_nodes)
{
const auto *input_node = loco::must_cast<const luci::CircleInput *>(input);
+ checkInputDimension(input_node);
total_input_size += getTensorSize(input_node);
}
@@ -254,6 +292,7 @@ void RecordMinMax::profileRawData(const std::string &mode, const std::string &in
for (auto input : input_nodes)
{
const auto *input_node = loco::must_cast<const luci::CircleInput *>(input);
+ checkInputDimension(input_node);
total_input_size += getTensorSize(input_node);
}
@@ -296,12 +335,12 @@ void RecordMinMax::profileData(const std::string &mode, const std::string &input
{
try
{
- HDF5Importer importer(input_data_path);
- importer.importGroup();
+ dio::hdf5::HDF5Importer importer(input_data_path);
+ importer.importGroup("value");
bool is_raw_data = importer.isRawData();
- const auto num_records = importer.numRecords();
+ const auto num_records = importer.numData();
if (num_records == 0)
throw std::runtime_error("The input data file does not contain any record.");
@@ -319,12 +358,13 @@ void RecordMinMax::profileData(const std::string &mode, const std::string &input
{
const auto *input_node = loco::must_cast<const luci::CircleInput *>(input_nodes[input_idx]);
assert(input_node->index() == input_idx);
+ checkInputDimension(input_node);
std::vector<char> input_data(getTensorSize(input_node));
if (!is_raw_data)
{
DataType dtype;
- Shape shape(input_node->rank());
+ Shape shape;
importer.readTensor(record_idx, input_idx, &dtype, &shape, input_data.data());
// Check the type and the shape of the input data is valid
@@ -376,43 +416,47 @@ void RecordMinMax::profileDataWithRandomInputs(const std::string &mode, float mi
{
const auto *input_node = loco::must_cast<const luci::CircleInput *>(input_nodes[input_idx]);
assert(input_node->index() == input_idx);
- uint32_t num_elements = 1;
- for (uint32_t i = 0; i < input_node->rank(); i++)
- {
- if (!input_node->dim(i).known())
- throw std::runtime_error("Input dimension must be known");
+ checkInputDimension(input_node);
- num_elements *= input_node->dim(i).value();
- }
-
- if (num_elements == 0)
- throw std::runtime_error("Only support non-zero sized inputs");
+ const auto num_elements = numElements(input_node);
// TODO Support more input data types
assert(input_node->dtype() == loco::DataType::FLOAT32 ||
- input_node->dtype() == loco::DataType::BOOL);
+ input_node->dtype() == loco::DataType::BOOL ||
+ input_node->dtype() == loco::DataType::S32 ||
+ input_node->dtype() == loco::DataType::S64);
if (input_node->dtype() == DataType::FLOAT32)
- // clang-format off
{
- std::vector<float> input_data(num_elements);
+ std::vector<float> input_data(num_elements);
- // Write random data
- for (auto &iter : input_data)
- iter = static_cast<float>(dist(gen));
+ // Write random data
+ for (auto &iter : input_data)
+ iter = static_cast<float>(dist(gen));
- // TODO: Input data is copied twice (file -> buffer (input_data) -> interpreter inputs)
- // We can redcue the copy by directly writing data from file to interpreter inputs
- _interpreter->writeInputTensor(input_node, input_data.data(),
- input_data.size() * sizeof(float));
+ // TODO: Input data is copied twice (file -> buffer (input_data) -> interpreter inputs)
+ // We can redcue the copy by directly writing data from file to interpreter inputs
+ _interpreter->writeInputTensor(input_node, input_data.data(),
+ input_data.size() * sizeof(float));
}
- // clang-format on
else if (input_node->dtype() == DataType::BOOL)
{
auto input_data = genRandomBoolData(gen, num_elements);
_interpreter->writeInputTensor(input_node, input_data.data(),
input_data.size() * sizeof(uint8_t));
}
+ else if (input_node->dtype() == DataType::S32)
+ {
+ auto input_data = genRandomIntData<int32_t>(gen, num_elements, 0, 100);
+ _interpreter->writeInputTensor(input_node, input_data.data(),
+ input_data.size() * sizeof(int32_t));
+ }
+ else if (input_node->dtype() == DataType::S64)
+ {
+ auto input_data = genRandomIntData<int64_t>(gen, num_elements, 0, 100);
+ _interpreter->writeInputTensor(input_node, input_data.data(),
+ input_data.size() * sizeof(int64_t));
+ }
}
_interpreter->interpret();
diff --git a/compiler/souschef/CMakeLists.txt b/compiler/souschef/CMakeLists.txt
index ca7eddc6f..f57102f1f 100644
--- a/compiler/souschef/CMakeLists.txt
+++ b/compiler/souschef/CMakeLists.txt
@@ -1,7 +1,7 @@
nnas_find_package(Protobuf QUIET)
if(NOT Protobuf_FOUND)
- message(STATUS "Build souschef: FAILED (missing Protobuf")
+ message(STATUS "Build souschef: FAILED (missing Protobuf)")
return()
endif(NOT Protobuf_FOUND)
diff --git a/compiler/tf2tfliteV2-conversion-test/CMakeLists.txt b/compiler/tf2tfliteV2-conversion-test/CMakeLists.txt
index 3e7e57747..0b4739374 100644
--- a/compiler/tf2tfliteV2-conversion-test/CMakeLists.txt
+++ b/compiler/tf2tfliteV2-conversion-test/CMakeLists.txt
@@ -72,7 +72,7 @@ list(APPEND TEST_DEPS "${TEST_RUNNER}")
get_target_property(ARTIFACTS_BIN_PATH testDataGenerator BINARY_DIR)
-set(VIRTUALENV "${NNCC_OVERLAY_DIR}/venv_1_13_2")
+set(VIRTUALENV "${NNCC_OVERLAY_DIR}/venv_2_8_0")
###
### Generate test.config
diff --git a/compiler/tfl-inspect/CMakeLists.txt b/compiler/tfl-inspect/CMakeLists.txt
index 6ba55c357..9e1cb720f 100644
--- a/compiler/tfl-inspect/CMakeLists.txt
+++ b/compiler/tfl-inspect/CMakeLists.txt
@@ -10,5 +10,6 @@ add_executable(tfl-inspect ${DRIVER} ${SOURCES})
target_include_directories(tfl-inspect PRIVATE src)
target_link_libraries(tfl-inspect arser)
target_link_libraries(tfl-inspect foder)
-target_link_libraries(tfl-inspect mio_tflite260)
+target_link_libraries(tfl-inspect mio_tflite280)
+target_link_libraries(tfl-inspect mio_tflite280_helper)
target_link_libraries(tfl-inspect safemain)
diff --git a/compiler/tfl-inspect/requires.cmake b/compiler/tfl-inspect/requires.cmake
index 9a7477b81..a11f6b200 100644
--- a/compiler/tfl-inspect/requires.cmake
+++ b/compiler/tfl-inspect/requires.cmake
@@ -1,4 +1,4 @@
require("arser")
require("foder")
-require("mio-tflite260")
+require("mio-tflite280")
require("safemain")
diff --git a/compiler/tfl-inspect/src/Reader.cpp b/compiler/tfl-inspect/src/Reader.cpp
index 41a8396bb..6c4529516 100644
--- a/compiler/tfl-inspect/src/Reader.cpp
+++ b/compiler/tfl-inspect/src/Reader.cpp
@@ -16,6 +16,8 @@
#include "Reader.h"
+#include <mio_tflite280/Helper.h>
+
#include <cassert>
#include <sstream>
#include <string>
@@ -23,72 +25,6 @@
namespace tflinspect
{
-// This will provide v3/v3a format neutral BuiltinOperator
-tflite::BuiltinOperator builtin_code_neutral(const tflite::OperatorCode *opcode)
-{
- assert(opcode != nullptr);
- int8_t dp_code = opcode->deprecated_builtin_code();
- // 127 is max of int8_t which is upper bound of v3 builtin_code
- // NOTE TensorFlow uses 'BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES' for 127
- if (dp_code < 127 && dp_code >= 0)
- return tflite::BuiltinOperator(dp_code);
- return opcode->builtin_code();
-}
-
-bool is_valid(const tflite::OperatorCode *opcode)
-{
- tflite::BuiltinOperator code = builtin_code_neutral(opcode);
- return (tflite::BuiltinOperator_MIN <= code && code <= tflite::BuiltinOperator_MAX);
-}
-
-bool is_custom(const tflite::OperatorCode *opcode)
-{
- tflite::BuiltinOperator code = builtin_code_neutral(opcode);
- return (code == tflite::BuiltinOperator_CUSTOM);
-}
-
-std::string opcode_name(const tflite::OperatorCode *opcode)
-{
- assert(opcode);
-
- if (!is_valid(opcode))
- {
- std::ostringstream oss;
- oss << "(invalid)";
- return oss.str();
- }
-
- if (is_custom(opcode))
- {
- if (!opcode->custom_code())
- return "(invalid custom)";
-
- std::string custom_op = "CUSTOM(";
- custom_op += opcode->custom_code()->c_str();
- custom_op += ")";
- return custom_op;
- }
-
- tflite::BuiltinOperator code = builtin_code_neutral(opcode);
- return tflite::EnumNameBuiltinOperator(code);
-}
-
-const char *tensor_type(const tflite::Tensor *tensor)
-{
- return tflite::EnumNameTensorType(tensor->type());
-}
-
-const char *tensor_name(const tflite::Tensor *tensor)
-{
- static const char *kEmptyTensorName = "(noname)";
-
- auto name = tensor->name();
- if (name)
- return name->c_str();
-
- return kEmptyTensorName;
-}
-
Reader::Reader(const tflite::Model *model)
{
_subgraphs = model->subgraphs();
@@ -135,7 +71,7 @@ tflite::BuiltinOperator Reader::builtin_code(const tflite::Operator *op) const
assert(index < _op_codes.size());
const tflite::OperatorCode *opcode = _op_codes.at(index);
- return tflinspect::builtin_code_neutral(opcode);
+ return mio::tflite::builtin_code_neutral(opcode);
}
std::string Reader::opcode_name(const tflite::Operator *op) const
@@ -144,14 +80,14 @@ std::string Reader::opcode_name(const tflite::Operator *op) const
assert(index < _op_codes.size());
const tflite::OperatorCode *opcode = _op_codes.at(index);
- if (!is_valid(opcode))
+ if (!mio::tflite::is_valid(opcode))
{
std::ostringstream oss;
oss << "(invalid: " << index << ")";
return oss.str();
}
- return tflinspect::opcode_name(opcode);
+ return mio::tflite::opcode_name(opcode);
}
bool Reader::select_subgraph(uint32_t sgindex)
diff --git a/compiler/tfl-inspect/src/Reader.h b/compiler/tfl-inspect/src/Reader.h
index 91b7bb940..98554cf85 100644
--- a/compiler/tfl-inspect/src/Reader.h
+++ b/compiler/tfl-inspect/src/Reader.h
@@ -36,13 +36,6 @@ template <typename T> std::vector<T> as_index_vector(const flatbuffers::Vector<T
return ret;
}
-tflite::BuiltinOperator builtin_code_neutral(const tflite::OperatorCode *opcode);
-bool is_valid(const tflite::OperatorCode *opcode);
-bool is_custom(const tflite::OperatorCode *opcode);
-std::string opcode_name(const tflite::OperatorCode *opcode);
-const char *tensor_type(const tflite::Tensor *tensor);
-const char *tensor_name(const tflite::Tensor *tensor);
-
/**
* @brief Loads TF lite file and provides helpers to access attributes
*/
diff --git a/compiler/tfl-verify/CMakeLists.txt b/compiler/tfl-verify/CMakeLists.txt
index a87d30c5e..2fba335ea 100644
--- a/compiler/tfl-verify/CMakeLists.txt
+++ b/compiler/tfl-verify/CMakeLists.txt
@@ -8,6 +8,6 @@ add_executable(tfl-verify ${SOURCES})
target_include_directories(tfl-verify PRIVATE src)
target_link_libraries(tfl-verify arser)
target_link_libraries(tfl-verify foder)
-target_link_libraries(tfl-verify mio_tflite260)
+target_link_libraries(tfl-verify mio_tflite280)
target_link_libraries(tfl-verify safemain)
target_link_libraries(tfl-verify cwrap)
diff --git a/compiler/tfl-verify/requires.cmake b/compiler/tfl-verify/requires.cmake
index 72803d890..b107bdfe7 100644
--- a/compiler/tfl-verify/requires.cmake
+++ b/compiler/tfl-verify/requires.cmake
@@ -1,5 +1,5 @@
require("arser")
require("foder")
-require("mio-tflite260")
+require("mio-tflite280")
require("safemain")
require("cwrap")
diff --git a/compiler/tflchef/CMakeLists.txt b/compiler/tflchef/CMakeLists.txt
index ac7fe4b7c..948b1cecd 100644
--- a/compiler/tflchef/CMakeLists.txt
+++ b/compiler/tflchef/CMakeLists.txt
@@ -5,10 +5,10 @@ if(NOT Protobuf_FOUND)
return()
endif(NOT Protobuf_FOUND)
-if(NOT TARGET mio_tflite260)
- message(STATUS "Build tflchef: FAILED (missing mio_tflite260)")
+if(NOT TARGET mio_tflite280)
+ message(STATUS "Build tflchef: FAILED (missing mio_tflite280)")
return()
-endif(NOT TARGET mio_tflite260)
+endif(NOT TARGET mio_tflite280)
# Recipe Parser
add_subdirectory(proto)
diff --git a/compiler/tflchef/core/CMakeLists.txt b/compiler/tflchef/core/CMakeLists.txt
index 413b78b15..6b6fed57b 100644
--- a/compiler/tflchef/core/CMakeLists.txt
+++ b/compiler/tflchef/core/CMakeLists.txt
@@ -5,5 +5,5 @@ target_include_directories(tflchef_core PUBLIC include)
target_include_directories(tflchef_core PRIVATE src)
target_link_libraries(tflchef_core tflchef_proto)
target_link_libraries(tflchef_core tflchef_log)
-target_link_libraries(tflchef_core mio_tflite260)
+target_link_libraries(tflchef_core mio_tflite280)
target_link_libraries(tflchef_core souschef)
diff --git a/compiler/tflchef/core/src/ModelChef.cpp b/compiler/tflchef/core/src/ModelChef.cpp
index ada5ff5d1..93b9334a6 100644
--- a/compiler/tflchef/core/src/ModelChef.cpp
+++ b/compiler/tflchef/core/src/ModelChef.cpp
@@ -722,15 +722,13 @@ GeneratedModel cook(const ::tflchef::ModelRecipe &model_recipe)
auto inputs = flatbuffer_builder->CreateVector(tensormap_inputs);
auto outputs = flatbuffer_builder->CreateVector(tensormap_outputs);
- auto method_name = flatbuffer_builder->CreateString(rec_signature_def.method_name());
- auto key = flatbuffer_builder->CreateString(rec_signature_def.key());
- // TODO add validation for method_name and key
+ auto signature_key = flatbuffer_builder->CreateString(rec_signature_def.signature_key());
+ // TODO add validation for signature_key
::tflite::SignatureDefBuilder signature_def_builder{*flatbuffer_builder};
signature_def_builder.add_inputs(inputs);
signature_def_builder.add_outputs(outputs);
- signature_def_builder.add_method_name(method_name);
- signature_def_builder.add_key(key);
+ signature_def_builder.add_signature_key(signature_key);
signature_def_builder.add_subgraph_index(rec_signature_def.subgraph_index());
signdef_vec.emplace_back(signature_def_builder.Finish());
diff --git a/compiler/tflchef/core/src/Op/FullyConnected.cpp b/compiler/tflchef/core/src/Op/FullyConnected.cpp
index 45269916c..7173a67ba 100644
--- a/compiler/tflchef/core/src/Op/FullyConnected.cpp
+++ b/compiler/tflchef/core/src/Op/FullyConnected.cpp
@@ -29,6 +29,7 @@ flatbuffers::Offset<void> FullyConnectedChef::value(flatbuffers::FlatBufferBuild
tflite::FullyConnectedOptionsBuilder fc_options_builder{fbb};
fc_options_builder.add_fused_activation_function(tflite_activation);
+ fc_options_builder.add_keep_num_dims(operation.fullyconnected_options().keep_num_dims());
return fc_options_builder.Finish().Union();
}
diff --git a/compiler/tflchef/core/src/Op/SVDF.cpp b/compiler/tflchef/core/src/Op/SVDF.cpp
new file mode 100644
index 000000000..690896cf1
--- /dev/null
+++ b/compiler/tflchef/core/src/Op/SVDF.cpp
@@ -0,0 +1,41 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "SVDF.h"
+#include "Convert.h"
+
+#include <cassert>
+
+flatbuffers::Offset<void> SVDFChef::value(flatbuffers::FlatBufferBuilder &fbb) const
+{
+ assert(_operation->has_svdf_options());
+
+ const auto &svdf_options = _operation->svdf_options();
+
+ const auto tflite_activation = as_tflite_activation(svdf_options.activation());
+
+ tflite::SVDFOptionsBuilder svdf_options_builder{fbb};
+ svdf_options_builder.add_fused_activation_function(tflite_activation);
+ svdf_options_builder.add_asymmetric_quantize_inputs(svdf_options.asymmetric_quantize_inputs());
+ svdf_options_builder.add_rank(svdf_options.rank());
+
+ return svdf_options_builder.Finish().Union();
+}
+
+std::unique_ptr<OpChef> SVDFChefFactory::create(const tflchef::Operation *operation) const
+{
+ return std::unique_ptr<OpChef>{new SVDFChef{operation}};
+}
diff --git a/compiler/tflchef/core/src/Op/SVDF.h b/compiler/tflchef/core/src/Op/SVDF.h
new file mode 100644
index 000000000..9bf0b6efb
--- /dev/null
+++ b/compiler/tflchef/core/src/Op/SVDF.h
@@ -0,0 +1,46 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __OP_SVDF_H__
+#define __OP_SVDF_H__
+
+#include "OpChef.h"
+
+class SVDFChef final : public OpChef
+{
+public:
+ explicit SVDFChef(const tflchef::Operation *operation) : _operation{operation}
+ {
+ // DO NOTHING
+ }
+
+public:
+ tflite::BuiltinOperator code(void) const override { return tflite::BuiltinOperator_SVDF; }
+
+ tflite::BuiltinOptions type(void) const override { return tflite::BuiltinOptions_SVDFOptions; }
+
+ flatbuffers::Offset<void> value(flatbuffers::FlatBufferBuilder &fbb) const override;
+
+private:
+ const tflchef::Operation *_operation;
+};
+
+struct SVDFChefFactory final : public OpChefFactory
+{
+ std::unique_ptr<OpChef> create(const tflchef::Operation *operation) const override;
+};
+
+#endif // __OP_SVDF_H__
diff --git a/compiler/tflchef/core/src/OpChef.def b/compiler/tflchef/core/src/OpChef.def
index b1e8a3829..beebd359f 100644
--- a/compiler/tflchef/core/src/OpChef.def
+++ b/compiler/tflchef/core/src/OpChef.def
@@ -104,6 +104,7 @@ OP_CHEF(Squeeze, SqueezeChefFactory)
OP_CHEF(StridedSlice, StridedSliceChefFactory)
OP_CHEF(Sub, SubChefFactory)
OP_CHEF(Sum, SumChefFactory)
+OP_CHEF(SVDF, SVDFChefFactory)
OP_CHEF(Tanh, TanhChefFactory)
OP_CHEF(Tile, TileChefFactory)
OP_CHEF(TopKV2, TopKV2ChefFactory)
diff --git a/compiler/tflchef/core/src/OpChefs.h b/compiler/tflchef/core/src/OpChefs.h
index 35688ba95..159019abf 100644
--- a/compiler/tflchef/core/src/OpChefs.h
+++ b/compiler/tflchef/core/src/OpChefs.h
@@ -117,6 +117,7 @@
#include "Op/StridedSlice.h"
#include "Op/Sub.h"
#include "Op/Sum.h"
+#include "Op/SVDF.h"
#include "Op/Tanh.h"
#include "Op/Tile.h"
#include "Op/TopKV2.h"
diff --git a/compiler/tflchef/proto/tflchef.proto b/compiler/tflchef/proto/tflchef.proto
index 4162cb123..1abefafe1 100644
--- a/compiler/tflchef/proto/tflchef.proto
+++ b/compiler/tflchef/proto/tflchef.proto
@@ -182,6 +182,7 @@ message FloorModOptions {
message FullyConnectedOptions {
optional Activation activation = 1 [default = NONE];
+ optional bool keep_num_dims = 2 [ default = false ];
}
message AddOptions {
@@ -366,6 +367,12 @@ message SquaredDifferenceOptions {
// None
}
+message SVDFOptions {
+ optional int32 rank = 1 [default = 0];
+ optional Activation activation = 2 [default = NONE];
+ optional bool asymmetric_quantize_inputs = 3 [default = false];
+}
+
message FillOptions {
// None
}
@@ -589,7 +596,7 @@ message Operation {
optional ZerosLikeOptions zeros_like_options = 153;
// ConcatEmbeddingsOptions 154
// LSHProjectionOptions 155
- // SVDFOptions 156
+ optional SVDFOptions svdf_options = 156;
// RNNOptions 157
optional L2NormOptions l2norm_options = 158;
optional LocalResponseNormalizationOptions local_response_normalization_options = 159;
@@ -658,8 +665,8 @@ message TensorMap {
message SignatureDef {
repeated TensorMap inputs = 4;
repeated TensorMap outputs = 5;
- optional string method_name = 6;
- optional string key = 10;
+ optional string signature_key = 6;
+ // optional string key = 10; obsolete in TF2.8.0
optional uint32 subgraph_index = 12;
}
diff --git a/compiler/tflchef/requires.cmake b/compiler/tflchef/requires.cmake
index 78bfa2d07..a01da4258 100644
--- a/compiler/tflchef/requires.cmake
+++ b/compiler/tflchef/requires.cmake
@@ -1,7 +1,7 @@
require("arser")
require("nnkit")
require("cwrap")
-require("mio-tflite260")
+require("mio-tflite280")
require("safemain")
require("hermes")
require("hermes-std")
diff --git a/compiler/tflchef/tests/CMakeLists.txt b/compiler/tflchef/tests/CMakeLists.txt
index 5c4dff012..26cf67f4f 100644
--- a/compiler/tflchef/tests/CMakeLists.txt
+++ b/compiler/tflchef/tests/CMakeLists.txt
@@ -1,10 +1,11 @@
-if(NOT TARGET nnkit-run)
- return()
-endif(NOT TARGET nnkit-run)
-
-if(NOT TARGET nnkit_tflite_backend)
- return()
-endif(NOT TARGET nnkit_tflite_backend)
+set(TFLCHEF_FILE_PATH $<TARGET_FILE:tflchef-file>)
+set(TFLCHEF_REVERSE_PATH $<TARGET_FILE:tflchef-reverse>)
+if(DEFINED ENV{BUILD_HOST_EXEC})
+ # TODO use better way to represent path for host executable
+ set(TFLCHEF_FILE_PATH $ENV{BUILD_HOST_EXEC}/compiler/tflchef/tools/file/tflchef-file)
+ set(TFLCHEF_REVERSE_PATH $ENV{BUILD_HOST_EXEC}/compiler/tflchef/tools/reverse/tflchef-reverse)
+ message(STATUS "TFLCHEF_FILE_PATH = ${TFLCHEF_FILE_PATH}")
+endif(DEFINED ENV{BUILD_HOST_EXEC})
nncc_find_resource(TensorFlowLiteRecipes)
set(TENSORFLOWLITERECIPES_DIR "${TensorFlowLiteRecipes_DIR}")
@@ -26,8 +27,8 @@ foreach(RECIPE IN ITEMS ${RECIPES})
# Generate .tflite
add_custom_command(OUTPUT ${RECIPE_OUTPUT_FILE}
- COMMAND tflchef-file ${RECIPE_SOURCE_FILE} ${RECIPE_OUTPUT_FILE}
- DEPENDS tflchef-file ${RECIPE_SOURCE_FILE}
+ COMMAND ${TFLCHEF_FILE_PATH} ${RECIPE_SOURCE_FILE} ${RECIPE_OUTPUT_FILE}
+ DEPENDS ${TFLCHEF_FILE_PATH} ${RECIPE_SOURCE_FILE}
COMMENT "Generating ${RECIPE_OUTPUT_FILE}")
list(APPEND TESTS ${RECIPE_PREFIX})
@@ -52,8 +53,8 @@ foreach(RECIPE IN ITEMS ${RECIPES})
# Generate .tflite
add_custom_command(OUTPUT ${RECIPE_OUTPUT_FILE}
- COMMAND tflchef-file ${RECIPE_SOURCE_FILE} ${RECIPE_OUTPUT_FILE}
- DEPENDS tflchef-file ${RECIPE_SOURCE_FILE}
+ COMMAND ${TFLCHEF_FILE_PATH} ${RECIPE_SOURCE_FILE} ${RECIPE_OUTPUT_FILE}
+ DEPENDS ${TFLCHEF_FILE_PATH} ${RECIPE_SOURCE_FILE}
COMMENT "Generating ${RECIPE_OUTPUT_FILE}")
list(APPEND TESTS ${RECIPE_PREFIX})
@@ -76,16 +77,16 @@ foreach(TFLITEFILE IN ITEMS ${GEN_TFLITEFILES})
# Generate .gen.recipe from generated .tflite
add_custom_command(OUTPUT ${RECIPE_GEN_OUTPUT_FILE}
- COMMAND tflchef-reverse ${RECIPE_OUTPUT_FILE} ${RECIPE_GEN_OUTPUT_FILE}
- DEPENDS tflchef-reverse ${RECIPE_OUTPUT_FILE}
+ COMMAND ${TFLCHEF_REVERSE_PATH} ${RECIPE_OUTPUT_FILE} ${RECIPE_GEN_OUTPUT_FILE}
+ DEPENDS ${TFLCHEF_REVERSE_PATH} ${RECIPE_OUTPUT_FILE}
COMMENT "Generating ${RECIPE_GEN_OUTPUT_FILE}")
# now we are going to generate .gen.tflite from .gen.recipe
# to check generated .gen.recipe file is correct by using it.
# as weight values may be different, binary comparision is not acceptable.
add_custom_command(OUTPUT ${RECIPE_GEN_OUTPUT_FILE2}
- COMMAND tflchef-file ${RECIPE_GEN_OUTPUT_FILE} ${RECIPE_GEN_OUTPUT_FILE2}
- DEPENDS tflchef-file ${RECIPE_GEN_OUTPUT_FILE}
+ COMMAND ${TFLCHEF_FILE_PATH} ${RECIPE_GEN_OUTPUT_FILE} ${RECIPE_GEN_OUTPUT_FILE2}
+ DEPENDS ${TFLCHEF_FILE_PATH} ${RECIPE_GEN_OUTPUT_FILE}
COMMENT "Generating ${RECIPE_GEN_OUTPUT_FILE2}")
list(APPEND TESTS ${TFLITE_PREFIX}.gen)
@@ -104,13 +105,13 @@ foreach(TFLITEFILE IN ITEMS ${GEN_TFLITEFILES})
# Generate .gen.recipe from generated .tflite
add_custom_command(OUTPUT ${RECIPE_GEN_OUTPUT_FILE}
- COMMAND tflchef-reverse ${RECIPE_OUTPUT_FILE} ${RECIPE_GEN_OUTPUT_FILE}
- DEPENDS tflchef-reverse ${RECIPE_OUTPUT_FILE}
+ COMMAND ${TFLCHEF_REVERSE_PATH} ${RECIPE_OUTPUT_FILE} ${RECIPE_GEN_OUTPUT_FILE}
+ DEPENDS ${TFLCHEF_REVERSE_PATH} ${RECIPE_OUTPUT_FILE}
COMMENT "Generating ${RECIPE_GEN_OUTPUT_FILE}")
add_custom_command(OUTPUT ${RECIPE_GEN_OUTPUT_FILE2}
- COMMAND tflchef-file ${RECIPE_GEN_OUTPUT_FILE} ${RECIPE_GEN_OUTPUT_FILE2}
- DEPENDS tflchef-file ${RECIPE_GEN_OUTPUT_FILE}
+ COMMAND ${TFLCHEF_FILE_PATH} ${RECIPE_GEN_OUTPUT_FILE} ${RECIPE_GEN_OUTPUT_FILE2}
+ DEPENDS ${TFLCHEF_FILE_PATH} ${RECIPE_GEN_OUTPUT_FILE}
COMMENT "Generating ${RECIPE_GEN_OUTPUT_FILE2}")
list(APPEND TESTS ${TFLITE_PREFIX}.gen)
@@ -123,7 +124,9 @@ add_custom_target(tflchef_testfiles ALL DEPENDS ${TESTFILES})
# Using mio_tflite_validate for temporary as it only calls flatbuffer validate
# TODO do testing with running the model with runtime/interpreter
+# NOTE for ARM32 cross build, $<TARGET_FILE:mio_tflite280_validate> is used as-is
+# as test should run in ARM32 device
add_test(NAME tflchef_test
COMMAND "${CMAKE_CURRENT_SOURCE_DIR}/runvalidate.sh"
- $<TARGET_FILE:mio_tflite_validate>
+ $<TARGET_FILE:mio_tflite280_validate>
${TESTS})
diff --git a/compiler/tflchef/tests/signature_def_index/test.recipe b/compiler/tflchef/tests/signature_def_index/test.recipe
index 4481752ef..9e95edf00 100644
--- a/compiler/tflchef/tests/signature_def_index/test.recipe
+++ b/compiler/tflchef/tests/signature_def_index/test.recipe
@@ -50,8 +50,7 @@ signature_def {
name: "ofm1"
tensor_index: 1
}
- method_name: "serving_default"
- key: "serv"
+ signature_key: "serving_default"
subgraph_index: 0
}
input: "ifm"
diff --git a/compiler/tflchef/tests/signature_def_name/test.recipe b/compiler/tflchef/tests/signature_def_name/test.recipe
index 79be25138..4847f7dd8 100644
--- a/compiler/tflchef/tests/signature_def_name/test.recipe
+++ b/compiler/tflchef/tests/signature_def_name/test.recipe
@@ -50,8 +50,7 @@ signature_def {
name: "out1"
tensor: "ofm1"
}
- method_name: "serving_default"
- key: "serv"
+ signature_key: "serving_default"
subgraph_index: 0
}
input: "ifm"
diff --git a/compiler/tflchef/tflite/CMakeLists.txt b/compiler/tflchef/tflite/CMakeLists.txt
index 3c4c3fff6..3c3352b0a 100644
--- a/compiler/tflchef/tflite/CMakeLists.txt
+++ b/compiler/tflchef/tflite/CMakeLists.txt
@@ -4,6 +4,7 @@ add_library(tflchef_tflite STATIC ${SOURCES})
target_include_directories(tflchef_tflite PUBLIC include)
target_include_directories(tflchef_tflite PRIVATE src)
target_link_libraries(tflchef_tflite tflchef_proto)
-target_link_libraries(tflchef_tflite mio_tflite260)
+target_link_libraries(tflchef_tflite mio_tflite280)
+target_link_libraries(tflchef_tflite mio_tflite280_helper)
target_link_libraries(tflchef_tflite cwrap)
target_link_libraries(tflchef_tflite souschef)
diff --git a/compiler/tflchef/tflite/src/Op/FullyConnected.cpp b/compiler/tflchef/tflite/src/Op/FullyConnected.cpp
index 1f6e73aa6..bbc749fe4 100644
--- a/compiler/tflchef/tflite/src/Op/FullyConnected.cpp
+++ b/compiler/tflchef/tflite/src/Op/FullyConnected.cpp
@@ -48,6 +48,7 @@ tflchef::Operation *TFliteOpFullyConnected::build(const tflite::Operator *op, TF
auto op_options = operation->mutable_fullyconnected_options();
op_options->set_activation(as_tflchef_activation(op_params->fused_activation_function()));
+ op_options->set_keep_num_dims(op_params->keep_num_dims());
return operation;
}
diff --git a/compiler/tflchef/tflite/src/Op/SVDF.cpp b/compiler/tflchef/tflite/src/Op/SVDF.cpp
new file mode 100644
index 000000000..015f968a8
--- /dev/null
+++ b/compiler/tflchef/tflite/src/Op/SVDF.cpp
@@ -0,0 +1,59 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "SVDF.h"
+
+#include "Convert.h"
+
+namespace tflchef
+{
+
+void TFliteOpSVDF::filler(const tflite::Operator *op, TFliteImport *import,
+ tflchef::ModelRecipe *model_recipe) const
+{
+ const std::vector<int32_t> &inputs = as_index_vector(op->inputs());
+ assert(inputs.size() == 5);
+
+ // optional input tensor idx has minus value.
+ const bool hasBias = (inputs.at(3) >= 0);
+
+ // Note: last input is variable tensor without data
+ import->set_tensor_filler(inputs.at(1));
+ import->set_tensor_filler(inputs.at(2));
+ if (hasBias)
+ import->set_tensor_filler(inputs.at(3));
+}
+
+tflchef::Operation *TFliteOpSVDF::build(const tflite::Operator *op, TFliteImport *import,
+ tflchef::ModelRecipe *model_recipe) const
+{
+ const auto op_params = op->builtin_options_as_SVDFOptions();
+ assert(op_params != nullptr);
+
+ auto operation = model_recipe->add_operation();
+
+ operation->set_type("SVDF");
+
+ auto op_options = operation->mutable_svdf_options();
+
+ op_options->set_activation(as_tflchef_activation(op_params->fused_activation_function()));
+ op_options->set_asymmetric_quantize_inputs(op_params->asymmetric_quantize_inputs());
+ op_options->set_rank(op_params->rank());
+
+ return operation;
+}
+
+} // namespace tflchef
diff --git a/compiler/tflchef/tflite/src/Op/SVDF.h b/compiler/tflchef/tflite/src/Op/SVDF.h
new file mode 100644
index 000000000..a59ca54a2
--- /dev/null
+++ b/compiler/tflchef/tflite/src/Op/SVDF.h
@@ -0,0 +1,39 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __TFLITE_OP_SVDF_H__
+#define __TFLITE_OP_SVDF_H__
+
+#include "TFliteOpChef.h"
+
+namespace tflchef
+{
+
+/**
+ * @brief tflchef operator builder for SVDF
+ */
+class TFliteOpSVDF : public TFliteOpChef
+{
+public:
+ void filler(const tflite::Operator *op, TFliteImport *import,
+ tflchef::ModelRecipe *model_recipe) const override;
+ tflchef::Operation *build(const tflite::Operator *op, TFliteImport *import,
+ tflchef::ModelRecipe *model_recipe) const override;
+};
+
+} // namespace tflchef
+
+#endif // __TFLITE_OP_SVDF_H__
diff --git a/compiler/tflchef/tflite/src/RecipeChef.cpp b/compiler/tflchef/tflite/src/RecipeChef.cpp
index d9215a4c4..0701707c1 100644
--- a/compiler/tflchef/tflite/src/RecipeChef.cpp
+++ b/compiler/tflchef/tflite/src/RecipeChef.cpp
@@ -15,6 +15,7 @@
*/
#include <tflchef/RecipeChef.h>
+#include <mio_tflite280/Helper.h>
#include "Convert.h"
#include "TFliteImport.h"
@@ -42,7 +43,7 @@ void set_inputs(TFliteImport *import, tflchef::Operation *operation, const tflit
else
{
auto tensor = tensors->Get(input);
- std::string name = tensor_name(tensor);
+ std::string name = mio::tflite::tensor_name(tensor);
operation->add_input(name);
}
}
@@ -56,7 +57,7 @@ void set_outputs(TFliteImport *import, tflchef::Operation *operation, const tfli
for (auto output : outputs)
{
auto tensor = tensors->Get(output);
- std::string name = tensor_name(tensor);
+ std::string name = mio::tflite::tensor_name(tensor);
operation->add_output(name);
}
}
@@ -108,7 +109,7 @@ std::unique_ptr<ModelRecipe> generate_recipe(const tflite::Model *model)
::tflchef::Operand *operand = model_recipe->add_operand();
- operand->set_name(tensor_name(tensor));
+ operand->set_name(mio::tflite::tensor_name(tensor));
operand->set_type(as_tflchef_type(tensor->type()));
operand->set_is_variable(tensor->is_variable());
@@ -311,14 +312,14 @@ std::unique_ptr<ModelRecipe> generate_recipe(const tflite::Model *model)
for (const auto input : inputs)
{
auto tensor = tensors->Get(input);
- std::string name = tensor_name(tensor);
+ std::string name = mio::tflite::tensor_name(tensor);
model_recipe->add_input(name);
}
for (const auto output : outputs)
{
auto tensor = tensors->Get(output);
- std::string name = tensor_name(tensor);
+ std::string name = mio::tflite::tensor_name(tensor);
model_recipe->add_output(name);
}
diff --git a/compiler/tflchef/tflite/src/TFliteImport.cpp b/compiler/tflchef/tflite/src/TFliteImport.cpp
index 1462ee7f4..7114ab019 100644
--- a/compiler/tflchef/tflite/src/TFliteImport.cpp
+++ b/compiler/tflchef/tflite/src/TFliteImport.cpp
@@ -18,50 +18,13 @@
#include "Convert.h"
+#include <mio_tflite280/Helper.h>
+
#include <sstream>
namespace tflchef
{
-const char *kEmptyTensorName = "(noname)";
-
-const char *tensor_type(const tflite::Tensor *tensor)
-{
- return tflite::EnumNameTensorType(tensor->type());
-}
-
-const char *tensor_name(const tflite::Tensor *tensor)
-{
- auto name = tensor->name();
- if (name)
- return name->c_str();
- return kEmptyTensorName;
-}
-
-// This will provide v3/v3a format neutral BuiltinOperator
-tflite::BuiltinOperator builtin_code_neutral(const tflite::OperatorCode *opcode)
-{
- assert(opcode != nullptr);
- int8_t dp_code = opcode->deprecated_builtin_code();
- // 127 is max of int8_t which is upper bound of v3 builtin_code
- // NOTE TensorFlow uses 'BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES' for 127
- if (dp_code < 127 && dp_code >= 0)
- return tflite::BuiltinOperator(dp_code);
- return opcode->builtin_code();
-}
-
-bool is_valid(const tflite::OperatorCode *opcode)
-{
- tflite::BuiltinOperator code = builtin_code_neutral(opcode);
- return (tflite::BuiltinOperator_MIN <= code && code <= tflite::BuiltinOperator_MAX);
-}
-
-bool is_custom(const tflite::OperatorCode *opcode)
-{
- tflite::BuiltinOperator code = builtin_code_neutral(opcode);
- return (code == tflite::BuiltinOperator_CUSTOM);
-}
-
TFliteImport::TFliteImport(const tflite::Model *model)
{
_subgraphs = model->subgraphs();
@@ -104,7 +67,7 @@ tflite::BuiltinOperator TFliteImport::builtin_code(const tflite::Operator *op) c
assert(index < _op_codes.size());
const tflite::OperatorCode *opcode = _op_codes.at(index);
- return builtin_code_neutral(opcode);
+ return mio::tflite::builtin_code_neutral(opcode);
}
std::string TFliteImport::opcode_name(const tflite::Operator *op) const
@@ -113,14 +76,14 @@ std::string TFliteImport::opcode_name(const tflite::Operator *op) const
assert(index < _op_codes.size());
const tflite::OperatorCode *opcode = _op_codes.at(index);
- if (!is_valid(opcode))
+ if (!mio::tflite::is_valid(opcode))
{
std::ostringstream oss;
oss << "(invalid: " << index << ")";
return oss.str();
}
- if (is_custom(opcode))
+ if (mio::tflite::is_custom(opcode))
{
if (!opcode->custom_code())
return "(invalid custom)";
@@ -128,7 +91,7 @@ std::string TFliteImport::opcode_name(const tflite::Operator *op) const
return opcode->custom_code()->c_str();
}
- tflite::BuiltinOperator code = builtin_code_neutral(opcode);
+ tflite::BuiltinOperator code = mio::tflite::builtin_code_neutral(opcode);
return EnumNameBuiltinOperator(code);
}
diff --git a/compiler/tflchef/tflite/src/TFliteImport.h b/compiler/tflchef/tflite/src/TFliteImport.h
index 43b5bbaff..e6722e455 100644
--- a/compiler/tflchef/tflite/src/TFliteImport.h
+++ b/compiler/tflchef/tflite/src/TFliteImport.h
@@ -34,12 +34,6 @@ using TFliteTensors_t = flatbuffers::Vector<flatbuffers::Offset<tflite::Tensor>>
using TFliteBuffers_t = flatbuffers::Vector<flatbuffers::Offset<tflite::Buffer>>;
using TFliteOperators_t = flatbuffers::Vector<flatbuffers::Offset<tflite::Operator>>;
-const char *tensor_type(const tflite::Tensor *tensor);
-const char *tensor_name(const tflite::Tensor *tensor);
-tflite::BuiltinOperator builtin_code_neutral(const tflite::OperatorCode *opcode);
-bool is_valid(const tflite::OperatorCode *opcode);
-bool is_custom(const tflite::OperatorCode *opcode);
-
/**
* @brief Loads TF lite file and provides helpers to access attributes
*/
diff --git a/compiler/tflchef/tflite/src/TFliteOpChefs.h b/compiler/tflchef/tflite/src/TFliteOpChefs.h
index 26ada7d0a..b38b35a61 100644
--- a/compiler/tflchef/tflite/src/TFliteOpChefs.h
+++ b/compiler/tflchef/tflite/src/TFliteOpChefs.h
@@ -117,6 +117,7 @@
#include "Op/StridedSlice.h"
#include "Op/Sub.h"
#include "Op/Sum.h"
+#include "Op/SVDF.h"
#include "Op/Tanh.h"
#include "Op/Tile.h"
#include "Op/TopKV2.h"
diff --git a/compiler/tflchef/tflite/src/TFliteOpRegistry.h b/compiler/tflchef/tflite/src/TFliteOpRegistry.h
index 06394ddfa..4cbe7cfcb 100644
--- a/compiler/tflchef/tflite/src/TFliteOpRegistry.h
+++ b/compiler/tflchef/tflite/src/TFliteOpRegistry.h
@@ -154,6 +154,7 @@ private:
REG_TFL_OP(STRIDED_SLICE, TFliteOpStridedSlice);
REG_TFL_OP(SUB, TFliteOpSub);
REG_TFL_OP(SUM, TFliteOpSum);
+ REG_TFL_OP(SVDF, TFliteOpSVDF);
REG_TFL_OP(TANH, TFliteOpTanh);
REG_TFL_OP(TILE, TFliteOpTile);
REG_TFL_OP(TOPK_V2, TFliteOpTopKV2);
diff --git a/compiler/tfldump/CMakeLists.txt b/compiler/tfldump/CMakeLists.txt
index 83f7febad..fac0be6bf 100644
--- a/compiler/tfldump/CMakeLists.txt
+++ b/compiler/tfldump/CMakeLists.txt
@@ -1,7 +1,7 @@
-if(NOT TARGET mio_tflite260)
- message(STATUS "Build tfldump: FAILED (missing mio_tflite260)")
+if(NOT TARGET mio_tflite280)
+ message(STATUS "Build tfldump: FAILED (missing mio_tflite280)")
return()
-endif(NOT TARGET mio_tflite260)
+endif(NOT TARGET mio_tflite280)
set(DRIVER "driver/Driver.cpp")
@@ -10,6 +10,6 @@ file(GLOB_RECURSE SOURCES "src/*.cpp")
add_executable(tfldump ${DRIVER} ${SOURCES})
target_include_directories(tfldump PRIVATE include)
target_link_libraries(tfldump arser)
-target_link_libraries(tfldump mio_tflite260)
+target_link_libraries(tfldump mio_tflite280)
+target_link_libraries(tfldump mio_tflite280_helper)
target_link_libraries(tfldump safemain)
-target_link_libraries(tfldump flatbuffers-1.12)
diff --git a/compiler/tfldump/requires.cmake b/compiler/tfldump/requires.cmake
index d0f9cccba..b1abf9486 100644
--- a/compiler/tfldump/requires.cmake
+++ b/compiler/tfldump/requires.cmake
@@ -1,3 +1,3 @@
require("arser")
-require("mio-tflite260")
+require("mio-tflite280")
require("safemain")
diff --git a/compiler/tfldump/src/Dump.cpp b/compiler/tfldump/src/Dump.cpp
index 2351e4c3d..2a87e47d7 100644
--- a/compiler/tfldump/src/Dump.cpp
+++ b/compiler/tfldump/src/Dump.cpp
@@ -15,6 +15,7 @@
*/
#include <tfldump/Dump.h>
+#include <mio_tflite280/Helper.h>
#include "Read.h"
#include "OpPrinter.h"
@@ -127,7 +128,7 @@ void dump_sub_graph(std::ostream &os, tflread::Reader &reader)
// dump operands(tensors)
os << "Operands: T(subgraph index : tensor index) TYPE (shape) (shape_signature) "
- << "B(buffer index) OperandName" << std::endl;
+ << "B(buffer index) (variable) OperandName" << std::endl;
for (uint32_t i = 0; i < tensors->Length(); ++i)
{
// TODO refactor to some better structure
@@ -137,7 +138,7 @@ void dump_sub_graph(std::ostream &os, tflread::Reader &reader)
if (tensor->shape())
dims = tflread::as_index_vector(tensor->shape());
- os << "T(" << reader.subgraph_index() << ":" << i << ") " << tflread::tensor_type(tensor)
+ os << "T(" << reader.subgraph_index() << ":" << i << ") " << mio::tflite::tensor_type(tensor)
<< " ";
os << "(" << dims << ") ";
if (tensor->shape_signature())
@@ -146,7 +147,11 @@ void dump_sub_graph(std::ostream &os, tflread::Reader &reader)
os << "(" << dims_sig << ") ";
}
os << "B(" << tensor->buffer() << ") ";
- os << tflread::tensor_name(tensor) << std::endl;
+ if (tensor->is_variable())
+ {
+ os << "(variable) ";
+ }
+ os << mio::tflite::tensor_name(tensor) << std::endl;
if (auto q_params = tensor->quantization())
{
@@ -298,7 +303,7 @@ void dump_sub_graph(std::ostream &os, tflread::Reader &reader)
if (input >= 0)
{
auto tensor = tensors->Get(input);
- os << tflread::tensor_name(tensor);
+ os << mio::tflite::tensor_name(tensor);
}
os << std::endl;
}
@@ -308,7 +313,7 @@ void dump_sub_graph(std::ostream &os, tflread::Reader &reader)
if (output >= 0)
{
auto tensor = tensors->Get(output);
- os << tflread::tensor_name(tensor);
+ os << mio::tflite::tensor_name(tensor);
}
os << std::endl;
}
@@ -321,14 +326,14 @@ void dump_sub_graph(std::ostream &os, tflread::Reader &reader)
for (const auto input : reader.inputs())
{
auto tensor = tensors->Get(input);
- std::string name = tflread::tensor_name(tensor);
+ std::string name = mio::tflite::tensor_name(tensor);
os << "I T(" << reader.subgraph_index() << ":" << input << ") " << name << std::endl;
}
for (const auto output : reader.outputs())
{
auto tensor = tensors->Get(output);
- std::string name = tflread::tensor_name(tensor);
+ std::string name = mio::tflite::tensor_name(tensor);
os << "O T(" << reader.subgraph_index() << ":" << output << ") " << name << std::endl;
}
@@ -360,7 +365,7 @@ void dump_model(std::ostream &os, const tflite::Model *model)
tflite::BuiltinOperator op_code = opcode->builtin_code();
tflite::BuiltinOperator dp_code = tflite::BuiltinOperator(opcode->deprecated_builtin_code());
- auto op_name = tflread::opcode_name(opcode);
+ auto op_name = mio::tflite::opcode_name(opcode);
auto op_version = opcode->version();
os << "[" << opcode_index << "] " << op_name << " (code: " << op_code
@@ -405,9 +410,8 @@ void dump_model(std::ostream &os, const tflite::Model *model)
for (uint32_t i = 0; i < signaturedefs->Length(); ++i)
{
auto sign_i = signaturedefs->Get(i);
- os << "S(" << i << ") method_name(" << sign_i->method_name()->c_str() << "), key("
- << sign_i->key()->c_str() << "), sub_graph(" << sign_i->subgraph_index() << ")"
- << std::endl;
+ os << "S(" << i << ") signature_key(" << sign_i->signature_key()->c_str() << "), sub_graph("
+ << sign_i->subgraph_index() << ")" << std::endl;
auto inputs_i = sign_i->inputs();
for (uint32_t t = 0; t < inputs_i->Length(); ++t)
diff --git a/compiler/tfldump/src/Load.cpp b/compiler/tfldump/src/Load.cpp
index fe04a5dd6..d2f6e06f1 100644
--- a/compiler/tfldump/src/Load.cpp
+++ b/compiler/tfldump/src/Load.cpp
@@ -76,7 +76,7 @@ public:
{
if (_value != -1)
{
- // Close on descturction
+ // Close on destruction
close(_value);
}
}
diff --git a/compiler/tfldump/src/OpPrinter.cpp b/compiler/tfldump/src/OpPrinter.cpp
index 90cba7173..47edcb086 100644
--- a/compiler/tfldump/src/OpPrinter.cpp
+++ b/compiler/tfldump/src/OpPrinter.cpp
@@ -602,6 +602,23 @@ public:
}
};
+class SVDFPrinter : public OpPrinter
+{
+public:
+ void options(const tflite::Operator *op, std::ostream &os) const override
+ {
+ if (auto *params = op->builtin_options_as_SVDFOptions())
+ {
+ os << " ";
+ os << "rank(" << params->rank() << ") ";
+ os << "activation(" << EnumNameActivationFunctionType(params->fused_activation_function())
+ << ") ";
+ os << "asymmetric_quantize_inputs(" << params->asymmetric_quantize_inputs() << ") ";
+ os << std::endl;
+ }
+ }
+};
+
class TransposeConvPrinter : public OpPrinter
{
public:
@@ -776,6 +793,7 @@ OpPrinterRegistry::OpPrinterRegistry()
_op_map[tflite::BuiltinOperator_STRIDED_SLICE] = make_unique<StridedSlicePrinter>();
_op_map[tflite::BuiltinOperator_SUB] = make_unique<SubPrinter>();
_op_map[tflite::BuiltinOperator_SUM] = make_unique<ReducerPrinter>();
+ _op_map[tflite::BuiltinOperator_SVDF] = make_unique<SVDFPrinter>();
_op_map[tflite::BuiltinOperator_TRANSPOSE_CONV] = make_unique<TransposeConvPrinter>();
// There is no Option for TOPK_V2
_op_map[tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM] =
diff --git a/compiler/tfldump/src/Read.cpp b/compiler/tfldump/src/Read.cpp
index 8b3a96e83..454e3a8a1 100644
--- a/compiler/tfldump/src/Read.cpp
+++ b/compiler/tfldump/src/Read.cpp
@@ -16,76 +16,14 @@
#include "Read.h"
+#include <mio_tflite280/Helper.h>
+
#include <sstream>
#include <string>
namespace tflread
{
-// This will provide v3/v3a format neutral BuiltinOperator
-tflite::BuiltinOperator builtin_code_neutral(const tflite::OperatorCode *opcode)
-{
- assert(opcode != nullptr);
- int8_t dp_code = opcode->deprecated_builtin_code();
- if (dp_code < 127 && dp_code >= 0)
- return tflite::BuiltinOperator(dp_code);
- return opcode->builtin_code();
-}
-
-bool is_valid(const tflite::OperatorCode *opcode)
-{
- tflite::BuiltinOperator code = builtin_code_neutral(opcode);
- return (tflite::BuiltinOperator_MIN <= code && code <= tflite::BuiltinOperator_MAX);
-}
-
-bool is_custom(const tflite::OperatorCode *opcode)
-{
- tflite::BuiltinOperator code = builtin_code_neutral(opcode);
- return (code == tflite::BuiltinOperator_CUSTOM);
-}
-
-std::string opcode_name(const tflite::OperatorCode *opcode)
-{
- assert(opcode);
-
- if (!is_valid(opcode))
- {
- std::ostringstream oss;
- oss << "(invalid)";
- return oss.str();
- }
-
- if (is_custom(opcode))
- {
- if (!opcode->custom_code())
- return "(invalid custom)";
-
- std::string custom_op = "CUSTOM(";
- custom_op += opcode->custom_code()->c_str();
- custom_op += ")";
- return custom_op;
- }
-
- tflite::BuiltinOperator code = builtin_code_neutral(opcode);
- return tflite::EnumNameBuiltinOperator(code);
-}
-
-const char *tensor_type(const tflite::Tensor *tensor)
-{
- return tflite::EnumNameTensorType(tensor->type());
-}
-
-const char *tensor_name(const tflite::Tensor *tensor)
-{
- static const char *kEmptyTensorName = "(noname)";
-
- auto name = tensor->name();
- if (name)
- return name->c_str();
-
- return kEmptyTensorName;
-}
-
Reader::Reader(const tflite::Model *model)
{
_version = model->version();
@@ -129,7 +67,7 @@ tflite::BuiltinOperator Reader::builtin_code(const tflite::Operator *op) const
assert(index < _op_codes.size());
const tflite::OperatorCode *opcode = _op_codes.at(index);
- return tflread::builtin_code_neutral(opcode);
+ return mio::tflite::builtin_code_neutral(opcode);
}
std::string Reader::opcode_name(const tflite::Operator *op) const
@@ -138,14 +76,14 @@ std::string Reader::opcode_name(const tflite::Operator *op) const
assert(index < _op_codes.size());
const tflite::OperatorCode *opcode = _op_codes.at(index);
- if (!is_valid(opcode))
+ if (!mio::tflite::is_valid(opcode))
{
std::ostringstream oss;
oss << "(invalid: " << index << ")";
return oss.str();
}
- return tflread::opcode_name(opcode);
+ return mio::tflite::opcode_name(opcode);
}
bool Reader::select_subgraph(uint32_t sgindex)
diff --git a/compiler/tfldump/src/Read.h b/compiler/tfldump/src/Read.h
index 80f317d0b..1ae63877f 100644
--- a/compiler/tfldump/src/Read.h
+++ b/compiler/tfldump/src/Read.h
@@ -36,13 +36,6 @@ template <typename T> std::vector<T> as_index_vector(const flatbuffers::Vector<T
return ret;
}
-tflite::BuiltinOperator builtin_code_neutral(const tflite::OperatorCode *opcode);
-bool is_valid(const tflite::OperatorCode *opcode);
-bool is_custom(const tflite::OperatorCode *opcode);
-std::string opcode_name(const tflite::OperatorCode *opcode);
-const char *tensor_type(const tflite::Tensor *tensor);
-const char *tensor_name(const tflite::Tensor *tensor);
-
/**
* @brief Loads TF lite file and provides helpers to access attributes
*/
diff --git a/compiler/tflite2circle/CMakeLists.txt b/compiler/tflite2circle/CMakeLists.txt
index 4ea01ad31..a317a6305 100644
--- a/compiler/tflite2circle/CMakeLists.txt
+++ b/compiler/tflite2circle/CMakeLists.txt
@@ -1,8 +1,8 @@
nnas_include(TargetRequire)
unset(REQUIRED_TARGETS)
-list(APPEND REQUIRED_TARGETS mio_tflite260)
-list(APPEND REQUIRED_TARGETS mio_circle)
+list(APPEND REQUIRED_TARGETS mio_tflite280)
+list(APPEND REQUIRED_TARGETS mio_circle04)
TargetRequire_Return(${REQUIRED_TARGETS})
set(DRIVER "driver/Driver.cpp")
@@ -13,8 +13,9 @@ target_include_directories(tflite2circle PRIVATE src)
target_link_libraries(tflite2circle arser)
target_link_libraries(tflite2circle foder)
target_link_libraries(tflite2circle safemain)
-target_link_libraries(tflite2circle mio_tflite260)
-target_link_libraries(tflite2circle mio_circle)
+target_link_libraries(tflite2circle mio_tflite280)
+target_link_libraries(tflite2circle mio_tflite280_helper)
+target_link_libraries(tflite2circle mio_circle04)
target_link_libraries(tflite2circle vconone)
target_link_libraries(tflite2circle nncc_coverage)
diff --git a/compiler/tflite2circle/requires.cmake b/compiler/tflite2circle/requires.cmake
index e39f9eeaf..3db9a2f2a 100644
--- a/compiler/tflite2circle/requires.cmake
+++ b/compiler/tflite2circle/requires.cmake
@@ -1,6 +1,6 @@
require("arser")
require("foder")
-require("mio-tflite260")
-require("mio-circle")
+require("mio-tflite280")
+require("mio-circle04")
require("safemain")
require("vconone")
diff --git a/compiler/tflite2circle/src/BuildBuiltinOptions.h b/compiler/tflite2circle/src/BuildBuiltinOptions.h
index dc6ff086c..88a4f71df 100644
--- a/compiler/tflite2circle/src/BuildBuiltinOptions.h
+++ b/compiler/tflite2circle/src/BuildBuiltinOptions.h
@@ -102,6 +102,7 @@
#include "BuildBuiltinOptions/SqueezeOptions.h"
#include "BuildBuiltinOptions/StridedSliceOptions.h"
#include "BuildBuiltinOptions/SubOptions.h"
+#include "BuildBuiltinOptions/SVDFOptions.h"
#include "BuildBuiltinOptions/TileOptions.h"
#include "BuildBuiltinOptions/TopKV2Options.h"
#include "BuildBuiltinOptions/TransposeOptions.h"
diff --git a/compiler/tflite2circle/src/BuildBuiltinOptions/FullyConnectedOptions.cpp b/compiler/tflite2circle/src/BuildBuiltinOptions/FullyConnectedOptions.cpp
index 2619b73eb..27410012d 100644
--- a/compiler/tflite2circle/src/BuildBuiltinOptions/FullyConnectedOptions.cpp
+++ b/compiler/tflite2circle/src/BuildBuiltinOptions/FullyConnectedOptions.cpp
@@ -37,6 +37,7 @@ build_circle_FullyConnectedOptions(flatbuffers::FlatBufferBuilder &fb, const tfl
else if (tflite_weight_format == tflite::FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8)
builtin_options_builder.add_weights_format(
circle::FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8);
+ builtin_options_builder.add_keep_num_dims(tflite_builtin_options->keep_num_dims());
return builtin_options_builder.Finish();
}
diff --git a/compiler/tflite2circle/src/BuildBuiltinOptions/SVDFOptions.cpp b/compiler/tflite2circle/src/BuildBuiltinOptions/SVDFOptions.cpp
new file mode 100644
index 000000000..e23738a69
--- /dev/null
+++ b/compiler/tflite2circle/src/BuildBuiltinOptions/SVDFOptions.cpp
@@ -0,0 +1,41 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "SVDFOptions.h"
+#include "DataLookup.h"
+
+#include <cassert>
+
+namespace tflite2circle
+{
+
+flatbuffers::Offset<circle::SVDFOptions>
+build_circle_SVDFOptions(flatbuffers::FlatBufferBuilder &fb, const tflite::Operator *op)
+{
+ auto *tflite_builtin_options = op->builtin_options_as_SVDFOptions();
+ assert(tflite_builtin_options);
+
+ circle::SVDFOptionsBuilder builtin_options_builder{fb};
+ builtin_options_builder.add_rank(tflite_builtin_options->rank());
+ builtin_options_builder.add_asymmetric_quantize_inputs(
+ tflite_builtin_options->asymmetric_quantize_inputs());
+ builtin_options_builder.add_fused_activation_function(
+ get_circle_activation_function_type(tflite_builtin_options->fused_activation_function()));
+
+ return builtin_options_builder.Finish();
+}
+
+} // namespace tflite2circle
diff --git a/compiler/tflite2circle/src/BuildBuiltinOptions/SVDFOptions.h b/compiler/tflite2circle/src/BuildBuiltinOptions/SVDFOptions.h
new file mode 100644
index 000000000..2ddbd3911
--- /dev/null
+++ b/compiler/tflite2circle/src/BuildBuiltinOptions/SVDFOptions.h
@@ -0,0 +1,31 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __BBO_SVDF_OPTIONS_H__
+#define __BBO_SVDF_OPTIONS_H__
+
+#include <mio/tflite/schema_generated.h>
+#include <mio/circle/schema_generated.h>
+
+namespace tflite2circle
+{
+
+flatbuffers::Offset<circle::SVDFOptions>
+build_circle_SVDFOptions(flatbuffers::FlatBufferBuilder &fb, const tflite::Operator *op);
+
+} // namespace tflite2circle
+
+#endif // __BBO_SVDF_OPTIONS_H__
diff --git a/compiler/tflite2circle/src/CircleModel.cpp b/compiler/tflite2circle/src/CircleModel.cpp
index 90cc415ff..d483b288f 100644
--- a/compiler/tflite2circle/src/CircleModel.cpp
+++ b/compiler/tflite2circle/src/CircleModel.cpp
@@ -16,11 +16,14 @@
#include <cassert>
#include <iostream>
+#include <map>
#include <memory>
#include "CircleModel.h"
#include "DataLookup.h"
+#include <mio_tflite280/Helper.h>
+
namespace tflite2circle
{
@@ -206,7 +209,8 @@ template <> void Offset<SubGraphLink>::build(const TFLFlatBufVec *tflite_flatbuf
auto tflite_inputs = it_sg->inputs();
std::vector<int32_t> input_vec{tflite_inputs->begin(), tflite_inputs->end()};
- // apply signature_def to input tensor index so that input orders are correct
+ // apply signature_def to the input tensor indices so that the input order matches TensorFlow
+ // Lite's interpreter._get_full_signature_list() method, which is sorted by name
// NOTE we do not need this when circle format supports signature_def
if (_tfl_signature_def_offsets != nullptr)
{
@@ -216,10 +220,16 @@ template <> void Offset<SubGraphLink>::build(const TFLFlatBufVec *tflite_flatbuf
{
auto inputs = it_signdef->inputs();
assert(inputs->size() == input_vec.size());
- uint32_t input_vec_idx = 0;
+
+ std::map<std::string, uint32_t> map_name_index;
for (auto it_tm : *inputs)
{
- input_vec[input_vec_idx++] = static_cast<int32_t>(it_tm->tensor_index());
+ map_name_index[it_tm->name()->str()] = it_tm->tensor_index();
+ }
+ uint32_t input_vec_idx = 0;
+ for (auto &item : map_name_index)
+ {
+ input_vec[input_vec_idx++] = item.second;
}
}
}
@@ -240,10 +250,16 @@ template <> void Offset<SubGraphLink>::build(const TFLFlatBufVec *tflite_flatbuf
{
auto outputs = it_signdef->outputs();
assert(outputs->size() == output_vec.size());
- uint32_t output_vec_idx = 0;
+
+ std::map<std::string, uint32_t> map_name_index;
for (auto it_tm : *outputs)
{
- output_vec[output_vec_idx++] = static_cast<int32_t>(it_tm->tensor_index());
+ map_name_index[it_tm->name()->str()] = it_tm->tensor_index();
+ }
+ uint32_t output_vec_idx = 0;
+ for (auto &item : map_name_index)
+ {
+ output_vec[output_vec_idx++] = item.second;
}
}
}
@@ -318,17 +334,6 @@ template <> void Offset<SubGraphLink>::build(const TFLFlatBufVec *tflite_flatbuf
_circle_flatbuffer_vec_offset = _fb->CreateVector(subgprahs_vec);
}
-tflite::BuiltinOperator builtin_code_neutral(const tflite::OperatorCode *opcode)
-{
- assert(opcode != nullptr);
- int8_t dp_code = opcode->deprecated_builtin_code();
- // 127 is max of int8_t which is upper bound of v3 builtin_code
- // NOTE TensorFlow uses 'BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES' for 127
- if (dp_code < 127 && dp_code >= 0)
- return tflite::BuiltinOperator(dp_code);
- return opcode->builtin_code();
-}
-
template <> void Offset<OperatorCodeLink>::build(const TFLFlatBufVec *tflite_flatbuffer_vec)
{
std::vector<flatbuffers::Offset<circle::OperatorCode>> operator_code_vec;
@@ -337,8 +342,9 @@ template <> void Offset<OperatorCodeLink>::build(const TFLFlatBufVec *tflite_fla
{
auto custom_code = _fb->CreateString(it->custom_code());
circle::OperatorCodeBuilder operator_code_builder{*_fb};
- // TODO support circle deprecated_builtin_code
- auto bt_code = builtin_code_neutral(it);
+ auto de_code = it->deprecated_builtin_code();
+ auto bt_code = it->builtin_code();
+ operator_code_builder.add_deprecated_builtin_code(get_circle_builtin_code(de_code));
operator_code_builder.add_builtin_code(get_circle_builtin_code(bt_code));
operator_code_builder.add_custom_code(custom_code);
operator_code_builder.add_version(it->version());
diff --git a/compiler/tflite2circle/src/DataLookup.cpp b/compiler/tflite2circle/src/DataLookup.cpp
index c5ed62e31..7c3aab089 100644
--- a/compiler/tflite2circle/src/DataLookup.cpp
+++ b/compiler/tflite2circle/src/DataLookup.cpp
@@ -34,6 +34,22 @@ circle::BuiltinOperator get_circle_builtin_code(tflite::BuiltinOperator tfl_bop)
}
}
+int8_t get_circle_builtin_code(int8_t tfl_bop_i8)
+{
+ tflite::BuiltinOperator tfl_bop = static_cast<tflite::BuiltinOperator>(tfl_bop_i8);
+
+ switch (tfl_bop)
+ {
+#define TFL_OPERATOR(OP) \
+ case tflite::BuiltinOperator_##OP: \
+ return static_cast<int8_t>(circle::BuiltinOperator_##OP);
+#include "TFLOperator.lst"
+#undef TFL_OPERATOR
+ default:
+ throw std::runtime_error("tflite2circle: wrong op");
+ }
+}
+
circle::TensorType get_circle_tensortype(tflite::TensorType tfl_tt)
{
switch (tfl_tt)
diff --git a/compiler/tflite2circle/src/DataLookup.h b/compiler/tflite2circle/src/DataLookup.h
index 601d014dd..5aeeb6eca 100644
--- a/compiler/tflite2circle/src/DataLookup.h
+++ b/compiler/tflite2circle/src/DataLookup.h
@@ -30,6 +30,8 @@ namespace tflite2circle
*/
circle::BuiltinOperator get_circle_builtin_code(tflite::BuiltinOperator tfl_bop);
+int8_t get_circle_builtin_code(int8_t tfl_bop_i8);
+
/**
* @brief Returns circle TensorType according to tflite.
*
diff --git a/compiler/tflite2circle/src/TFLBuiltinOptions.lst b/compiler/tflite2circle/src/TFLBuiltinOptions.lst
index f2de7e046..d55ba464a 100644
--- a/compiler/tflite2circle/src/TFLBuiltinOptions.lst
+++ b/compiler/tflite2circle/src/TFLBuiltinOptions.lst
@@ -9,7 +9,7 @@ TFL_BUILTIN_OPTIONS(DepthwiseConv2DOptions)
//TFL_BUILTIN_OPTIONS(ConcatEmbeddingsOptions)
//TFL_BUILTIN_OPTIONS(LSHProjectionOptions)
TFL_BUILTIN_OPTIONS(Pool2DOptions)
-//TFL_BUILTIN_OPTIONS(SVDFOptions)
+TFL_BUILTIN_OPTIONS(SVDFOptions)
//TFL_BUILTIN_OPTIONS(RNNOptions)
TFL_BUILTIN_OPTIONS(FullyConnectedOptions)
TFL_BUILTIN_OPTIONS(SoftmaxOptions)
diff --git a/compiler/vconone/CMakeLists.txt b/compiler/vconone/CMakeLists.txt
index 2241c9ec9..3841a1b78 100644
--- a/compiler/vconone/CMakeLists.txt
+++ b/compiler/vconone/CMakeLists.txt
@@ -1,5 +1,5 @@
if (NOT VCONONE_VERSION)
- set(VCONONE_VERSION 0x0000000000130001)
+ set(VCONONE_VERSION 0x0000000000140001)
# NOTE order is [build patch minor major]
# if VCONONE_VERSION is set with -D option, it will be cached
# you may have to remove cache file if you remove -D option